mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-09 23:33:24 +00:00
Compare commits
286 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
65a87815e7 | ||
|
|
b80793ca82 | ||
|
|
601550f238 | ||
|
|
41b1cf2273 | ||
|
|
3b4f9f43db | ||
|
|
37a09ecb23 | ||
|
|
0da34d3c2d | ||
|
|
74bf7eda8f | ||
|
|
d3100085b0 | ||
|
|
f481d25133 | ||
|
|
8c6c90da74 | ||
|
|
24bcfd9c03 | ||
|
|
816fb4c5da | ||
|
|
c1bb77c7c9 | ||
|
|
6bcac3a55a | ||
|
|
fc346f4537 | ||
|
|
43e531a3b6 | ||
|
|
d24ea4ce2a | ||
|
|
2c30c981ae | ||
|
|
aa1da8a858 | ||
|
|
f1e9a787d7 | ||
|
|
4eeec297de | ||
|
|
77cc4ce3a0 | ||
|
|
37dfea1d3f | ||
|
|
e6626c672a | ||
|
|
c66cb0afd2 | ||
|
|
fb48eee973 | ||
|
|
bb44e5ec44 | ||
|
|
c785c1a3ca | ||
|
|
514ae341c8 | ||
|
|
0659ffab75 | ||
|
|
8ce07f38dd | ||
|
|
7cb398d167 | ||
|
|
c3e12c5e58 | ||
|
|
1825fc7503 | ||
|
|
48732ba05e | ||
|
|
acf483c9e6 | ||
|
|
3b3e0d1141 | ||
|
|
7acd428507 | ||
|
|
0aaf177640 | ||
|
|
450d1227bd | ||
|
|
492b9c46f0 | ||
|
|
6e634fe3f9 | ||
|
|
4e26182d14 | ||
|
|
eb7571936c | ||
|
|
5382764d8a | ||
|
|
49c8ec69d0 | ||
|
|
3b421c8181 | ||
|
|
21d2329947 | ||
|
|
0993413bab | ||
|
|
713388dd7b | ||
|
|
e6c7af0fa9 | ||
|
|
837aa6e3aa | ||
|
|
d210be06c2 | ||
|
|
afc8a0f9be | ||
|
|
af8e9ef458 | ||
|
|
cec6f993ad | ||
|
|
950de29f48 | ||
|
|
d6ec33e8e1 | ||
|
|
081cfe806e | ||
|
|
c1c62a6c04 | ||
|
|
a99522224f | ||
|
|
f5d46b9ca2 | ||
|
|
d693d7993b | ||
|
|
5936f9895c | ||
|
|
2fdf5d2793 | ||
|
|
b3da00d2ed | ||
|
|
740277a9f2 | ||
|
|
f91807b6b9 | ||
|
|
57d18bb226 | ||
|
|
10b9c6cb8a | ||
|
|
b24786f8a7 | ||
|
|
7b0eb41ebc | ||
|
|
70949929db | ||
|
|
7c9c89dace | ||
|
|
ef5901c81b | ||
|
|
d4829c82f7 | ||
|
|
a5f4166a9b | ||
|
|
0cbfe7f457 | ||
|
|
f2b1ec4f9e | ||
|
|
1cc21cc45b | ||
|
|
07cf616e2b | ||
|
|
2b8c466e88 | ||
|
|
ca2174ea48 | ||
|
|
c09fb2a79d | ||
|
|
4445a165e9 | ||
|
|
e92e2af71a | ||
|
|
a6bdd9a652 | ||
|
|
349a6349b3 | ||
|
|
00822770ec | ||
|
|
1a0ceda0fc | ||
|
|
b9ae4ab803 | ||
|
|
72add453d2 | ||
|
|
2789396435 | ||
|
|
61da7bd981 | ||
|
|
ae4c502792 | ||
|
|
ec6068060b | ||
|
|
ecb01d3dcd | ||
|
|
22c0c00bd4 | ||
|
|
9eb3e7a6c4 | ||
|
|
357c191510 | ||
|
|
5db244af76 | ||
|
|
dc375d1b74 | ||
|
|
9c040445af | ||
|
|
fff866424e | ||
|
|
2d12becfd6 | ||
|
|
252f7e0751 | ||
|
|
b2b17528cb | ||
|
|
55f938164b | ||
|
|
76294f0c59 | ||
|
|
2bcee78c6e | ||
|
|
7fe8246a9f | ||
|
|
93fe58e31e | ||
|
|
e5b5dc870f | ||
|
|
a54877c023 | ||
|
|
bb86a0c0c4 | ||
|
|
5fa23c7f41 | ||
|
|
f9a09b7f23 | ||
|
|
b0cde626fe | ||
|
|
e42ef9a95d | ||
|
|
abf1629ec7 | ||
|
|
73dc0b10b8 | ||
|
|
2ea95266e3 | ||
|
|
922d4141c0 | ||
|
|
1f8f198c45 | ||
|
|
c55275342c | ||
|
|
9261b0c20b | ||
|
|
7cc725496e | ||
|
|
5726a99c80 | ||
|
|
b5756bf729 | ||
|
|
709d999f9f | ||
|
|
24c18614f0 | ||
|
|
603f06a762 | ||
|
|
98f0a3e3bd | ||
|
|
e186ccb0d4 | ||
|
|
8fc0b08b70 | ||
|
|
52a257dc24 | ||
|
|
a12d907f55 | ||
|
|
453aaf8774 | ||
|
|
1b1ab1fb9b | ||
|
|
a9d0bb72da | ||
|
|
d328e54e4b | ||
|
|
5a7932cba4 | ||
|
|
1dbeb0827a | ||
|
|
2c8821891c | ||
|
|
0a2555b0f3 | ||
|
|
020df41efe | ||
|
|
f8f8cf17ce | ||
|
|
f31f7f701a | ||
|
|
b5fe78eb70 | ||
|
|
d1f667cf8d | ||
|
|
54ad7c1b6b | ||
|
|
d560c20c26 | ||
|
|
5abeca1f9e | ||
|
|
294eac3a88 | ||
|
|
a31104020c | ||
|
|
65bec4d734 | ||
|
|
edb2993838 | ||
|
|
c0d8e0dec7 | ||
|
|
795da13d5d | ||
|
|
55789df275 | ||
|
|
9e652a3540 | ||
|
|
46a6782065 | ||
|
|
c359f61859 | ||
|
|
908c8eab5b | ||
|
|
f5f2c69233 | ||
|
|
63d4de5eea | ||
|
|
af15083496 | ||
|
|
c4722e42b1 | ||
|
|
f9a991365f | ||
|
|
6df16bedba | ||
|
|
632a2fd2f2 | ||
|
|
5626637fbd | ||
|
|
2db89211a9 | ||
|
|
587371eb14 | ||
|
|
75818b1e25 | ||
|
|
a45c6defa7 | ||
|
|
cbe56955a9 | ||
|
|
8ea6ac913d | ||
|
|
ae1e8a5191 | ||
|
|
b3ccc55f09 | ||
|
|
40bee3e8d9 | ||
|
|
1ce56d7413 | ||
|
|
41a78be3a2 | ||
|
|
1ff5de9a31 | ||
|
|
46a6853046 | ||
|
|
4b2d40bd67 | ||
|
|
726f1a590c | ||
|
|
575881cb59 | ||
|
|
d02df0141b | ||
|
|
e4bc9da913 | ||
|
|
8c6be49625 | ||
|
|
c727e4251f | ||
|
|
99266be998 | ||
|
|
d0f3fd96f8 | ||
|
|
f361b2716d | ||
|
|
086d8d0d0b | ||
|
|
627dee1dac | ||
|
|
93147dddeb | ||
|
|
c0f9b15a58 | ||
|
|
6f2fbdcbae | ||
|
|
55c3197fb8 | ||
|
|
65debb874f | ||
|
|
3caadac003 | ||
|
|
6a9e3a6b84 | ||
|
|
269972440a | ||
|
|
cce13e6ad2 | ||
|
|
8a565dcad8 | ||
|
|
d536110404 | ||
|
|
48e957ddff | ||
|
|
94563d622c | ||
|
|
5a2cf0d53c | ||
|
|
2573358173 | ||
|
|
09cd3cff91 | ||
|
|
ab0bf1b517 | ||
|
|
58e09f8e5f | ||
|
|
2334a2b174 | ||
|
|
bc61bf36b2 | ||
|
|
7726a44ca2 | ||
|
|
dc55fb0ce3 | ||
|
|
a146c6c0aa | ||
|
|
4c133d3ea9 | ||
|
|
544238772a | ||
|
|
f3ccd85ba1 | ||
|
|
dc279de443 | ||
|
|
bf1634bda0 | ||
|
|
166d2d24d9 | ||
|
|
4cbcc835d1 | ||
|
|
b93026d83a | ||
|
|
5ed2133ff9 | ||
|
|
e9dd44e623 | ||
|
|
cc8c4ffb5f | ||
|
|
1510bfcb6f | ||
|
|
bcd2208b51 | ||
|
|
09b19f5c4e | ||
|
|
7b01ca0e2e | ||
|
|
9c65e17a21 | ||
|
|
fe6fc628ed | ||
|
|
8192eeabc8 | ||
|
|
c3f1cdd7e5 | ||
|
|
c6bd91b86b | ||
|
|
ce0c6aa82b | ||
|
|
349ddcaa89 | ||
|
|
bb9fe52f1e | ||
|
|
afe4c1bfb7 | ||
|
|
3c85d2a4d7 | ||
|
|
865af9f19e | ||
|
|
2b97cb98b5 | ||
|
|
938a799263 | ||
|
|
e17d4f8d98 | ||
|
|
c8cae1f74d | ||
|
|
0040d78496 | ||
|
|
896de027cc | ||
|
|
fc329ebf37 | ||
|
|
15bc99f6ea | ||
|
|
91841a5519 | ||
|
|
eaab1d6824 | ||
|
|
0cfe310df6 | ||
|
|
918b6955e4 | ||
|
|
3ec7991e5f | ||
|
|
532fbf00d4 | ||
|
|
45b6fffd7f | ||
|
|
5a3eb08739 | ||
|
|
0dff329162 | ||
|
|
49c1740b47 | ||
|
|
3fbee51e9f | ||
|
|
a3dc56d2a0 | ||
|
|
63643c44a1 | ||
|
|
1d93608dbe | ||
|
|
d125b7de92 | ||
|
|
d5654ee316 | ||
|
|
3b34521ad9 | ||
|
|
7197fb350b | ||
|
|
6e349bfcc7 | ||
|
|
234056072d | ||
|
|
76330f4bff | ||
|
|
d468eec6ec | ||
|
|
40e85a6759 | ||
|
|
9bc6cc5b41 | ||
|
|
d109be159c | ||
|
|
eddf31e55b | ||
|
|
7e9d0db6aa | ||
|
|
2f1874ede5 | ||
|
|
52364af5bf | ||
|
|
cc116ce67d | ||
|
|
40efc2ba43 |
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- run: git fetch --force --tags
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '>=1.24.0'
|
||||
go-version: '>=1.26.0'
|
||||
cache: true
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -3,10 +3,11 @@ cli-proxy-api
|
||||
cliproxy
|
||||
*.exe
|
||||
|
||||
|
||||
# Configuration
|
||||
config.yaml
|
||||
.env
|
||||
|
||||
.mcp.json
|
||||
# Generated content
|
||||
bin/*
|
||||
logs/*
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.24-alpine AS builder
|
||||
FROM golang:1.26-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 51 KiB |
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
@@ -70,8 +72,10 @@ func main() {
|
||||
// Command-line flags to control the application's behavior.
|
||||
var login bool
|
||||
var codexLogin bool
|
||||
var codexDeviceLogin bool
|
||||
var claudeLogin bool
|
||||
var qwenLogin bool
|
||||
var kiloLogin bool
|
||||
var iflowLogin bool
|
||||
var iflowCookie bool
|
||||
var noBrowser bool
|
||||
@@ -88,14 +92,18 @@ func main() {
|
||||
var vertexImport string
|
||||
var configPath string
|
||||
var password string
|
||||
var tuiMode bool
|
||||
var standalone bool
|
||||
var noIncognito bool
|
||||
var useIncognito bool
|
||||
|
||||
// Define command-line flags for different operation modes.
|
||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
@@ -114,6 +122,8 @@ func main() {
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
flag.StringVar(&password, "password", "", "")
|
||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||
|
||||
flag.CommandLine.Usage = func() {
|
||||
out := flag.CommandLine.Output()
|
||||
@@ -475,7 +485,7 @@ func main() {
|
||||
}
|
||||
|
||||
// Register built-in access providers before constructing services.
|
||||
configaccess.Register()
|
||||
configaccess.Register(&cfg.SDKConfig)
|
||||
|
||||
// Handle different command modes based on the provided flags.
|
||||
|
||||
@@ -494,11 +504,16 @@ func main() {
|
||||
} else if codexLogin {
|
||||
// Handle Codex login
|
||||
cmd.DoCodexLogin(cfg, options)
|
||||
} else if codexDeviceLogin {
|
||||
// Handle Codex device-code login
|
||||
cmd.DoCodexDeviceLogin(cfg, options)
|
||||
} else if claudeLogin {
|
||||
// Handle Claude login
|
||||
cmd.DoClaudeLogin(cfg, options)
|
||||
} else if qwenLogin {
|
||||
cmd.DoQwenLogin(cfg, options)
|
||||
} else if kiloLogin {
|
||||
cmd.DoKiloLogin(cfg, options)
|
||||
} else if iflowLogin {
|
||||
cmd.DoIFlowLogin(cfg, options)
|
||||
} else if iflowCookie {
|
||||
@@ -536,15 +551,89 @@ func main() {
|
||||
cmd.WaitForCloudDeploy()
|
||||
return
|
||||
}
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
if tuiMode {
|
||||
if standalone {
|
||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
hook := tui.NewLogHook(2000)
|
||||
hook.SetFormatter(&logging.LogFormatter{})
|
||||
log.AddHook(hook)
|
||||
|
||||
// 初始化并启动 Kiro token 后台刷新
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
origLogOutput := log.StandardLogger().Out
|
||||
log.SetOutput(io.Discard)
|
||||
|
||||
devNull, errOpenDevNull := os.Open(os.DevNull)
|
||||
if errOpenDevNull == nil {
|
||||
os.Stdout = devNull
|
||||
os.Stderr = devNull
|
||||
}
|
||||
|
||||
restoreIO := func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
log.SetOutput(origLogOutput)
|
||||
if devNull != nil {
|
||||
_ = devNull.Close()
|
||||
}
|
||||
}
|
||||
|
||||
localMgmtPassword := fmt.Sprintf("tui-%d-%d", os.Getpid(), time.Now().UnixNano())
|
||||
if password == "" {
|
||||
password = localMgmtPassword
|
||||
}
|
||||
|
||||
cancel, done := cmd.StartServiceBackground(cfg, configFilePath, password)
|
||||
|
||||
client := tui.NewClient(cfg.Port, password)
|
||||
ready := false
|
||||
backoff := 100 * time.Millisecond
|
||||
for i := 0; i < 30; i++ {
|
||||
if _, errGetConfig := client.GetConfig(); errGetConfig == nil {
|
||||
ready = true
|
||||
break
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
if backoff < time.Second {
|
||||
backoff = time.Duration(float64(backoff) * 1.5)
|
||||
}
|
||||
}
|
||||
|
||||
if !ready {
|
||||
restoreIO()
|
||||
cancel()
|
||||
<-done
|
||||
fmt.Fprintf(os.Stderr, "TUI error: embedded server is not ready\n")
|
||||
return
|
||||
}
|
||||
|
||||
if errRun := tui.Run(cfg.Port, password, hook, origStdout); errRun != nil {
|
||||
restoreIO()
|
||||
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
|
||||
} else {
|
||||
restoreIO()
|
||||
}
|
||||
|
||||
cancel()
|
||||
<-done
|
||||
} else {
|
||||
// Default TUI mode: pure management client.
|
||||
// The proxy server must already be running.
|
||||
if errRun := tui.Run(cfg.Port, password, nil, os.Stdout); errRun != nil {
|
||||
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
}
|
||||
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
|
||||
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
|
||||
host: ""
|
||||
host: ''
|
||||
|
||||
# Server port
|
||||
port: 8317
|
||||
@@ -8,8 +8,8 @@ port: 8317
|
||||
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
|
||||
tls:
|
||||
enable: false
|
||||
cert: ""
|
||||
key: ""
|
||||
cert: ''
|
||||
key: ''
|
||||
|
||||
# Management API settings
|
||||
remote-management:
|
||||
@@ -20,22 +20,22 @@ remote-management:
|
||||
# Management key. If a plaintext value is provided here, it will be hashed on startup.
|
||||
# All management requests (even from localhost) require this key.
|
||||
# Leave empty to disable the Management API entirely (404 for all /v0/management routes).
|
||||
secret-key: ""
|
||||
secret-key: ''
|
||||
|
||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||
disable-control-panel: false
|
||||
|
||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
|
||||
|
||||
# Authentication directory (supports ~ for home directory)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
auth-dir: '~/.cli-proxy-api'
|
||||
|
||||
# API keys for authentication
|
||||
api-keys:
|
||||
- "your-api-key-1"
|
||||
- "your-api-key-2"
|
||||
- "your-api-key-3"
|
||||
- 'your-api-key-1'
|
||||
- 'your-api-key-2'
|
||||
- 'your-api-key-3'
|
||||
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
@@ -43,7 +43,7 @@ debug: false
|
||||
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
|
||||
pprof:
|
||||
enable: false
|
||||
addr: "127.0.0.1:8316"
|
||||
addr: '127.0.0.1:8316'
|
||||
|
||||
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
|
||||
commercial-mode: false
|
||||
@@ -68,11 +68,15 @@ error-logs-max-files: 10
|
||||
usage-statistics-enabled: false
|
||||
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
proxy-url: ''
|
||||
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
|
||||
# When true, forward filtered upstream response headers to downstream clients.
|
||||
# Default is false (disabled).
|
||||
passthrough-headers: false
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
@@ -86,7 +90,7 @@ quota-exceeded:
|
||||
|
||||
# Routing strategy for selecting credentials when multiple match.
|
||||
routing:
|
||||
strategy: "round-robin" # round-robin (default), fill-first
|
||||
strategy: 'round-robin' # round-robin (default), fill-first
|
||||
|
||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||
ws-auth: false
|
||||
@@ -160,6 +164,15 @@ nonstream-keepalive-interval: 0
|
||||
# sensitive-words: # optional: words to obfuscate with zero-width characters
|
||||
# - "API"
|
||||
# - "proxy"
|
||||
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||
|
||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||
# These are used as fallbacks when the client does not send its own headers.
|
||||
# claude-header-defaults:
|
||||
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
||||
# package-version: "0.74.0"
|
||||
# runtime-version: "v24.3.0"
|
||||
# timeout: "600"
|
||||
|
||||
# Kiro (AWS CodeWhisperer) configuration
|
||||
# Note: Kiro API currently only operates in us-east-1 region
|
||||
@@ -171,6 +184,21 @@ nonstream-keepalive-interval: 0
|
||||
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
|
||||
|
||||
# Kilocode (OAuth-based code assistant)
|
||||
# Note: Kilocode uses OAuth device flow authentication.
|
||||
# Use the CLI command: ./server --kilo-login
|
||||
# This will save credentials to the auth directory (default: ~/.cli-proxy-api/)
|
||||
# oauth-model-alias:
|
||||
# kilo:
|
||||
# - name: "minimax/minimax-m2.5:free"
|
||||
# alias: "minimax-m2.5"
|
||||
# - name: "z-ai/glm-5:free"
|
||||
# alias: "glm-5"
|
||||
# oauth-excluded-models:
|
||||
# kilo:
|
||||
# - "kilo-claude-opus-4-6" # exclude specific models (exact match)
|
||||
# - "*:free" # wildcard matching suffix (e.g. all free models)
|
||||
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||
@@ -239,7 +267,7 @@ nonstream-keepalive-interval: 0
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||
#oauth-model-alias:
|
||||
# oauth-model-alias:
|
||||
# antigravity:
|
||||
# - name: "rev19-uic3-1p"
|
||||
# alias: "gemini-2.5-computer-use-preview-10-2025"
|
||||
@@ -265,9 +293,6 @@ nonstream-keepalive-interval: 0
|
||||
# aistudio:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# antigravity:
|
||||
# - name: "gemini-3-pro-preview"
|
||||
# alias: "g3p"
|
||||
# claude:
|
||||
# - name: "claude-sonnet-4-5-20250929"
|
||||
# alias: "cs4.5"
|
||||
|
||||
@@ -7,80 +7,71 @@ The `github.com/router-for-me/CLIProxyAPI/v6/sdk/access` package centralizes inb
|
||||
```go
|
||||
import (
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
```
|
||||
|
||||
Add the module with `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access`.
|
||||
|
||||
## Provider Registry
|
||||
|
||||
Providers are registered globally and then attached to a `Manager` as a snapshot:
|
||||
|
||||
- `RegisterProvider(type, provider)` installs a pre-initialized provider instance.
|
||||
- Registration order is preserved the first time each `type` is seen.
|
||||
- `RegisteredProviders()` returns the providers in that order.
|
||||
|
||||
## Manager Lifecycle
|
||||
|
||||
```go
|
||||
manager := sdkaccess.NewManager()
|
||||
providers, err := sdkaccess.BuildProviders(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manager.SetProviders(providers)
|
||||
manager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
* `NewManager` constructs an empty manager.
|
||||
* `SetProviders` replaces the provider slice using a defensive copy.
|
||||
* `Providers` retrieves a snapshot that can be iterated safely from other goroutines.
|
||||
* `BuildProviders` translates `config.Config` access declarations into runnable providers. When the config omits explicit providers but defines inline API keys, the helper auto-installs the built-in `config-api-key` provider.
|
||||
|
||||
If the manager itself is `nil` or no providers are configured, the call returns `nil, nil`, allowing callers to treat access control as disabled.
|
||||
|
||||
## Authenticating Requests
|
||||
|
||||
```go
|
||||
result, err := manager.Authenticate(ctx, req)
|
||||
result, authErr := manager.Authenticate(ctx, req)
|
||||
switch {
|
||||
case err == nil:
|
||||
case authErr == nil:
|
||||
// Authentication succeeded; result describes the provider and principal.
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
|
||||
// No recognizable credentials were supplied.
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
|
||||
// Supplied credentials were present but rejected.
|
||||
default:
|
||||
// Transport-level failure was returned by a provider.
|
||||
// Internal/transport failure was returned by a provider.
|
||||
}
|
||||
```
|
||||
|
||||
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that surface `ErrNotHandled`, and tracks whether any provider reported `ErrNoCredentials` or `ErrInvalidCredential` for downstream error reporting.
|
||||
|
||||
If the manager itself is `nil` or no providers are registered, the call returns `nil, nil`, allowing callers to treat access control as disabled without branching on errors.
|
||||
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that return `AuthErrorCodeNotHandled`, and aggregates `AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` for a final result.
|
||||
|
||||
Each `Result` includes the provider identifier, the resolved principal, and optional metadata (for example, which header carried the credential).
|
||||
|
||||
## Configuration Layout
|
||||
## Built-in `config-api-key` Provider
|
||||
|
||||
The manager expects access providers under the `auth.providers` key inside `config.yaml`:
|
||||
The proxy includes one built-in access provider:
|
||||
|
||||
- `config-api-key`: Validates API keys declared under top-level `api-keys`.
|
||||
- Credential sources: `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, `?key=`, `?auth_token=`
|
||||
- Metadata: `Result.Metadata["source"]` is set to the matched source label.
|
||||
|
||||
In the CLI server and `sdk/cliproxy`, this provider is registered automatically based on the loaded configuration.
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: inline-api
|
||||
type: config-api-key
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
```
|
||||
|
||||
Fields map directly to `config.AccessProvider`: `name` labels the provider, `type` selects the registered factory, `sdk` can name an external module, `api-keys` seeds inline credentials, and `config` passes provider-specific options.
|
||||
## Loading Providers from External Go Modules
|
||||
|
||||
### Loading providers from external SDK modules
|
||||
|
||||
To consume a provider shipped in another Go module, point the `sdk` field at the module path and import it for its registration side effect:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: partner-auth
|
||||
type: partner-token
|
||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
||||
config:
|
||||
region: us-west-2
|
||||
audience: cli-proxy
|
||||
```
|
||||
To consume a provider shipped in another Go module, import it for its registration side effect:
|
||||
|
||||
```go
|
||||
import (
|
||||
@@ -89,19 +80,11 @@ import (
|
||||
)
|
||||
```
|
||||
|
||||
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before `BuildProviders` is called.
|
||||
|
||||
## Built-in Providers
|
||||
|
||||
The SDK ships with one provider out of the box:
|
||||
|
||||
- `config-api-key`: Validates API keys declared inline or under top-level `api-keys`. It accepts the key from `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, or the `?key=` query string and reports `ErrInvalidCredential` when no match is found.
|
||||
|
||||
Additional providers can be delivered by third-party packages. When a provider package is imported, it registers itself with `sdkaccess.RegisterProvider`.
|
||||
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before you call `RegisteredProviders()` (or before `cliproxy.NewBuilder().Build()`).
|
||||
|
||||
### Metadata and auditing
|
||||
|
||||
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, or `query-key`). Populate this map in custom providers to enrich logs and downstream auditing.
|
||||
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, `query-key`, `query-auth-token`). Populate this map in custom providers to enrich logs and downstream auditing.
|
||||
|
||||
## Writing Custom Providers
|
||||
|
||||
@@ -110,13 +93,13 @@ type customProvider struct{}
|
||||
|
||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
token := r.Header.Get("X-Custom")
|
||||
if token == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if token != "expected" {
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
|
||||
}
|
||||
|
||||
func init() {
|
||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
||||
return &customProvider{}, nil
|
||||
})
|
||||
sdkaccess.RegisterProvider("custom", &customProvider{})
|
||||
}
|
||||
```
|
||||
|
||||
A provider must implement `Identifier()` and `Authenticate()`. To expose it to configuration, call `RegisterProvider` inside `init`. Provider factories receive the specific `AccessProvider` block plus the full root configuration for contextual needs.
|
||||
A provider must implement `Identifier()` and `Authenticate()`. To make it available to the access manager, call `RegisterProvider` inside `init` with an initialized provider instance.
|
||||
|
||||
## Error Semantics
|
||||
|
||||
- `ErrNoCredentials`: no credentials were present or recognized by any provider.
|
||||
- `ErrInvalidCredential`: at least one provider processed the credentials but rejected them.
|
||||
- `ErrNotHandled`: instructs the manager to fall through to the next provider without affecting aggregate error reporting.
|
||||
- `NewNoCredentialsError()` (`AuthErrorCodeNoCredentials`): no credentials were present or recognized. (HTTP 401)
|
||||
- `NewInvalidCredentialError()` (`AuthErrorCodeInvalidCredential`): credentials were present but rejected. (HTTP 401)
|
||||
- `NewNotHandledError()` (`AuthErrorCodeNotHandled`): fall through to the next provider.
|
||||
- `NewInternalAuthError(message, cause)` (`AuthErrorCodeInternal`): transport/system failure. (HTTP 500)
|
||||
|
||||
Return custom errors to surface transport failures; they propagate immediately to the caller instead of being masked.
|
||||
Errors propagate immediately to the caller unless they are classified as `not_handled` / `no_credentials` / `invalid_credential` and can be aggregated by the manager.
|
||||
|
||||
## Integration with cliproxy Service
|
||||
|
||||
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a preconfigured manager allows you to extend or override the default providers:
|
||||
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a manager lets you reuse the same instance in your host process:
|
||||
|
||||
```go
|
||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
||||
manager := sdkaccess.NewManager()
|
||||
manager.SetProviders(providers)
|
||||
accessManager := sdkaccess.NewManager()
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(coreCfg).
|
||||
WithAccessManager(manager).
|
||||
WithConfigPath("config.yaml").
|
||||
WithRequestAccessManager(accessManager).
|
||||
Build()
|
||||
```
|
||||
|
||||
The service reuses the manager for every inbound request, ensuring consistent authentication across embedded deployments and the canonical CLI binary.
|
||||
Register any custom providers (typically via blank imports) before calling `Build()` so they are present in the global registry snapshot.
|
||||
|
||||
### Hot reloading providers
|
||||
### Hot reloading
|
||||
|
||||
When configuration changes, rebuild providers and swap them into the manager:
|
||||
When configuration changes, refresh any config-backed providers and then reset the manager's provider chain:
|
||||
|
||||
```go
|
||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
||||
if err != nil {
|
||||
log.Errorf("reload auth providers failed: %v", err)
|
||||
return
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
This mirrors the behaviour in `cliproxy.Service.refreshAccessProviders` and `api.Server.applyAccessConfig`, enabling runtime updates without restarting the process.
|
||||
This mirrors the behaviour in `internal/access.ApplyAccessProviders`, enabling runtime updates without restarting the process.
|
||||
|
||||
@@ -7,80 +7,71 @@
|
||||
```go
|
||||
import (
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
```
|
||||
|
||||
通过 `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access` 添加依赖。
|
||||
|
||||
## Provider Registry
|
||||
|
||||
访问提供者是全局注册,然后以快照形式挂到 `Manager` 上:
|
||||
|
||||
- `RegisterProvider(type, provider)` 注册一个已经初始化好的 provider 实例。
|
||||
- 每个 `type` 第一次出现时会记录其注册顺序。
|
||||
- `RegisteredProviders()` 会按该顺序返回 provider 列表。
|
||||
|
||||
## 管理器生命周期
|
||||
|
||||
```go
|
||||
manager := sdkaccess.NewManager()
|
||||
providers, err := sdkaccess.BuildProviders(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manager.SetProviders(providers)
|
||||
manager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
- `NewManager` 创建空管理器。
|
||||
- `SetProviders` 替换提供者切片并做防御性拷贝。
|
||||
- `Providers` 返回适合并发读取的快照。
|
||||
- `BuildProviders` 将 `config.Config` 中的访问配置转换成可运行的提供者。当配置没有显式声明但包含顶层 `api-keys` 时,会自动挂载内建的 `config-api-key` 提供者。
|
||||
|
||||
如果管理器本身为 `nil` 或未配置任何 provider,调用会返回 `nil, nil`,可视为关闭访问控制。
|
||||
|
||||
## 认证请求
|
||||
|
||||
```go
|
||||
result, err := manager.Authenticate(ctx, req)
|
||||
result, authErr := manager.Authenticate(ctx, req)
|
||||
switch {
|
||||
case err == nil:
|
||||
case authErr == nil:
|
||||
// Authentication succeeded; result carries provider and principal.
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
|
||||
// No recognizable credentials were supplied.
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
|
||||
// Credentials were present but rejected.
|
||||
default:
|
||||
// Provider surfaced a transport-level failure.
|
||||
}
|
||||
```
|
||||
|
||||
`Manager.Authenticate` 按配置顺序遍历提供者。遇到成功立即返回,`ErrNotHandled` 会继续尝试下一个;若发现 `ErrNoCredentials` 或 `ErrInvalidCredential`,会在遍历结束后汇总给调用方。
|
||||
|
||||
若管理器本身为 `nil` 或尚未注册提供者,调用会返回 `nil, nil`,让调用方无需针对错误做额外分支即可关闭访问控制。
|
||||
`Manager.Authenticate` 会按顺序遍历 provider:遇到成功立即返回,`AuthErrorCodeNotHandled` 会继续尝试下一个;`AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` 会在遍历结束后汇总给调用方。
|
||||
|
||||
`Result` 提供认证提供者标识、解析出的主体以及可选元数据(例如凭证来源)。
|
||||
|
||||
## 配置结构
|
||||
## 内建 `config-api-key` Provider
|
||||
|
||||
在 `config.yaml` 的 `auth.providers` 下定义访问提供者:
|
||||
代理内置一个访问提供者:
|
||||
|
||||
- `config-api-key`:校验 `config.yaml` 顶层的 `api-keys`。
|
||||
- 凭证来源:`Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key`、`?key=`、`?auth_token=`
|
||||
- 元数据:`Result.Metadata["source"]` 会写入匹配到的来源标识
|
||||
|
||||
在 CLI 服务端与 `sdk/cliproxy` 中,该 provider 会根据加载到的配置自动注册。
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: inline-api
|
||||
type: config-api-key
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
```
|
||||
|
||||
条目映射到 `config.AccessProvider`:`name` 指定实例名,`type` 选择注册的工厂,`sdk` 可引用第三方模块,`api-keys` 提供内联凭证,`config` 用于传递特定选项。
|
||||
## 引入外部 Go 模块提供者
|
||||
|
||||
### 引入外部 SDK 提供者
|
||||
|
||||
若要消费其它 Go 模块输出的访问提供者,可在配置里填写 `sdk` 字段并在代码中引入该包,利用其 `init` 注册过程:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: partner-auth
|
||||
type: partner-token
|
||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
||||
config:
|
||||
region: us-west-2
|
||||
audience: cli-proxy
|
||||
```
|
||||
若要消费其它 Go 模块输出的访问提供者,直接用空白标识符导入以触发其 `init` 注册即可:
|
||||
|
||||
```go
|
||||
import (
|
||||
@@ -89,19 +80,11 @@ import (
|
||||
)
|
||||
```
|
||||
|
||||
通过空白标识符导入即可确保 `init` 调用,先于 `BuildProviders` 完成 `sdkaccess.RegisterProvider`。
|
||||
|
||||
## 内建提供者
|
||||
|
||||
当前 SDK 默认内置:
|
||||
|
||||
- `config-api-key`:校验配置中的 API Key。它从 `Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key` 以及查询参数 `?key=` 提取凭证,不匹配时抛出 `ErrInvalidCredential`。
|
||||
|
||||
导入第三方包即可通过 `sdkaccess.RegisterProvider` 注册更多类型。
|
||||
空白导入可确保 `init` 先执行,从而在你调用 `RegisteredProviders()`(或 `cliproxy.NewBuilder().Build()`)之前完成 `sdkaccess.RegisterProvider`。
|
||||
|
||||
### 元数据与审计
|
||||
|
||||
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key` 或 `query-key`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
||||
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key`、`query-key`、`query-auth-token`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
||||
|
||||
## 编写自定义提供者
|
||||
|
||||
@@ -110,13 +93,13 @@ type customProvider struct{}
|
||||
|
||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
token := r.Header.Get("X-Custom")
|
||||
if token == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if token != "expected" {
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
|
||||
}
|
||||
|
||||
func init() {
|
||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
||||
return &customProvider{}, nil
|
||||
})
|
||||
sdkaccess.RegisterProvider("custom", &customProvider{})
|
||||
}
|
||||
```
|
||||
|
||||
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中调用 `RegisterProvider` 暴露给配置层,工厂函数既能读取当前条目,也能访问完整根配置。
|
||||
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中用已初始化实例调用 `RegisterProvider` 注册到全局 registry。
|
||||
|
||||
## 错误语义
|
||||
|
||||
- `ErrNoCredentials`:任何提供者都未识别到凭证。
|
||||
- `ErrInvalidCredential`:至少一个提供者处理了凭证但判定无效。
|
||||
- `ErrNotHandled`:告诉管理器跳到下一个提供者,不影响最终错误统计。
|
||||
- `NewNoCredentialsError()`(`AuthErrorCodeNoCredentials`):未提供或未识别到凭证。(HTTP 401)
|
||||
- `NewInvalidCredentialError()`(`AuthErrorCodeInvalidCredential`):凭证存在但校验失败。(HTTP 401)
|
||||
- `NewNotHandledError()`(`AuthErrorCodeNotHandled`):告诉管理器跳到下一个 provider。
|
||||
- `NewInternalAuthError(message, cause)`(`AuthErrorCodeInternal`):网络/系统错误。(HTTP 500)
|
||||
|
||||
自定义错误(例如网络异常)会马上冒泡返回。
|
||||
除可汇总的 `not_handled` / `no_credentials` / `invalid_credential` 外,其它错误会立即冒泡返回。
|
||||
|
||||
## 与 cliproxy 集成
|
||||
|
||||
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果需要扩展内置行为,可传入自定义管理器:
|
||||
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果希望在宿主进程里复用同一个 `Manager` 实例,可传入自定义管理器:
|
||||
|
||||
```go
|
||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
||||
manager := sdkaccess.NewManager()
|
||||
manager.SetProviders(providers)
|
||||
accessManager := sdkaccess.NewManager()
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(coreCfg).
|
||||
WithAccessManager(manager).
|
||||
WithConfigPath("config.yaml").
|
||||
WithRequestAccessManager(accessManager).
|
||||
Build()
|
||||
```
|
||||
|
||||
服务会复用该管理器处理每一个入站请求,实现与 CLI 二进制一致的访问控制体验。
|
||||
请在调用 `Build()` 之前完成自定义 provider 的注册(通常通过空白导入触发 `init`),以确保它们被包含在全局 registry 的快照中。
|
||||
|
||||
### 动态热更新提供者
|
||||
|
||||
当配置发生变化时,可以重新构建提供者并替换当前列表:
|
||||
当配置发生变化时,刷新依赖配置的 provider,然后重置 manager 的 provider 链:
|
||||
|
||||
```go
|
||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
||||
if err != nil {
|
||||
log.Errorf("reload auth providers failed: %v", err)
|
||||
return
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
这一流程与 `cliproxy.Service.refreshAccessProviders` 和 `api.Server.applyAccessConfig` 保持一致,避免为更新访问策略而重启进程。
|
||||
这一流程与 `internal/access.ApplyAccessProviders` 保持一致,避免为更新访问策略而重启进程。
|
||||
|
||||
@@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request,
|
||||
return clipexec.Response{}, errors.New("count tokens not implemented")
|
||||
}
|
||||
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
|
||||
ch := make(chan clipexec.StreamChunk, 1)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
|
||||
}()
|
||||
return ch, nil
|
||||
return &clipexec.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
|
||||
@@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c
|
||||
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) {
|
||||
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
||||
}
|
||||
|
||||
|
||||
23
go.mod
23
go.mod
@@ -1,9 +1,13 @@
|
||||
module github.com/router-for-me/CLIProxyAPI/v6
|
||||
|
||||
go 1.24.0
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.0.6
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/charmbracelet/bubbles v1.0.0
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/fxamacker/cbor/v2 v2.9.0
|
||||
github.com/gin-gonic/gin v1.10.1
|
||||
@@ -33,8 +37,16 @@ require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.6 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
@@ -42,6 +54,7 @@ require (
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-git/gcfg/v2 v2.0.2 // indirect
|
||||
@@ -58,19 +71,27 @@ require (
|
||||
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/sergi/go-diff v1.4.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
|
||||
45
go.sum
45
go.sum
@@ -10,10 +10,34 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
|
||||
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
|
||||
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
|
||||
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
|
||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
||||
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
@@ -33,6 +57,8 @@ github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o
|
||||
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
|
||||
@@ -101,8 +127,14 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
|
||||
@@ -114,6 +146,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||
@@ -124,6 +162,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
@@ -161,6 +201,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
@@ -168,12 +210,15 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -4,19 +4,28 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
var registerOnce sync.Once
|
||||
|
||||
// Register ensures the config-access provider is available to the access manager.
|
||||
func Register() {
|
||||
registerOnce.Do(func() {
|
||||
sdkaccess.RegisterProvider(sdkconfig.AccessProviderTypeConfigAPIKey, newProvider)
|
||||
})
|
||||
func Register(cfg *sdkconfig.SDKConfig) {
|
||||
if cfg == nil {
|
||||
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
|
||||
return
|
||||
}
|
||||
|
||||
keys := normalizeKeys(cfg.APIKeys)
|
||||
if len(keys) == 0 {
|
||||
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
|
||||
return
|
||||
}
|
||||
|
||||
sdkaccess.RegisterProvider(
|
||||
sdkaccess.AccessProviderTypeConfigAPIKey,
|
||||
newProvider(sdkaccess.DefaultAccessProviderName, keys),
|
||||
)
|
||||
}
|
||||
|
||||
type provider struct {
|
||||
@@ -24,34 +33,31 @@ type provider struct {
|
||||
keys map[string]struct{}
|
||||
}
|
||||
|
||||
func newProvider(cfg *sdkconfig.AccessProvider, _ *sdkconfig.SDKConfig) (sdkaccess.Provider, error) {
|
||||
name := cfg.Name
|
||||
if name == "" {
|
||||
name = sdkconfig.DefaultAccessProviderName
|
||||
func newProvider(name string, keys []string) *provider {
|
||||
providerName := strings.TrimSpace(name)
|
||||
if providerName == "" {
|
||||
providerName = sdkaccess.DefaultAccessProviderName
|
||||
}
|
||||
keys := make(map[string]struct{}, len(cfg.APIKeys))
|
||||
for _, key := range cfg.APIKeys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
keys[key] = struct{}{}
|
||||
keySet := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
return &provider{name: name, keys: keys}, nil
|
||||
return &provider{name: providerName, keys: keySet}
|
||||
}
|
||||
|
||||
func (p *provider) Identifier() string {
|
||||
if p == nil || p.name == "" {
|
||||
return sdkconfig.DefaultAccessProviderName
|
||||
return sdkaccess.DefaultAccessProviderName
|
||||
}
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
if p == nil {
|
||||
return nil, sdkaccess.ErrNotHandled
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if len(p.keys) == 0 {
|
||||
return nil, sdkaccess.ErrNotHandled
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
||||
@@ -63,7 +69,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
||||
queryAuthToken = r.URL.Query().Get("auth_token")
|
||||
}
|
||||
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNoCredentialsError()
|
||||
}
|
||||
|
||||
apiKey := extractBearerToken(authHeader)
|
||||
@@ -94,7 +100,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
||||
}
|
||||
}
|
||||
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
|
||||
func extractBearerToken(header string) string {
|
||||
@@ -110,3 +116,26 @@ func extractBearerToken(header string) string {
|
||||
}
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
func normalizeKeys(keys []string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]string, 0, len(keys))
|
||||
seen := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
trimmedKey := strings.TrimSpace(key)
|
||||
if trimmedKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[trimmedKey]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmedKey] = struct{}{}
|
||||
normalized = append(normalized, trimmedKey)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkConfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -17,26 +17,26 @@ import (
|
||||
// ordered provider slice along with the identifiers of providers that were added, updated, or
|
||||
// removed compared to the previous configuration.
|
||||
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Provider) (result []sdkaccess.Provider, added, updated, removed []string, err error) {
|
||||
_ = oldCfg
|
||||
if newCfg == nil {
|
||||
return nil, nil, nil, nil, nil
|
||||
}
|
||||
|
||||
result = sdkaccess.RegisteredProviders()
|
||||
|
||||
existingMap := make(map[string]sdkaccess.Provider, len(existing))
|
||||
for _, provider := range existing {
|
||||
if provider == nil {
|
||||
providerID := identifierFromProvider(provider)
|
||||
if providerID == "" {
|
||||
continue
|
||||
}
|
||||
existingMap[provider.Identifier()] = provider
|
||||
existingMap[providerID] = provider
|
||||
}
|
||||
|
||||
oldCfgMap := accessProviderMap(oldCfg)
|
||||
newEntries := collectProviderEntries(newCfg)
|
||||
|
||||
result = make([]sdkaccess.Provider, 0, len(newEntries))
|
||||
finalIDs := make(map[string]struct{}, len(newEntries))
|
||||
finalIDs := make(map[string]struct{}, len(result))
|
||||
|
||||
isInlineProvider := func(id string) bool {
|
||||
return strings.EqualFold(id, sdkConfig.DefaultAccessProviderName)
|
||||
return strings.EqualFold(id, sdkaccess.DefaultAccessProviderName)
|
||||
}
|
||||
appendChange := func(list *[]string, id string) {
|
||||
if isInlineProvider(id) {
|
||||
@@ -45,85 +45,28 @@ func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Prov
|
||||
*list = append(*list, id)
|
||||
}
|
||||
|
||||
for _, providerCfg := range newEntries {
|
||||
key := providerIdentifier(providerCfg)
|
||||
if key == "" {
|
||||
for _, provider := range result {
|
||||
providerID := identifierFromProvider(provider)
|
||||
if providerID == "" {
|
||||
continue
|
||||
}
|
||||
finalIDs[providerID] = struct{}{}
|
||||
|
||||
forceRebuild := strings.EqualFold(strings.TrimSpace(providerCfg.Type), sdkConfig.AccessProviderTypeConfigAPIKey)
|
||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||
isAliased := oldCfgProvider == providerCfg
|
||||
if !forceRebuild && !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
|
||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||
result = append(result, existingProvider)
|
||||
finalIDs[key] = struct{}{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
existingProvider, exists := existingMap[providerID]
|
||||
if !exists {
|
||||
appendChange(&added, providerID)
|
||||
continue
|
||||
}
|
||||
|
||||
provider, buildErr := sdkaccess.BuildProvider(providerCfg, &newCfg.SDKConfig)
|
||||
if buildErr != nil {
|
||||
return nil, nil, nil, nil, buildErr
|
||||
}
|
||||
if _, ok := oldCfgMap[key]; ok {
|
||||
if _, existed := existingMap[key]; existed {
|
||||
appendChange(&updated, key)
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
result = append(result, provider)
|
||||
finalIDs[key] = struct{}{}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(newCfg.APIKeys); inline != nil {
|
||||
key := providerIdentifier(inline)
|
||||
if key != "" {
|
||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||
if providerConfigEqual(oldCfgProvider, inline) {
|
||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||
result = append(result, existingProvider)
|
||||
finalIDs[key] = struct{}{}
|
||||
goto inlineDone
|
||||
}
|
||||
}
|
||||
}
|
||||
provider, buildErr := sdkaccess.BuildProvider(inline, &newCfg.SDKConfig)
|
||||
if buildErr != nil {
|
||||
return nil, nil, nil, nil, buildErr
|
||||
}
|
||||
if _, existed := existingMap[key]; existed {
|
||||
appendChange(&updated, key)
|
||||
} else if _, hadOld := oldCfgMap[key]; hadOld {
|
||||
appendChange(&updated, key)
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
result = append(result, provider)
|
||||
finalIDs[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
inlineDone:
|
||||
}
|
||||
|
||||
removedSet := make(map[string]struct{})
|
||||
for id := range existingMap {
|
||||
if _, ok := finalIDs[id]; !ok {
|
||||
if isInlineProvider(id) {
|
||||
continue
|
||||
}
|
||||
removedSet[id] = struct{}{}
|
||||
if !providerInstanceEqual(existingProvider, provider) {
|
||||
appendChange(&updated, providerID)
|
||||
}
|
||||
}
|
||||
|
||||
removed = make([]string, 0, len(removedSet))
|
||||
for id := range removedSet {
|
||||
removed = append(removed, id)
|
||||
for providerID := range existingMap {
|
||||
if _, exists := finalIDs[providerID]; exists {
|
||||
continue
|
||||
}
|
||||
appendChange(&removed, providerID)
|
||||
}
|
||||
|
||||
sort.Strings(added)
|
||||
@@ -142,6 +85,7 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
|
||||
}
|
||||
|
||||
existing := manager.Providers()
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
providers, added, updated, removed, err := ReconcileProviders(oldCfg, newCfg, existing)
|
||||
if err != nil {
|
||||
log.Errorf("failed to reconcile request auth providers: %v", err)
|
||||
@@ -160,111 +104,24 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func accessProviderMap(cfg *config.Config) map[string]*sdkConfig.AccessProvider {
|
||||
result := make(map[string]*sdkConfig.AccessProvider)
|
||||
if cfg == nil {
|
||||
return result
|
||||
}
|
||||
for i := range cfg.Access.Providers {
|
||||
providerCfg := &cfg.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
continue
|
||||
}
|
||||
key := providerIdentifier(providerCfg)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
result[key] = providerCfg
|
||||
}
|
||||
if len(result) == 0 && len(cfg.APIKeys) > 0 {
|
||||
if provider := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); provider != nil {
|
||||
if key := providerIdentifier(provider); key != "" {
|
||||
result[key] = provider
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func collectProviderEntries(cfg *config.Config) []*sdkConfig.AccessProvider {
|
||||
entries := make([]*sdkConfig.AccessProvider, 0, len(cfg.Access.Providers))
|
||||
for i := range cfg.Access.Providers {
|
||||
providerCfg := &cfg.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
continue
|
||||
}
|
||||
if key := providerIdentifier(providerCfg); key != "" {
|
||||
entries = append(entries, providerCfg)
|
||||
}
|
||||
}
|
||||
if len(entries) == 0 && len(cfg.APIKeys) > 0 {
|
||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); inline != nil {
|
||||
entries = append(entries, inline)
|
||||
}
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func providerIdentifier(provider *sdkConfig.AccessProvider) string {
|
||||
func identifierFromProvider(provider sdkaccess.Provider) string {
|
||||
if provider == nil {
|
||||
return ""
|
||||
}
|
||||
if name := strings.TrimSpace(provider.Name); name != "" {
|
||||
return name
|
||||
}
|
||||
typ := strings.TrimSpace(provider.Type)
|
||||
if typ == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.EqualFold(typ, sdkConfig.AccessProviderTypeConfigAPIKey) {
|
||||
return sdkConfig.DefaultAccessProviderName
|
||||
}
|
||||
return typ
|
||||
return strings.TrimSpace(provider.Identifier())
|
||||
}
|
||||
|
||||
func providerConfigEqual(a, b *sdkConfig.AccessProvider) bool {
|
||||
func providerInstanceEqual(a, b sdkaccess.Provider) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == nil && b == nil
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
|
||||
if reflect.TypeOf(a) != reflect.TypeOf(b) {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
|
||||
return false
|
||||
valueA := reflect.ValueOf(a)
|
||||
valueB := reflect.ValueOf(b)
|
||||
if valueA.Kind() == reflect.Pointer && valueB.Kind() == reflect.Pointer {
|
||||
return valueA.Pointer() == valueB.Pointer()
|
||||
}
|
||||
if !stringSetEqual(a.APIKeys, b.APIKeys) {
|
||||
return false
|
||||
}
|
||||
if len(a.Config) != len(b.Config) {
|
||||
return false
|
||||
}
|
||||
if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func stringSetEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
seen := make(map[string]int, len(a))
|
||||
for _, val := range a {
|
||||
seen[val]++
|
||||
}
|
||||
for _, val := range b {
|
||||
count := seen[val]
|
||||
if count == 0 {
|
||||
return false
|
||||
}
|
||||
if count == 1 {
|
||||
delete(seen, val)
|
||||
} else {
|
||||
seen[val] = count - 1
|
||||
}
|
||||
}
|
||||
return len(seen) == 0
|
||||
return reflect.DeepEqual(a, b)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -189,9 +190,21 @@ func (h *Handler) APICall(c *gin.Context) {
|
||||
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
|
||||
}
|
||||
|
||||
// When caller indicates CBOR in request headers, convert JSON string payload to CBOR bytes.
|
||||
useCBORPayload := headerContainsValue(reqHeaders, "Content-Type", "application/cbor")
|
||||
|
||||
var requestBody io.Reader
|
||||
if body.Data != "" {
|
||||
requestBody = strings.NewReader(body.Data)
|
||||
if useCBORPayload {
|
||||
cborPayload, errEncode := encodeJSONStringToCBOR(body.Data)
|
||||
if errEncode != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json data for cbor content-type"})
|
||||
return
|
||||
}
|
||||
requestBody = bytes.NewReader(cborPayload)
|
||||
} else {
|
||||
requestBody = strings.NewReader(body.Data)
|
||||
}
|
||||
}
|
||||
|
||||
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
|
||||
@@ -234,10 +247,18 @@ func (h *Handler) APICall(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// For CBOR upstream responses, decode into plain text or JSON string before returning.
|
||||
responseBodyText := string(respBody)
|
||||
if headerContainsValue(reqHeaders, "Accept", "application/cbor") || strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "application/cbor") {
|
||||
if decodedBody, errDecode := decodeCBORBodyToTextOrJSON(respBody); errDecode == nil {
|
||||
responseBodyText = decodedBody
|
||||
}
|
||||
}
|
||||
|
||||
response := apiCallResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header,
|
||||
Body: string(respBody),
|
||||
Body: responseBodyText,
|
||||
}
|
||||
|
||||
// If this is a GitHub Copilot token endpoint response, try to enrich with quota information
|
||||
@@ -747,6 +768,83 @@ func buildProxyTransport(proxyStr string) *http.Transport {
|
||||
return nil
|
||||
}
|
||||
|
||||
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
|
||||
func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
|
||||
if len(headers) == 0 {
|
||||
return false
|
||||
}
|
||||
for key, value := range headers {
|
||||
if !strings.EqualFold(strings.TrimSpace(key), strings.TrimSpace(targetKey)) {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(strings.ToLower(value), strings.ToLower(strings.TrimSpace(targetValue))) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// encodeJSONStringToCBOR converts a JSON string payload into CBOR bytes.
|
||||
func encodeJSONStringToCBOR(jsonString string) ([]byte, error) {
|
||||
var payload any
|
||||
if errUnmarshal := json.Unmarshal([]byte(jsonString), &payload); errUnmarshal != nil {
|
||||
return nil, errUnmarshal
|
||||
}
|
||||
return cbor.Marshal(payload)
|
||||
}
|
||||
|
||||
// decodeCBORBodyToTextOrJSON decodes CBOR bytes to plain text (for string payloads) or JSON string.
|
||||
func decodeCBORBodyToTextOrJSON(raw []byte) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var payload any
|
||||
if errUnmarshal := cbor.Unmarshal(raw, &payload); errUnmarshal != nil {
|
||||
return "", errUnmarshal
|
||||
}
|
||||
|
||||
jsonCompatible := cborValueToJSONCompatible(payload)
|
||||
switch typed := jsonCompatible.(type) {
|
||||
case string:
|
||||
return typed, nil
|
||||
case []byte:
|
||||
return string(typed), nil
|
||||
default:
|
||||
jsonBytes, errMarshal := json.Marshal(jsonCompatible)
|
||||
if errMarshal != nil {
|
||||
return "", errMarshal
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
}
|
||||
|
||||
// cborValueToJSONCompatible recursively converts CBOR-decoded values into JSON-marshalable values.
|
||||
func cborValueToJSONCompatible(value any) any {
|
||||
switch typed := value.(type) {
|
||||
case map[any]any:
|
||||
out := make(map[string]any, len(typed))
|
||||
for key, item := range typed {
|
||||
out[fmt.Sprint(key)] = cborValueToJSONCompatible(item)
|
||||
}
|
||||
return out
|
||||
case map[string]any:
|
||||
out := make(map[string]any, len(typed))
|
||||
for key, item := range typed {
|
||||
out[key] = cborValueToJSONCompatible(item)
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]any, len(typed))
|
||||
for i, item := range typed {
|
||||
out[i] = cborValueToJSONCompatible(item)
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return typed
|
||||
}
|
||||
}
|
||||
|
||||
// QuotaDetail represents quota information for a specific resource type
|
||||
type QuotaDetail struct {
|
||||
Entitlement float64 `json:"entitlement"`
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
@@ -411,6 +412,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
||||
if !auth.LastRefreshedAt.IsZero() {
|
||||
entry["last_refresh"] = auth.LastRefreshedAt
|
||||
}
|
||||
if !auth.NextRetryAfter.IsZero() {
|
||||
entry["next_retry_after"] = auth.NextRetryAfter
|
||||
}
|
||||
if path != "" {
|
||||
entry["path"] = path
|
||||
entry["source"] = "file"
|
||||
@@ -813,6 +817,87 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||
}
|
||||
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
|
||||
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
if h.authManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Prefix *string `json:"prefix"`
|
||||
ProxyURL *string `json:"proxy_url"`
|
||||
Priority *int `json:"priority"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(req.Name)
|
||||
if name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Find auth by name or ID
|
||||
var targetAuth *coreauth.Auth
|
||||
if auth, ok := h.authManager.GetByID(name); ok {
|
||||
targetAuth = auth
|
||||
} else {
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth.FileName == name {
|
||||
targetAuth = auth
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if targetAuth == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
|
||||
return
|
||||
}
|
||||
|
||||
changed := false
|
||||
if req.Prefix != nil {
|
||||
targetAuth.Prefix = *req.Prefix
|
||||
changed = true
|
||||
}
|
||||
if req.ProxyURL != nil {
|
||||
targetAuth.ProxyURL = *req.ProxyURL
|
||||
changed = true
|
||||
}
|
||||
if req.Priority != nil {
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if *req.Priority == 0 {
|
||||
delete(targetAuth.Metadata, "priority")
|
||||
} else {
|
||||
targetAuth.Metadata["priority"] = *req.Priority
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
|
||||
if !changed {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"})
|
||||
return
|
||||
}
|
||||
|
||||
targetAuth.UpdatedAt = time.Now()
|
||||
|
||||
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||
if h == nil || h.authManager == nil {
|
||||
return
|
||||
@@ -869,11 +954,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
||||
if store == nil {
|
||||
return "", fmt.Errorf("token store unavailable")
|
||||
}
|
||||
if h.postAuthHook != nil {
|
||||
if err := h.postAuthHook(ctx, record); err != nil {
|
||||
return "", fmt.Errorf("post-auth hook failed: %w", err)
|
||||
}
|
||||
}
|
||||
return store.Save(ctx, record)
|
||||
}
|
||||
|
||||
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Claude authentication...")
|
||||
|
||||
@@ -1018,6 +1109,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
||||
|
||||
@@ -1193,6 +1285,30 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
ts.Checked = true
|
||||
} else if strings.EqualFold(requestedProjectID, "GOOGLE_ONE") {
|
||||
ts.Auto = false
|
||||
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
SetOAuthSessionError(state, "Google One auto-discovery failed")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
log.Error("Google One auto-discovery returned empty project ID")
|
||||
SetOAuthSessionError(state, "Google One auto-discovery returned empty project ID")
|
||||
return
|
||||
}
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the auto-discovered project")
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
@@ -1252,6 +1368,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Codex authentication...")
|
||||
|
||||
@@ -1397,6 +1514,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Antigravity authentication...")
|
||||
|
||||
@@ -1561,6 +1679,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Qwen authentication...")
|
||||
|
||||
@@ -1616,6 +1735,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Kimi authentication...")
|
||||
|
||||
@@ -1692,6 +1812,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing iFlow authentication...")
|
||||
|
||||
@@ -1811,8 +1932,6 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
|
||||
|
||||
// Initialize Copilot auth service
|
||||
// We need to import "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" first if not present
|
||||
// Assuming copilot package is imported as "copilot"
|
||||
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
|
||||
|
||||
// Initiate device flow
|
||||
@@ -1826,7 +1945,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
authURL := deviceCode.VerificationURI
|
||||
userCode := deviceCode.UserCode
|
||||
|
||||
RegisterOAuthSession(state, "github")
|
||||
RegisterOAuthSession(state, "github-copilot")
|
||||
|
||||
go func() {
|
||||
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
|
||||
@@ -1838,9 +1957,13 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
username, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||
userInfo, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||
if errUser != nil {
|
||||
log.Warnf("Failed to fetch user info: %v", errUser)
|
||||
}
|
||||
|
||||
username := userInfo.Login
|
||||
if username == "" {
|
||||
username = "github-user"
|
||||
}
|
||||
|
||||
@@ -1849,18 +1972,26 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
TokenType: tokenData.TokenType,
|
||||
Scope: tokenData.Scope,
|
||||
Username: username,
|
||||
Email: userInfo.Email,
|
||||
Name: userInfo.Name,
|
||||
Type: "github-copilot",
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("github-%s.json", username)
|
||||
fileName := fmt.Sprintf("github-copilot-%s.json", username)
|
||||
label := userInfo.Email
|
||||
if label == "" {
|
||||
label = username
|
||||
}
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "github",
|
||||
Provider: "github-copilot",
|
||||
Label: label,
|
||||
FileName: fileName,
|
||||
Storage: tokenStorage,
|
||||
Metadata: map[string]any{
|
||||
"email": username,
|
||||
"email": userInfo.Email,
|
||||
"username": username,
|
||||
"name": userInfo.Name,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1874,7 +2005,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use GitHub Copilot services through this CLI")
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("github")
|
||||
CompleteOAuthSessionsByProvider("github-copilot")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
@@ -2124,7 +2255,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
// Auto-discovery: try onboardUser without specifying a project
|
||||
// to let Google auto-provision one (matches Gemini CLI headless behavior
|
||||
// and Antigravity's FetchProjectID pattern).
|
||||
autoOnboardReq := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer autoCancel()
|
||||
for attempt := 1; ; attempt++ {
|
||||
var onboardResp map[string]any
|
||||
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
|
||||
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
|
||||
}
|
||||
|
||||
if done, okDone := onboardResp["done"].(bool); okDone && done {
|
||||
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
|
||||
switch v := resp["cloudaicompanionProject"].(type) {
|
||||
case string:
|
||||
projectID = strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if id, okID := v["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
|
||||
select {
|
||||
case <-autoCtx.Done():
|
||||
return &projectSelectionRequiredError{}
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
}
|
||||
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
|
||||
}
|
||||
|
||||
onboardReqBody := map[string]any{
|
||||
@@ -2374,6 +2546,14 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||
}
|
||||
|
||||
// PopulateAuthContext extracts request info and adds it to the context
|
||||
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
|
||||
info := &coreauth.RequestInfo{
|
||||
Query: c.Request.URL.Query(),
|
||||
Headers: c.Request.Header,
|
||||
}
|
||||
return coreauth.WithRequestInfo(ctx, info)
|
||||
}
|
||||
const kiroCallbackPort = 9876
|
||||
|
||||
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||
@@ -2668,3 +2848,88 @@ func generateKiroPKCE() (verifier, challenge string, err error) {
|
||||
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
func (h *Handler) RequestKiloToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
fmt.Println("Initializing Kilo authentication...")
|
||||
|
||||
state := fmt.Sprintf("kil-%d", time.Now().UnixNano())
|
||||
kilocodeAuth := kilo.NewKiloAuth()
|
||||
|
||||
resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initiate device flow: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "kilo")
|
||||
|
||||
go func() {
|
||||
fmt.Printf("Please visit %s and enter code: %s\n", resp.VerificationURL, resp.Code)
|
||||
|
||||
status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
|
||||
if err != nil {
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := kilocodeAuth.GetProfile(ctx, status.Token)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to fetch profile: %v", err)
|
||||
profile = &kilo.Profile{Email: status.UserEmail}
|
||||
}
|
||||
|
||||
var orgID string
|
||||
if len(profile.Orgs) > 0 {
|
||||
orgID = profile.Orgs[0].ID
|
||||
}
|
||||
|
||||
defaults, err := kilocodeAuth.GetDefaults(ctx, status.Token, orgID)
|
||||
if err != nil {
|
||||
defaults = &kilo.Defaults{}
|
||||
}
|
||||
|
||||
ts := &kilo.KiloTokenStorage{
|
||||
Token: status.Token,
|
||||
OrganizationID: orgID,
|
||||
Model: defaults.Model,
|
||||
Email: status.UserEmail,
|
||||
Type: "kilo",
|
||||
}
|
||||
|
||||
fileName := kilo.CredentialFileName(status.UserEmail)
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kilo",
|
||||
FileName: fileName,
|
||||
Storage: ts,
|
||||
Metadata: map[string]any{
|
||||
"email": status.UserEmail,
|
||||
"organization_id": orgID,
|
||||
"model": defaults.Model,
|
||||
},
|
||||
}
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("kilo")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": "ok",
|
||||
"url": resp.VerificationURL,
|
||||
"state": state,
|
||||
"user_code": resp.Code,
|
||||
"verification_uri": resp.VerificationURL,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -28,8 +28,7 @@ func (h *Handler) GetConfig(c *gin.Context) {
|
||||
c.JSON(200, gin.H{})
|
||||
return
|
||||
}
|
||||
cfgCopy := *h.cfg
|
||||
c.JSON(200, &cfgCopy)
|
||||
c.JSON(200, new(*h.cfg))
|
||||
}
|
||||
|
||||
type releaseInfo struct {
|
||||
|
||||
@@ -109,14 +109,13 @@ func (h *Handler) GetAPIKeys(c *gin.Context) { c.JSON(200, gin.H{"api-keys": h.c
|
||||
func (h *Handler) PutAPIKeys(c *gin.Context) {
|
||||
h.putStringList(c, func(v []string) {
|
||||
h.cfg.APIKeys = append([]string(nil), v...)
|
||||
h.cfg.Access.Providers = nil
|
||||
}, nil)
|
||||
}
|
||||
func (h *Handler) PatchAPIKeys(c *gin.Context) {
|
||||
h.patchStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
||||
h.patchStringList(c, &h.cfg.APIKeys, func() {})
|
||||
}
|
||||
func (h *Handler) DeleteAPIKeys(c *gin.Context) {
|
||||
h.deleteFromStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
||||
h.deleteFromStringList(c, &h.cfg.APIKeys, func() {})
|
||||
}
|
||||
|
||||
// gemini-api-key: []GeminiKey
|
||||
@@ -797,10 +796,10 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
|
||||
c.JSON(404, gin.H{"error": "channel not found"})
|
||||
return
|
||||
}
|
||||
delete(h.cfg.OAuthModelAlias, channel)
|
||||
if len(h.cfg.OAuthModelAlias) == 0 {
|
||||
h.cfg.OAuthModelAlias = nil
|
||||
}
|
||||
// Set to nil instead of deleting the key so that the "explicitly disabled"
|
||||
// marker survives config reload and prevents SanitizeOAuthModelAlias from
|
||||
// re-injecting default aliases (fixes #222).
|
||||
h.cfg.OAuthModelAlias[channel] = nil
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ type Handler struct {
|
||||
allowRemoteOverride bool
|
||||
envSecret string
|
||||
logDir string
|
||||
postAuthHook coreauth.PostAuthHook
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
|
||||
h.logDir = dir
|
||||
}
|
||||
|
||||
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
|
||||
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||
h.postAuthHook = hook
|
||||
}
|
||||
|
||||
// Middleware enforces access control for management endpoints.
|
||||
// All requests (local and remote) require a valid management key.
|
||||
// Additionally, remote access requires allow-remote-management=true.
|
||||
|
||||
@@ -15,10 +15,12 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
|
||||
|
||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||
// It captures detailed information about the request and response, including headers and body,
|
||||
// and uses the provided RequestLogger to record this data. When logging is disabled in the
|
||||
// logger, it still captures data so that upstream errors can be persisted.
|
||||
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
|
||||
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
|
||||
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if logger == nil {
|
||||
@@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.Method == http.MethodGet {
|
||||
if shouldSkipMethodForRequestLogging(c.Request) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
loggerEnabled := logger.IsEnabled()
|
||||
|
||||
// Capture request information
|
||||
requestInfo, err := captureRequestInfo(c)
|
||||
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
|
||||
if err != nil {
|
||||
// Log error but continue processing
|
||||
// In a real implementation, you might want to use a proper logger here
|
||||
@@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
|
||||
// Create response writer wrapper
|
||||
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
||||
if !logger.IsEnabled() {
|
||||
if !loggerEnabled {
|
||||
wrapper.logOnErrorOnly = true
|
||||
}
|
||||
c.Writer = wrapper
|
||||
@@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
|
||||
if req == nil {
|
||||
return true
|
||||
}
|
||||
if req.Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
return !isResponsesWebsocketUpgrade(req)
|
||||
}
|
||||
|
||||
func isResponsesWebsocketUpgrade(req *http.Request) bool {
|
||||
if req == nil || req.URL == nil {
|
||||
return false
|
||||
}
|
||||
if req.URL.Path != "/v1/responses" {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket")
|
||||
}
|
||||
|
||||
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
|
||||
if loggerEnabled {
|
||||
return true
|
||||
}
|
||||
if req == nil || req.Body == nil {
|
||||
return false
|
||||
}
|
||||
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return false
|
||||
}
|
||||
if req.ContentLength <= 0 {
|
||||
return false
|
||||
}
|
||||
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
|
||||
}
|
||||
|
||||
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
||||
// It captures the URL, method, headers, and body. The request body is read and then
|
||||
// restored so that it can be processed by subsequent handlers.
|
||||
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
|
||||
// Capture URL with sensitive query parameters masked
|
||||
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||
url := c.Request.URL.Path
|
||||
@@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
|
||||
// Capture request body
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
if captureBody && c.Request.Body != nil {
|
||||
// Read the body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
|
||||
138
internal/api/middleware/request_logging_test.go
Normal file
138
internal/api/middleware/request_logging_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *http.Request
|
||||
skip bool
|
||||
}{
|
||||
{
|
||||
name: "nil request",
|
||||
req: nil,
|
||||
skip: true,
|
||||
},
|
||||
{
|
||||
name: "post request should not skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodPost,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
},
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "plain get should skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/models"},
|
||||
Header: http.Header{},
|
||||
},
|
||||
skip: true,
|
||||
},
|
||||
{
|
||||
name: "responses websocket upgrade should not skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
Header: http.Header{"Upgrade": []string{"websocket"}},
|
||||
},
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "responses get without upgrade should skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
Header: http.Header{},
|
||||
},
|
||||
skip: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
got := shouldSkipMethodForRequestLogging(tests[i].req)
|
||||
if got != tests[i].skip {
|
||||
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldCaptureRequestBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
loggerEnabled bool
|
||||
req *http.Request
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "logger enabled always captures",
|
||||
loggerEnabled: true,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
ContentLength: -1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nil request",
|
||||
loggerEnabled: false,
|
||||
req: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "small known size json in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
ContentLength: 2,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "large known size skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unknown size skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: -1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "multipart skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: 1,
|
||||
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
|
||||
if got != tests[i].want {
|
||||
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
)
|
||||
|
||||
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||
|
||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||
type RequestInfo struct {
|
||||
URL string // URL is the request URL.
|
||||
@@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
||||
|
||||
// Only fall back to request payload hints when Content-Type is not set yet.
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
bodyStr := string(w.requestInfo.Body)
|
||||
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
|
||||
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
|
||||
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||
@@ -361,16 +363,32 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||
if c != nil {
|
||||
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
||||
switch value := bodyOverride.(type) {
|
||||
case []byte:
|
||||
if len(value) > 0 {
|
||||
return bytes.Clone(value)
|
||||
}
|
||||
case string:
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return []byte(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
return w.requestInfo.Body
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
if w.requestInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if len(w.requestInfo.Body) > 0 {
|
||||
requestBody = w.requestInfo.Body
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||
}); ok {
|
||||
|
||||
43
internal/api/middleware/response_writer_test.go
Normal file
43
internal/api/middleware/response_writer_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{
|
||||
requestInfo: &RequestInfo{Body: []byte("original-body")},
|
||||
}
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
if string(body) != "original-body" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "original-body")
|
||||
}
|
||||
|
||||
c.Set(requestBodyOverrideContextKey, []byte("override-body"))
|
||||
body = wrapper.extractRequestBody(c)
|
||||
if string(body) != "override-body" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{}
|
||||
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
if string(body) != "override-as-string" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||
}
|
||||
}
|
||||
@@ -127,8 +127,7 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
||||
|
||||
// Store initial config for partial reload comparison
|
||||
settingsCopy := settings
|
||||
m.lastConfig = &settingsCopy
|
||||
m.lastConfig = new(settings)
|
||||
|
||||
// Initialize localhost restriction setting (hot-reloadable)
|
||||
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)
|
||||
|
||||
@@ -215,7 +215,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
|
||||
// Don't log as error for context canceled - it's usually client closing connection
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
|
||||
return
|
||||
} else {
|
||||
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
||||
}
|
||||
|
||||
@@ -493,6 +493,30 @@ func TestReverseProxy_ErrorHandler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) {
|
||||
// Test that context.Canceled errors return 499 without generic error response
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource(""))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a canceled context to trigger the cancellation path
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(ctx)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Directly invoke the ErrorHandler with context.Canceled
|
||||
proxy.ErrorHandler(rr, req, context.Canceled)
|
||||
|
||||
// Body should be empty for canceled requests (no JSON error response)
|
||||
body := rr.Body.Bytes()
|
||||
if len(body) > 0 {
|
||||
t.Fatalf("expected empty body for canceled context, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
|
||||
// Upstream returns gzipped JSON without Content-Encoding header
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -122,7 +122,7 @@ func (rw *ResponseRewriter) Flush() {
|
||||
}
|
||||
|
||||
// modelFieldPaths lists all JSON paths where model name may appear
|
||||
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||
|
||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
||||
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
||||
|
||||
110
internal/api/modules/amp/response_rewriter_test.go
Normal file
110
internal/api/modules/amp/response_rewriter_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRewriteModelInResponse_TopLevel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"id":"resp_1","model":"gpt-5.3-codex","output":[]}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"id":"resp_1","model":"gpt-5.2-codex","output":[]}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_ResponseModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"completed"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"completed"}}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_ResponseCreated(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"in_progress"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"in_progress"}}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_NoModelField(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.output_item.added","item":{"id":"item_1","type":"message"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_EmptyOriginalModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: ""}
|
||||
|
||||
input := []byte(`{"model":"gpt-5.3-codex"}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification when originalModel is empty, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_SSEWithResponseModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
chunk := []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.3-codex\",\"status\":\"completed\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
expected := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.2-codex\",\"status\":\"completed\"}}\n\n"
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_MultipleEvents(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
chunk := []byte("data: {\"type\":\"response.created\",\"response\":{\"model\":\"gpt-5.3-codex\"}}\n\ndata: {\"type\":\"response.output_item.added\",\"item\":{\"id\":\"item_1\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
if string(result) == string(chunk) {
|
||||
t.Error("expected response.model to be rewritten in SSE stream")
|
||||
}
|
||||
if !contains(result, []byte(`"model":"gpt-5.2-codex"`)) {
|
||||
t.Errorf("expected rewritten model in output, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_MessageModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "claude-opus-4.5"}
|
||||
|
||||
chunk := []byte("data: {\"message\":{\"model\":\"claude-sonnet-4\",\"role\":\"assistant\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
expected := "data: {\"message\":{\"model\":\"claude-opus-4.5\",\"role\":\"assistant\"}}\n\n"
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func contains(data, substr []byte) bool {
|
||||
for i := 0; i <= len(data)-len(substr); i++ {
|
||||
if string(data[i:i+len(substr)]) == string(substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -52,6 +52,7 @@ type serverOptionConfig struct {
|
||||
keepAliveEnabled bool
|
||||
keepAliveTimeout time.Duration
|
||||
keepAliveOnTimeout func()
|
||||
postAuthHook auth.PostAuthHook
|
||||
}
|
||||
|
||||
// ServerOption customises HTTP server construction.
|
||||
@@ -112,6 +113,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
|
||||
}
|
||||
}
|
||||
|
||||
// WithPostAuthHook registers a hook to be called after auth record creation.
|
||||
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
|
||||
return func(cfg *serverOptionConfig) {
|
||||
cfg.postAuthHook = hook
|
||||
}
|
||||
}
|
||||
|
||||
// Server represents the main API server.
|
||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||
type Server struct {
|
||||
@@ -263,6 +271,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
}
|
||||
logDir := logging.ResolveLogDirectory(cfg)
|
||||
s.mgmt.SetLogDirectory(logDir)
|
||||
if optionState.postAuthHook != nil {
|
||||
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
|
||||
}
|
||||
s.localPassword = optionState.localPassword
|
||||
|
||||
// Setup routes
|
||||
@@ -285,8 +296,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
optionState.routerConfigurator(engine, s.handlers, cfg)
|
||||
}
|
||||
|
||||
// Register management routes when configuration or environment secrets are available.
|
||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
|
||||
// Register management routes when configuration or environment secrets are available,
|
||||
// or when a local management password is provided (e.g. TUI mode).
|
||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
|
||||
s.managementRoutesEnabled.Store(hasManagementSecret)
|
||||
if hasManagementSecret {
|
||||
s.registerManagementRoutes()
|
||||
@@ -329,6 +341,7 @@ func (s *Server) setupRoutes() {
|
||||
v1.POST("/completions", openaiHandlers.Completions)
|
||||
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
||||
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
||||
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
|
||||
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
||||
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
|
||||
}
|
||||
@@ -642,6 +655,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
|
||||
mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields)
|
||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||
|
||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||
@@ -649,6 +663,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
|
||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
@@ -683,14 +698,17 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
||||
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
// Synchronously ensure management.html is available with a detached context.
|
||||
// Control panel bootstrap should not be canceled by client disconnects.
|
||||
if !managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository) {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.WithError(err).Error("failed to stat management control panel asset")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithError(err).Error("failed to stat management control panel asset")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.File(filePath)
|
||||
@@ -980,10 +998,6 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
staticDir := managementasset.StaticDir(s.configFilePath)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
if s.mgmt != nil {
|
||||
s.mgmt.SetConfig(cfg)
|
||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||
@@ -1062,14 +1076,10 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing API key"})
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
|
||||
default:
|
||||
statusCode := err.HTTPStatusCode()
|
||||
if statusCode >= http.StatusInternalServerError {
|
||||
log.Errorf("authentication middleware error: %v", err)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authentication service error"})
|
||||
}
|
||||
c.AbortWithStatusJSON(statusCode, gin.H{"error": err.Message})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
// OAuth configuration constants for Claude/Anthropic
|
||||
const (
|
||||
AuthURL = "https://claude.ai/oauth/authorize"
|
||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
TokenURL = "https://api.anthropic.com/v1/oauth/token"
|
||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
RedirectURI = "http://localhost:54545/callback"
|
||||
)
|
||||
|
||||
@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
|
||||
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
// Encode and write the token data as JSON
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
||||
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
||||
// authorization code and PKCE verifier.
|
||||
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
|
||||
// a caller-provided redirect URI. This supports alternate auth flows such as device
|
||||
// login while preserving the existing token parsing and storage behavior.
|
||||
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
if pkceCodes == nil {
|
||||
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
||||
}
|
||||
if strings.TrimSpace(redirectURI) == "" {
|
||||
return nil, fmt.Errorf("redirect URI is required for token exchange")
|
||||
}
|
||||
|
||||
// Prepare token exchange request
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {ClientID},
|
||||
"code": {code},
|
||||
"redirect_uri": {RedirectURI},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"code_verifier": {pkceCodes.CodeVerifier},
|
||||
}
|
||||
|
||||
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
||||
if err == nil {
|
||||
return tokenData, nil
|
||||
}
|
||||
if isNonRetryableRefreshErr(err) {
|
||||
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func isNonRetryableRefreshErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
raw := strings.ToLower(err.Error())
|
||||
return strings.Contains(raw, "refresh_token_reused")
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
||||
// This is typically called after a successful token refresh to persist the new credentials.
|
||||
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
||||
|
||||
44
internal/auth/codex/openai_auth_test.go
Normal file
44
internal/auth/codex/openai_auth_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
||||
var calls int32
|
||||
auth := &CodexAuth{
|
||||
httpClient: &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
|
||||
Header: make(http.Header),
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for non-retryable refresh failure")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
|
||||
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
||||
}
|
||||
}
|
||||
@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -82,15 +82,21 @@ func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *Devi
|
||||
}
|
||||
|
||||
// Fetch the GitHub username
|
||||
username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||
userInfo, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||
if err != nil {
|
||||
log.Warnf("copilot: failed to fetch user info: %v", err)
|
||||
username = "unknown"
|
||||
}
|
||||
|
||||
username := userInfo.Login
|
||||
if username == "" {
|
||||
username = "github-user"
|
||||
}
|
||||
|
||||
return &CopilotAuthBundle{
|
||||
TokenData: tokenData,
|
||||
Username: username,
|
||||
Email: userInfo.Email,
|
||||
Name: userInfo.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -150,12 +156,12 @@ func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bo
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
username, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
||||
userInfo, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
return true, username, nil
|
||||
return true, userInfo.Login, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
|
||||
@@ -165,6 +171,8 @@ func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotToke
|
||||
TokenType: bundle.TokenData.TokenType,
|
||||
Scope: bundle.TokenData.Scope,
|
||||
Username: bundle.Username,
|
||||
Email: bundle.Email,
|
||||
Name: bundle.Name,
|
||||
Type: "github-copilot",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
||||
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", copilotClientID)
|
||||
data.Set("scope", "user:email")
|
||||
data.Set("scope", "read:user user:email")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
@@ -211,15 +211,25 @@ func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode st
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FetchUserInfo retrieves the GitHub username for the authenticated user.
|
||||
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
||||
// GitHubUserInfo holds GitHub user profile information.
|
||||
type GitHubUserInfo struct {
|
||||
// Login is the GitHub username.
|
||||
Login string
|
||||
// Email is the primary email address (may be empty if not public).
|
||||
Email string
|
||||
// Name is the display name.
|
||||
Name string
|
||||
}
|
||||
|
||||
// FetchUserInfo retrieves the GitHub user profile for the authenticated user.
|
||||
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (GitHubUserInfo, error) {
|
||||
if accessToken == "" {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
|
||||
if err != nil {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
@@ -227,7 +237,7 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
@@ -237,19 +247,25 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
|
||||
|
||||
if !isHTTPSuccess(resp.StatusCode) {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||
}
|
||||
|
||||
var userInfo struct {
|
||||
var raw struct {
|
||||
Login string `json:"login"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
if err = json.NewDecoder(resp.Body).Decode(&raw); err != nil {
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
}
|
||||
|
||||
if userInfo.Login == "" {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
||||
if raw.Login == "" {
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
||||
}
|
||||
|
||||
return userInfo.Login, nil
|
||||
return GitHubUserInfo{
|
||||
Login: raw.Login,
|
||||
Email: raw.Email,
|
||||
Name: raw.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
213
internal/auth/copilot/oauth_test.go
Normal file
213
internal/auth/copilot/oauth_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// roundTripFunc lets us inject a custom transport for testing.
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
||||
|
||||
// newTestClient returns an *http.Client whose requests are redirected to the given test server,
|
||||
// regardless of the original URL host.
|
||||
func newTestClient(srv *httptest.Server) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
req2 := req.Clone(req.Context())
|
||||
req2.URL.Scheme = "http"
|
||||
req2.URL.Host = strings.TrimPrefix(srv.URL, "http://")
|
||||
return srv.Client().Transport.RoundTrip(req2)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_FullProfile verifies that FetchUserInfo returns login, email, and name.
|
||||
func TestFetchUserInfo_FullProfile(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"login": "octocat",
|
||||
"email": "octocat@github.com",
|
||||
"name": "The Octocat",
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if info.Login != "octocat" {
|
||||
t.Errorf("Login: got %q, want %q", info.Login, "octocat")
|
||||
}
|
||||
if info.Email != "octocat@github.com" {
|
||||
t.Errorf("Email: got %q, want %q", info.Email, "octocat@github.com")
|
||||
}
|
||||
if info.Name != "The Octocat" {
|
||||
t.Errorf("Name: got %q, want %q", info.Name, "The Octocat")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_EmptyEmail verifies graceful handling when email is absent (private account).
|
||||
func TestFetchUserInfo_EmptyEmail(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// GitHub returns null for private emails.
|
||||
_, _ = w.Write([]byte(`{"login":"privateuser","email":null,"name":"Private User"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if info.Login != "privateuser" {
|
||||
t.Errorf("Login: got %q, want %q", info.Login, "privateuser")
|
||||
}
|
||||
if info.Email != "" {
|
||||
t.Errorf("Email: got %q, want empty string", info.Email)
|
||||
}
|
||||
if info.Name != "Private User" {
|
||||
t.Errorf("Name: got %q, want %q", info.Name, "Private User")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_EmptyToken verifies error is returned for empty access token.
|
||||
func TestFetchUserInfo_EmptyToken(t *testing.T) {
|
||||
client := &DeviceFlowClient{httpClient: http.DefaultClient}
|
||||
_, err := client.FetchUserInfo(context.Background(), "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty token, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_EmptyLogin verifies error is returned when API returns no login.
|
||||
func TestFetchUserInfo_EmptyLogin(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"email":"someone@example.com","name":"No Login"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
_, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty login, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_HTTPError verifies error is returned on non-2xx response.
|
||||
func TestFetchUserInfo_HTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
_, err := client.FetchUserInfo(context.Background(), "bad-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 401 response, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopilotTokenStorage_EmailNameFields verifies Email and Name serialise correctly.
|
||||
func TestCopilotTokenStorage_EmailNameFields(t *testing.T) {
|
||||
ts := &CopilotTokenStorage{
|
||||
AccessToken: "ghu_abc",
|
||||
TokenType: "bearer",
|
||||
Scope: "read:user user:email",
|
||||
Username: "octocat",
|
||||
Email: "octocat@github.com",
|
||||
Name: "The Octocat",
|
||||
Type: "github-copilot",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(ts)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal error: %v", err)
|
||||
}
|
||||
|
||||
var out map[string]any
|
||||
if err = json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range []string{"access_token", "username", "email", "name", "type"} {
|
||||
if _, ok := out[key]; !ok {
|
||||
t.Errorf("expected key %q in JSON output, not found", key)
|
||||
}
|
||||
}
|
||||
if out["email"] != "octocat@github.com" {
|
||||
t.Errorf("email: got %v, want %q", out["email"], "octocat@github.com")
|
||||
}
|
||||
if out["name"] != "The Octocat" {
|
||||
t.Errorf("name: got %v, want %q", out["name"], "The Octocat")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopilotTokenStorage_OmitEmptyEmailName verifies email/name are omitted when empty (omitempty).
|
||||
func TestCopilotTokenStorage_OmitEmptyEmailName(t *testing.T) {
|
||||
ts := &CopilotTokenStorage{
|
||||
AccessToken: "ghu_abc",
|
||||
Username: "octocat",
|
||||
Type: "github-copilot",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(ts)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal error: %v", err)
|
||||
}
|
||||
|
||||
var out map[string]any
|
||||
if err = json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := out["email"]; ok {
|
||||
t.Error("email key should be omitted when empty (omitempty), but was present")
|
||||
}
|
||||
if _, ok := out["name"]; ok {
|
||||
t.Error("name key should be omitted when empty (omitempty), but was present")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopilotAuthBundle_EmailNameFields verifies bundle carries email and name through the pipeline.
|
||||
func TestCopilotAuthBundle_EmailNameFields(t *testing.T) {
|
||||
bundle := &CopilotAuthBundle{
|
||||
TokenData: &CopilotTokenData{AccessToken: "ghu_abc"},
|
||||
Username: "octocat",
|
||||
Email: "octocat@github.com",
|
||||
Name: "The Octocat",
|
||||
}
|
||||
if bundle.Email != "octocat@github.com" {
|
||||
t.Errorf("bundle.Email: got %q, want %q", bundle.Email, "octocat@github.com")
|
||||
}
|
||||
if bundle.Name != "The Octocat" {
|
||||
t.Errorf("bundle.Name: got %q, want %q", bundle.Name, "The Octocat")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGitHubUserInfo_Struct verifies the exported GitHubUserInfo struct fields are accessible.
|
||||
func TestGitHubUserInfo_Struct(t *testing.T) {
|
||||
info := GitHubUserInfo{
|
||||
Login: "octocat",
|
||||
Email: "octocat@github.com",
|
||||
Name: "The Octocat",
|
||||
}
|
||||
if info.Login == "" || info.Email == "" || info.Name == "" {
|
||||
t.Error("GitHubUserInfo fields should not be empty")
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,10 @@ type CopilotTokenStorage struct {
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
// Username is the GitHub username associated with this token.
|
||||
Username string `json:"username"`
|
||||
// Email is the GitHub email address associated with this token.
|
||||
Email string `json:"email,omitempty"`
|
||||
// Name is the GitHub display name associated with this token.
|
||||
Name string `json:"name,omitempty"`
|
||||
// Type indicates the authentication provider type, always "github-copilot" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
@@ -46,6 +50,10 @@ type CopilotAuthBundle struct {
|
||||
TokenData *CopilotTokenData
|
||||
// Username is the GitHub username.
|
||||
Username string
|
||||
// Email is the GitHub email address.
|
||||
Email string
|
||||
// Name is the GitHub display name.
|
||||
Name string
|
||||
}
|
||||
|
||||
// DeviceCodeResponse represents GitHub's device code response.
|
||||
|
||||
@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
|
||||
|
||||
// Type indicates the authentication provider type, always "gemini" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
|
||||
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "gemini"
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
|
||||
Scope string `json:"scope"`
|
||||
Cookie string `json:"cookie"`
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serialises the token storage to disk.
|
||||
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
168
internal/auth/kilo/kilo_auth.go
Normal file
168
internal/auth/kilo/kilo_auth.go
Normal file
@@ -0,0 +1,168 @@
|
||||
// Package kilo provides authentication and token management functionality
|
||||
// for Kilo AI services.
|
||||
package kilo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// BaseURL is the base URL for the Kilo AI API.
|
||||
BaseURL = "https://api.kilo.ai/api"
|
||||
)
|
||||
|
||||
// DeviceAuthResponse represents the response from initiating device flow.
|
||||
type DeviceAuthResponse struct {
|
||||
Code string `json:"code"`
|
||||
VerificationURL string `json:"verificationUrl"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
// DeviceStatusResponse represents the response when polling for device flow status.
|
||||
type DeviceStatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Token string `json:"token"`
|
||||
UserEmail string `json:"userEmail"`
|
||||
}
|
||||
|
||||
// Profile represents the user profile from Kilo AI.
|
||||
type Profile struct {
|
||||
Email string `json:"email"`
|
||||
Orgs []Organization `json:"organizations"`
|
||||
}
|
||||
|
||||
// Organization represents a Kilo AI organization.
|
||||
type Organization struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Defaults represents default settings for an organization or user.
|
||||
type Defaults struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// KiloAuth provides methods for handling the Kilo AI authentication flow.
|
||||
type KiloAuth struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewKiloAuth creates a new instance of KiloAuth.
|
||||
func NewKiloAuth() *KiloAuth {
|
||||
return &KiloAuth{
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateDeviceFlow starts the device authentication flow.
|
||||
func (k *KiloAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceAuthResponse, error) {
|
||||
resp, err := k.client.Post(BaseURL+"/device-auth/codes", "application/json", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to initiate device flow: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var data DeviceAuthResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// PollForToken polls for the device flow completion.
|
||||
func (k *KiloAuth) PollForToken(ctx context.Context, code string) (*DeviceStatusResponse, error) {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
resp, err := k.client.Get(BaseURL + "/device-auth/codes/" + code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var data DeviceStatusResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch data.Status {
|
||||
case "approved":
|
||||
return &data, nil
|
||||
case "denied", "expired":
|
||||
return nil, fmt.Errorf("device flow %s", data.Status)
|
||||
case "pending":
|
||||
continue
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown status: %s", data.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetProfile fetches the user's profile.
|
||||
func (k *KiloAuth) GetProfile(ctx context.Context, token string) (*Profile, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", BaseURL+"/profile", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create get profile request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := k.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get profile: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var profile Profile
|
||||
if err := json.NewDecoder(resp.Body).Decode(&profile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
// GetDefaults fetches default settings for an organization.
|
||||
func (k *KiloAuth) GetDefaults(ctx context.Context, token, orgID string) (*Defaults, error) {
|
||||
url := BaseURL + "/defaults"
|
||||
if orgID != "" {
|
||||
url = BaseURL + "/organizations/" + orgID + "/defaults"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create get defaults request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := k.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get defaults: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var defaults Defaults
|
||||
if err := json.NewDecoder(resp.Body).Decode(&defaults); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &defaults, nil
|
||||
}
|
||||
60
internal/auth/kilo/kilo_token.go
Normal file
60
internal/auth/kilo/kilo_token.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Package kilo provides authentication and token management functionality
|
||||
// for Kilo AI services.
|
||||
package kilo
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// KiloTokenStorage stores token information for Kilo AI authentication.
|
||||
type KiloTokenStorage struct {
|
||||
// Token is the Kilo access token.
|
||||
Token string `json:"kilocodeToken"`
|
||||
|
||||
// OrganizationID is the Kilo organization ID.
|
||||
OrganizationID string `json:"kilocodeOrganizationId"`
|
||||
|
||||
// Model is the default model to use.
|
||||
Model string `json:"kilocodeModel"`
|
||||
|
||||
// Email is the email address of the authenticated user.
|
||||
Email string `json:"email"`
|
||||
|
||||
// Type indicates the authentication provider type, always "kilo" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Kilo token storage to a JSON file.
|
||||
func (ts *KiloTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "kilo"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := f.Close(); errClose != nil {
|
||||
log.Errorf("failed to close file: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CredentialFileName returns the filename used to persist Kilo credentials.
|
||||
func CredentialFileName(email string) string {
|
||||
return fmt.Sprintf("kilo-%s.json", email)
|
||||
}
|
||||
@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
|
||||
Expired string `json:"expired,omitempty"`
|
||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(f)
|
||||
encoder.SetIndent("", " ")
|
||||
if err = encoder.Encode(ts); err != nil {
|
||||
if err = encoder.Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -40,8 +40,7 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||
if err != nil {
|
||||
var authErr *claude.AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
if authErr, ok := errors.AsType[*claude.AuthenticationError](err); ok {
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == claude.ErrPortInUse.Type {
|
||||
os.Exit(claude.ErrPortInUse.Code)
|
||||
|
||||
@@ -22,6 +22,7 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewKimiAuthenticator(),
|
||||
sdkAuth.NewKiroAuthenticator(),
|
||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||
sdkAuth.NewKiloAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
@@ -32,8 +32,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
|
||||
if err != nil {
|
||||
var emailErr *sdkAuth.EmailRequiredError
|
||||
if errors.As(err, &emailErr) {
|
||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
||||
log.Error(emailErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
54
internal/cmd/kilo_login.go
Normal file
54
internal/cmd/kilo_login.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
)
|
||||
|
||||
// DoKiloLogin handles the Kilo device flow using the shared authentication manager.
|
||||
// It initiates the device-based authentication process for Kilo AI services and saves
|
||||
// the authentication tokens to the configured auth directory.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration
|
||||
// - options: Login options including browser behavior and prompts
|
||||
func DoKiloLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = func(prompt string) (string, error) {
|
||||
fmt.Print(prompt)
|
||||
var value string
|
||||
fmt.Scanln(&value)
|
||||
return strings.TrimSpace(value), nil
|
||||
}
|
||||
}
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "kilo", cfg, authOpts)
|
||||
if err != nil {
|
||||
fmt.Printf("Kilo authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
|
||||
fmt.Println("Kilo authentication successful!")
|
||||
}
|
||||
@@ -100,49 +100,74 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
|
||||
log.Info("Authentication successful.")
|
||||
|
||||
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
||||
if errProjects != nil {
|
||||
log.Errorf("Failed to get project list: %v", errProjects)
|
||||
return
|
||||
var activatedProjects []string
|
||||
|
||||
useGoogleOne := false
|
||||
if trimmedProjectID == "" && promptFn != nil {
|
||||
fmt.Println("\nSelect login mode:")
|
||||
fmt.Println(" 1. Code Assist (GCP project, manual selection)")
|
||||
fmt.Println(" 2. Google One (personal account, auto-discover project)")
|
||||
choice, errPrompt := promptFn("Enter choice [1/2] (default: 1): ")
|
||||
if errPrompt == nil && strings.TrimSpace(choice) == "2" {
|
||||
useGoogleOne = true
|
||||
}
|
||||
}
|
||||
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
return
|
||||
}
|
||||
if len(projectSelections) == 0 {
|
||||
log.Error("No project selected; aborting login.")
|
||||
return
|
||||
}
|
||||
|
||||
activatedProjects := make([]string, 0, len(projectSelections))
|
||||
seenProjects := make(map[string]bool)
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
var projectErr *projectSelectionRequiredError
|
||||
if errors.As(errSetup, &projectErr) {
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
showProjectSelectionHelp(storage.Email, projects)
|
||||
return
|
||||
}
|
||||
log.Errorf("Failed to complete user setup: %v", errSetup)
|
||||
if useGoogleOne {
|
||||
log.Info("Google One mode: auto-discovering project...")
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
return
|
||||
}
|
||||
finalID := strings.TrimSpace(storage.ProjectID)
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
autoProject := strings.TrimSpace(storage.ProjectID)
|
||||
if autoProject == "" {
|
||||
log.Error("Google One auto-discovery returned empty project ID")
|
||||
return
|
||||
}
|
||||
log.Infof("Auto-discovered project: %s", autoProject)
|
||||
activatedProjects = []string{autoProject}
|
||||
} else {
|
||||
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
||||
if errProjects != nil {
|
||||
log.Errorf("Failed to get project list: %v", errProjects)
|
||||
return
|
||||
}
|
||||
|
||||
// Skip duplicates
|
||||
if seenProjects[finalID] {
|
||||
log.Infof("Project %s already activated, skipping", finalID)
|
||||
continue
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
return
|
||||
}
|
||||
if len(projectSelections) == 0 {
|
||||
log.Error("No project selected; aborting login.")
|
||||
return
|
||||
}
|
||||
|
||||
seenProjects := make(map[string]bool)
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
if _, ok := errors.AsType[*projectSelectionRequiredError](errSetup); ok {
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
showProjectSelectionHelp(storage.Email, projects)
|
||||
return
|
||||
}
|
||||
log.Errorf("Failed to complete user setup: %v", errSetup)
|
||||
return
|
||||
}
|
||||
finalID := strings.TrimSpace(storage.ProjectID)
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
}
|
||||
|
||||
if seenProjects[finalID] {
|
||||
log.Infof("Project %s already activated, skipping", finalID)
|
||||
continue
|
||||
}
|
||||
seenProjects[finalID] = true
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
seenProjects[finalID] = true
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
|
||||
storage.Auto = false
|
||||
@@ -235,7 +260,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
// Auto-discovery: try onboardUser without specifying a project
|
||||
// to let Google auto-provision one (matches Gemini CLI headless behavior
|
||||
// and Antigravity's FetchProjectID pattern).
|
||||
autoOnboardReq := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer autoCancel()
|
||||
for attempt := 1; ; attempt++ {
|
||||
var onboardResp map[string]any
|
||||
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
|
||||
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
|
||||
}
|
||||
|
||||
if done, okDone := onboardResp["done"].(bool); okDone && done {
|
||||
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
|
||||
switch v := resp["cloudaicompanionProject"].(type) {
|
||||
case string:
|
||||
projectID = strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if id, okID := v["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
|
||||
select {
|
||||
case <-autoCtx.Done():
|
||||
return &projectSelectionRequiredError{}
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
}
|
||||
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
|
||||
}
|
||||
|
||||
onboardReqBody := map[string]any{
|
||||
@@ -617,7 +683,7 @@ func updateAuthRecord(record *cliproxyauth.Auth, storage *gemini.GeminiTokenStor
|
||||
return
|
||||
}
|
||||
|
||||
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, false)
|
||||
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, true)
|
||||
|
||||
if record.Metadata == nil {
|
||||
record.Metadata = make(map[string]any)
|
||||
|
||||
60
internal/cmd/openai_device_login.go
Normal file
60
internal/cmd/openai_device_login.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
codexLoginModeMetadataKey = "codex_login_mode"
|
||||
codexLoginModeDevice = "device"
|
||||
)
|
||||
|
||||
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
|
||||
// existing codex-login OAuth callback flow intact.
|
||||
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{
|
||||
codexLoginModeMetadataKey: codexLoginModeDevice,
|
||||
},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
if err != nil {
|
||||
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == codex.ErrPortInUse.Type {
|
||||
os.Exit(codex.ErrPortInUse.Code)
|
||||
}
|
||||
return
|
||||
}
|
||||
fmt.Printf("Codex device authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
fmt.Println("Codex device authentication successful!")
|
||||
}
|
||||
@@ -54,8 +54,7 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
if err != nil {
|
||||
var authErr *codex.AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == codex.ErrPortInUse.Type {
|
||||
os.Exit(codex.ErrPortInUse.Code)
|
||||
|
||||
@@ -44,8 +44,7 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
||||
if err != nil {
|
||||
var emailErr *sdkAuth.EmailRequiredError
|
||||
if errors.As(err, &emailErr) {
|
||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
||||
log.Error(emailErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,6 +55,34 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
|
||||
}
|
||||
}
|
||||
|
||||
// StartServiceBackground starts the proxy service in a background goroutine
|
||||
// and returns a cancel function for shutdown and a done channel.
|
||||
func StartServiceBackground(cfg *config.Config, configPath string, localPassword string) (cancel func(), done <-chan struct{}) {
|
||||
builder := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath(configPath).
|
||||
WithLocalManagementPassword(localPassword)
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
service, err := builder.Build()
|
||||
if err != nil {
|
||||
log.Errorf("failed to build proxy service: %v", err)
|
||||
close(doneCh)
|
||||
return cancelFn, doneCh
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(doneCh)
|
||||
if err := service.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("proxy service exited with error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return cancelFn, doneCh
|
||||
}
|
||||
|
||||
// WaitForCloudDeploy waits indefinitely for shutdown signals in cloud deploy mode
|
||||
// when no configuration file is available.
|
||||
func WaitForCloudDeploy() {
|
||||
|
||||
@@ -97,6 +97,10 @@ type Config struct {
|
||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||
|
||||
// ClaudeHeaderDefaults configures default header values for Claude API requests.
|
||||
// These are used as fallbacks when the client does not send its own headers.
|
||||
ClaudeHeaderDefaults ClaudeHeaderDefaults `yaml:"claude-header-defaults" json:"claude-header-defaults"`
|
||||
|
||||
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
|
||||
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
|
||||
|
||||
@@ -130,6 +134,15 @@ type Config struct {
|
||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
|
||||
// when the client does not send them. Update these when Claude Code releases a new version.
|
||||
type ClaudeHeaderDefaults struct {
|
||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||
Timeout string `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// TLSConfig holds HTTPS server settings.
|
||||
type TLSConfig struct {
|
||||
// Enable toggles HTTPS server mode.
|
||||
@@ -301,6 +314,10 @@ type CloakConfig struct {
|
||||
// SensitiveWords is a list of words to obfuscate with zero-width characters.
|
||||
// This can help bypass certain content filters.
|
||||
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
|
||||
|
||||
// CacheUserID controls whether Claude user_id values are cached per API key.
|
||||
// When false, a fresh random user_id is generated for every request.
|
||||
CacheUserID *bool `yaml:"cache-user-id,omitempty" json:"cache-user-id,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeKey represents the configuration for a Claude API key,
|
||||
@@ -368,6 +385,9 @@ type CodexKey struct {
|
||||
// If empty, the default Codex API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
|
||||
// Websockets enables the Responses API websocket transport for this credential.
|
||||
Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
|
||||
|
||||
// ProxyURL overrides the global proxy setting for this API key if provided.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
@@ -632,9 +652,6 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
}
|
||||
|
||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||
syncInlineAccessProvider(&cfg)
|
||||
|
||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||
cfg.SanitizeGeminiKeys()
|
||||
|
||||
@@ -739,14 +756,46 @@ func payloadRawString(value any) ([]byte, bool) {
|
||||
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||
// It also injects default aliases for channels that have built-in defaults (e.g., kiro)
|
||||
// when no user-configured aliases exist for those channels.
|
||||
func (cfg *Config) SanitizeOAuthModelAlias() {
|
||||
if cfg == nil || len(cfg.OAuthModelAlias) == 0 {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Inject channel defaults when the channel is absent in user config.
|
||||
// Presence is checked case-insensitively and includes explicit nil/empty markers.
|
||||
if cfg.OAuthModelAlias == nil {
|
||||
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
|
||||
}
|
||||
hasChannel := func(channel string) bool {
|
||||
for k := range cfg.OAuthModelAlias {
|
||||
if strings.EqualFold(strings.TrimSpace(k), channel) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
if !hasChannel("kiro") {
|
||||
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
|
||||
}
|
||||
if !hasChannel("github-copilot") {
|
||||
cfg.OAuthModelAlias["github-copilot"] = defaultGitHubCopilotAliases()
|
||||
}
|
||||
|
||||
if len(cfg.OAuthModelAlias) == 0 {
|
||||
return
|
||||
}
|
||||
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
|
||||
for rawChannel, aliases := range cfg.OAuthModelAlias {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(aliases) == 0 {
|
||||
if channel == "" {
|
||||
continue
|
||||
}
|
||||
// Preserve channels that were explicitly set to empty/nil – they act
|
||||
// as "disabled" markers so default injection won't re-add them (#222).
|
||||
if len(aliases) == 0 {
|
||||
out[channel] = nil
|
||||
continue
|
||||
}
|
||||
seenAlias := make(map[string]struct{}, len(aliases))
|
||||
@@ -888,18 +937,6 @@ func normalizeModelPrefix(prefix string) string {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func syncInlineAccessProvider(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
if len(cfg.APIKeys) == 0 {
|
||||
if provider := cfg.ConfigAPIKeyProvider(); provider != nil && len(provider.APIKeys) > 0 {
|
||||
cfg.APIKeys = append([]string(nil), provider.APIKeys...)
|
||||
}
|
||||
}
|
||||
cfg.Access.Providers = nil
|
||||
}
|
||||
|
||||
// looksLikeBcrypt returns true if the provided string appears to be a bcrypt hash.
|
||||
func looksLikeBcrypt(s string) bool {
|
||||
return len(s) > 4 && (s[:4] == "$2a$" || s[:4] == "$2b$" || s[:4] == "$2y$")
|
||||
@@ -987,7 +1024,7 @@ func hashSecret(secret string) (string, error) {
|
||||
// SaveConfigPreserveComments writes the config back to YAML while preserving existing comments
|
||||
// and key ordering by loading the original file into a yaml.Node tree and updating values in-place.
|
||||
func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
persistCfg := sanitizeConfigForPersist(cfg)
|
||||
persistCfg := cfg
|
||||
// Load original YAML as a node tree to preserve comments and ordering.
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
@@ -1055,16 +1092,6 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func sanitizeConfigForPersist(cfg *Config) *Config {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *cfg
|
||||
clone.SDKConfig = cfg.SDKConfig
|
||||
clone.SDKConfig.Access = AccessConfig{}
|
||||
return &clone
|
||||
}
|
||||
|
||||
// SaveConfigPreserveCommentsUpdateNestedScalar updates a nested scalar key path like ["a","b"]
|
||||
// while preserving comments and positions.
|
||||
func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error {
|
||||
@@ -1161,8 +1188,13 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
|
||||
|
||||
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
|
||||
// key order and comments of existing keys in dst. New keys are only added if their
|
||||
// value is non-zero to avoid polluting the config with defaults.
|
||||
func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
// value is non-zero and not a known default to avoid polluting the config with defaults.
|
||||
func mergeMappingPreserve(dst, src *yaml.Node, path ...[]string) {
|
||||
var currentPath []string
|
||||
if len(path) > 0 {
|
||||
currentPath = path[0]
|
||||
}
|
||||
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
@@ -1176,16 +1208,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
sk := src.Content[i]
|
||||
sv := src.Content[i+1]
|
||||
idx := findMapKeyIndex(dst, sk.Value)
|
||||
childPath := appendPath(currentPath, sk.Value)
|
||||
if idx >= 0 {
|
||||
// Merge into existing value node (always update, even to zero values)
|
||||
dv := dst.Content[idx+1]
|
||||
mergeNodePreserve(dv, sv)
|
||||
mergeNodePreserve(dv, sv, childPath)
|
||||
} else {
|
||||
// New key: only add if value is non-zero to avoid polluting config with defaults
|
||||
if isZeroValueNode(sv) {
|
||||
// New key: only add if value is non-zero and not a known default
|
||||
candidate := deepCopyNode(sv)
|
||||
pruneKnownDefaultsInNewNode(childPath, candidate)
|
||||
if isKnownDefaultValue(childPath, candidate) {
|
||||
continue
|
||||
}
|
||||
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
|
||||
dst.Content = append(dst.Content, deepCopyNode(sk), candidate)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1193,7 +1228,12 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
// mergeNodePreserve merges src into dst for scalars, mappings and sequences while
|
||||
// reusing destination nodes to keep comments and anchors. For sequences, it updates
|
||||
// in-place by index.
|
||||
func mergeNodePreserve(dst, src *yaml.Node) {
|
||||
func mergeNodePreserve(dst, src *yaml.Node, path ...[]string) {
|
||||
var currentPath []string
|
||||
if len(path) > 0 {
|
||||
currentPath = path[0]
|
||||
}
|
||||
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
@@ -1202,7 +1242,7 @@ func mergeNodePreserve(dst, src *yaml.Node) {
|
||||
if dst.Kind != yaml.MappingNode {
|
||||
copyNodeShallow(dst, src)
|
||||
}
|
||||
mergeMappingPreserve(dst, src)
|
||||
mergeMappingPreserve(dst, src, currentPath)
|
||||
case yaml.SequenceNode:
|
||||
// Preserve explicit null style if dst was null and src is empty sequence
|
||||
if dst.Kind == yaml.ScalarNode && dst.Tag == "!!null" && len(src.Content) == 0 {
|
||||
@@ -1225,7 +1265,7 @@ func mergeNodePreserve(dst, src *yaml.Node) {
|
||||
dst.Content[i] = deepCopyNode(src.Content[i])
|
||||
continue
|
||||
}
|
||||
mergeNodePreserve(dst.Content[i], src.Content[i])
|
||||
mergeNodePreserve(dst.Content[i], src.Content[i], currentPath)
|
||||
if dst.Content[i] != nil && src.Content[i] != nil &&
|
||||
dst.Content[i].Kind == yaml.MappingNode && src.Content[i].Kind == yaml.MappingNode {
|
||||
pruneMissingMapKeys(dst.Content[i], src.Content[i])
|
||||
@@ -1267,6 +1307,94 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
|
||||
return -1
|
||||
}
|
||||
|
||||
// appendPath appends a key to the path, returning a new slice to avoid modifying the original.
|
||||
func appendPath(path []string, key string) []string {
|
||||
if len(path) == 0 {
|
||||
return []string{key}
|
||||
}
|
||||
newPath := make([]string, len(path)+1)
|
||||
copy(newPath, path)
|
||||
newPath[len(path)] = key
|
||||
return newPath
|
||||
}
|
||||
|
||||
// isKnownDefaultValue returns true if the given node at the specified path
|
||||
// represents a known default value that should not be written to the config file.
|
||||
// This prevents non-zero defaults from polluting the config.
|
||||
func isKnownDefaultValue(path []string, node *yaml.Node) bool {
|
||||
// First check if it's a zero value
|
||||
if isZeroValueNode(node) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Match known non-zero defaults by exact dotted path.
|
||||
if len(path) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
fullPath := strings.Join(path, ".")
|
||||
|
||||
// Check string defaults
|
||||
if node.Kind == yaml.ScalarNode && node.Tag == "!!str" {
|
||||
switch fullPath {
|
||||
case "pprof.addr":
|
||||
return node.Value == DefaultPprofAddr
|
||||
case "remote-management.panel-github-repository":
|
||||
return node.Value == DefaultPanelGitHubRepository
|
||||
case "routing.strategy":
|
||||
return node.Value == "round-robin"
|
||||
}
|
||||
}
|
||||
|
||||
// Check integer defaults
|
||||
if node.Kind == yaml.ScalarNode && node.Tag == "!!int" {
|
||||
switch fullPath {
|
||||
case "error-logs-max-files":
|
||||
return node.Value == "10"
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// pruneKnownDefaultsInNewNode removes default-valued descendants from a new node
|
||||
// before it is appended into the destination YAML tree.
|
||||
func pruneKnownDefaultsInNewNode(path []string, node *yaml.Node) {
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch node.Kind {
|
||||
case yaml.MappingNode:
|
||||
filtered := make([]*yaml.Node, 0, len(node.Content))
|
||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||
keyNode := node.Content[i]
|
||||
valueNode := node.Content[i+1]
|
||||
if keyNode == nil || valueNode == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
childPath := appendPath(path, keyNode.Value)
|
||||
if isKnownDefaultValue(childPath, valueNode) {
|
||||
continue
|
||||
}
|
||||
|
||||
pruneKnownDefaultsInNewNode(childPath, valueNode)
|
||||
if (valueNode.Kind == yaml.MappingNode || valueNode.Kind == yaml.SequenceNode) &&
|
||||
len(valueNode.Content) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
filtered = append(filtered, keyNode, valueNode)
|
||||
}
|
||||
node.Content = filtered
|
||||
case yaml.SequenceNode:
|
||||
for _, child := range node.Content {
|
||||
pruneKnownDefaultsInNewNode(path, child)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isZeroValueNode returns true if the YAML node represents a zero/default value
|
||||
// that should not be written as a new key to preserve config cleanliness.
|
||||
// For mappings and sequences, recursively checks if all children are zero values.
|
||||
|
||||
@@ -20,6 +20,43 @@ var antigravityModelConversionTable = map[string]string{
|
||||
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||
}
|
||||
|
||||
// defaultKiroAliases returns the default oauth-model-alias configuration
|
||||
// for the kiro channel. Maps kiro-prefixed model names to standard Claude model
|
||||
// names so that clients like Claude Code can use standard names directly.
|
||||
func defaultKiroAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
// Sonnet 4.5
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||
// Sonnet 4
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
|
||||
// Opus 4.6
|
||||
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
|
||||
// Opus 4.5
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
|
||||
// Haiku 4.5
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
|
||||
}
|
||||
}
|
||||
|
||||
// defaultGitHubCopilotAliases returns default oauth-model-alias entries that
|
||||
// expose Claude hyphen-style IDs for GitHub Copilot Claude models.
|
||||
// This keeps compatibility with clients (e.g. Claude Code) that use
|
||||
// Anthropic-style model IDs like "claude-opus-4-6".
|
||||
func defaultGitHubCopilotAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
|
||||
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
|
||||
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
|
||||
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
|
||||
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||
}
|
||||
}
|
||||
|
||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
||||
// for the antigravity channel when neither field exists.
|
||||
func defaultAntigravityAliases() []OAuthModelAlias {
|
||||
|
||||
@@ -54,3 +54,208 @@ func TestSanitizeOAuthModelAlias_AllowsMultipleAliasesForSameName(t *testing.T)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroAliases(t *testing.T) {
|
||||
// When no kiro aliases are configured, defaults should be injected
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"codex": {
|
||||
{Name: "gpt-5", Alias: "g5"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) == 0 {
|
||||
t.Fatal("expected default kiro aliases to be injected")
|
||||
}
|
||||
|
||||
// Check that standard Claude model names are present
|
||||
aliasSet := make(map[string]bool)
|
||||
for _, a := range kiroAliases {
|
||||
aliasSet[a.Alias] = true
|
||||
}
|
||||
expectedAliases := []string{
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-sonnet-4",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-haiku-4-5",
|
||||
}
|
||||
for _, expected := range expectedAliases {
|
||||
if !aliasSet[expected] {
|
||||
t.Fatalf("expected default kiro alias %q to be present", expected)
|
||||
}
|
||||
}
|
||||
|
||||
// All should have fork=true
|
||||
for _, a := range kiroAliases {
|
||||
if !a.Fork {
|
||||
t.Fatalf("expected all default kiro aliases to have fork=true, got fork=false for %q", a.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
// Codex aliases should still be preserved
|
||||
if len(cfg.OAuthModelAlias["codex"]) != 1 {
|
||||
t.Fatal("expected codex aliases to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_InjectsDefaultGitHubCopilotAliases(t *testing.T) {
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"codex": {
|
||||
{Name: "gpt-5", Alias: "g5"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
|
||||
if len(copilotAliases) == 0 {
|
||||
t.Fatal("expected default github-copilot aliases to be injected")
|
||||
}
|
||||
|
||||
aliasSet := make(map[string]bool, len(copilotAliases))
|
||||
for _, a := range copilotAliases {
|
||||
aliasSet[a.Alias] = true
|
||||
if !a.Fork {
|
||||
t.Fatalf("expected all default github-copilot aliases to have fork=true, got fork=false for %q", a.Alias)
|
||||
}
|
||||
}
|
||||
expectedAliases := []string{
|
||||
"claude-haiku-4-5",
|
||||
"claude-opus-4-1",
|
||||
"claude-opus-4-5",
|
||||
"claude-opus-4-6",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4-6",
|
||||
}
|
||||
for _, expected := range expectedAliases {
|
||||
if !aliasSet[expected] {
|
||||
t.Fatalf("expected default github-copilot alias %q to be present", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
|
||||
// When user has configured kiro aliases, defaults should NOT be injected
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"kiro": {
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "my-custom-sonnet", Fork: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) != 1 {
|
||||
t.Fatalf("expected 1 user-configured kiro alias, got %d", len(kiroAliases))
|
||||
}
|
||||
if kiroAliases[0].Alias != "my-custom-sonnet" {
|
||||
t.Fatalf("expected user alias to be preserved, got %q", kiroAliases[0].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserGitHubCopilotAliases(t *testing.T) {
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"github-copilot": {
|
||||
{Name: "claude-opus-4.6", Alias: "my-opus", Fork: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
|
||||
if len(copilotAliases) != 1 {
|
||||
t.Fatalf("expected 1 user-configured github-copilot alias, got %d", len(copilotAliases))
|
||||
}
|
||||
if copilotAliases[0].Alias != "my-opus" {
|
||||
t.Fatalf("expected user alias to be preserved, got %q", copilotAliases[0].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
|
||||
// When user explicitly deletes kiro aliases (key exists with nil value),
|
||||
// defaults should NOT be re-injected on subsequent sanitize calls (#222).
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"kiro": nil, // explicitly deleted
|
||||
"codex": {{Name: "gpt-5", Alias: "g5"}},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) != 0 {
|
||||
t.Fatalf("expected kiro aliases to remain empty after explicit deletion, got %d aliases", len(kiroAliases))
|
||||
}
|
||||
// The key itself must still be present to prevent re-injection on next reload
|
||||
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
|
||||
t.Fatal("expected kiro key to be preserved as nil marker after sanitization")
|
||||
}
|
||||
// Other channels should be unaffected
|
||||
if len(cfg.OAuthModelAlias["codex"]) != 1 {
|
||||
t.Fatal("expected codex aliases to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_GitHubCopilotDoesNotReinjectAfterExplicitDeletion(t *testing.T) {
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"github-copilot": nil, // explicitly deleted
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
|
||||
if len(copilotAliases) != 0 {
|
||||
t.Fatalf("expected github-copilot aliases to remain empty after explicit deletion, got %d aliases", len(copilotAliases))
|
||||
}
|
||||
if _, exists := cfg.OAuthModelAlias["github-copilot"]; !exists {
|
||||
t.Fatal("expected github-copilot key to be preserved as nil marker after sanitization")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
|
||||
// Same as above but with empty slice instead of nil (PUT with empty body).
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"kiro": {}, // explicitly set to empty
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
if len(cfg.OAuthModelAlias["kiro"]) != 0 {
|
||||
t.Fatalf("expected kiro aliases to remain empty, got %d aliases", len(cfg.OAuthModelAlias["kiro"]))
|
||||
}
|
||||
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
|
||||
t.Fatal("expected kiro key to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroWhenEmpty(t *testing.T) {
|
||||
// When OAuthModelAlias is nil, kiro defaults should still be injected
|
||||
cfg := &Config{}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) == 0 {
|
||||
t.Fatal("expected default kiro aliases to be injected when OAuthModelAlias is nil")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,8 +20,9 @@ type SDKConfig struct {
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
|
||||
// Default is false (disabled).
|
||||
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
|
||||
|
||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||
@@ -42,65 +43,3 @@ type StreamingConfig struct {
|
||||
// <= 0 disables bootstrap retries. Default is 0.
|
||||
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||
}
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
type AccessConfig struct {
|
||||
// Providers lists configured authentication providers.
|
||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||
}
|
||||
|
||||
// AccessProvider describes a request authentication provider entry.
|
||||
type AccessProvider struct {
|
||||
// Name is the instance identifier for the provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Type selects the provider implementation registered via the SDK.
|
||||
Type string `yaml:"type" json:"type"`
|
||||
|
||||
// SDK optionally names a third-party SDK module providing this provider.
|
||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||
|
||||
// APIKeys lists inline keys for providers that require them.
|
||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||
|
||||
// Config passes provider-specific options to the implementation.
|
||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||
|
||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||
DefaultAccessProviderName = "config-inline"
|
||||
)
|
||||
|
||||
// ConfigAPIKeyProvider returns the first inline API key provider if present.
|
||||
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range c.Access.Providers {
|
||||
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
|
||||
if c.Access.Providers[i].Name == "" {
|
||||
c.Access.Providers[i].Name = DefaultAccessProviderName
|
||||
}
|
||||
return &c.Access.Providers[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||
// It returns nil when no keys are supplied.
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
provider := &AccessProvider{
|
||||
Name: DefaultAccessProviderName,
|
||||
Type: AccessProviderTypeConfigAPIKey,
|
||||
APIKeys: append([]string(nil), keys...),
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
@@ -27,4 +27,7 @@ const (
|
||||
|
||||
// Kiro represents the AWS CodeWhisperer (Kiro) provider identifier.
|
||||
Kiro = "kiro"
|
||||
|
||||
// Kilo represents the Kilo AI provider identifier.
|
||||
Kilo = "kilo"
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,6 +29,7 @@ const (
|
||||
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
managementSyncMinInterval = 30 * time.Second
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
@@ -37,11 +39,10 @@ const ManagementFileName = managementAssetName
|
||||
var (
|
||||
lastUpdateCheckMu sync.Mutex
|
||||
lastUpdateCheckTime time.Time
|
||||
|
||||
currentConfigPtr atomic.Pointer[config.Config]
|
||||
disableControlPanel atomic.Bool
|
||||
schedulerOnce sync.Once
|
||||
schedulerConfigPath atomic.Value
|
||||
sfGroup singleflight.Group
|
||||
)
|
||||
|
||||
// SetCurrentConfig stores the latest configuration snapshot for management asset decisions.
|
||||
@@ -50,16 +51,7 @@ func SetCurrentConfig(cfg *config.Config) {
|
||||
currentConfigPtr.Store(nil)
|
||||
return
|
||||
}
|
||||
|
||||
prevDisabled := disableControlPanel.Load()
|
||||
currentConfigPtr.Store(cfg)
|
||||
disableControlPanel.Store(cfg.RemoteManagement.DisableControlPanel)
|
||||
|
||||
if prevDisabled && !cfg.RemoteManagement.DisableControlPanel {
|
||||
lastUpdateCheckMu.Lock()
|
||||
lastUpdateCheckTime = time.Time{}
|
||||
lastUpdateCheckMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// StartAutoUpdater launches a background goroutine that periodically ensures the management asset is up to date.
|
||||
@@ -92,7 +84,7 @@ func runAutoUpdater(ctx context.Context) {
|
||||
log.Debug("management asset auto-updater skipped: config not yet available")
|
||||
return
|
||||
}
|
||||
if disableControlPanel.Load() {
|
||||
if cfg.RemoteManagement.DisableControlPanel {
|
||||
log.Debug("management asset auto-updater skipped: control panel disabled")
|
||||
return
|
||||
}
|
||||
@@ -181,103 +173,106 @@ func FilePath(configFilePath string) string {
|
||||
}
|
||||
|
||||
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
|
||||
// The function is designed to run in a background goroutine and will never panic.
|
||||
// It enforces a 3-hour rate limit to avoid frequent checks on config/auth file changes.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) {
|
||||
// It coalesces concurrent sync attempts and returns whether the asset exists after the sync attempt.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) bool {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
if disableControlPanel.Load() {
|
||||
log.Debug("management asset sync skipped: control panel disabled by configuration")
|
||||
return
|
||||
}
|
||||
|
||||
staticDir = strings.TrimSpace(staticDir)
|
||||
if staticDir == "" {
|
||||
log.Debug("management asset sync skipped: empty static directory")
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting: check only once every 3 hours
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
timeSinceLastCheck := now.Sub(lastUpdateCheckTime)
|
||||
if timeSinceLastCheck < updateCheckInterval {
|
||||
_, _, _ = sfGroup.Do(localPath, func() (interface{}, error) {
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
timeSinceLastAttempt := now.Sub(lastUpdateCheckTime)
|
||||
if !lastUpdateCheckTime.IsZero() && timeSinceLastAttempt < managementSyncMinInterval {
|
||||
lastUpdateCheckMu.Unlock()
|
||||
log.Debugf(
|
||||
"management asset sync skipped by throttle: last attempt %v ago (interval %v)",
|
||||
timeSinceLastAttempt.Round(time.Second),
|
||||
managementSyncMinInterval,
|
||||
)
|
||||
return nil, nil
|
||||
}
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
log.Debugf("management asset update check skipped: last check was %v ago (interval: %v)", timeSinceLastCheck.Round(time.Second), updateCheckInterval)
|
||||
return
|
||||
}
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return
|
||||
}
|
||||
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.WithError(err).Debug("failed to read local management asset hash")
|
||||
}
|
||||
localHash = ""
|
||||
}
|
||||
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return
|
||||
}
|
||||
|
||||
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
|
||||
log.Debug("management asset is already up to date")
|
||||
return
|
||||
}
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.WithError(err).Debug("failed to read local management asset hash")
|
||||
}
|
||||
return
|
||||
localHash = ""
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return
|
||||
}
|
||||
|
||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
||||
}
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to update management asset on disk")
|
||||
return
|
||||
}
|
||||
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
|
||||
log.Debug("management asset is already up to date")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to update management asset on disk")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
_, err := os.Stat(localPath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package misc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
|
||||
func LogCredentialSeparator() {
|
||||
log.Debug(credentialSeparator)
|
||||
}
|
||||
|
||||
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
|
||||
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
|
||||
var data map[string]any
|
||||
|
||||
// Fast path: if source is already a map, just copy it to avoid mutation of original
|
||||
if srcMap, ok := source.(map[string]any); ok {
|
||||
data = make(map[string]any, len(srcMap)+len(metadata))
|
||||
for k, v := range srcMap {
|
||||
data[k] = v
|
||||
}
|
||||
} else {
|
||||
// Slow path: marshal to JSON and back to map to respect JSON tags
|
||||
temp, err := json.Marshal(source)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal source: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(temp, &data); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Merge extra metadata
|
||||
if metadata != nil {
|
||||
if data == nil {
|
||||
data = make(map[string]any)
|
||||
}
|
||||
for k, v := range metadata {
|
||||
data[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
21
internal/registry/kilo_models.go
Normal file
21
internal/registry/kilo_models.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Package registry provides model definitions for various AI service providers.
|
||||
package registry
|
||||
|
||||
// GetKiloModels returns the Kilo model definitions
|
||||
func GetKiloModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
// --- Base Models ---
|
||||
{
|
||||
ID: "kilo/auto",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "kilo",
|
||||
Type: "kilo",
|
||||
DisplayName: "Kilo Auto",
|
||||
Description: "Automatic model selection by Kilo",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,9 @@ import (
|
||||
// - codex
|
||||
// - qwen
|
||||
// - iflow
|
||||
// - kimi
|
||||
// - kiro
|
||||
// - kilo
|
||||
// - github-copilot
|
||||
// - kiro
|
||||
// - amazonq
|
||||
@@ -43,10 +45,14 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
return GetQwenModels()
|
||||
case "iflow":
|
||||
return GetIFlowModels()
|
||||
case "kimi":
|
||||
return GetKimiModels()
|
||||
case "github-copilot":
|
||||
return GetGitHubCopilotModels()
|
||||
case "kiro":
|
||||
return GetKiroModels()
|
||||
case "kilo":
|
||||
return GetKiloModels()
|
||||
case "amazonq":
|
||||
return GetAmazonQModels()
|
||||
case "antigravity":
|
||||
@@ -93,8 +99,10 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
GetOpenAIModels(),
|
||||
GetQwenModels(),
|
||||
GetIFlowModels(),
|
||||
GetKimiModels(),
|
||||
GetGitHubCopilotModels(),
|
||||
GetKiroModels(),
|
||||
GetKiloModels(),
|
||||
GetAmazonQModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
@@ -121,7 +129,19 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
||||
func GetGitHubCopilotModels() []*ModelInfo {
|
||||
now := int64(1732752000) // 2024-11-27
|
||||
return []*ModelInfo{
|
||||
gpt4oEntries := []struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
Description string
|
||||
}{
|
||||
{ID: "gpt-4o-2024-11-20", DisplayName: "GPT-4o (2024-11-20)", Description: "OpenAI GPT-4o 2024-11-20 via GitHub Copilot"},
|
||||
{ID: "gpt-4o-2024-08-06", DisplayName: "GPT-4o (2024-08-06)", Description: "OpenAI GPT-4o 2024-08-06 via GitHub Copilot"},
|
||||
{ID: "gpt-4o-2024-05-13", DisplayName: "GPT-4o (2024-05-13)", Description: "OpenAI GPT-4o 2024-05-13 via GitHub Copilot"},
|
||||
{ID: "gpt-4o", DisplayName: "GPT-4o", Description: "OpenAI GPT-4o via GitHub Copilot"},
|
||||
{ID: "gpt-4-o-preview", DisplayName: "GPT-4-o Preview", Description: "OpenAI GPT-4-o Preview via GitHub Copilot"},
|
||||
}
|
||||
|
||||
models := []*ModelInfo{
|
||||
{
|
||||
ID: "gpt-4.1",
|
||||
Object: "model",
|
||||
@@ -133,6 +153,23 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
},
|
||||
}
|
||||
|
||||
for _, entry := range gpt4oEntries {
|
||||
models = append(models, &ModelInfo{
|
||||
ID: entry.ID,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: entry.DisplayName,
|
||||
Description: entry.Description,
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
})
|
||||
}
|
||||
|
||||
return append(models, []*ModelInfo{
|
||||
{
|
||||
ID: "gpt-5",
|
||||
Object: "model",
|
||||
@@ -144,6 +181,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5-mini",
|
||||
@@ -156,6 +194,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5-codex",
|
||||
@@ -168,6 +207,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1",
|
||||
@@ -180,6 +220,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex",
|
||||
@@ -192,6 +233,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex-mini",
|
||||
@@ -204,6 +246,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex-max",
|
||||
@@ -216,6 +259,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2",
|
||||
@@ -228,6 +272,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2-codex",
|
||||
@@ -240,6 +285,20 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.3-codex",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.3 Codex",
|
||||
Description: "OpenAI GPT-5.3 Codex via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-haiku-4.5",
|
||||
@@ -277,6 +336,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4.6",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Claude Opus 4.6",
|
||||
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4",
|
||||
Object: "model",
|
||||
@@ -301,6 +372,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4.6",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Claude Sonnet 4.6",
|
||||
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "gemini-2.5-pro",
|
||||
Object: "model",
|
||||
@@ -323,6 +406,17 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Gemini 3.1 Pro (Preview)",
|
||||
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
@@ -357,7 +451,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
},
|
||||
}
|
||||
}...)
|
||||
}
|
||||
|
||||
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
|
||||
@@ -388,6 +482,18 @@ func GetKiroModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4-6",
|
||||
Object: "model",
|
||||
Created: 1739836800, // 2025-02-18
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Sonnet 4.6",
|
||||
Description: "Claude Sonnet 4.6 via Kiro (1.3x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-opus-4-5",
|
||||
Object: "model",
|
||||
@@ -436,6 +542,87 @@ func GetKiroModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
// --- 第三方模型 (通过 Kiro 接入) ---
|
||||
{
|
||||
ID: "kiro-deepseek-3-2",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro DeepSeek 3.2",
|
||||
Description: "DeepSeek 3.2 via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-minimax-m2-1",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro MiniMax M2.1",
|
||||
Description: "MiniMax M2.1 via Kiro",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-qwen3-coder-next",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Qwen3 Coder Next",
|
||||
Description: "Qwen3 Coder Next via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-4o",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-4o",
|
||||
Description: "OpenAI GPT-4o via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-4",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-4",
|
||||
Description: "OpenAI GPT-4 via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 8192,
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-4-turbo",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-4 Turbo",
|
||||
Description: "OpenAI GPT-4 Turbo via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-3-5-turbo",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-3.5 Turbo",
|
||||
Description: "OpenAI GPT-3.5 Turbo via Kiro",
|
||||
ContextLength: 16384,
|
||||
MaxCompletionTokens: 4096,
|
||||
},
|
||||
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
|
||||
{
|
||||
ID: "kiro-claude-opus-4-6-agentic",
|
||||
@@ -449,6 +636,18 @@ func GetKiroModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4-6-agentic",
|
||||
Object: "model",
|
||||
Created: 1739836800, // 2025-02-18
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Sonnet 4.6 (Agentic)",
|
||||
Description: "Claude Sonnet 4.6 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-opus-4-5-agentic",
|
||||
Object: "model",
|
||||
@@ -497,6 +696,42 @@ func GetKiroModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-deepseek-3-2-agentic",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro DeepSeek 3.2 (Agentic)",
|
||||
Description: "DeepSeek 3.2 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-minimax-m2-1-agentic",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro MiniMax M2.1 (Agentic)",
|
||||
Description: "MiniMax M2.1 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-qwen3-coder-next-agentic",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Qwen3 Coder Next (Agentic)",
|
||||
Description: "Qwen3 Coder Next optimized for coding agents (chunked writes)",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,17 @@ func GetClaudeModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-6",
|
||||
Object: "model",
|
||||
Created: 1771372800, // 2026-02-17
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.6 Sonnet",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-6",
|
||||
Object: "model",
|
||||
@@ -40,6 +51,18 @@ func GetClaudeModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 128000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-6",
|
||||
Object: "model",
|
||||
Created: 1771286400, // 2026-02-17
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.6 Sonnet",
|
||||
Description: "Best combination of speed and intelligence",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-5-20251101",
|
||||
Object: "model",
|
||||
@@ -173,6 +196,21 @@ func GetGeminiModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-pro-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Pro Preview",
|
||||
Description: "Gemini 3.1 Pro Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
@@ -283,6 +321,21 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-pro-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Pro Preview",
|
||||
Description: "Gemini 3.1 Pro Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
Object: "model",
|
||||
@@ -425,6 +478,21 @@ func GetGeminiCLIModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-pro-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Pro Preview",
|
||||
Description: "Gemini 3.1 Pro Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
@@ -506,6 +574,21 @@ func GetAIStudioModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-pro-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Pro Preview",
|
||||
Description: "Gemini 3.1 Pro Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
@@ -742,6 +825,20 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.3-codex-spark",
|
||||
Object: "model",
|
||||
Created: 1770912000,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.3",
|
||||
DisplayName: "GPT 5.3 Codex Spark",
|
||||
Description: "Ultra-fast coding model.",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -774,6 +871,19 @@ func GetQwenModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 2048,
|
||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||
},
|
||||
{
|
||||
ID: "coder-model",
|
||||
Object: "model",
|
||||
Created: 1771171200,
|
||||
OwnedBy: "qwen",
|
||||
Type: "qwen",
|
||||
Version: "3.5",
|
||||
DisplayName: "Qwen 3.5 Plus",
|
||||
Description: "efficient hybrid model with leading coding performance",
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||
},
|
||||
{
|
||||
ID: "vision-model",
|
||||
Object: "model",
|
||||
@@ -806,18 +916,12 @@ func GetIFlowModels() []*ModelInfo {
|
||||
Created int64
|
||||
Thinking *ThinkingSupport
|
||||
}{
|
||||
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
|
||||
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
||||
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||
@@ -826,10 +930,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -863,11 +964,15 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 128000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
}
|
||||
|
||||
@@ -601,8 +601,7 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
now := time.Now()
|
||||
registration.QuotaExceededClients[clientID] = &now
|
||||
registration.QuotaExceededClients[clientID] = new(time.Now())
|
||||
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,12 +164,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the AI Studio API.
|
||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(first wsrelay.StreamEvent) {
|
||||
defer close(out)
|
||||
var param any
|
||||
@@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
}
|
||||
}(firstEvent)
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request using the AI Studio API.
|
||||
|
||||
@@ -54,8 +54,78 @@ const (
|
||||
var (
|
||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
randSourceMutex sync.Mutex
|
||||
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
|
||||
// from any antigravity auth. Empty fetches never overwrite this cache.
|
||||
antigravityPrimaryModelsCache struct {
|
||||
mu sync.RWMutex
|
||||
models []*registry.ModelInfo
|
||||
}
|
||||
)
|
||||
|
||||
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*registry.ModelInfo, 0, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil || strings.TrimSpace(model.ID) == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, cloneAntigravityModelInfo(model))
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
|
||||
if model == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *model
|
||||
if len(model.SupportedGenerationMethods) > 0 {
|
||||
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
|
||||
}
|
||||
if len(model.SupportedParameters) > 0 {
|
||||
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||
}
|
||||
if model.Thinking != nil {
|
||||
thinkingClone := *model.Thinking
|
||||
if len(model.Thinking.Levels) > 0 {
|
||||
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||
}
|
||||
clone.Thinking = &thinkingClone
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
|
||||
cloned := cloneAntigravityModels(models)
|
||||
if len(cloned) == 0 {
|
||||
return false
|
||||
}
|
||||
antigravityPrimaryModelsCache.mu.Lock()
|
||||
antigravityPrimaryModelsCache.models = cloned
|
||||
antigravityPrimaryModelsCache.mu.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||
antigravityPrimaryModelsCache.mu.RLock()
|
||||
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
|
||||
antigravityPrimaryModelsCache.mu.RUnlock()
|
||||
return cloned
|
||||
}
|
||||
|
||||
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||
models := loadAntigravityPrimaryModels()
|
||||
if len(models) > 0 {
|
||||
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||
type AntigravityExecutor struct {
|
||||
cfg *config.Config
|
||||
@@ -232,7 +302,7 @@ attemptLoop:
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
@@ -436,7 +506,7 @@ attemptLoop:
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
return resp, nil
|
||||
@@ -645,7 +715,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -775,7 +845,6 @@ attemptLoop:
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -820,7 +889,7 @@ attemptLoop:
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}(httpResp)
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
@@ -968,7 +1037,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
@@ -1008,8 +1077,8 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
exec := &AntigravityExecutor{cfg: cfg}
|
||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil || token == "" {
|
||||
return nil
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
}
|
||||
@@ -1021,7 +1090,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
modelsURL := baseURL + antigravityModelsPath
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
||||
if errReq != nil {
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
@@ -1033,13 +1102,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
@@ -1051,19 +1120,27 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
return nil
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
@@ -1108,9 +1185,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
}
|
||||
models = append(models, modelInfo)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
storeAntigravityPrimaryModels(models)
|
||||
return models
|
||||
}
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestAntigravityBuildRequest_SanitizesGeminiToolSchema(t *testing.T) {
|
||||
body := buildRequestBodyFromPayload(t, "gemini-2.5-pro")
|
||||
|
||||
decl := extractFirstFunctionDeclaration(t, body)
|
||||
if _, ok := decl["parametersJsonSchema"]; ok {
|
||||
t.Fatalf("parametersJsonSchema should be renamed to parameters")
|
||||
}
|
||||
|
||||
params, ok := decl["parameters"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parameters missing or invalid type")
|
||||
}
|
||||
assertSchemaSanitizedAndPropertyPreserved(t, params)
|
||||
}
|
||||
|
||||
func TestAntigravityBuildRequest_SanitizesAntigravityToolSchema(t *testing.T) {
|
||||
body := buildRequestBodyFromPayload(t, "claude-opus-4-6")
|
||||
|
||||
decl := extractFirstFunctionDeclaration(t, body)
|
||||
params, ok := decl["parameters"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parameters missing or invalid type")
|
||||
}
|
||||
assertSchemaSanitizedAndPropertyPreserved(t, params)
|
||||
}
|
||||
|
||||
func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
executor := &AntigravityExecutor{}
|
||||
auth := &cliproxyauth.Auth{}
|
||||
payload := []byte(`{
|
||||
"request": {
|
||||
"tools": [
|
||||
{
|
||||
"function_declarations": [
|
||||
{
|
||||
"name": "tool_1",
|
||||
"parametersJsonSchema": {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"$id": "root-schema",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"$id": {"type": "string"},
|
||||
"arg": {
|
||||
"type": "object",
|
||||
"prefill": "hello",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b"],
|
||||
"enumTitles": ["A", "B"]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"patternProperties": {
|
||||
"^x-": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
|
||||
req, err := executor.buildRequest(context.Background(), auth, "token", modelName, payload, false, "", "https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("buildRequest error: %v", err)
|
||||
}
|
||||
|
||||
raw, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body error: %v", err)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(raw, &body); err != nil {
|
||||
t.Fatalf("unmarshal request body error: %v, body=%s", err, string(raw))
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func extractFirstFunctionDeclaration(t *testing.T, body map[string]any) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
request, ok := body["request"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("request missing or invalid type")
|
||||
}
|
||||
tools, ok := request["tools"].([]any)
|
||||
if !ok || len(tools) == 0 {
|
||||
t.Fatalf("tools missing or empty")
|
||||
}
|
||||
tool, ok := tools[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("first tool invalid type")
|
||||
}
|
||||
decls, ok := tool["function_declarations"].([]any)
|
||||
if !ok || len(decls) == 0 {
|
||||
t.Fatalf("function_declarations missing or empty")
|
||||
}
|
||||
decl, ok := decls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("first function declaration invalid type")
|
||||
}
|
||||
return decl
|
||||
}
|
||||
|
||||
func assertSchemaSanitizedAndPropertyPreserved(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
|
||||
if _, ok := params["$id"]; ok {
|
||||
t.Fatalf("root $id should be removed from schema")
|
||||
}
|
||||
if _, ok := params["patternProperties"]; ok {
|
||||
t.Fatalf("patternProperties should be removed from schema")
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("properties missing or invalid type")
|
||||
}
|
||||
if _, ok := props["$id"]; !ok {
|
||||
t.Fatalf("property named $id should be preserved")
|
||||
}
|
||||
|
||||
arg, ok := props["arg"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("arg property missing or invalid type")
|
||||
}
|
||||
if _, ok := arg["prefill"]; ok {
|
||||
t.Fatalf("prefill should be removed from nested schema")
|
||||
}
|
||||
|
||||
argProps, ok := arg["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("arg.properties missing or invalid type")
|
||||
}
|
||||
mode, ok := argProps["mode"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("mode property missing or invalid type")
|
||||
}
|
||||
if _, ok := mode["enumTitles"]; ok {
|
||||
t.Fatalf("enumTitles should be removed from nested schema")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func resetAntigravityPrimaryModelsCacheForTest() {
|
||||
antigravityPrimaryModelsCache.mu.Lock()
|
||||
antigravityPrimaryModelsCache.models = nil
|
||||
antigravityPrimaryModelsCache.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
|
||||
resetAntigravityPrimaryModelsCacheForTest()
|
||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||
|
||||
seed := []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4-5"},
|
||||
{ID: "gemini-2.5-pro"},
|
||||
}
|
||||
if updated := storeAntigravityPrimaryModels(seed); !updated {
|
||||
t.Fatal("expected non-empty model list to update primary cache")
|
||||
}
|
||||
|
||||
if updated := storeAntigravityPrimaryModels(nil); updated {
|
||||
t.Fatal("expected nil model list not to overwrite primary cache")
|
||||
}
|
||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
|
||||
t.Fatal("expected empty model list not to overwrite primary cache")
|
||||
}
|
||||
|
||||
got := loadAntigravityPrimaryModels()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected cached model count 2, got %d", len(got))
|
||||
}
|
||||
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
|
||||
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
|
||||
resetAntigravityPrimaryModelsCacheForTest()
|
||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||
|
||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
|
||||
ID: "gpt-5",
|
||||
DisplayName: "GPT-5",
|
||||
SupportedGenerationMethods: []string{"generateContent"},
|
||||
SupportedParameters: []string{"temperature"},
|
||||
Thinking: ®istry.ThinkingSupport{
|
||||
Levels: []string{"high"},
|
||||
},
|
||||
}}); !updated {
|
||||
t.Fatal("expected model cache update")
|
||||
}
|
||||
|
||||
got := loadAntigravityPrimaryModels()
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one cached model, got %d", len(got))
|
||||
}
|
||||
got[0].ID = "mutated-id"
|
||||
if len(got[0].SupportedGenerationMethods) > 0 {
|
||||
got[0].SupportedGenerationMethods[0] = "mutated-method"
|
||||
}
|
||||
if len(got[0].SupportedParameters) > 0 {
|
||||
got[0].SupportedParameters[0] = "mutated-parameter"
|
||||
}
|
||||
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
|
||||
got[0].Thinking.Levels[0] = "mutated-level"
|
||||
}
|
||||
|
||||
again := loadAntigravityPrimaryModels()
|
||||
if len(again) != 1 {
|
||||
t.Fatalf("expected one cached model after mutation, got %d", len(again))
|
||||
}
|
||||
if again[0].ID != "gpt-5" {
|
||||
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
|
||||
}
|
||||
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
|
||||
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
|
||||
}
|
||||
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
|
||||
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
|
||||
}
|
||||
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
|
||||
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -116,7 +117,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
|
||||
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
||||
// based on client type and configuration.
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
@@ -134,7 +135,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
bodyForTranslation := body
|
||||
bodyForUpstream := body
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
@@ -143,7 +144,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -208,7 +209,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
} else {
|
||||
reporter.publish(ctx, parseClaudeUsage(data))
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
|
||||
}
|
||||
var param any
|
||||
@@ -222,11 +223,11 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
data,
|
||||
¶m,
|
||||
)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -257,7 +258,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
|
||||
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
||||
// based on client type and configuration.
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
@@ -275,7 +276,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
bodyForTranslation := body
|
||||
bodyForUpstream := body
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
@@ -284,7 +285,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -329,7 +330,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -348,7 +348,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
}
|
||||
// Forward the line as-is to preserve SSE format
|
||||
@@ -375,7 +375,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(
|
||||
@@ -398,7 +398,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
@@ -423,7 +423,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
// Extract betas from body and convert to header (for count_tokens too)
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
body = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
@@ -432,7 +432,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -487,7 +487,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "input_tokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
@@ -638,7 +638,49 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
|
||||
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
|
||||
func mapStainlessOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return "MacOS"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "freebsd":
|
||||
return "FreeBSD"
|
||||
default:
|
||||
return "Other::" + runtime.GOOS
|
||||
}
|
||||
}
|
||||
|
||||
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
|
||||
func mapStainlessArch() string {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return "x64"
|
||||
case "arm64":
|
||||
return "arm64"
|
||||
case "386":
|
||||
return "x86"
|
||||
default:
|
||||
return "other::" + runtime.GOARCH
|
||||
}
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
|
||||
hdrDefault := func(cfgVal, fallback string) string {
|
||||
if cfgVal != "" {
|
||||
return cfgVal
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
var hd config.ClaudeHeaderDefaults
|
||||
if cfg != nil {
|
||||
hd = cfg.ClaudeHeaderDefaults
|
||||
}
|
||||
|
||||
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
|
||||
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
|
||||
if isAnthropicBase && useAPIKey {
|
||||
@@ -685,16 +727,17 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", "v24.3.0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", "0.55.1")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", "arm64")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", "MacOS")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", "60")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "claude-cli/1.0.83 (external, cli)")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
|
||||
r.Header.Set("Connection", "keep-alive")
|
||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||
if stream {
|
||||
@@ -702,6 +745,8 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
} else {
|
||||
r.Header.Set("Accept", "application/json")
|
||||
}
|
||||
// Keep OS/Arch mapping dynamic (not configurable).
|
||||
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
@@ -753,11 +798,21 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// Collect built-in tool names (those with a non-empty "type" field) so we can
|
||||
// skip them consistently in both tools and message history.
|
||||
builtinTools := map[string]bool{}
|
||||
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
|
||||
builtinTools[name] = true
|
||||
}
|
||||
|
||||
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
||||
tools.ForEach(func(index, tool gjson.Result) bool {
|
||||
// Skip built-in tools (web_search, code_execution, etc.) which have
|
||||
// a "type" field and require their name to remain unchanged.
|
||||
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
|
||||
if n := tool.Get("name").String(); n != "" {
|
||||
builtinTools[n] = true
|
||||
}
|
||||
return true
|
||||
}
|
||||
name := tool.Get("name").String()
|
||||
@@ -772,7 +827,7 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
|
||||
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
|
||||
name := gjson.GetBytes(body, "tool_choice.name").String()
|
||||
if name != "" && !strings.HasPrefix(name, prefix) {
|
||||
if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] {
|
||||
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
|
||||
}
|
||||
}
|
||||
@@ -784,15 +839,38 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
return true
|
||||
}
|
||||
content.ForEach(func(contentIndex, part gjson.Result) bool {
|
||||
if part.Get("type").String() != "tool_use" {
|
||||
return true
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "tool_use":
|
||||
name := part.Get("name").String()
|
||||
if name == "" || strings.HasPrefix(name, prefix) || builtinTools[name] {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||
case "tool_reference":
|
||||
toolName := part.Get("tool_name").String()
|
||||
if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+toolName)
|
||||
case "tool_result":
|
||||
// Handle nested tool_reference blocks inside tool_result.content[]
|
||||
nestedContent := part.Get("content")
|
||||
if nestedContent.Exists() && nestedContent.IsArray() {
|
||||
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
|
||||
if nestedPart.Get("type").String() == "tool_reference" {
|
||||
nestedToolName := nestedPart.Get("tool_name").String()
|
||||
if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] {
|
||||
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
name := part.Get("name").String()
|
||||
if name == "" || strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||
return true
|
||||
})
|
||||
return true
|
||||
@@ -811,15 +889,38 @@ func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
|
||||
return body
|
||||
}
|
||||
content.ForEach(func(index, part gjson.Result) bool {
|
||||
if part.Get("type").String() != "tool_use" {
|
||||
return true
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "tool_use":
|
||||
name := part.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
|
||||
case "tool_reference":
|
||||
toolName := part.Get("tool_name").String()
|
||||
if !strings.HasPrefix(toolName, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.tool_name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix))
|
||||
case "tool_result":
|
||||
// Handle nested tool_reference blocks inside tool_result.content[]
|
||||
nestedContent := part.Get("content")
|
||||
if nestedContent.Exists() && nestedContent.IsArray() {
|
||||
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
|
||||
if nestedPart.Get("type").String() == "tool_reference" {
|
||||
nestedToolName := nestedPart.Get("tool_name").String()
|
||||
if strings.HasPrefix(nestedToolName, prefix) {
|
||||
nestedPath := fmt.Sprintf("content.%d.content.%d.tool_name", index.Int(), nestedIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, nestedPath, strings.TrimPrefix(nestedToolName, prefix))
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
name := part.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
|
||||
return true
|
||||
})
|
||||
return body
|
||||
@@ -834,15 +935,34 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
|
||||
return line
|
||||
}
|
||||
contentBlock := gjson.GetBytes(payload, "content_block")
|
||||
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
|
||||
if !contentBlock.Exists() {
|
||||
return line
|
||||
}
|
||||
name := contentBlock.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
|
||||
if err != nil {
|
||||
|
||||
blockType := contentBlock.Get("type").String()
|
||||
var updated []byte
|
||||
var err error
|
||||
|
||||
switch blockType {
|
||||
case "tool_use":
|
||||
name := contentBlock.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err = sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
case "tool_reference":
|
||||
toolName := contentBlock.Get("tool_name").String()
|
||||
if !strings.HasPrefix(toolName, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err = sjson.SetBytes(payload, "content_block.tool_name", strings.TrimPrefix(toolName, prefix))
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
default:
|
||||
return line
|
||||
}
|
||||
|
||||
@@ -862,10 +982,10 @@ func getClientUserAgent(ctx context.Context) string {
|
||||
}
|
||||
|
||||
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
|
||||
// Returns (cloakMode, strictMode, sensitiveWords).
|
||||
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
|
||||
// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID).
|
||||
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) {
|
||||
if auth == nil || auth.Attributes == nil {
|
||||
return "auto", false, nil
|
||||
return "auto", false, nil, false
|
||||
}
|
||||
|
||||
cloakMode := auth.Attributes["cloak_mode"]
|
||||
@@ -883,7 +1003,9 @@ func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
|
||||
}
|
||||
}
|
||||
|
||||
return cloakMode, strictMode, sensitiveWords
|
||||
cacheUserID := strings.EqualFold(strings.TrimSpace(auth.Attributes["cloak_cache_user_id"]), "true")
|
||||
|
||||
return cloakMode, strictMode, sensitiveWords, cacheUserID
|
||||
}
|
||||
|
||||
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
|
||||
@@ -916,16 +1038,24 @@ func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *c
|
||||
}
|
||||
|
||||
// injectFakeUserID generates and injects a fake user ID into the request metadata.
|
||||
func injectFakeUserID(payload []byte) []byte {
|
||||
// When useCache is false, a new user ID is generated for every call.
|
||||
func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
|
||||
generateID := func() string {
|
||||
if useCache {
|
||||
return cachedUserID(apiKey)
|
||||
}
|
||||
return generateFakeUserID()
|
||||
}
|
||||
|
||||
metadata := gjson.GetBytes(payload, "metadata")
|
||||
if !metadata.Exists() {
|
||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
|
||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
|
||||
return payload
|
||||
}
|
||||
|
||||
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||
if existingUserID == "" || !isValidUserID(existingUserID) {
|
||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
|
||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
|
||||
}
|
||||
return payload
|
||||
}
|
||||
@@ -962,7 +1092,7 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
|
||||
// applyCloaking applies cloaking transformations to the payload based on config and client.
|
||||
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
|
||||
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string) []byte {
|
||||
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte {
|
||||
clientUserAgent := getClientUserAgent(ctx)
|
||||
|
||||
// Get cloak config from ClaudeKey configuration
|
||||
@@ -972,16 +1102,20 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
||||
var cloakMode string
|
||||
var strictMode bool
|
||||
var sensitiveWords []string
|
||||
var cacheUserID bool
|
||||
|
||||
if cloakCfg != nil {
|
||||
cloakMode = cloakCfg.Mode
|
||||
strictMode = cloakCfg.StrictMode
|
||||
sensitiveWords = cloakCfg.SensitiveWords
|
||||
if cloakCfg.CacheUserID != nil {
|
||||
cacheUserID = *cloakCfg.CacheUserID
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to auth attributes if no config found
|
||||
if cloakMode == "" {
|
||||
attrMode, attrStrict, attrWords := getCloakConfigFromAuth(auth)
|
||||
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
|
||||
cloakMode = attrMode
|
||||
if !strictMode {
|
||||
strictMode = attrStrict
|
||||
@@ -989,6 +1123,12 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
||||
if len(sensitiveWords) == 0 {
|
||||
sensitiveWords = attrWords
|
||||
}
|
||||
if cloakCfg == nil || cloakCfg.CacheUserID == nil {
|
||||
cacheUserID = attrCache
|
||||
}
|
||||
} else if cloakCfg == nil || cloakCfg.CacheUserID == nil {
|
||||
_, _, _, attrCache := getCloakConfigFromAuth(auth)
|
||||
cacheUserID = attrCache
|
||||
}
|
||||
|
||||
// Determine if cloaking should be applied
|
||||
@@ -1002,7 +1142,7 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
||||
}
|
||||
|
||||
// Inject fake user ID
|
||||
payload = injectFakeUserID(payload)
|
||||
payload = injectFakeUserID(payload, apiKey, cacheUserID)
|
||||
|
||||
// Apply sensitive word obfuscation
|
||||
if len(sensitiveWords) > 0 {
|
||||
|
||||
@@ -2,9 +2,18 @@ package executor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||
@@ -25,6 +34,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.tool_name").String(); got != "proxy_beta" {
|
||||
t.Fatalf("messages.0.content.0.tool_name = %q, want %q", got, "proxy_beta")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != "proxy_gamma" {
|
||||
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, "proxy_gamma")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
@@ -37,6 +58,97 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_BuiltinToolSkipped(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"type": "web_search_20250305", "name": "web_search", "max_uses": 5},
|
||||
{"name": "Read"}
|
||||
],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}},
|
||||
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"name": "Read"}
|
||||
],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_CustomToolsPrefixed(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [{"name": "Read"}, {"name": "Write"}],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}},
|
||||
{"type": "tool_use", "name": "Write", "id": "w1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Write" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Write")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Write" {
|
||||
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"type": "web_search_20250305", "name": "web_search"},
|
||||
{"name": "Read"}
|
||||
],
|
||||
"tool_choice": {"type": "tool", "name": "web_search"}
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
@@ -49,6 +161,18 @@ func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "content.0.tool_name").String(); got != "alpha" {
|
||||
t.Fatalf("content.0.tool_name = %q, want %q", got, "alpha")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "content.1.tool_name").String(); got != "bravo" {
|
||||
t.Fatalf("content.1.tool_name = %q, want %q", got, "bravo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
|
||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||
@@ -61,3 +185,166 @@ func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) {
|
||||
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`)
|
||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||
|
||||
payload := bytes.TrimSpace(out)
|
||||
if bytes.HasPrefix(payload, []byte("data:")) {
|
||||
payload = bytes.TrimSpace(payload[len("data:"):])
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "content_block.tool_name").String(); got != "beta" {
|
||||
t.Fatalf("content_block.tool_name = %q, want %q", got, "beta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
|
||||
if got != "proxy_mcp__nia__manage_resource" {
|
||||
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "proxy_mcp__nia__manage_resource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
var userIDs []string
|
||||
var requestModels []string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
userID := gjson.GetBytes(body, "metadata.user_id").String()
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
userIDs = append(userIDs, userID)
|
||||
requestModels = append(requestModels, model)
|
||||
t.Logf("HTTP Server received request: model=%s, user_id=%s, url=%s", model, userID, r.URL.String())
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
t.Logf("End-to-end test: Fake HTTP server started at %s", server.URL)
|
||||
|
||||
cacheEnabled := true
|
||||
executor := NewClaudeExecutor(&config.Config{
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{
|
||||
APIKey: "key-123",
|
||||
BaseURL: server.URL,
|
||||
Cloak: &config.CloakConfig{
|
||||
CacheUserID: &cacheEnabled,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
models := []string{"claude-3-5-sonnet", "claude-3-5-haiku"}
|
||||
for _, model := range models {
|
||||
t.Logf("Sending request for model: %s", model)
|
||||
modelPayload, _ := sjson.SetBytes(payload, "model", model)
|
||||
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: model,
|
||||
Payload: modelPayload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
}); err != nil {
|
||||
t.Fatalf("Execute(%s) error: %v", model, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(userIDs) != 2 {
|
||||
t.Fatalf("expected 2 requests, got %d", len(userIDs))
|
||||
}
|
||||
if userIDs[0] == "" || userIDs[1] == "" {
|
||||
t.Fatal("expected user_id to be populated")
|
||||
}
|
||||
t.Logf("user_id[0] (model=%s): %s", requestModels[0], userIDs[0])
|
||||
t.Logf("user_id[1] (model=%s): %s", requestModels[1], userIDs[1])
|
||||
if userIDs[0] != userIDs[1] {
|
||||
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
|
||||
}
|
||||
if !isValidUserID(userIDs[0]) {
|
||||
t.Fatalf("user_id %q is not valid", userIDs[0])
|
||||
}
|
||||
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
var userIDs []string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
userIDs = append(userIDs, gjson.GetBytes(body, "metadata.user_id").String())
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
}); err != nil {
|
||||
t.Fatalf("Execute call %d error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(userIDs) != 2 {
|
||||
t.Fatalf("expected 2 requests, got %d", len(userIDs))
|
||||
}
|
||||
if userIDs[0] == "" || userIDs[1] == "" {
|
||||
t.Fatal("expected user_id to be populated")
|
||||
}
|
||||
if userIDs[0] == userIDs[1] {
|
||||
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
|
||||
}
|
||||
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) {
|
||||
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "content.0.content.0.tool_name").String()
|
||||
if got != "mcp__nia__manage_resource" {
|
||||
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "mcp__nia__manage_resource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) {
|
||||
// tool_result.content can be a string - should not be processed
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content").String()
|
||||
if got != "plain string result" {
|
||||
t.Fatalf("string content should remain unchanged = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
|
||||
if got != "web_search" {
|
||||
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,8 +28,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
codexClientVersion = "0.98.0"
|
||||
codexUserAgent = "codex_cli_rs/0.98.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
codexClientVersion = "0.101.0"
|
||||
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
)
|
||||
|
||||
var dataTag = []byte("data:")
|
||||
@@ -156,7 +156,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
||||
@@ -260,7 +260,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
@@ -273,11 +273,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.ensurePublished(ctx)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
|
||||
}
|
||||
@@ -358,11 +358,10 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
err = newCodexStatusErr(httpResp.StatusCode, data)
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -397,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
@@ -643,7 +642,6 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
}
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
|
||||
|
||||
@@ -675,6 +673,35 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
}
|
||||
|
||||
func newCodexStatusErr(statusCode int, body []byte) statusErr {
|
||||
err := statusErr{code: statusCode, msg: string(body)}
|
||||
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
|
||||
err.retryAfter = retryAfter
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
|
||||
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(errorBody, "error.type").String()) != "usage_limit_reached" {
|
||||
return nil
|
||||
}
|
||||
if resetsAt := gjson.GetBytes(errorBody, "error.resets_at").Int(); resetsAt > 0 {
|
||||
resetAtTime := time.Unix(resetsAt, 0)
|
||||
if resetAtTime.After(now) {
|
||||
retryAfter := resetAtTime.Sub(now)
|
||||
return &retryAfter
|
||||
}
|
||||
}
|
||||
if resetsInSeconds := gjson.GetBytes(errorBody, "error.resets_in_seconds").Int(); resetsInSeconds > 0 {
|
||||
retryAfter := time.Duration(resetsInSeconds) * time.Second
|
||||
return &retryAfter
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
|
||||
65
internal/runtime/executor/codex_executor_retry_test.go
Normal file
65
internal/runtime/executor/codex_executor_retry_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseCodexRetryAfter(t *testing.T) {
|
||||
now := time.Unix(1_700_000_000, 0)
|
||||
|
||||
t.Run("resets_in_seconds", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":123}}`)
|
||||
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||
if retryAfter == nil {
|
||||
t.Fatalf("expected retryAfter, got nil")
|
||||
}
|
||||
if *retryAfter != 123*time.Second {
|
||||
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 123*time.Second)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prefers resets_at", func(t *testing.T) {
|
||||
resetAt := now.Add(5 * time.Minute).Unix()
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":1}}`)
|
||||
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||
if retryAfter == nil {
|
||||
t.Fatalf("expected retryAfter, got nil")
|
||||
}
|
||||
if *retryAfter != 5*time.Minute {
|
||||
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 5*time.Minute)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fallback when resets_at is past", func(t *testing.T) {
|
||||
resetAt := now.Add(-1 * time.Minute).Unix()
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":77}}`)
|
||||
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||
if retryAfter == nil {
|
||||
t.Fatalf("expected retryAfter, got nil")
|
||||
}
|
||||
if *retryAfter != 77*time.Second {
|
||||
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 77*time.Second)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-429 status code", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":30}}`)
|
||||
if got := parseCodexRetryAfter(http.StatusBadRequest, body, now); got != nil {
|
||||
t.Fatalf("expected nil for non-429, got %v", *got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non usage_limit_reached error type", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"type":"server_error","resets_in_seconds":30}}`)
|
||||
if got := parseCodexRetryAfter(http.StatusTooManyRequests, body, now); got != nil {
|
||||
t.Fatalf("expected nil for non-usage_limit_reached, got %v", *got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func itoa(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
|
||||
1408
internal/runtime/executor/codex_websockets_executor.go
Normal file
1408
internal/runtime/executor/codex_websockets_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -225,7 +225,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -256,7 +256,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Gemini CLI API.
|
||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -382,7 +382,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response, reqBody []byte, attemptModel string) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -441,7 +440,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
if len(lastBody) > 0 {
|
||||
@@ -546,7 +545,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
lastStatus = resp.StatusCode
|
||||
lastBody = append([]byte(nil), data...)
|
||||
@@ -899,8 +898,7 @@ func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
|
||||
if matches := re.FindStringSubmatch(message); len(matches) > 1 {
|
||||
seconds, err := strconv.Atoi(matches[1])
|
||||
if err == nil {
|
||||
duration := time.Duration(seconds) * time.Second
|
||||
return &duration, nil
|
||||
return new(time.Duration(seconds) * time.Second), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,12 +205,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Gemini API.
|
||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -298,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -335,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request using the Gemini API.
|
||||
@@ -416,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
||||
|
||||
@@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Vertex AI API.
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
to := sdktranslator.FromString("gemini")
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -524,12 +524,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -618,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -650,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -743,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -775,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// countTokensWithServiceAccount counts tokens using service account credentials.
|
||||
@@ -859,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||
@@ -943,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
333
internal/runtime/executor/github_copilot_executor_test.go
Normal file
333
internal/runtime/executor/github_copilot_executor_test.go
Normal file
@@ -0,0 +1,333 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "suffix stripped",
|
||||
model: "claude-opus-4.6(medium)",
|
||||
wantModel: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "no suffix unchanged",
|
||||
model: "claude-opus-4.6",
|
||||
wantModel: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "different suffix stripped",
|
||||
model: "gpt-4o(high)",
|
||||
wantModel: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "numeric suffix stripped",
|
||||
model: "gemini-2.5-pro(8192)",
|
||||
wantModel: "gemini-2.5-pro",
|
||||
},
|
||||
}
|
||||
|
||||
e := &GitHubCopilotExecutor{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"` + tt.model + `","messages":[]}`)
|
||||
got := e.normalizeModel(tt.model, body)
|
||||
|
||||
gotModel := gjson.GetBytes(got, "model").String()
|
||||
if gotModel != tt.wantModel {
|
||||
t.Fatalf("normalizeModel() model = %q, want %q", gotModel, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") {
|
||||
t.Fatal("expected openai-response source to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") {
|
||||
t.Fatal("expected codex model to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
|
||||
t.Fatal("expected default openai source with non-codex model to use /chat/completions")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`)
|
||||
got := normalizeGitHubCopilotChatTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want 1", len(tools))
|
||||
}
|
||||
if tools[0].Get("type").String() != "function" {
|
||||
t.Fatalf("tool type = %q, want function", tools[0].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`)
|
||||
got := normalizeGitHubCopilotChatTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
|
||||
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
in := gjson.GetBytes(got, "input")
|
||||
if !in.IsArray() {
|
||||
t.Fatalf("input type = %v, want array", in.Type)
|
||||
}
|
||||
raw := in.Raw
|
||||
if !strings.Contains(raw, "sys text") || !strings.Contains(raw, "user text") || !strings.Contains(raw, "assistant text") {
|
||||
t.Fatalf("input = %s, want structured array with all texts", raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "messages").Exists() {
|
||||
t.Fatal("messages should be removed after conversion")
|
||||
}
|
||||
if gjson.GetBytes(got, "system").Exists() {
|
||||
t.Fatal("system should be removed after conversion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":{"foo":"bar"}}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
in := gjson.GetBytes(got, "input")
|
||||
if in.Type != gjson.String {
|
||||
t.Fatalf("input type = %v, want string", in.Type)
|
||||
}
|
||||
if !strings.Contains(in.String(), "foo") {
|
||||
t.Fatalf("input = %q, want stringified object", in.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want 1", len(tools))
|
||||
}
|
||||
if tools[0].Get("name").String() != "sum" {
|
||||
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String())
|
||||
}
|
||||
if !tools[0].Get("parameters").Exists() {
|
||||
t.Fatal("expected parameters to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("tools len = %d, want 2", len(tools))
|
||||
}
|
||||
if tools[0].Get("type").String() != "function" {
|
||||
t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String())
|
||||
}
|
||||
if tools[0].Get("name").String() != "Bash" {
|
||||
t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String())
|
||||
}
|
||||
if tools[0].Get("description").String() != "Run commands" {
|
||||
t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String())
|
||||
}
|
||||
if !tools[0].Get("parameters").Exists() {
|
||||
t.Fatal("expected parameters to be set from input_schema")
|
||||
}
|
||||
if tools[0].Get("parameters.properties.command").Exists() != true {
|
||||
t.Fatal("expected parameters.properties.command to exist")
|
||||
}
|
||||
if tools[1].Get("name").String() != "Read" {
|
||||
t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice.type").String() != "function" {
|
||||
t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String())
|
||||
}
|
||||
if gjson.GetBytes(got, "tool_choice.name").String() != "sum" {
|
||||
t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tool_choice":{"type":"function"}}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
|
||||
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.Get(out, "type").String() != "message" {
|
||||
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.type").String() != "text" {
|
||||
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.text").String() != "hello" {
|
||||
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.Get(out, "content.0.type").String() != "tool_use" {
|
||||
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.name").String() != "sum" {
|
||||
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
|
||||
}
|
||||
if gjson.Get(out, "stop_reason").String() != "tool_use" {
|
||||
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
var param any
|
||||
|
||||
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m)
|
||||
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
|
||||
t.Fatalf("created events = %#v, want message_start", created)
|
||||
}
|
||||
|
||||
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m)
|
||||
joinedDelta := strings.Join(delta, "")
|
||||
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
|
||||
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
|
||||
}
|
||||
|
||||
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m)
|
||||
joinedCompleted := strings.Join(completed, "")
|
||||
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
|
||||
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for X-Initiator detection logic (Problem L) ---
|
||||
|
||||
func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"messages":[{"role":"system","content":"sys"},{"role":"user","content":"hello"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_AgentWithAssistantAndUserToolResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// Claude Code typical flow: last message is user (tool result), but has assistant in history
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (assistant exists in messages)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for x-github-api-version header (Problem M) ---
|
||||
|
||||
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
e.applyHeaders(req, "token", nil)
|
||||
if got := req.Header.Get("X-Github-Api-Version"); got != "2025-04-01" {
|
||||
t.Fatalf("X-Github-Api-Version = %q, want 2025-04-01", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for vision detection (Problem P) ---
|
||||
|
||||
func TestDetectVisionContent_WithImageURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
|
||||
if !detectVisionContent(body) {
|
||||
t.Fatal("expected vision content to be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVisionContent_WithImageType(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"image","source":{"data":"abc","media_type":"image/png"}}]}]}`)
|
||||
if !detectVisionContent(body) {
|
||||
t.Fatal("expected image type to be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVisionContent_NoVision(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
if detectVisionContent(body) {
|
||||
t.Fatal("expected no vision content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVisionContent_NoMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
// After Responses API normalization, messages is removed — detection should return false
|
||||
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`)
|
||||
if detectVisionContent(body) {
|
||||
t.Fatal("expected no vision content when messages field is absent")
|
||||
}
|
||||
}
|
||||
@@ -4,12 +4,16 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
@@ -165,12 +169,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request.
|
||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -258,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -290,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
@@ -453,6 +456,20 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
r.Header.Set("User-Agent", iflowUserAgent)
|
||||
|
||||
// Generate session-id
|
||||
sessionID := "session-" + generateUUID()
|
||||
r.Header.Set("session-id", sessionID)
|
||||
|
||||
// Generate timestamp and signature
|
||||
timestamp := time.Now().UnixMilli()
|
||||
r.Header.Set("x-iflow-timestamp", fmt.Sprintf("%d", timestamp))
|
||||
|
||||
signature := createIFlowSignature(iflowUserAgent, sessionID, timestamp, apiKey)
|
||||
if signature != "" {
|
||||
r.Header.Set("x-iflow-signature", signature)
|
||||
}
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
@@ -460,6 +477,23 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// createIFlowSignature generates HMAC-SHA256 signature for iFlow API requests.
|
||||
// The signature payload format is: userAgent:sessionId:timestamp
|
||||
func createIFlowSignature(userAgent, sessionID string, timestamp int64, apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return ""
|
||||
}
|
||||
payload := fmt.Sprintf("%s:%s:%d", userAgent, sessionID, timestamp)
|
||||
h := hmac.New(sha256.New, []byte(apiKey))
|
||||
h.Write([]byte(payload))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// generateUUID generates a random UUID v4 string.
|
||||
func generateUUID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
func iflowCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
|
||||
460
internal/runtime/executor/kilo_executor.go
Normal file
460
internal/runtime/executor/kilo_executor.go
Normal file
@@ -0,0 +1,460 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// KiloExecutor handles requests to Kilo API.
|
||||
type KiloExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKiloExecutor creates a new Kilo executor instance.
|
||||
func NewKiloExecutor(cfg *config.Config) *KiloExecutor {
|
||||
return &KiloExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
// Identifier returns the unique identifier for this executor.
|
||||
func (e *KiloExecutor) Identifier() string { return "kilo" }
|
||||
|
||||
// PrepareRequest prepares the HTTP request before execution.
|
||||
func (e *KiloExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
accessToken, _ := kiloCredentials(auth)
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return fmt.Errorf("kilo: missing access token")
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest executes a raw HTTP request.
|
||||
func (e *KiloExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("kilo executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request.
|
||||
func (e *KiloExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, orgID := kiloCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return resp, fmt.Errorf("kilo: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
endpoint := "/api/openrouter/chat/completions"
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
url := "https://api.kilo.ai" + endpoint
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if orgID != "" {
|
||||
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, body)
|
||||
reporter.publish(ctx, parseOpenAIUsage(body))
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request.
|
||||
func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, orgID := kiloCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("kilo: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
endpoint := "/api/openrouter/chat/completions"
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := "https://api.kilo.ai" + endpoint
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if orgID != "" {
|
||||
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpReq.Header.Set("Cache-Control", "no-cache")
|
||||
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
httpResp.Body.Close()
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 52_428_800)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: httpResp.Header.Clone(),
|
||||
Chunks: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refresh validates the Kilo token.
|
||||
func (e *KiloExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("missing auth")
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// CountTokens returns the token count for the given request.
|
||||
func (e *KiloExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("kilo: count tokens not supported")
|
||||
}
|
||||
|
||||
// kiloCredentials extracts access token and other info from auth.
|
||||
func kiloCredentials(auth *cliproxyauth.Auth) (accessToken, orgID string) {
|
||||
if auth == nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Prefer kilocode specific keys, then fall back to generic keys.
|
||||
// Check metadata first, then attributes.
|
||||
if auth.Metadata != nil {
|
||||
if token, ok := auth.Metadata["kilocodeToken"].(string); ok && token != "" {
|
||||
accessToken = token
|
||||
} else if token, ok := auth.Metadata["access_token"].(string); ok && token != "" {
|
||||
accessToken = token
|
||||
}
|
||||
|
||||
if org, ok := auth.Metadata["kilocodeOrganizationId"].(string); ok && org != "" {
|
||||
orgID = org
|
||||
} else if org, ok := auth.Metadata["organization_id"].(string); ok && org != "" {
|
||||
orgID = org
|
||||
}
|
||||
}
|
||||
|
||||
if accessToken == "" && auth.Attributes != nil {
|
||||
if token := auth.Attributes["kilocodeToken"]; token != "" {
|
||||
accessToken = token
|
||||
} else if token := auth.Attributes["access_token"]; token != "" {
|
||||
accessToken = token
|
||||
}
|
||||
}
|
||||
|
||||
if orgID == "" && auth.Attributes != nil {
|
||||
if org := auth.Attributes["kilocodeOrganizationId"]; org != "" {
|
||||
orgID = org
|
||||
} else if org := auth.Attributes["organization_id"]; org != "" {
|
||||
orgID = org
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, orgID
|
||||
}
|
||||
|
||||
// FetchKiloModels fetches models from Kilo API.
|
||||
func FetchKiloModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||
accessToken, orgID := kiloCredentials(auth)
|
||||
if accessToken == "" {
|
||||
log.Infof("kilo: no access token found, skipping dynamic model fetch (using static kilo/auto)")
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
log.Debugf("kilo: fetching dynamic models (orgID: %s)", orgID)
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.kilo.ai/api/openrouter/models", nil)
|
||||
if err != nil {
|
||||
log.Warnf("kilo: failed to create model fetch request: %v", err)
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if orgID != "" {
|
||||
req.Header.Set("X-Kilocode-OrganizationID", orgID)
|
||||
}
|
||||
req.Header.Set("User-Agent", "cli-proxy-kilo")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Warnf("kilo: fetch models canceled: %v", err)
|
||||
} else {
|
||||
log.Warnf("kilo: using static models (API fetch failed: %v)", err)
|
||||
}
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Warnf("kilo: failed to read models response: %v", err)
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("kilo: fetch models failed: status %d, body: %s", resp.StatusCode, string(body))
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(body, "data")
|
||||
if !result.Exists() {
|
||||
// Try root if data field is missing
|
||||
result = gjson.ParseBytes(body)
|
||||
if !result.IsArray() {
|
||||
log.Debugf("kilo: response body: %s", string(body))
|
||||
log.Warn("kilo: invalid API response format (expected array or data field with array)")
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
}
|
||||
|
||||
var dynamicModels []*registry.ModelInfo
|
||||
now := time.Now().Unix()
|
||||
count := 0
|
||||
totalCount := 0
|
||||
|
||||
result.ForEach(func(key, value gjson.Result) bool {
|
||||
totalCount++
|
||||
id := value.Get("id").String()
|
||||
pIdxResult := value.Get("preferredIndex")
|
||||
preferredIndex := pIdxResult.Int()
|
||||
|
||||
// Filter models where preferredIndex > 0 (Kilo-curated models)
|
||||
if preferredIndex <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if it's free. We look for :free suffix, is_free flag, or zero pricing.
|
||||
isFree := strings.HasSuffix(id, ":free") || id == "giga-potato" || value.Get("is_free").Bool()
|
||||
if !isFree {
|
||||
// Check pricing as fallback
|
||||
promptPricing := value.Get("pricing.prompt").String()
|
||||
if promptPricing == "0" || promptPricing == "0.0" {
|
||||
isFree = true
|
||||
}
|
||||
}
|
||||
|
||||
if !isFree {
|
||||
log.Debugf("kilo: skipping curated paid model: %s", id)
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debugf("kilo: found curated model: %s (preferredIndex: %d)", id, preferredIndex)
|
||||
|
||||
dynamicModels = append(dynamicModels, ®istry.ModelInfo{
|
||||
ID: id,
|
||||
DisplayName: value.Get("name").String(),
|
||||
ContextLength: int(value.Get("context_length").Int()),
|
||||
OwnedBy: "kilo",
|
||||
Type: "kilo",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
})
|
||||
count++
|
||||
return true
|
||||
})
|
||||
|
||||
log.Infof("kilo: fetched %d models from API, %d curated free (preferredIndex > 0)", totalCount, count)
|
||||
if count == 0 && totalCount > 0 {
|
||||
log.Warn("kilo: no curated free models found (check API response fields)")
|
||||
}
|
||||
|
||||
staticModels := registry.GetKiloModels()
|
||||
// Always include kilo/auto (first static model)
|
||||
allModels := append(staticModels[:1], dynamicModels...)
|
||||
|
||||
return allModels
|
||||
}
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
@@ -101,6 +102,10 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, err = normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
@@ -156,12 +161,12 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request to Kimi.
|
||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
from := opts.SourceFormat
|
||||
if from.String() == "claude" {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
@@ -201,6 +206,10 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, err = normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
@@ -244,7 +253,6 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -276,7 +284,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens estimates token count for Kimi requests.
|
||||
@@ -285,6 +293,150 @@ func (e *KimiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
return e.ClaudeExecutor.CountTokens(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
out := body
|
||||
pending := make([]string, 0)
|
||||
patched := 0
|
||||
patchedReasoning := 0
|
||||
ambiguous := 0
|
||||
latestReasoning := ""
|
||||
hasLatestReasoning := false
|
||||
|
||||
removePending := func(id string) {
|
||||
for idx := range pending {
|
||||
if pending[idx] != id {
|
||||
continue
|
||||
}
|
||||
pending = append(pending[:idx], pending[idx+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
msgs := messages.Array()
|
||||
for msgIdx := range msgs {
|
||||
msg := msgs[msgIdx]
|
||||
role := strings.TrimSpace(msg.Get("role").String())
|
||||
switch role {
|
||||
case "assistant":
|
||||
reasoning := msg.Get("reasoning_content")
|
||||
if reasoning.Exists() {
|
||||
reasoningText := reasoning.String()
|
||||
if strings.TrimSpace(reasoningText) != "" {
|
||||
latestReasoning = reasoningText
|
||||
hasLatestReasoning = true
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls := msg.Get("tool_calls")
|
||||
if !toolCalls.Exists() || !toolCalls.IsArray() || len(toolCalls.Array()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if !reasoning.Exists() || strings.TrimSpace(reasoning.String()) == "" {
|
||||
reasoningText := fallbackAssistantReasoning(msg, hasLatestReasoning, latestReasoning)
|
||||
path := fmt.Sprintf("messages.%d.reasoning_content", msgIdx)
|
||||
next, err := sjson.SetBytes(out, path, reasoningText)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi executor: failed to set assistant reasoning_content: %w", err)
|
||||
}
|
||||
out = next
|
||||
patchedReasoning++
|
||||
}
|
||||
|
||||
for _, tc := range toolCalls.Array() {
|
||||
id := strings.TrimSpace(tc.Get("id").String())
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
pending = append(pending, id)
|
||||
}
|
||||
case "tool":
|
||||
toolCallID := strings.TrimSpace(msg.Get("tool_call_id").String())
|
||||
if toolCallID == "" {
|
||||
toolCallID = strings.TrimSpace(msg.Get("call_id").String())
|
||||
if toolCallID != "" {
|
||||
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
|
||||
next, err := sjson.SetBytes(out, path, toolCallID)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi executor: failed to set tool_call_id from call_id: %w", err)
|
||||
}
|
||||
out = next
|
||||
patched++
|
||||
}
|
||||
}
|
||||
if toolCallID == "" {
|
||||
if len(pending) == 1 {
|
||||
toolCallID = pending[0]
|
||||
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
|
||||
next, err := sjson.SetBytes(out, path, toolCallID)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi executor: failed to infer tool_call_id: %w", err)
|
||||
}
|
||||
out = next
|
||||
patched++
|
||||
} else if len(pending) > 1 {
|
||||
ambiguous++
|
||||
}
|
||||
}
|
||||
if toolCallID != "" {
|
||||
removePending(toolCallID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if patched > 0 || patchedReasoning > 0 {
|
||||
log.WithFields(log.Fields{
|
||||
"patched_tool_messages": patched,
|
||||
"patched_reasoning_messages": patchedReasoning,
|
||||
}).Debug("kimi executor: normalized tool message fields")
|
||||
}
|
||||
if ambiguous > 0 {
|
||||
log.WithFields(log.Fields{
|
||||
"ambiguous_tool_messages": ambiguous,
|
||||
"pending_tool_calls": len(pending),
|
||||
}).Warn("kimi executor: tool messages missing tool_call_id with ambiguous candidates")
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func fallbackAssistantReasoning(msg gjson.Result, hasLatest bool, latest string) string {
|
||||
if hasLatest && strings.TrimSpace(latest) != "" {
|
||||
return latest
|
||||
}
|
||||
|
||||
content := msg.Get("content")
|
||||
if content.Type == gjson.String {
|
||||
if text := strings.TrimSpace(content.String()); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
if content.IsArray() {
|
||||
parts := make([]string, 0, len(content.Array()))
|
||||
for _, item := range content.Array() {
|
||||
text := strings.TrimSpace(item.Get("text").String())
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, text)
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
return "[reasoning unavailable]"
|
||||
}
|
||||
|
||||
// Refresh refreshes the Kimi token using the refresh token.
|
||||
func (e *KimiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
log.Debugf("kimi executor: refresh called")
|
||||
|
||||
205
internal/runtime/executor/kimi_executor_test.go
Normal file
205
internal/runtime/executor/kimi_executor_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_UsesCallIDFallback(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"list_directory:1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||
{"role":"tool","call_id":"list_directory:1","content":"[]"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||
if got != "list_directory:1" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "list_directory:1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_InferSinglePendingID(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_123","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
|
||||
{"role":"tool","content":"file-content"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||
if got != "call_123" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_AmbiguousMissingIDIsNotInferred(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[
|
||||
{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}},
|
||||
{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}
|
||||
]},
|
||||
{"role":"tool","content":"result-without-id"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
if gjson.GetBytes(out, "messages.1.tool_call_id").Exists() {
|
||||
t.Fatalf("messages.1.tool_call_id should be absent for ambiguous case, got %q", gjson.GetBytes(out, "messages.1.tool_call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_PreservesExistingToolCallID(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||
{"role":"tool","tool_call_id":"call_1","call_id":"different-id","content":"result"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||
if got != "call_1" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_InheritsPreviousReasoningForAssistantToolCalls(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":"plan","reasoning_content":"previous reasoning"},
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.reasoning_content").String()
|
||||
if got != "previous reasoning" {
|
||||
t.Fatalf("messages.1.reasoning_content = %q, want %q", got, "previous reasoning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_InsertsFallbackReasoningWhenMissing(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
reasoning := gjson.GetBytes(out, "messages.0.reasoning_content")
|
||||
if !reasoning.Exists() {
|
||||
t.Fatalf("messages.0.reasoning_content should exist")
|
||||
}
|
||||
if reasoning.String() != "[reasoning unavailable]" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", reasoning.String(), "[reasoning unavailable]")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_UsesContentAsReasoningFallback(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":[{"type":"text","text":"first line"},{"type":"text","text":"second line"}],"tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||
if got != "first line\nsecond line" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "first line\nsecond line")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_ReplacesEmptyReasoningContent(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":"assistant summary","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":""}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||
if got != "assistant summary" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "assistant summary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_PreservesExistingAssistantReasoning(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"keep me"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||
if got != "keep me" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "keep me")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_RepairsIDsAndReasoningTogether(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"r1"},
|
||||
{"role":"tool","call_id":"call_1","content":"[]"},
|
||||
{"role":"assistant","tool_calls":[{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
|
||||
{"role":"tool","call_id":"call_2","content":"file"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.1.tool_call_id").String(); got != "call_1" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.3.tool_call_id").String(); got != "call_2" {
|
||||
t.Fatalf("messages.3.tool_call_id = %q, want %q", got, "call_2")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.2.reasoning_content").String(); got != "r1" {
|
||||
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "r1")
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -385,6 +386,35 @@ func buildKiroEndpointConfigs(region string) []kiroEndpointConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// resolveKiroAPIRegion determines the AWS region for Kiro API calls.
|
||||
// Region priority:
|
||||
// 1. auth.Metadata["api_region"] - explicit API region override
|
||||
// 2. ProfileARN region - extracted from arn:aws:service:REGION:account:resource
|
||||
// 3. kiroDefaultRegion (us-east-1) - fallback
|
||||
// Note: OIDC "region" is NOT used - it's for token refresh, not API calls
|
||||
func resolveKiroAPIRegion(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return kiroDefaultRegion
|
||||
}
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
log.Debugf("kiro: using region %s (source: api_region)", r)
|
||||
return r
|
||||
}
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
|
||||
log.Debugf("kiro: using region %s (source: profile_arn)", arnRegion)
|
||||
return arnRegion
|
||||
}
|
||||
}
|
||||
// Note: OIDC "region" field is NOT used for API endpoint
|
||||
// Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
|
||||
// Using OIDC region for API calls causes DNS failures
|
||||
log.Debugf("kiro: using region %s (source: default)", kiroDefaultRegion)
|
||||
return kiroDefaultRegion
|
||||
}
|
||||
|
||||
// kiroEndpointConfigs is kept for backward compatibility with default us-east-1 region.
|
||||
// Prefer using buildKiroEndpointConfigs(region) for dynamic region support.
|
||||
var kiroEndpointConfigs = buildKiroEndpointConfigs(kiroDefaultRegion)
|
||||
@@ -403,30 +433,8 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
|
||||
return kiroEndpointConfigs
|
||||
}
|
||||
|
||||
// Determine API region with priority: api_region > profile_arn > region > default
|
||||
region := kiroDefaultRegion
|
||||
regionSource := "default"
|
||||
|
||||
if auth.Metadata != nil {
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
region = r
|
||||
regionSource = "api_region"
|
||||
} else {
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
|
||||
region = arnRegion
|
||||
regionSource = "profile_arn"
|
||||
}
|
||||
}
|
||||
// Note: OIDC "region" field is NOT used for API endpoint
|
||||
// Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
|
||||
// Using OIDC region for API calls causes DNS failures
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro: using region %s (source: %s)", region, regionSource)
|
||||
// Determine API region using shared resolution logic
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
|
||||
// Build endpoint configs for the specified region
|
||||
endpointConfigs := buildKiroEndpointConfigs(region)
|
||||
@@ -519,8 +527,12 @@ func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string,
|
||||
case "openai":
|
||||
log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String())
|
||||
return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
|
||||
case "kiro":
|
||||
// Body is already in Kiro format — pass through directly
|
||||
log.Debugf("kiro: body already in Kiro format, passing through directly")
|
||||
return body, false
|
||||
default:
|
||||
// Default to Claude format (also handles "claude", "kiro", etc.)
|
||||
// Default to Claude format
|
||||
log.Debugf("kiro: using Claude payload builder for source format: %s", sourceFormat.String())
|
||||
return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
|
||||
}
|
||||
@@ -636,10 +648,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
// Check if token is expired before making request
|
||||
// Check if token is expired before making request (covers both normal and web_search paths)
|
||||
if e.isTokenExpired(accessToken) {
|
||||
log.Infof("kiro: access token expired, attempting recovery")
|
||||
|
||||
@@ -668,6 +677,16 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
}
|
||||
}
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint")
|
||||
return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -1014,8 +1033,9 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
|
||||
// Build response in Claude format for Kiro translator
|
||||
// stopReason is extracted from upstream response by parseEventStream
|
||||
kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, requestedModel, usageInfo, stopReason)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, requestedModel, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -1033,7 +1053,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
|
||||
// ExecuteStream handles streaming requests to Kiro API.
|
||||
// Supports automatic token refresh on 401/403 errors and quota fallback on 429.
|
||||
func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
accessToken, profileArn := kiroCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("kiro: access token not found in auth")
|
||||
@@ -1057,10 +1077,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
// Check if token is expired before making request
|
||||
// Check if token is expired before making request (covers both normal and web_search paths)
|
||||
if e.isTokenExpired(accessToken) {
|
||||
log.Infof("kiro: access token expired, attempting recovery before stream request")
|
||||
|
||||
@@ -1089,6 +1106,20 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
}
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
|
||||
streamWebSearch, errWebSearch := e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
if errWebSearch != nil {
|
||||
return nil, errWebSearch
|
||||
}
|
||||
return &cliproxyexecutor.StreamResult{Chunks: streamWebSearch}, nil
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -1101,7 +1132,11 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
// Execute stream with retry on 401/403 and 429 (quota exhausted)
|
||||
// Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint
|
||||
return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
||||
streamKiro, errStreamKiro := e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
||||
if errStreamKiro != nil {
|
||||
return nil, errStreamKiro
|
||||
}
|
||||
return &cliproxyexecutor.StreamResult{Chunks: streamKiro}, nil
|
||||
}
|
||||
|
||||
// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors.
|
||||
@@ -1405,7 +1440,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
// So we always enable thinking parsing for Kiro responses
|
||||
log.Debugf("kiro: stream thinkingEnabled = %v (always true for Kiro)", thinkingEnabled)
|
||||
|
||||
e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter, thinkingEnabled)
|
||||
e.streamToChannel(ctx, resp.Body, out, from, payloadRequestedModel(opts, req.Model), opts.OriginalRequest, body, reporter, thinkingEnabled)
|
||||
}(httpResp, thinkingEnabled)
|
||||
|
||||
return out, nil
|
||||
@@ -1682,6 +1717,7 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
// Amazon Q format (amazonq- prefix) - same API as Kiro
|
||||
"amazonq-auto": "auto",
|
||||
"amazonq-claude-opus-4-6": "claude-opus-4.6",
|
||||
"amazonq-claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"amazonq-claude-opus-4-5": "claude-opus-4.5",
|
||||
"amazonq-claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||
"amazonq-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||
@@ -1690,6 +1726,7 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
"amazonq-claude-haiku-4-5": "claude-haiku-4.5",
|
||||
// Kiro format (kiro- prefix) - valid model names that should be preserved
|
||||
"kiro-claude-opus-4-6": "claude-opus-4.6",
|
||||
"kiro-claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"kiro-claude-opus-4-5": "claude-opus-4.5",
|
||||
"kiro-claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||
"kiro-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||
@@ -1700,6 +1737,8 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
// Native format (no prefix) - used by Kiro IDE directly
|
||||
"claude-opus-4-6": "claude-opus-4.6",
|
||||
"claude-opus-4.6": "claude-opus-4.6",
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"claude-sonnet-4.6": "claude-sonnet-4.6",
|
||||
"claude-opus-4-5": "claude-opus-4.5",
|
||||
"claude-opus-4.5": "claude-opus-4.5",
|
||||
"claude-haiku-4-5": "claude-haiku-4.5",
|
||||
@@ -1712,11 +1751,13 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
"auto": "auto",
|
||||
// Agentic variants (same backend model IDs, but with special system prompt)
|
||||
"claude-opus-4.6-agentic": "claude-opus-4.6",
|
||||
"claude-sonnet-4.6-agentic": "claude-sonnet-4.6",
|
||||
"claude-opus-4.5-agentic": "claude-opus-4.5",
|
||||
"claude-sonnet-4.5-agentic": "claude-sonnet-4.5",
|
||||
"claude-sonnet-4-agentic": "claude-sonnet-4",
|
||||
"claude-haiku-4.5-agentic": "claude-haiku-4.5",
|
||||
"kiro-claude-opus-4-6-agentic": "claude-opus-4.6",
|
||||
"kiro-claude-sonnet-4-6-agentic": "claude-sonnet-4.6",
|
||||
"kiro-claude-opus-4-5-agentic": "claude-opus-4.5",
|
||||
"kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5",
|
||||
"kiro-claude-sonnet-4-agentic": "claude-sonnet-4",
|
||||
@@ -1742,6 +1783,10 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
log.Debugf("kiro: unknown Sonnet 3.7 model '%s', mapping to claude-3-7-sonnet-20250219", model)
|
||||
return "claude-3-7-sonnet-20250219"
|
||||
}
|
||||
if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") {
|
||||
log.Debugf("kiro: unknown Sonnet 4.6 model '%s', mapping to claude-sonnet-4.6", model)
|
||||
return "claude-sonnet-4.6"
|
||||
}
|
||||
if strings.Contains(modelLower, "4-5") || strings.Contains(modelLower, "4.5") {
|
||||
log.Debugf("kiro: unknown Sonnet 4.5 model '%s', mapping to claude-sonnet-4.5", model)
|
||||
return "claude-sonnet-4.5"
|
||||
@@ -1753,6 +1798,10 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
|
||||
// Check for Opus variants
|
||||
if strings.Contains(modelLower, "opus") {
|
||||
if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") {
|
||||
log.Debugf("kiro: unknown Opus 4.6 model '%s', mapping to claude-opus-4.6", model)
|
||||
return "claude-opus-4.6"
|
||||
}
|
||||
log.Debugf("kiro: unknown Opus model '%s', mapping to claude-opus-4.5", model)
|
||||
return "claude-opus-4.5"
|
||||
}
|
||||
@@ -2502,6 +2551,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
isThinkingBlockOpen := false // Track if thinking content block SSE event is open
|
||||
thinkingBlockIndex := -1 // Index of the thinking content block
|
||||
var accumulatedThinkingContent strings.Builder // Accumulate thinking content for token counting
|
||||
hasOfficialReasoningEvent := false // Disable tag parsing after official reasoning events appear
|
||||
|
||||
// Buffer for handling partial tag matches at chunk boundaries
|
||||
var pendingContent strings.Builder // Buffer content that might be part of a tag
|
||||
@@ -2937,6 +2987,31 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
lastUsageUpdateTime = time.Now()
|
||||
}
|
||||
|
||||
if hasOfficialReasoningEvent {
|
||||
processText := strings.TrimSpace(strings.ReplaceAll(strings.ReplaceAll(contentDelta, kirocommon.ThinkingStartTag, ""), kirocommon.ThinkingEndTag, ""))
|
||||
if processText != "" {
|
||||
if !isTextBlockOpen {
|
||||
contentBlockIndex++
|
||||
isTextBlockOpen = true
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
}
|
||||
}
|
||||
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processText, contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// TAG-BASED THINKING PARSING: Parse <thinking> tags from content
|
||||
// Combine pending content with new content for processing
|
||||
pendingContent.WriteString(contentDelta)
|
||||
@@ -3215,6 +3290,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
}
|
||||
|
||||
if thinkingText != "" {
|
||||
hasOfficialReasoningEvent = true
|
||||
// Close text block if open before starting thinking block
|
||||
if isTextBlockOpen && contentBlockIndex >= 0 {
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
@@ -4096,6 +4172,659 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
|
||||
return isExpired
|
||||
}
|
||||
|
||||
// NOTE: Message merging functions moved to internal/translator/kiro/common/message_merge.go
|
||||
// NOTE: Tool calling support functions moved to internal/translator/kiro/claude/kiro_claude_tools.go
|
||||
// The executor now uses kiroclaude.* and kirocommon.* functions instead
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
// Web Search Handler (MCP API)
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
// fetchToolDescription caching:
|
||||
// Uses a mutex + fetched flag to ensure only one goroutine fetches at a time,
|
||||
// with automatic retry on failure:
|
||||
// - On failure, fetched stays false so subsequent calls will retry
|
||||
// - On success, fetched is set to true — subsequent calls skip immediately (mutex-free fast path)
|
||||
// The cached description is stored in the translator package via kiroclaude.SetWebSearchDescription(),
|
||||
// enabling the translator's convertClaudeToolsToKiro to read it when building Kiro requests.
|
||||
var (
|
||||
toolDescMu sync.Mutex
|
||||
toolDescFetched atomic.Bool
|
||||
)
|
||||
|
||||
// fetchToolDescription calls MCP tools/list to get the web_search tool description
|
||||
// and caches it. Safe to call concurrently — only one goroutine fetches at a time.
|
||||
// If the fetch fails, subsequent calls will retry. On success, no further fetches occur.
|
||||
// The httpClient parameter allows reusing a shared pooled HTTP client.
|
||||
func fetchToolDescription(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) {
|
||||
// Fast path: already fetched successfully, no lock needed
|
||||
if toolDescFetched.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
toolDescMu.Lock()
|
||||
defer toolDescMu.Unlock()
|
||||
|
||||
// Double-check after acquiring lock
|
||||
if toolDescFetched.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, authToken, httpClient, auth, authAttrs)
|
||||
reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`)
|
||||
log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", mcpEndpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to create tools/list request: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Reuse same headers as callMcpAPI
|
||||
handler.setMcpHeaders(req)
|
||||
|
||||
resp, err := handler.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: tools/list request failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil || resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body))
|
||||
|
||||
// Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}}
|
||||
var result struct {
|
||||
Result *struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
} `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
|
||||
log.Warnf("kiro/websearch: failed to parse tools/list response")
|
||||
return
|
||||
}
|
||||
|
||||
for _, tool := range result.Result.Tools {
|
||||
if tool.Name == "web_search" && tool.Description != "" {
|
||||
kiroclaude.SetWebSearchDescription(tool.Description)
|
||||
toolDescFetched.Store(true) // success — no more fetches
|
||||
log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// web_search tool not found in response
|
||||
log.Warnf("kiro/websearch: web_search tool not found in tools/list response")
|
||||
}
|
||||
|
||||
// webSearchHandler handles web search requests via Kiro MCP API
|
||||
type webSearchHandler struct {
|
||||
ctx context.Context
|
||||
mcpEndpoint string
|
||||
httpClient *http.Client
|
||||
authToken string
|
||||
auth *cliproxyauth.Auth // for applyDynamicFingerprint
|
||||
authAttrs map[string]string // optional, for custom headers from auth.Attributes
|
||||
}
|
||||
|
||||
// newWebSearchHandler creates a new webSearchHandler.
|
||||
// If httpClient is nil, a default client with 30s timeout is used.
|
||||
// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse.
|
||||
func newWebSearchHandler(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) *webSearchHandler {
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
return &webSearchHandler{
|
||||
ctx: ctx,
|
||||
mcpEndpoint: mcpEndpoint,
|
||||
httpClient: httpClient,
|
||||
authToken: authToken,
|
||||
auth: auth,
|
||||
authAttrs: authAttrs,
|
||||
}
|
||||
}
|
||||
|
||||
// setMcpHeaders sets standard MCP API headers on the request,
|
||||
// aligned with the GAR request pattern.
|
||||
func (h *webSearchHandler) setMcpHeaders(req *http.Request) {
|
||||
// 1. Content-Type & Accept (aligned with GAR)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
|
||||
// 2. Kiro-specific headers (aligned with GAR)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
|
||||
// 3. User-Agent: Reuse applyDynamicFingerprint for consistency
|
||||
applyDynamicFingerprint(req, h.auth)
|
||||
|
||||
// 4. AWS SDK identifiers
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
|
||||
// 5. Authentication
|
||||
req.Header.Set("Authorization", "Bearer "+h.authToken)
|
||||
|
||||
// 6. Custom headers from auth attributes
|
||||
util.ApplyCustomHeadersFromAttrs(req, h.authAttrs)
|
||||
}
|
||||
|
||||
// mcpMaxRetries is the maximum number of retries for MCP API calls.
|
||||
const mcpMaxRetries = 2
|
||||
|
||||
// callMcpAPI calls the Kiro MCP API with the given request.
|
||||
// Includes retry logic with exponential backoff for retryable errors.
|
||||
func (h *webSearchHandler) callMcpAPI(request *kiroclaude.McpRequest) (*kiroclaude.McpResponse, error) {
|
||||
requestBody, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal MCP request: %w", err)
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.mcpEndpoint, len(requestBody))
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= mcpMaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
backoff := time.Duration(1<<attempt) * time.Second
|
||||
if backoff > 10*time.Second {
|
||||
backoff = 10 * time.Second
|
||||
}
|
||||
log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr)
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return nil, h.ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(h.ctx, "POST", h.mcpEndpoint, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
h.setMcpHeaders(req)
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("MCP API request failed: %w", err)
|
||||
continue // network error → retry
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read MCP response: %w", err)
|
||||
continue // read error → retry
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body))
|
||||
|
||||
// Retryable HTTP status codes (aligned with GAR: 502, 503, 504)
|
||||
if resp.StatusCode >= 502 && resp.StatusCode <= 504 {
|
||||
lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body))
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var mcpResponse kiroclaude.McpResponse
|
||||
if err := json.Unmarshal(body, &mcpResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse MCP response: %w", err)
|
||||
}
|
||||
|
||||
if mcpResponse.Error != nil {
|
||||
code := -1
|
||||
if mcpResponse.Error.Code != nil {
|
||||
code = *mcpResponse.Error.Code
|
||||
}
|
||||
msg := "Unknown error"
|
||||
if mcpResponse.Error.Message != nil {
|
||||
msg = *mcpResponse.Error.Message
|
||||
}
|
||||
return nil, fmt.Errorf("MCP error %d: %s", code, msg)
|
||||
}
|
||||
|
||||
return &mcpResponse, nil
|
||||
}
|
||||
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// webSearchAuthAttrs extracts auth attributes for MCP calls.
|
||||
// Used by handleWebSearch and handleWebSearchStream to pass custom headers.
|
||||
func webSearchAuthAttrs(auth *cliproxyauth.Auth) map[string]string {
|
||||
if auth != nil {
|
||||
return auth.Attributes
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxWebSearchIterations = 5
|
||||
|
||||
// handleWebSearchStream handles web_search requests:
|
||||
// Step 1: tools/list (sync) → fetch/cache tool description
|
||||
// Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop
|
||||
// Note: We skip the "model decides to search" step because Claude Code already
|
||||
// decided to use web_search. The Kiro tool description restricts non-coding
|
||||
// topics, so asking the model again would cause it to refuse valid searches.
|
||||
func (e *KiroExecutor) handleWebSearchStream(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
// Extract search query from Claude Code's web_search tool_use
|
||||
query := kiroclaude.ExtractSearchQuery(req.Payload)
|
||||
if query == "" {
|
||||
log.Warnf("kiro/websearch: failed to extract search query, falling back to normal flow")
|
||||
return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
|
||||
|
||||
// ── Step 1: tools/list (SYNC) — cache tool description ──
|
||||
{
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
}
|
||||
|
||||
// Create output channel
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
|
||||
// Usage reporting: track web search requests like normal streaming requests
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
|
||||
go func() {
|
||||
var wsErr error
|
||||
defer reporter.trackFailure(ctx, &wsErr)
|
||||
defer close(out)
|
||||
|
||||
// Estimate input tokens using tokenizer (matching streamToChannel pattern)
|
||||
var totalUsage usage.Detail
|
||||
if enc, tokErr := getTokenizer(req.Model); tokErr == nil {
|
||||
if inp, e := countClaudeChatTokens(enc, req.Payload); e == nil && inp > 0 {
|
||||
totalUsage.InputTokens = inp
|
||||
} else {
|
||||
totalUsage.InputTokens = int64(len(req.Payload) / 4)
|
||||
}
|
||||
} else {
|
||||
totalUsage.InputTokens = int64(len(req.Payload) / 4)
|
||||
}
|
||||
if totalUsage.InputTokens == 0 && len(req.Payload) > 0 {
|
||||
totalUsage.InputTokens = 1
|
||||
}
|
||||
var accumulatedOutputLen int
|
||||
defer func() {
|
||||
if wsErr != nil {
|
||||
return // let trackFailure handle failure reporting
|
||||
}
|
||||
totalUsage.OutputTokens = int64(accumulatedOutputLen / 4)
|
||||
if accumulatedOutputLen > 0 && totalUsage.OutputTokens == 0 {
|
||||
totalUsage.OutputTokens = 1
|
||||
}
|
||||
reporter.publish(ctx, totalUsage)
|
||||
}()
|
||||
|
||||
// Send message_start event to client (aligned with streamToChannel pattern)
|
||||
// Use payloadRequestedModel to return user's original model alias
|
||||
msgStart := kiroclaude.BuildClaudeMessageStartEvent(
|
||||
payloadRequestedModel(opts, req.Model),
|
||||
totalUsage.InputTokens,
|
||||
)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: append(msgStart, '\n', '\n')}:
|
||||
}
|
||||
|
||||
// ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ──
|
||||
contentBlockIndex := 0
|
||||
currentQuery := query
|
||||
|
||||
// Replace web_search tool description with a minimal one that allows re-search.
|
||||
// The original tools/list description from Kiro restricts non-coding topics,
|
||||
// but we've already decided to search. We keep the tool so the model can
|
||||
// request additional searches when results are insufficient.
|
||||
simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload))
|
||||
if simplifyErr != nil {
|
||||
log.Warnf("kiro/websearch: failed to simplify web_search tool: %v, using original payload", simplifyErr)
|
||||
simplifiedPayload = bytes.Clone(req.Payload)
|
||||
}
|
||||
|
||||
currentClaudePayload := simplifiedPayload
|
||||
totalSearches := 0
|
||||
|
||||
// Generate toolUseId for the first iteration (Claude Code already decided to search)
|
||||
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
|
||||
|
||||
for iteration := 0; iteration < maxWebSearchIterations; iteration++ {
|
||||
log.Infof("kiro/websearch: search iteration %d/%d",
|
||||
iteration+1, maxWebSearchIterations)
|
||||
|
||||
// MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery)
|
||||
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
log.Warnf("kiro/websearch: MCP API call failed: %v, continuing with empty results", mcpErr)
|
||||
} else {
|
||||
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
|
||||
}
|
||||
|
||||
resultCount := 0
|
||||
if searchResults != nil {
|
||||
resultCount = len(searchResults.Results)
|
||||
}
|
||||
totalSearches++
|
||||
log.Infof("kiro/websearch: iteration %d — got %d search results", iteration+1, resultCount)
|
||||
|
||||
// Send search indicator events to client
|
||||
searchEvents := kiroclaude.GenerateSearchIndicatorEvents(currentQuery, currentToolUseId, searchResults, contentBlockIndex)
|
||||
for _, event := range searchEvents {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: event}:
|
||||
}
|
||||
}
|
||||
contentBlockIndex += 2
|
||||
|
||||
// Inject tool_use + tool_result into Claude payload, then call GAR
|
||||
var err error
|
||||
currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, currentToolUseId, currentQuery, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to inject tool results: %v", err)
|
||||
wsErr = fmt.Errorf("failed to inject tool results: %w", err)
|
||||
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
|
||||
return
|
||||
}
|
||||
|
||||
// Call GAR with modified Claude payload (full translation pipeline)
|
||||
modifiedReq := req
|
||||
modifiedReq.Payload = currentClaudePayload
|
||||
kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if kiroErr != nil {
|
||||
log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr)
|
||||
wsErr = fmt.Errorf("Kiro API failed at iteration %d: %w", iteration+1, kiroErr)
|
||||
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
|
||||
return
|
||||
}
|
||||
|
||||
// Analyze response
|
||||
analysis := kiroclaude.AnalyzeBufferedStream(kiroChunks)
|
||||
log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v",
|
||||
iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse)
|
||||
|
||||
if analysis.HasWebSearchToolUse && analysis.WebSearchQuery != "" && iteration+1 < maxWebSearchIterations {
|
||||
// Model wants another search
|
||||
filteredChunks := kiroclaude.FilterChunksForClient(kiroChunks, analysis.WebSearchToolUseIndex, contentBlockIndex)
|
||||
for _, chunk := range filteredChunks {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
|
||||
}
|
||||
}
|
||||
|
||||
currentQuery = analysis.WebSearchQuery
|
||||
currentToolUseId = analysis.WebSearchToolUseId
|
||||
continue
|
||||
}
|
||||
|
||||
// Model returned final response — stream to client
|
||||
for _, chunk := range kiroChunks {
|
||||
if contentBlockIndex > 0 && len(chunk) > 0 {
|
||||
adjusted, shouldForward := kiroclaude.AdjustSSEChunk(chunk, contentBlockIndex)
|
||||
if !shouldForward {
|
||||
continue
|
||||
}
|
||||
accumulatedOutputLen += len(adjusted)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}:
|
||||
}
|
||||
} else {
|
||||
accumulatedOutputLen += len(chunk)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Infof("kiro/websearch: completed after %d search iteration(s), total searches: %d", iteration+1, totalSearches)
|
||||
return
|
||||
}
|
||||
|
||||
log.Warnf("kiro/websearch: reached max iterations (%d), stopping search loop", maxWebSearchIterations)
|
||||
}()
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// handleWebSearch handles web_search requests for non-streaming Execute path.
|
||||
// Performs MCP search synchronously, injects results into the request payload,
|
||||
// then calls the normal non-streaming Kiro API path which returns a proper
|
||||
// Claude JSON response (not SSE chunks).
|
||||
func (e *KiroExecutor) handleWebSearch(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (cliproxyexecutor.Response, error) {
|
||||
// Extract search query from Claude Code's web_search tool_use
|
||||
query := kiroclaude.ExtractSearchQuery(req.Payload)
|
||||
if query == "" {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute")
|
||||
// Fall through to normal non-streaming path
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
|
||||
|
||||
// Step 1: Fetch/cache tool description (sync)
|
||||
{
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
}
|
||||
|
||||
// Step 2: Perform MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(query)
|
||||
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr)
|
||||
} else {
|
||||
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
|
||||
}
|
||||
|
||||
resultCount := 0
|
||||
if searchResults != nil {
|
||||
resultCount = len(searchResults.Results)
|
||||
}
|
||||
log.Infof("kiro/websearch: non-stream: got %d search results", resultCount)
|
||||
|
||||
// Step 3: Replace restrictive web_search tool description (align with streaming path)
|
||||
simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload))
|
||||
if simplifyErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to simplify web_search tool: %v, using original payload", simplifyErr)
|
||||
simplifiedPayload = bytes.Clone(req.Payload)
|
||||
}
|
||||
|
||||
// Step 4: Inject search tool_use + tool_result into Claude payload
|
||||
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
|
||||
modifiedPayload, err := kiroclaude.InjectToolResultsClaude(simplifiedPayload, currentToolUseId, query, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err)
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Step 5: Call Kiro API via the normal non-streaming path (executeWithRetry)
|
||||
// This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream
|
||||
// to produce a proper Claude JSON response
|
||||
modifiedReq := req
|
||||
modifiedReq.Payload = modifiedPayload
|
||||
|
||||
resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Step 6: Inject server_tool_use + web_search_tool_result into response
|
||||
// so Claude Code can display "Did X searches in Ys"
|
||||
indicators := []kiroclaude.SearchIndicator{
|
||||
{
|
||||
ToolUseID: currentToolUseId,
|
||||
Query: query,
|
||||
Results: searchResults,
|
||||
},
|
||||
}
|
||||
injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators)
|
||||
if injErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr)
|
||||
} else {
|
||||
resp.Payload = injectedPayload
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// callKiroAndBuffer calls the Kiro API and buffers all response chunks.
|
||||
// Returns the buffered chunks for analysis before forwarding to client.
|
||||
// Usage reporting is NOT done here — the caller (handleWebSearchStream) manages its own reporter.
|
||||
func (e *KiroExecutor) callKiroAndBuffer(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) ([][]byte, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
log.Debugf("kiro/websearch GAR request: %d bytes", len(body))
|
||||
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
kiroStream, err := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Buffer all chunks
|
||||
var chunks [][]byte
|
||||
for chunk := range kiroStream {
|
||||
if chunk.Err != nil {
|
||||
return chunks, chunk.Err
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
chunks = append(chunks, bytes.Clone(chunk.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro/websearch GAR response: %d chunks buffered", len(chunks))
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// callKiroDirectStream creates a direct streaming channel to Kiro API without search.
|
||||
func (e *KiroExecutor) callKiroDirectStream(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
var streamErr error
|
||||
defer reporter.trackFailure(ctx, &streamErr)
|
||||
|
||||
stream, streamErr := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
return stream, streamErr
|
||||
}
|
||||
|
||||
// sendFallbackText sends a simple text response when the Kiro API fails during the search loop.
|
||||
// Delegates SSE event construction to kiroclaude.BuildFallbackTextEvents() for alignment
|
||||
// with how streamToChannel() uses BuildClaude*Event() functions.
|
||||
func (e *KiroExecutor) sendFallbackText(
|
||||
ctx context.Context,
|
||||
out chan<- cliproxyexecutor.StreamChunk,
|
||||
contentBlockIndex int,
|
||||
query string,
|
||||
searchResults *kiroclaude.WebSearchResults,
|
||||
) {
|
||||
events := kiroclaude.BuildFallbackTextEvents(contentBlockIndex, query, searchResults)
|
||||
for _, event := range events {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: append(event, '\n', '\n')}:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// executeNonStreamFallback runs the standard non-streaming Execute path for a request.
|
||||
// Used by handleWebSearch after injecting search results, or as a fallback.
|
||||
func (e *KiroExecutor) executeNonStreamFallback(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (cliproxyexecutor.Response, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
var err error
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
resp, err := e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
@@ -172,11 +172,11 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
// Translate response back to source format when needed
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -258,7 +258,6 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -298,7 +297,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
// Ensure we record the request if no usage chunk was ever seen
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
@@ -22,11 +23,151 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
qwenUserAgent = "google-api-nodejs-client/9.15.1"
|
||||
qwenXGoogAPIClient = "gl-node/22.17.0"
|
||||
qwenClientMetadataValue = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
|
||||
qwenRateLimitWindow = time.Minute // sliding window duration
|
||||
)
|
||||
|
||||
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
|
||||
var qwenBeijingLoc = func() *time.Location {
|
||||
loc, err := time.LoadLocation("Asia/Shanghai")
|
||||
if err != nil || loc == nil {
|
||||
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
|
||||
return time.FixedZone("CST", 8*3600)
|
||||
}
|
||||
return loc
|
||||
}()
|
||||
|
||||
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
|
||||
var qwenQuotaCodes = map[string]struct{}{
|
||||
"insufficient_quota": {},
|
||||
"quota_exceeded": {},
|
||||
}
|
||||
|
||||
// qwenRateLimiter tracks request timestamps per credential for rate limiting.
|
||||
// Qwen has a limit of 60 requests per minute per account.
|
||||
var qwenRateLimiter = struct {
|
||||
sync.Mutex
|
||||
requests map[string][]time.Time // authID -> request timestamps
|
||||
}{
|
||||
requests: make(map[string][]time.Time),
|
||||
}
|
||||
|
||||
// redactAuthID returns a redacted version of the auth ID for safe logging.
|
||||
// Keeps a small prefix/suffix to allow correlation across events.
|
||||
func redactAuthID(id string) string {
|
||||
if id == "" {
|
||||
return ""
|
||||
}
|
||||
if len(id) <= 8 {
|
||||
return id
|
||||
}
|
||||
return id[:4] + "..." + id[len(id)-4:]
|
||||
}
|
||||
|
||||
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
|
||||
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
|
||||
func checkQwenRateLimit(authID string) error {
|
||||
if authID == "" {
|
||||
// Empty authID should not bypass rate limiting in production
|
||||
// Use debug level to avoid log spam for certain auth flows
|
||||
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-qwenRateLimitWindow)
|
||||
|
||||
qwenRateLimiter.Lock()
|
||||
defer qwenRateLimiter.Unlock()
|
||||
|
||||
// Get and filter timestamps within the window
|
||||
timestamps := qwenRateLimiter.requests[authID]
|
||||
var validTimestamps []time.Time
|
||||
for _, ts := range timestamps {
|
||||
if ts.After(windowStart) {
|
||||
validTimestamps = append(validTimestamps, ts)
|
||||
}
|
||||
}
|
||||
|
||||
// Always prune expired entries to prevent memory leak
|
||||
// Delete empty entries, otherwise update with pruned slice
|
||||
if len(validTimestamps) == 0 {
|
||||
delete(qwenRateLimiter.requests, authID)
|
||||
}
|
||||
|
||||
// Check if rate limit exceeded
|
||||
if len(validTimestamps) >= qwenRateLimitPerMin {
|
||||
// Calculate when the oldest request will expire
|
||||
oldestInWindow := validTimestamps[0]
|
||||
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
|
||||
if retryAfter < time.Second {
|
||||
retryAfter = time.Second
|
||||
}
|
||||
retryAfterSec := int(retryAfter.Seconds())
|
||||
return statusErr{
|
||||
code: http.StatusTooManyRequests,
|
||||
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
|
||||
retryAfter: &retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
// Record this request and update the map with pruned timestamps
|
||||
validTimestamps = append(validTimestamps, now)
|
||||
qwenRateLimiter.requests[authID] = validTimestamps
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
|
||||
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
|
||||
func isQwenQuotaError(body []byte) bool {
|
||||
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
|
||||
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
|
||||
|
||||
// Primary check: exact match on error.code or error.type (most reliable)
|
||||
if _, ok := qwenQuotaCodes[code]; ok {
|
||||
return true
|
||||
}
|
||||
if _, ok := qwenQuotaCodes[errType]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fallback: check message only if code/type don't match (less reliable)
|
||||
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
|
||||
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
|
||||
strings.Contains(msg, "free allocated quota exceeded") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
|
||||
// Returns the appropriate status code and retryAfter duration for statusErr.
|
||||
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
|
||||
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
|
||||
errCode = httpCode
|
||||
// Only check quota errors for expected status codes to avoid false positives
|
||||
// Qwen returns 403 for quota errors, 429 for rate limits
|
||||
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
|
||||
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
|
||||
cooldown := timeUntilNextDay()
|
||||
retryAfter = &cooldown
|
||||
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
|
||||
}
|
||||
return errCode, retryAfter
|
||||
}
|
||||
|
||||
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
|
||||
// Qwen's daily quota resets at 00:00 Beijing time.
|
||||
func timeUntilNextDay() time.Duration {
|
||||
now := time.Now()
|
||||
nowLocal := now.In(qwenBeijingLoc)
|
||||
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
|
||||
return tomorrow.Sub(now)
|
||||
}
|
||||
|
||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
||||
type QwenExecutor struct {
|
||||
@@ -69,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
if opts.Alt == "responses/compact" {
|
||||
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
|
||||
// Check rate limit before proceeding
|
||||
var authID string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
}
|
||||
if err := checkQwenRateLimit(authID); err != nil {
|
||||
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||
return resp, err
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
token, baseURL := qwenCreds(auth)
|
||||
@@ -104,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, err
|
||||
}
|
||||
applyQwenHeaders(httpReq, token, false)
|
||||
var authID, authLabel, authType, authValue string
|
||||
var authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
@@ -137,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
|
||||
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
@@ -152,14 +305,25 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
|
||||
// Check rate limit before proceeding
|
||||
var authID string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
}
|
||||
if err := checkQwenRateLimit(authID); err != nil {
|
||||
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
token, baseURL := qwenCreds(auth)
|
||||
@@ -202,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
applyQwenHeaders(httpReq, token, true)
|
||||
var authID, authLabel, authType, authValue string
|
||||
var authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
@@ -230,15 +393,16 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
|
||||
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -270,7 +434,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
@@ -344,8 +508,18 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
r.Header.Set("User-Agent", qwenUserAgent)
|
||||
r.Header.Set("X-Goog-Api-Client", qwenXGoogAPIClient)
|
||||
r.Header.Set("Client-Metadata", qwenClientMetadataValue)
|
||||
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
|
||||
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
|
||||
r.Header.Set("Sec-Fetch-Mode", "cors")
|
||||
r.Header.Set("X-Stainless-Lang", "js")
|
||||
r.Header.Set("X-Stainless-Arch", "arm64")
|
||||
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
|
||||
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
|
||||
r.Header.Set("X-Stainless-Retry-Count", "0")
|
||||
r.Header.Set("X-Stainless-Os", "MacOS")
|
||||
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
|
||||
r.Header.Set("X-Stainless-Runtime", "node")
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
return
|
||||
|
||||
89
internal/runtime/executor/user_id_cache.go
Normal file
89
internal/runtime/executor/user_id_cache.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type userIDCacheEntry struct {
|
||||
value string
|
||||
expire time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
userIDCache = make(map[string]userIDCacheEntry)
|
||||
userIDCacheMu sync.RWMutex
|
||||
userIDCacheCleanupOnce sync.Once
|
||||
)
|
||||
|
||||
const (
|
||||
userIDTTL = time.Hour
|
||||
userIDCacheCleanupPeriod = 15 * time.Minute
|
||||
)
|
||||
|
||||
func startUserIDCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(userIDCacheCleanupPeriod)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredUserIDs()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func purgeExpiredUserIDs() {
|
||||
now := time.Now()
|
||||
userIDCacheMu.Lock()
|
||||
for key, entry := range userIDCache {
|
||||
if !entry.expire.After(now) {
|
||||
delete(userIDCache, key)
|
||||
}
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func userIDCacheKey(apiKey string) string {
|
||||
sum := sha256.Sum256([]byte(apiKey))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func cachedUserID(apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return generateFakeUserID()
|
||||
}
|
||||
|
||||
userIDCacheCleanupOnce.Do(startUserIDCacheCleanup)
|
||||
|
||||
key := userIDCacheKey(apiKey)
|
||||
now := time.Now()
|
||||
|
||||
userIDCacheMu.RLock()
|
||||
entry, ok := userIDCache[key]
|
||||
valid := ok && entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value)
|
||||
userIDCacheMu.RUnlock()
|
||||
if valid {
|
||||
userIDCacheMu.Lock()
|
||||
entry = userIDCache[key]
|
||||
if entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value) {
|
||||
entry.expire = now.Add(userIDTTL)
|
||||
userIDCache[key] = entry
|
||||
userIDCacheMu.Unlock()
|
||||
return entry.value
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
newID := generateFakeUserID()
|
||||
|
||||
userIDCacheMu.Lock()
|
||||
entry, ok = userIDCache[key]
|
||||
if !ok || entry.value == "" || !entry.expire.After(now) || !isValidUserID(entry.value) {
|
||||
entry.value = newID
|
||||
}
|
||||
entry.expire = now.Add(userIDTTL)
|
||||
userIDCache[key] = entry
|
||||
userIDCacheMu.Unlock()
|
||||
return entry.value
|
||||
}
|
||||
86
internal/runtime/executor/user_id_cache_test.go
Normal file
86
internal/runtime/executor/user_id_cache_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func resetUserIDCache() {
|
||||
userIDCacheMu.Lock()
|
||||
userIDCache = make(map[string]userIDCacheEntry)
|
||||
userIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
first := cachedUserID("api-key-1")
|
||||
second := cachedUserID("api-key-1")
|
||||
|
||||
if first == "" {
|
||||
t.Fatal("expected generated user_id to be non-empty")
|
||||
}
|
||||
if first != second {
|
||||
t.Fatalf("expected cached user_id to be reused, got %q and %q", first, second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
expiredID := cachedUserID("api-key-expired")
|
||||
cacheKey := userIDCacheKey("api-key-expired")
|
||||
userIDCacheMu.Lock()
|
||||
userIDCache[cacheKey] = userIDCacheEntry{
|
||||
value: expiredID,
|
||||
expire: time.Now().Add(-time.Minute),
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
|
||||
newID := cachedUserID("api-key-expired")
|
||||
if newID == expiredID {
|
||||
t.Fatalf("expected expired user_id to be replaced, got %q", newID)
|
||||
}
|
||||
if newID == "" {
|
||||
t.Fatal("expected regenerated user_id to be non-empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
first := cachedUserID("api-key-1")
|
||||
second := cachedUserID("api-key-2")
|
||||
|
||||
if first == second {
|
||||
t.Fatalf("expected different API keys to have different user_ids, got %q", first)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
key := "api-key-renew"
|
||||
id := cachedUserID(key)
|
||||
cacheKey := userIDCacheKey(key)
|
||||
|
||||
soon := time.Now()
|
||||
userIDCacheMu.Lock()
|
||||
userIDCache[cacheKey] = userIDCacheEntry{
|
||||
value: id,
|
||||
expire: soon.Add(2 * time.Second),
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
|
||||
if refreshed := cachedUserID(key); refreshed != id {
|
||||
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
|
||||
}
|
||||
|
||||
userIDCacheMu.RLock()
|
||||
entry := userIDCache[cacheKey]
|
||||
userIDCacheMu.RUnlock()
|
||||
|
||||
if entry.expire.Sub(soon) < 30*time.Minute {
|
||||
t.Fatalf("expected TTL to renew, got %v remaining", entry.expire.Sub(soon))
|
||||
}
|
||||
}
|
||||
@@ -223,16 +223,71 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
||||
} else if functionResponseResult.IsArray() {
|
||||
frResults := functionResponseResult.Array()
|
||||
if len(frResults) == 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
|
||||
nonImageCount := 0
|
||||
lastNonImageRaw := ""
|
||||
filteredJSON := "[]"
|
||||
imagePartsJSON := "[]"
|
||||
for _, fr := range frResults {
|
||||
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := fr.Get("source.data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
imagePartJSON := `{}`
|
||||
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||
continue
|
||||
}
|
||||
|
||||
nonImageCount++
|
||||
lastNonImageRaw = fr.Raw
|
||||
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
|
||||
}
|
||||
|
||||
if nonImageCount == 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
|
||||
} else if nonImageCount > 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
}
|
||||
|
||||
// Place image data inside functionResponse.parts as inlineData
|
||||
// instead of as sibling parts in the outer content, to avoid
|
||||
// base64 data bloating the text context.
|
||||
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||
}
|
||||
|
||||
} else if functionResponseResult.IsObject() {
|
||||
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := functionResponseResult.Get("source.data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
imagePartJSON := `{}`
|
||||
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON := "[]"
|
||||
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
}
|
||||
} else if functionResponseResult.Raw != "" {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
// Content field is missing entirely — .Raw is empty which
|
||||
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
@@ -244,7 +299,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if sourceResult.Get("type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := sourceResult.Get("data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
@@ -344,7 +399,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// Inject interleaved thinking hint when both tools and thinking are active
|
||||
hasTools := toolDeclCount > 0
|
||||
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && thinkingResult.Get("type").String() == "enabled"
|
||||
thinkingType := thinkingResult.Get("type").String()
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive")
|
||||
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
|
||||
|
||||
if hasTools && hasThinking && isClaudeThinking {
|
||||
@@ -377,12 +433,18 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
switch t.Get("type").String() {
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
|
||||
@@ -413,8 +413,8 @@ func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
|
||||
if !inlineData.Exists() {
|
||||
t.Error("inlineData should exist")
|
||||
}
|
||||
if inlineData.Get("mime_type").String() != "image/png" {
|
||||
t.Error("mime_type mismatch")
|
||||
if inlineData.Get("mimeType").String() != "image/png" {
|
||||
t.Error("mimeType mismatch")
|
||||
}
|
||||
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||
t.Error("data mismatch")
|
||||
@@ -661,6 +661,508 @@ func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultNoContent(t *testing.T) {
|
||||
// Bug repro: tool_result with no content field produces invalid JSON
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "MyTool-123-456",
|
||||
"name": "MyTool",
|
||||
"input": {"key": "value"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "MyTool-123-456"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Errorf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
// Verify the functionResponse has a valid result value
|
||||
fr := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse.response.result")
|
||||
if !fr.Exists() {
|
||||
t.Error("functionResponse.response.result should exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
|
||||
// Bug repro: tool_result with null content produces invalid JSON
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "MyTool-123-456",
|
||||
"name": "MyTool",
|
||||
"input": {"key": "value"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "MyTool-123-456",
|
||||
"content": null
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Errorf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultWithImage(t *testing.T) {
|
||||
// tool_result with array content containing text + image should place
|
||||
// image data inside functionResponse.parts as inlineData, not as a
|
||||
// sibling part in the outer content (to avoid base64 context bloat).
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Read-123-456",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "File content here"
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "iVBORw0KGgoAAAANSUhEUg=="
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
// Image should be inside functionResponse.parts, not as outer sibling part
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// Text content should be in response.result
|
||||
resultText := funcResp.Get("response.result.text").String()
|
||||
if resultText != "File content here" {
|
||||
t.Errorf("Expected response.result.text = 'File content here', got '%s'", resultText)
|
||||
}
|
||||
|
||||
// Image should be in functionResponse.parts[0].inlineData
|
||||
inlineData := funcResp.Get("parts.0.inlineData")
|
||||
if !inlineData.Exists() {
|
||||
t.Fatal("functionResponse.parts[0].inlineData should exist")
|
||||
}
|
||||
if inlineData.Get("mimeType").String() != "image/png" {
|
||||
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineData.Get("mimeType").String())
|
||||
}
|
||||
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||
t.Error("data mismatch")
|
||||
}
|
||||
|
||||
// Image should NOT be in outer parts (only functionResponse part should exist)
|
||||
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
|
||||
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
|
||||
t.Errorf("Expected only 1 outer part (functionResponse), got %d", len(outerParts.Array()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultWithSingleImage(t *testing.T) {
|
||||
// tool_result with single image object as content should place
|
||||
// image data inside functionResponse.parts, not as outer sibling part.
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Read-789-012",
|
||||
"content": {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": "/9j/4AAQSkZJRgABAQ=="
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// response.result should be empty (image only)
|
||||
if funcResp.Get("response.result").String() != "" {
|
||||
t.Errorf("Expected empty response.result for image-only content, got '%s'", funcResp.Get("response.result").String())
|
||||
}
|
||||
|
||||
// Image should be in functionResponse.parts[0].inlineData
|
||||
inlineData := funcResp.Get("parts.0.inlineData")
|
||||
if !inlineData.Exists() {
|
||||
t.Fatal("functionResponse.parts[0].inlineData should exist")
|
||||
}
|
||||
if inlineData.Get("mimeType").String() != "image/jpeg" {
|
||||
t.Errorf("Expected mimeType 'image/jpeg', got '%s'", inlineData.Get("mimeType").String())
|
||||
}
|
||||
|
||||
// Image should NOT be in outer parts
|
||||
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
|
||||
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
|
||||
t.Errorf("Expected only 1 outer part, got %d", len(outerParts.Array()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultWithMultipleImagesAndTexts(t *testing.T) {
|
||||
// tool_result with array content: 2 text items + 2 images
|
||||
// All images go into functionResponse.parts, texts into response.result array
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Multi-001",
|
||||
"content": [
|
||||
{"type": "text", "text": "First text"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/png", "data": "AAAA"}
|
||||
},
|
||||
{"type": "text", "text": "Second text"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/jpeg", "data": "BBBB"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// Multiple text items => response.result is an array
|
||||
resultArr := funcResp.Get("response.result")
|
||||
if !resultArr.IsArray() {
|
||||
t.Fatalf("Expected response.result to be an array, got: %s", resultArr.Raw)
|
||||
}
|
||||
results := resultArr.Array()
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("Expected 2 result items, got %d", len(results))
|
||||
}
|
||||
|
||||
// Both images should be in functionResponse.parts
|
||||
imgParts := funcResp.Get("parts").Array()
|
||||
if len(imgParts) != 2 {
|
||||
t.Fatalf("Expected 2 image parts in functionResponse.parts, got %d", len(imgParts))
|
||||
}
|
||||
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||
t.Errorf("Expected first image mimeType 'image/png', got '%s'", imgParts[0].Get("inlineData.mimeType").String())
|
||||
}
|
||||
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
|
||||
t.Errorf("Expected first image data 'AAAA', got '%s'", imgParts[0].Get("inlineData.data").String())
|
||||
}
|
||||
if imgParts[1].Get("inlineData.mimeType").String() != "image/jpeg" {
|
||||
t.Errorf("Expected second image mimeType 'image/jpeg', got '%s'", imgParts[1].Get("inlineData.mimeType").String())
|
||||
}
|
||||
if imgParts[1].Get("inlineData.data").String() != "BBBB" {
|
||||
t.Errorf("Expected second image data 'BBBB', got '%s'", imgParts[1].Get("inlineData.data").String())
|
||||
}
|
||||
|
||||
// Only 1 outer part (the functionResponse itself)
|
||||
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(outerParts) != 1 {
|
||||
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultWithOnlyMultipleImages(t *testing.T) {
|
||||
// tool_result with only images (no text) — response.result should be empty string
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "ImgOnly-001",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/png", "data": "PNG1"}
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/gif", "data": "GIF1"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// No text => response.result should be empty string
|
||||
if funcResp.Get("response.result").String() != "" {
|
||||
t.Errorf("Expected empty response.result, got '%s'", funcResp.Get("response.result").String())
|
||||
}
|
||||
|
||||
// Both images in functionResponse.parts
|
||||
imgParts := funcResp.Get("parts").Array()
|
||||
if len(imgParts) != 2 {
|
||||
t.Fatalf("Expected 2 image parts, got %d", len(imgParts))
|
||||
}
|
||||
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||
t.Error("first image mimeType mismatch")
|
||||
}
|
||||
if imgParts[1].Get("inlineData.mimeType").String() != "image/gif" {
|
||||
t.Error("second image mimeType mismatch")
|
||||
}
|
||||
|
||||
// Only 1 outer part
|
||||
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(outerParts) != 1 {
|
||||
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultImageNotBase64(t *testing.T) {
|
||||
// image with source.type != "base64" should be treated as non-image (falls through)
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "NotB64-001",
|
||||
"content": [
|
||||
{"type": "text", "text": "some output"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "url", "url": "https://example.com/img.png"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// Non-base64 image is treated as non-image, so it goes into the filtered results
|
||||
// along with the text item. Since there are 2 non-image items, result is array.
|
||||
resultArr := funcResp.Get("response.result")
|
||||
if !resultArr.IsArray() {
|
||||
t.Fatalf("Expected response.result to be an array (2 non-image items), got: %s", resultArr.Raw)
|
||||
}
|
||||
results := resultArr.Array()
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("Expected 2 result items, got %d", len(results))
|
||||
}
|
||||
|
||||
// No functionResponse.parts (no base64 images collected)
|
||||
if funcResp.Get("parts").Exists() {
|
||||
t.Error("functionResponse.parts should NOT exist when no base64 images")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingData(t *testing.T) {
|
||||
// image with source.type=base64 but missing data field
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "NoData-001",
|
||||
"content": [
|
||||
{"type": "text", "text": "output"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": "image/png"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// The image is still classified as base64 image (type check passes),
|
||||
// but data field is missing => inlineData has mimeType but no data
|
||||
imgParts := funcResp.Get("parts").Array()
|
||||
if len(imgParts) != 1 {
|
||||
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
|
||||
}
|
||||
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||
t.Error("mimeType should still be set")
|
||||
}
|
||||
if imgParts[0].Get("inlineData.data").Exists() {
|
||||
t.Error("data should not exist when source.data is missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *testing.T) {
|
||||
// image with source.type=base64 but missing media_type field
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "NoMime-001",
|
||||
"content": [
|
||||
{"type": "text", "text": "output"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "data": "AAAA"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
|
||||
// The image is still classified as base64 image,
|
||||
// but media_type is missing => inlineData has data but no mimeType
|
||||
imgParts := funcResp.Get("parts").Array()
|
||||
if len(imgParts) != 1 {
|
||||
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
|
||||
}
|
||||
if imgParts[0].Get("inlineData.mimeType").Exists() {
|
||||
t.Error("mimeType should not exist when media_type is missing")
|
||||
}
|
||||
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
|
||||
t.Error("data should still be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||
// When tools + thinking but no system instruction, should create one with hint
|
||||
inputJSON := []byte(`{
|
||||
|
||||
@@ -93,3 +93,81 @@ func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
|
||||
// When functionResponse contains a "parts" field with inlineData (from Claude
|
||||
// translator's image embedding), fixCLIToolResponse should preserve it as-is.
|
||||
// parseFunctionResponseRaw returns response.Raw for valid JSON objects,
|
||||
// so extra fields like "parts" survive the pipeline.
|
||||
input := `{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{
|
||||
"functionCall": {"name": "screenshot", "args": {}}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{
|
||||
"functionResponse": {
|
||||
"id": "tool-001",
|
||||
"name": "screenshot",
|
||||
"response": {"result": "Screenshot taken"},
|
||||
"parts": [
|
||||
{"inlineData": {"mimeType": "image/png", "data": "iVBOR"}}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
// Find the function response content (role=function)
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContent gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContent = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if !funcContent.Exists() {
|
||||
t.Fatal("function role content should exist in output")
|
||||
}
|
||||
|
||||
// The functionResponse should be preserved with its parts field
|
||||
funcResp := funcContent.Get("parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist in output")
|
||||
}
|
||||
|
||||
// Verify the parts field with inlineData is preserved
|
||||
inlineParts := funcResp.Get("parts").Array()
|
||||
if len(inlineParts) != 1 {
|
||||
t.Fatalf("Expected 1 inlineData part in functionResponse.parts, got %d", len(inlineParts))
|
||||
}
|
||||
if inlineParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineParts[0].Get("inlineData.mimeType").String())
|
||||
}
|
||||
if inlineParts[0].Get("inlineData.data").String() != "iVBOR" {
|
||||
t.Errorf("Expected data 'iVBOR', got '%s'", inlineParts[0].Get("inlineData.data").String())
|
||||
}
|
||||
|
||||
// Verify response.result is also preserved
|
||||
if funcResp.Get("response.result").String() != "Screenshot taken" {
|
||||
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,7 +187,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||
mime := pieces[0]
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
@@ -201,7 +201,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
ext = sp[len(sp)-1]
|
||||
}
|
||||
if mimeType, ok := misc.MimeTypes[ext]; ok {
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mimeType)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
|
||||
p++
|
||||
} else {
|
||||
@@ -235,7 +235,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||
mime := pieces[0]
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
|
||||
@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
|
||||
@@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
||||
}
|
||||
}
|
||||
|
||||
case "file":
|
||||
fileData := part.Get("file.file_data").String()
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
semicolonIdx := strings.Index(fileData, ";")
|
||||
commaIdx := strings.Index(fileData, ",")
|
||||
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||
data := fileData[commaIdx+1:]
|
||||
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
var textAggregate strings.Builder
|
||||
var partsJSON []string
|
||||
hasImage := false
|
||||
hasFile := false
|
||||
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
ptype := part.Get("type").String()
|
||||
@@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
hasImage = true
|
||||
}
|
||||
}
|
||||
case "input_file":
|
||||
fileData := part.Get("file_data").String()
|
||||
if fileData != "" {
|
||||
mediaType := "application/octet-stream"
|
||||
data := fileData
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
trimmed := strings.TrimPrefix(fileData, "data:")
|
||||
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
|
||||
if len(mediaAndData) == 2 {
|
||||
if mediaAndData[0] != "" {
|
||||
mediaType = mediaAndData[0]
|
||||
}
|
||||
data = mediaAndData[1]
|
||||
}
|
||||
}
|
||||
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
|
||||
contentPart, _ = sjson.Set(contentPart, "source.data", data)
|
||||
partsJSON = append(partsJSON, contentPart)
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
hasFile = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if len(partsJSON) > 0 {
|
||||
msg := `{"role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
if len(partsJSON) == 1 && !hasImage {
|
||||
if len(partsJSON) == 1 && !hasImage && !hasFile {
|
||||
// Preserve legacy behavior for single text content
|
||||
msg, _ = sjson.Delete(msg, "content")
|
||||
textPart := gjson.Parse(partsJSON[0])
|
||||
|
||||
@@ -222,6 +222,10 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
case "adaptive":
|
||||
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
||||
// and let ApplyThinking normalize per target model capability.
|
||||
reasoningEffort = string(thinking.LevelXHigh)
|
||||
case "disabled":
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
|
||||
@@ -22,8 +22,9 @@ var (
|
||||
|
||||
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
|
||||
type ConvertCodexResponseToClaudeParams struct {
|
||||
HasToolCall bool
|
||||
BlockIndex int
|
||||
HasToolCall bool
|
||||
BlockIndex int
|
||||
HasReceivedArgumentsDelta bool
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
|
||||
@@ -137,6 +138,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
itemType := itemResult.Get("type").String()
|
||||
if itemType == "function_call" {
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
|
||||
@@ -171,12 +173,29 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
}
|
||||
} else if typeStr == "response.function_call_arguments.delta" {
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
|
||||
|
||||
output += "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
} else if typeStr == "response.function_call_arguments.done" {
|
||||
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
|
||||
// in a single "done" event without preceding "delta" events.
|
||||
// Emit the full arguments as a single input_json_delta so the
|
||||
// downstream Claude client receives the complete tool input.
|
||||
// When delta events were already received, skip to avoid duplicating arguments.
|
||||
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
|
||||
if args := rootResult.Get("arguments").String(); args != "" {
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.partial_json", args)
|
||||
|
||||
output += "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return []string{output}
|
||||
|
||||
@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
}
|
||||
case "file":
|
||||
// Files are not specified in examples; skip for now
|
||||
if role == "user" {
|
||||
fileData := it.Get("file.file_data").String()
|
||||
filename := it.Get("file.filename").String()
|
||||
if fileData != "" {
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", "input_file")
|
||||
part, _ = sjson.Set(part, "file_data", fileData)
|
||||
if filename != "" {
|
||||
part, _ = sjson.Set(part, "filename", filename)
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,10 +20,12 @@ var (
|
||||
|
||||
// ConvertCliToOpenAIParams holds parameters for response conversion.
|
||||
type ConvertCliToOpenAIParams struct {
|
||||
ResponseID string
|
||||
CreatedAt int64
|
||||
Model string
|
||||
FunctionCallIndex int
|
||||
ResponseID string
|
||||
CreatedAt int64
|
||||
Model string
|
||||
FunctionCallIndex int
|
||||
HasReceivedArgumentsDelta bool
|
||||
HasToolCallAnnounced bool
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToOpenAI translates a single chunk of a streaming response from the
|
||||
@@ -43,10 +45,12 @@ type ConvertCliToOpenAIParams struct {
|
||||
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
if *param == nil {
|
||||
*param = &ConvertCliToOpenAIParams{
|
||||
Model: modelName,
|
||||
CreatedAt: 0,
|
||||
ResponseID: "",
|
||||
FunctionCallIndex: -1,
|
||||
Model: modelName,
|
||||
CreatedAt: 0,
|
||||
ResponseID: "",
|
||||
FunctionCallIndex: -1,
|
||||
HasReceivedArgumentsDelta: false,
|
||||
HasToolCallAnnounced: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +94,9 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
}
|
||||
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
}
|
||||
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
}
|
||||
@@ -115,35 +122,93 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||
} else if dataType == "response.output_item.done" {
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
} else if dataType == "response.output_item.added" {
|
||||
itemResult := rootResult.Get("item")
|
||||
if itemResult.Exists() {
|
||||
if itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// set the index
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened
|
||||
name := itemResult.Get("name").String()
|
||||
// Build reverse map on demand from original request tools
|
||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Increment index for this new function call item.
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
|
||||
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
|
||||
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened.
|
||||
name := itemResult.Get("name").String()
|
||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
|
||||
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.function_call_arguments.delta" {
|
||||
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
|
||||
|
||||
deltaValue := rootResult.Get("delta").String()
|
||||
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.function_call_arguments.done" {
|
||||
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
|
||||
// Arguments were already streamed via delta events; nothing to emit.
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Fallback: no delta events were received, emit the full arguments as a single chunk.
|
||||
fullArgs := rootResult.Get("arguments").String()
|
||||
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.output_item.done" {
|
||||
itemResult := rootResult.Get("item")
|
||||
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
|
||||
// Tool call was already announced via output_item.added; skip emission.
|
||||
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Fallback path: model skipped output_item.added, so emit complete tool call now.
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened.
|
||||
name := itemResult.Get("name").String()
|
||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else {
|
||||
return []string{}
|
||||
}
|
||||
@@ -205,6 +270,9 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
}
|
||||
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
}
|
||||
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
}
|
||||
|
||||
@@ -26,6 +26,11 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
|
||||
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
|
||||
|
||||
// Delete the user field as it is not supported by the Codex upstream.
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
||||
|
||||
// Convert role "system" to "developer" in input array to comply with Codex API requirements.
|
||||
rawJSON = convertSystemRoleToDeveloper(rawJSON)
|
||||
@@ -33,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction
|
||||
// for Codex upstream compatibility.
|
||||
//
|
||||
// Codex /responses currently rejects context_management with:
|
||||
// {"detail":"Unsupported parameter: context_management"}.
|
||||
//
|
||||
// Compatibility strategy:
|
||||
// 1) Remove context_management before forwarding to Codex upstream.
|
||||
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
|
||||
if !gjson.GetBytes(rawJSON, "context_management").Exists() {
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
||||
// with role "system" to role "developer". This is necessary because Codex API does not
|
||||
// accept "system" role in the input array.
|
||||
|
||||
@@ -263,3 +263,58 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
|
||||
t.Errorf("Expected third role 'assistant', got '%s'", thirdRole.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserFieldDeletion(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
"user": "test-user",
|
||||
"input": [{"role": "user", "content": "Hello"}]
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Verify user field is deleted
|
||||
userField := gjson.Get(outputStr, "user")
|
||||
if userField.Exists() {
|
||||
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextManagementCompactionCompatibility(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
"context_management": [
|
||||
{
|
||||
"type": "compaction",
|
||||
"compact_threshold": 12000
|
||||
}
|
||||
],
|
||||
"input": [{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if gjson.Get(outputStr, "context_management").Exists() {
|
||||
t.Fatalf("context_management should be removed for Codex compatibility")
|
||||
}
|
||||
if gjson.Get(outputStr, "truncation").Exists() {
|
||||
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
"truncation": "disabled",
|
||||
"input": [{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if gjson.Get(outputStr, "truncation").Exists() {
|
||||
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user