mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-10 07:43:07 +00:00
Compare commits
238 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b8c466e88 | ||
|
|
ca2174ea48 | ||
|
|
c09fb2a79d | ||
|
|
4445a165e9 | ||
|
|
e92e2af71a | ||
|
|
a6bdd9a652 | ||
|
|
349a6349b3 | ||
|
|
00822770ec | ||
|
|
72add453d2 | ||
|
|
2789396435 | ||
|
|
61da7bd981 | ||
|
|
ae4c502792 | ||
|
|
ec6068060b | ||
|
|
ecb01d3dcd | ||
|
|
22c0c00bd4 | ||
|
|
9eb3e7a6c4 | ||
|
|
357c191510 | ||
|
|
5db244af76 | ||
|
|
dc375d1b74 | ||
|
|
9c040445af | ||
|
|
fff866424e | ||
|
|
2d12becfd6 | ||
|
|
252f7e0751 | ||
|
|
b2b17528cb | ||
|
|
55f938164b | ||
|
|
76294f0c59 | ||
|
|
2bcee78c6e | ||
|
|
7fe8246a9f | ||
|
|
93fe58e31e | ||
|
|
e5b5dc870f | ||
|
|
a54877c023 | ||
|
|
bb86a0c0c4 | ||
|
|
5fa23c7f41 | ||
|
|
f9a09b7f23 | ||
|
|
b0cde626fe | ||
|
|
e42ef9a95d | ||
|
|
abf1629ec7 | ||
|
|
73dc0b10b8 | ||
|
|
2ea95266e3 | ||
|
|
922d4141c0 | ||
|
|
1f8f198c45 | ||
|
|
c55275342c | ||
|
|
9261b0c20b | ||
|
|
7cc725496e | ||
|
|
5726a99c80 | ||
|
|
b5756bf729 | ||
|
|
709d999f9f | ||
|
|
24c18614f0 | ||
|
|
603f06a762 | ||
|
|
98f0a3e3bd | ||
|
|
e186ccb0d4 | ||
|
|
8fc0b08b70 | ||
|
|
52a257dc24 | ||
|
|
a12d907f55 | ||
|
|
453aaf8774 | ||
|
|
1b1ab1fb9b | ||
|
|
a9d0bb72da | ||
|
|
d328e54e4b | ||
|
|
5a7932cba4 | ||
|
|
1dbeb0827a | ||
|
|
2c8821891c | ||
|
|
0a2555b0f3 | ||
|
|
020df41efe | ||
|
|
f8f8cf17ce | ||
|
|
f31f7f701a | ||
|
|
b5fe78eb70 | ||
|
|
d1f667cf8d | ||
|
|
54ad7c1b6b | ||
|
|
d560c20c26 | ||
|
|
5abeca1f9e | ||
|
|
294eac3a88 | ||
|
|
a31104020c | ||
|
|
65bec4d734 | ||
|
|
edb2993838 | ||
|
|
c0d8e0dec7 | ||
|
|
795da13d5d | ||
|
|
55789df275 | ||
|
|
9e652a3540 | ||
|
|
46a6782065 | ||
|
|
c359f61859 | ||
|
|
908c8eab5b | ||
|
|
f5f2c69233 | ||
|
|
63d4de5eea | ||
|
|
af15083496 | ||
|
|
c4722e42b1 | ||
|
|
f9a991365f | ||
|
|
6df16bedba | ||
|
|
632a2fd2f2 | ||
|
|
5626637fbd | ||
|
|
2db89211a9 | ||
|
|
587371eb14 | ||
|
|
75818b1e25 | ||
|
|
cbe56955a9 | ||
|
|
8ea6ac913d | ||
|
|
ae1e8a5191 | ||
|
|
b3ccc55f09 | ||
|
|
1ce56d7413 | ||
|
|
41a78be3a2 | ||
|
|
1ff5de9a31 | ||
|
|
46a6853046 | ||
|
|
4b2d40bd67 | ||
|
|
726f1a590c | ||
|
|
575881cb59 | ||
|
|
d02df0141b | ||
|
|
e4bc9da913 | ||
|
|
8c6be49625 | ||
|
|
c727e4251f | ||
|
|
99266be998 | ||
|
|
d0f3fd96f8 | ||
|
|
f361b2716d | ||
|
|
086d8d0d0b | ||
|
|
627dee1dac | ||
|
|
55c3197fb8 | ||
|
|
5a2cf0d53c | ||
|
|
2573358173 | ||
|
|
09cd3cff91 | ||
|
|
ab0bf1b517 | ||
|
|
58e09f8e5f | ||
|
|
2334a2b174 | ||
|
|
bc61bf36b2 | ||
|
|
7726a44ca2 | ||
|
|
dc55fb0ce3 | ||
|
|
a146c6c0aa | ||
|
|
4c133d3ea9 | ||
|
|
544238772a | ||
|
|
f3ccd85ba1 | ||
|
|
dc279de443 | ||
|
|
bf1634bda0 | ||
|
|
166d2d24d9 | ||
|
|
4cbcc835d1 | ||
|
|
b93026d83a | ||
|
|
5ed2133ff9 | ||
|
|
e9dd44e623 | ||
|
|
cc8c4ffb5f | ||
|
|
1510bfcb6f | ||
|
|
bcd2208b51 | ||
|
|
09b19f5c4e | ||
|
|
7b01ca0e2e | ||
|
|
9c65e17a21 | ||
|
|
fe6fc628ed | ||
|
|
8192eeabc8 | ||
|
|
c3f1cdd7e5 | ||
|
|
c6bd91b86b | ||
|
|
349ddcaa89 | ||
|
|
bb9fe52f1e | ||
|
|
afe4c1bfb7 | ||
|
|
865af9f19e | ||
|
|
2b97cb98b5 | ||
|
|
938a799263 | ||
|
|
e17d4f8d98 | ||
|
|
c8cae1f74d | ||
|
|
0040d78496 | ||
|
|
896de027cc | ||
|
|
fc329ebf37 | ||
|
|
15bc99f6ea | ||
|
|
91841a5519 | ||
|
|
eaab1d6824 | ||
|
|
0cfe310df6 | ||
|
|
918b6955e4 | ||
|
|
3ec7991e5f | ||
|
|
532fbf00d4 | ||
|
|
45b6fffd7f | ||
|
|
5a3eb08739 | ||
|
|
0dff329162 | ||
|
|
49c1740b47 | ||
|
|
3fbee51e9f | ||
|
|
a3dc56d2a0 | ||
|
|
63643c44a1 | ||
|
|
1d93608dbe | ||
|
|
d125b7de92 | ||
|
|
d5654ee316 | ||
|
|
3b34521ad9 | ||
|
|
7197fb350b | ||
|
|
6e349bfcc7 | ||
|
|
234056072d | ||
|
|
76330f4bff | ||
|
|
d468eec6ec | ||
|
|
40e85a6759 | ||
|
|
9bc6cc5b41 | ||
|
|
d109be159c | ||
|
|
eddf31e55b | ||
|
|
7e9d0db6aa | ||
|
|
2f1874ede5 | ||
|
|
6b83585b53 | ||
|
|
78ef04fcf1 | ||
|
|
b7e4f00c5f | ||
|
|
c20507c15e | ||
|
|
f7d0019df7 | ||
|
|
52364af5bf | ||
|
|
f410dd0440 | ||
|
|
eb5582c17c | ||
|
|
1c6cb2bec3 | ||
|
|
80b5e79e75 | ||
|
|
d182e893b6 | ||
|
|
2e8d49a641 | ||
|
|
6abd7d27d9 | ||
|
|
8fa12af403 | ||
|
|
77586ed7d3 | ||
|
|
394497fb2f | ||
|
|
fc7b6ef086 | ||
|
|
98edcad39d | ||
|
|
cc116ce67d | ||
|
|
1187aa8222 | ||
|
|
a35d66443b | ||
|
|
40ad4a42ea | ||
|
|
dc9b4dd017 | ||
|
|
68cb81a258 | ||
|
|
16693053f5 | ||
|
|
40efc2ba43 | ||
|
|
4e3bad3907 | ||
|
|
c874f19f2a | ||
|
|
f5f26f0cbe | ||
|
|
e7e3ca1efb | ||
|
|
4b00312fef | ||
|
|
c5fd3db01e | ||
|
|
e35ffaa925 | ||
|
|
f870a9d2a7 | ||
|
|
165e03f3a7 | ||
|
|
86bdb7808c | ||
|
|
b4e034be1c | ||
|
|
84fcebf538 | ||
|
|
74d9a1ffed | ||
|
|
a5a25dec57 | ||
|
|
c71905e5e8 | ||
|
|
bc78d668ac | ||
|
|
e93eebc2e9 | ||
|
|
5bd0896ad7 | ||
|
|
09ecfbcaed | ||
|
|
f0bd14b64f | ||
|
|
14f044ce4f | ||
|
|
88872baffc | ||
|
|
dbecf5330e | ||
|
|
706590c62a | ||
|
|
49ef22ab78 | ||
|
|
ae4638712e | ||
|
|
233be6272a | ||
|
|
47cb52385e | ||
|
|
a406ca2d5a |
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- run: git fetch --force --tags
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '>=1.24.0'
|
||||
go-version: '>=1.26.0'
|
||||
cache: true
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -3,10 +3,11 @@ cli-proxy-api
|
||||
cliproxy
|
||||
*.exe
|
||||
|
||||
|
||||
# Configuration
|
||||
config.yaml
|
||||
.env
|
||||
|
||||
.mcp.json
|
||||
# Generated content
|
||||
bin/*
|
||||
logs/*
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.24-alpine AS builder
|
||||
FROM golang:1.26-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 51 KiB |
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
@@ -72,11 +74,13 @@ func main() {
|
||||
var codexLogin bool
|
||||
var claudeLogin bool
|
||||
var qwenLogin bool
|
||||
var kiloLogin bool
|
||||
var iflowLogin bool
|
||||
var iflowCookie bool
|
||||
var noBrowser bool
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
var kimiLogin bool
|
||||
var kiroLogin bool
|
||||
var kiroGoogleLogin bool
|
||||
var kiroAWSLogin bool
|
||||
@@ -87,6 +91,8 @@ func main() {
|
||||
var vertexImport string
|
||||
var configPath string
|
||||
var password string
|
||||
var tuiMode bool
|
||||
var standalone bool
|
||||
var noIncognito bool
|
||||
var useIncognito bool
|
||||
|
||||
@@ -95,6 +101,7 @@ func main() {
|
||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
@@ -102,6 +109,7 @@ func main() {
|
||||
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||
@@ -112,6 +120,8 @@ func main() {
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
flag.StringVar(&password, "password", "", "")
|
||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||
|
||||
flag.CommandLine.Usage = func() {
|
||||
out := flag.CommandLine.Output()
|
||||
@@ -473,7 +483,7 @@ func main() {
|
||||
}
|
||||
|
||||
// Register built-in access providers before constructing services.
|
||||
configaccess.Register()
|
||||
configaccess.Register(&cfg.SDKConfig)
|
||||
|
||||
// Handle different command modes based on the provided flags.
|
||||
|
||||
@@ -497,10 +507,14 @@ func main() {
|
||||
cmd.DoClaudeLogin(cfg, options)
|
||||
} else if qwenLogin {
|
||||
cmd.DoQwenLogin(cfg, options)
|
||||
} else if kiloLogin {
|
||||
cmd.DoKiloLogin(cfg, options)
|
||||
} else if iflowLogin {
|
||||
cmd.DoIFlowLogin(cfg, options)
|
||||
} else if iflowCookie {
|
||||
cmd.DoIFlowCookieAuth(cfg, options)
|
||||
} else if kimiLogin {
|
||||
cmd.DoKimiLogin(cfg, options)
|
||||
} else if kiroLogin {
|
||||
// For Kiro auth, default to incognito mode for multi-account support
|
||||
// Users can explicitly override with --no-incognito
|
||||
@@ -532,15 +546,89 @@ func main() {
|
||||
cmd.WaitForCloudDeploy()
|
||||
return
|
||||
}
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
if tuiMode {
|
||||
if standalone {
|
||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
hook := tui.NewLogHook(2000)
|
||||
hook.SetFormatter(&logging.LogFormatter{})
|
||||
log.AddHook(hook)
|
||||
|
||||
// 初始化并启动 Kiro token 后台刷新
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
origLogOutput := log.StandardLogger().Out
|
||||
log.SetOutput(io.Discard)
|
||||
|
||||
devNull, errOpenDevNull := os.Open(os.DevNull)
|
||||
if errOpenDevNull == nil {
|
||||
os.Stdout = devNull
|
||||
os.Stderr = devNull
|
||||
}
|
||||
|
||||
restoreIO := func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
log.SetOutput(origLogOutput)
|
||||
if devNull != nil {
|
||||
_ = devNull.Close()
|
||||
}
|
||||
}
|
||||
|
||||
localMgmtPassword := fmt.Sprintf("tui-%d-%d", os.Getpid(), time.Now().UnixNano())
|
||||
if password == "" {
|
||||
password = localMgmtPassword
|
||||
}
|
||||
|
||||
cancel, done := cmd.StartServiceBackground(cfg, configFilePath, password)
|
||||
|
||||
client := tui.NewClient(cfg.Port, password)
|
||||
ready := false
|
||||
backoff := 100 * time.Millisecond
|
||||
for i := 0; i < 30; i++ {
|
||||
if _, errGetConfig := client.GetConfig(); errGetConfig == nil {
|
||||
ready = true
|
||||
break
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
if backoff < time.Second {
|
||||
backoff = time.Duration(float64(backoff) * 1.5)
|
||||
}
|
||||
}
|
||||
|
||||
if !ready {
|
||||
restoreIO()
|
||||
cancel()
|
||||
<-done
|
||||
fmt.Fprintf(os.Stderr, "TUI error: embedded server is not ready\n")
|
||||
return
|
||||
}
|
||||
|
||||
if errRun := tui.Run(cfg.Port, password, hook, origStdout); errRun != nil {
|
||||
restoreIO()
|
||||
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
|
||||
} else {
|
||||
restoreIO()
|
||||
}
|
||||
|
||||
cancel()
|
||||
<-done
|
||||
} else {
|
||||
// Default TUI mode: pure management client.
|
||||
// The proxy server must already be running.
|
||||
if errRun := tui.Run(cfg.Port, password, nil, os.Stdout); errRun != nil {
|
||||
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
}
|
||||
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
|
||||
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
|
||||
host: ""
|
||||
host: ''
|
||||
|
||||
# Server port
|
||||
port: 8317
|
||||
@@ -8,8 +8,8 @@ port: 8317
|
||||
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
|
||||
tls:
|
||||
enable: false
|
||||
cert: ""
|
||||
key: ""
|
||||
cert: ''
|
||||
key: ''
|
||||
|
||||
# Management API settings
|
||||
remote-management:
|
||||
@@ -20,22 +20,22 @@ remote-management:
|
||||
# Management key. If a plaintext value is provided here, it will be hashed on startup.
|
||||
# All management requests (even from localhost) require this key.
|
||||
# Leave empty to disable the Management API entirely (404 for all /v0/management routes).
|
||||
secret-key: ""
|
||||
secret-key: ''
|
||||
|
||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||
disable-control-panel: false
|
||||
|
||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
|
||||
|
||||
# Authentication directory (supports ~ for home directory)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
auth-dir: '~/.cli-proxy-api'
|
||||
|
||||
# API keys for authentication
|
||||
api-keys:
|
||||
- "your-api-key-1"
|
||||
- "your-api-key-2"
|
||||
- "your-api-key-3"
|
||||
- 'your-api-key-1'
|
||||
- 'your-api-key-2'
|
||||
- 'your-api-key-3'
|
||||
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
@@ -43,7 +43,7 @@ debug: false
|
||||
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
|
||||
pprof:
|
||||
enable: false
|
||||
addr: "127.0.0.1:8316"
|
||||
addr: '127.0.0.1:8316'
|
||||
|
||||
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
|
||||
commercial-mode: false
|
||||
@@ -68,11 +68,15 @@ error-logs-max-files: 10
|
||||
usage-statistics-enabled: false
|
||||
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
proxy-url: ''
|
||||
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
|
||||
# When true, forward filtered upstream response headers to downstream clients.
|
||||
# Default is false (disabled).
|
||||
passthrough-headers: false
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
@@ -86,7 +90,7 @@ quota-exceeded:
|
||||
|
||||
# Routing strategy for selecting credentials when multiple match.
|
||||
routing:
|
||||
strategy: "round-robin" # round-robin (default), fill-first
|
||||
strategy: 'round-robin' # round-robin (default), fill-first
|
||||
|
||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||
ws-auth: false
|
||||
@@ -161,6 +165,14 @@ nonstream-keepalive-interval: 0
|
||||
# - "API"
|
||||
# - "proxy"
|
||||
|
||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||
# These are used as fallbacks when the client does not send its own headers.
|
||||
# claude-header-defaults:
|
||||
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
||||
# package-version: "0.74.0"
|
||||
# runtime-version: "v24.3.0"
|
||||
# timeout: "600"
|
||||
|
||||
# Kiro (AWS CodeWhisperer) configuration
|
||||
# Note: Kiro API currently only operates in us-east-1 region
|
||||
#kiro:
|
||||
@@ -171,6 +183,21 @@ nonstream-keepalive-interval: 0
|
||||
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
|
||||
|
||||
# Kilocode (OAuth-based code assistant)
|
||||
# Note: Kilocode uses OAuth device flow authentication.
|
||||
# Use the CLI command: ./server --kilo-login
|
||||
# This will save credentials to the auth directory (default: ~/.cli-proxy-api/)
|
||||
# oauth-model-alias:
|
||||
# kilo:
|
||||
# - name: "minimax/minimax-m2.5:free"
|
||||
# alias: "minimax-m2.5"
|
||||
# - name: "z-ai/glm-5:free"
|
||||
# alias: "glm-5"
|
||||
# oauth-excluded-models:
|
||||
# kilo:
|
||||
# - "kilo-claude-opus-4-6" # exclude specific models (exact match)
|
||||
# - "*:free" # wildcard matching suffix (e.g. all free models)
|
||||
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||
@@ -236,10 +263,10 @@ nonstream-keepalive-interval: 0
|
||||
|
||||
# Global OAuth model name aliases (per channel)
|
||||
# These aliases rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||
#oauth-model-alias:
|
||||
# oauth-model-alias:
|
||||
# antigravity:
|
||||
# - name: "rev19-uic3-1p"
|
||||
# alias: "gemini-2.5-computer-use-preview-10-2025"
|
||||
@@ -265,9 +292,6 @@ nonstream-keepalive-interval: 0
|
||||
# aistudio:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# antigravity:
|
||||
# - name: "gemini-3-pro-preview"
|
||||
# alias: "g3p"
|
||||
# claude:
|
||||
# - name: "claude-sonnet-4-5-20250929"
|
||||
# alias: "cs4.5"
|
||||
@@ -280,6 +304,9 @@ nonstream-keepalive-interval: 0
|
||||
# iflow:
|
||||
# - name: "glm-4.7"
|
||||
# alias: "glm-god"
|
||||
# kimi:
|
||||
# - name: "kimi-k2.5"
|
||||
# alias: "k2.5"
|
||||
# kiro:
|
||||
# - name: "kiro-claude-opus-4-5"
|
||||
# alias: "op45"
|
||||
@@ -309,6 +336,8 @@ nonstream-keepalive-interval: 0
|
||||
# - "vision-model"
|
||||
# iflow:
|
||||
# - "tstars2.0"
|
||||
# kimi:
|
||||
# - "kimi-k2-thinking"
|
||||
# kiro:
|
||||
# - "kiro-claude-haiku-4-5"
|
||||
# github-copilot:
|
||||
|
||||
@@ -7,80 +7,71 @@ The `github.com/router-for-me/CLIProxyAPI/v6/sdk/access` package centralizes inb
|
||||
```go
|
||||
import (
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
```
|
||||
|
||||
Add the module with `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access`.
|
||||
|
||||
## Provider Registry
|
||||
|
||||
Providers are registered globally and then attached to a `Manager` as a snapshot:
|
||||
|
||||
- `RegisterProvider(type, provider)` installs a pre-initialized provider instance.
|
||||
- Registration order is preserved the first time each `type` is seen.
|
||||
- `RegisteredProviders()` returns the providers in that order.
|
||||
|
||||
## Manager Lifecycle
|
||||
|
||||
```go
|
||||
manager := sdkaccess.NewManager()
|
||||
providers, err := sdkaccess.BuildProviders(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manager.SetProviders(providers)
|
||||
manager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
* `NewManager` constructs an empty manager.
|
||||
* `SetProviders` replaces the provider slice using a defensive copy.
|
||||
* `Providers` retrieves a snapshot that can be iterated safely from other goroutines.
|
||||
* `BuildProviders` translates `config.Config` access declarations into runnable providers. When the config omits explicit providers but defines inline API keys, the helper auto-installs the built-in `config-api-key` provider.
|
||||
|
||||
If the manager itself is `nil` or no providers are configured, the call returns `nil, nil`, allowing callers to treat access control as disabled.
|
||||
|
||||
## Authenticating Requests
|
||||
|
||||
```go
|
||||
result, err := manager.Authenticate(ctx, req)
|
||||
result, authErr := manager.Authenticate(ctx, req)
|
||||
switch {
|
||||
case err == nil:
|
||||
case authErr == nil:
|
||||
// Authentication succeeded; result describes the provider and principal.
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
|
||||
// No recognizable credentials were supplied.
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
|
||||
// Supplied credentials were present but rejected.
|
||||
default:
|
||||
// Transport-level failure was returned by a provider.
|
||||
// Internal/transport failure was returned by a provider.
|
||||
}
|
||||
```
|
||||
|
||||
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that surface `ErrNotHandled`, and tracks whether any provider reported `ErrNoCredentials` or `ErrInvalidCredential` for downstream error reporting.
|
||||
|
||||
If the manager itself is `nil` or no providers are registered, the call returns `nil, nil`, allowing callers to treat access control as disabled without branching on errors.
|
||||
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that return `AuthErrorCodeNotHandled`, and aggregates `AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` for a final result.
|
||||
|
||||
Each `Result` includes the provider identifier, the resolved principal, and optional metadata (for example, which header carried the credential).
|
||||
|
||||
## Configuration Layout
|
||||
## Built-in `config-api-key` Provider
|
||||
|
||||
The manager expects access providers under the `auth.providers` key inside `config.yaml`:
|
||||
The proxy includes one built-in access provider:
|
||||
|
||||
- `config-api-key`: Validates API keys declared under top-level `api-keys`.
|
||||
- Credential sources: `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, `?key=`, `?auth_token=`
|
||||
- Metadata: `Result.Metadata["source"]` is set to the matched source label.
|
||||
|
||||
In the CLI server and `sdk/cliproxy`, this provider is registered automatically based on the loaded configuration.
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: inline-api
|
||||
type: config-api-key
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
```
|
||||
|
||||
Fields map directly to `config.AccessProvider`: `name` labels the provider, `type` selects the registered factory, `sdk` can name an external module, `api-keys` seeds inline credentials, and `config` passes provider-specific options.
|
||||
## Loading Providers from External Go Modules
|
||||
|
||||
### Loading providers from external SDK modules
|
||||
|
||||
To consume a provider shipped in another Go module, point the `sdk` field at the module path and import it for its registration side effect:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: partner-auth
|
||||
type: partner-token
|
||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
||||
config:
|
||||
region: us-west-2
|
||||
audience: cli-proxy
|
||||
```
|
||||
To consume a provider shipped in another Go module, import it for its registration side effect:
|
||||
|
||||
```go
|
||||
import (
|
||||
@@ -89,19 +80,11 @@ import (
|
||||
)
|
||||
```
|
||||
|
||||
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before `BuildProviders` is called.
|
||||
|
||||
## Built-in Providers
|
||||
|
||||
The SDK ships with one provider out of the box:
|
||||
|
||||
- `config-api-key`: Validates API keys declared inline or under top-level `api-keys`. It accepts the key from `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, or the `?key=` query string and reports `ErrInvalidCredential` when no match is found.
|
||||
|
||||
Additional providers can be delivered by third-party packages. When a provider package is imported, it registers itself with `sdkaccess.RegisterProvider`.
|
||||
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before you call `RegisteredProviders()` (or before `cliproxy.NewBuilder().Build()`).
|
||||
|
||||
### Metadata and auditing
|
||||
|
||||
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, or `query-key`). Populate this map in custom providers to enrich logs and downstream auditing.
|
||||
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, `query-key`, `query-auth-token`). Populate this map in custom providers to enrich logs and downstream auditing.
|
||||
|
||||
## Writing Custom Providers
|
||||
|
||||
@@ -110,13 +93,13 @@ type customProvider struct{}
|
||||
|
||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
token := r.Header.Get("X-Custom")
|
||||
if token == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if token != "expected" {
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
|
||||
}
|
||||
|
||||
func init() {
|
||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
||||
return &customProvider{}, nil
|
||||
})
|
||||
sdkaccess.RegisterProvider("custom", &customProvider{})
|
||||
}
|
||||
```
|
||||
|
||||
A provider must implement `Identifier()` and `Authenticate()`. To expose it to configuration, call `RegisterProvider` inside `init`. Provider factories receive the specific `AccessProvider` block plus the full root configuration for contextual needs.
|
||||
A provider must implement `Identifier()` and `Authenticate()`. To make it available to the access manager, call `RegisterProvider` inside `init` with an initialized provider instance.
|
||||
|
||||
## Error Semantics
|
||||
|
||||
- `ErrNoCredentials`: no credentials were present or recognized by any provider.
|
||||
- `ErrInvalidCredential`: at least one provider processed the credentials but rejected them.
|
||||
- `ErrNotHandled`: instructs the manager to fall through to the next provider without affecting aggregate error reporting.
|
||||
- `NewNoCredentialsError()` (`AuthErrorCodeNoCredentials`): no credentials were present or recognized. (HTTP 401)
|
||||
- `NewInvalidCredentialError()` (`AuthErrorCodeInvalidCredential`): credentials were present but rejected. (HTTP 401)
|
||||
- `NewNotHandledError()` (`AuthErrorCodeNotHandled`): fall through to the next provider.
|
||||
- `NewInternalAuthError(message, cause)` (`AuthErrorCodeInternal`): transport/system failure. (HTTP 500)
|
||||
|
||||
Return custom errors to surface transport failures; they propagate immediately to the caller instead of being masked.
|
||||
Errors propagate immediately to the caller unless they are classified as `not_handled` / `no_credentials` / `invalid_credential` and can be aggregated by the manager.
|
||||
|
||||
## Integration with cliproxy Service
|
||||
|
||||
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a preconfigured manager allows you to extend or override the default providers:
|
||||
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a manager lets you reuse the same instance in your host process:
|
||||
|
||||
```go
|
||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
||||
manager := sdkaccess.NewManager()
|
||||
manager.SetProviders(providers)
|
||||
accessManager := sdkaccess.NewManager()
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(coreCfg).
|
||||
WithAccessManager(manager).
|
||||
WithConfigPath("config.yaml").
|
||||
WithRequestAccessManager(accessManager).
|
||||
Build()
|
||||
```
|
||||
|
||||
The service reuses the manager for every inbound request, ensuring consistent authentication across embedded deployments and the canonical CLI binary.
|
||||
Register any custom providers (typically via blank imports) before calling `Build()` so they are present in the global registry snapshot.
|
||||
|
||||
### Hot reloading providers
|
||||
### Hot reloading
|
||||
|
||||
When configuration changes, rebuild providers and swap them into the manager:
|
||||
When configuration changes, refresh any config-backed providers and then reset the manager's provider chain:
|
||||
|
||||
```go
|
||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
||||
if err != nil {
|
||||
log.Errorf("reload auth providers failed: %v", err)
|
||||
return
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
This mirrors the behaviour in `cliproxy.Service.refreshAccessProviders` and `api.Server.applyAccessConfig`, enabling runtime updates without restarting the process.
|
||||
This mirrors the behaviour in `internal/access.ApplyAccessProviders`, enabling runtime updates without restarting the process.
|
||||
|
||||
@@ -7,80 +7,71 @@
|
||||
```go
|
||||
import (
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
```
|
||||
|
||||
通过 `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access` 添加依赖。
|
||||
|
||||
## Provider Registry
|
||||
|
||||
访问提供者是全局注册,然后以快照形式挂到 `Manager` 上:
|
||||
|
||||
- `RegisterProvider(type, provider)` 注册一个已经初始化好的 provider 实例。
|
||||
- 每个 `type` 第一次出现时会记录其注册顺序。
|
||||
- `RegisteredProviders()` 会按该顺序返回 provider 列表。
|
||||
|
||||
## 管理器生命周期
|
||||
|
||||
```go
|
||||
manager := sdkaccess.NewManager()
|
||||
providers, err := sdkaccess.BuildProviders(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manager.SetProviders(providers)
|
||||
manager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
- `NewManager` 创建空管理器。
|
||||
- `SetProviders` 替换提供者切片并做防御性拷贝。
|
||||
- `Providers` 返回适合并发读取的快照。
|
||||
- `BuildProviders` 将 `config.Config` 中的访问配置转换成可运行的提供者。当配置没有显式声明但包含顶层 `api-keys` 时,会自动挂载内建的 `config-api-key` 提供者。
|
||||
|
||||
如果管理器本身为 `nil` 或未配置任何 provider,调用会返回 `nil, nil`,可视为关闭访问控制。
|
||||
|
||||
## 认证请求
|
||||
|
||||
```go
|
||||
result, err := manager.Authenticate(ctx, req)
|
||||
result, authErr := manager.Authenticate(ctx, req)
|
||||
switch {
|
||||
case err == nil:
|
||||
case authErr == nil:
|
||||
// Authentication succeeded; result carries provider and principal.
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
|
||||
// No recognizable credentials were supplied.
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
|
||||
// Credentials were present but rejected.
|
||||
default:
|
||||
// Provider surfaced a transport-level failure.
|
||||
}
|
||||
```
|
||||
|
||||
`Manager.Authenticate` 按配置顺序遍历提供者。遇到成功立即返回,`ErrNotHandled` 会继续尝试下一个;若发现 `ErrNoCredentials` 或 `ErrInvalidCredential`,会在遍历结束后汇总给调用方。
|
||||
|
||||
若管理器本身为 `nil` 或尚未注册提供者,调用会返回 `nil, nil`,让调用方无需针对错误做额外分支即可关闭访问控制。
|
||||
`Manager.Authenticate` 会按顺序遍历 provider:遇到成功立即返回,`AuthErrorCodeNotHandled` 会继续尝试下一个;`AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` 会在遍历结束后汇总给调用方。
|
||||
|
||||
`Result` 提供认证提供者标识、解析出的主体以及可选元数据(例如凭证来源)。
|
||||
|
||||
## 配置结构
|
||||
## 内建 `config-api-key` Provider
|
||||
|
||||
在 `config.yaml` 的 `auth.providers` 下定义访问提供者:
|
||||
代理内置一个访问提供者:
|
||||
|
||||
- `config-api-key`:校验 `config.yaml` 顶层的 `api-keys`。
|
||||
- 凭证来源:`Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key`、`?key=`、`?auth_token=`
|
||||
- 元数据:`Result.Metadata["source"]` 会写入匹配到的来源标识
|
||||
|
||||
在 CLI 服务端与 `sdk/cliproxy` 中,该 provider 会根据加载到的配置自动注册。
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: inline-api
|
||||
type: config-api-key
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
```
|
||||
|
||||
条目映射到 `config.AccessProvider`:`name` 指定实例名,`type` 选择注册的工厂,`sdk` 可引用第三方模块,`api-keys` 提供内联凭证,`config` 用于传递特定选项。
|
||||
## 引入外部 Go 模块提供者
|
||||
|
||||
### 引入外部 SDK 提供者
|
||||
|
||||
若要消费其它 Go 模块输出的访问提供者,可在配置里填写 `sdk` 字段并在代码中引入该包,利用其 `init` 注册过程:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: partner-auth
|
||||
type: partner-token
|
||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
||||
config:
|
||||
region: us-west-2
|
||||
audience: cli-proxy
|
||||
```
|
||||
若要消费其它 Go 模块输出的访问提供者,直接用空白标识符导入以触发其 `init` 注册即可:
|
||||
|
||||
```go
|
||||
import (
|
||||
@@ -89,19 +80,11 @@ import (
|
||||
)
|
||||
```
|
||||
|
||||
通过空白标识符导入即可确保 `init` 调用,先于 `BuildProviders` 完成 `sdkaccess.RegisterProvider`。
|
||||
|
||||
## 内建提供者
|
||||
|
||||
当前 SDK 默认内置:
|
||||
|
||||
- `config-api-key`:校验配置中的 API Key。它从 `Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key` 以及查询参数 `?key=` 提取凭证,不匹配时抛出 `ErrInvalidCredential`。
|
||||
|
||||
导入第三方包即可通过 `sdkaccess.RegisterProvider` 注册更多类型。
|
||||
空白导入可确保 `init` 先执行,从而在你调用 `RegisteredProviders()`(或 `cliproxy.NewBuilder().Build()`)之前完成 `sdkaccess.RegisterProvider`。
|
||||
|
||||
### 元数据与审计
|
||||
|
||||
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key` 或 `query-key`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
||||
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key`、`query-key`、`query-auth-token`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
||||
|
||||
## 编写自定义提供者
|
||||
|
||||
@@ -110,13 +93,13 @@ type customProvider struct{}
|
||||
|
||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
token := r.Header.Get("X-Custom")
|
||||
if token == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if token != "expected" {
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
|
||||
}
|
||||
|
||||
func init() {
|
||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
||||
return &customProvider{}, nil
|
||||
})
|
||||
sdkaccess.RegisterProvider("custom", &customProvider{})
|
||||
}
|
||||
```
|
||||
|
||||
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中调用 `RegisterProvider` 暴露给配置层,工厂函数既能读取当前条目,也能访问完整根配置。
|
||||
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中用已初始化实例调用 `RegisterProvider` 注册到全局 registry。
|
||||
|
||||
## 错误语义
|
||||
|
||||
- `ErrNoCredentials`:任何提供者都未识别到凭证。
|
||||
- `ErrInvalidCredential`:至少一个提供者处理了凭证但判定无效。
|
||||
- `ErrNotHandled`:告诉管理器跳到下一个提供者,不影响最终错误统计。
|
||||
- `NewNoCredentialsError()`(`AuthErrorCodeNoCredentials`):未提供或未识别到凭证。(HTTP 401)
|
||||
- `NewInvalidCredentialError()`(`AuthErrorCodeInvalidCredential`):凭证存在但校验失败。(HTTP 401)
|
||||
- `NewNotHandledError()`(`AuthErrorCodeNotHandled`):告诉管理器跳到下一个 provider。
|
||||
- `NewInternalAuthError(message, cause)`(`AuthErrorCodeInternal`):网络/系统错误。(HTTP 500)
|
||||
|
||||
自定义错误(例如网络异常)会马上冒泡返回。
|
||||
除可汇总的 `not_handled` / `no_credentials` / `invalid_credential` 外,其它错误会立即冒泡返回。
|
||||
|
||||
## 与 cliproxy 集成
|
||||
|
||||
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果需要扩展内置行为,可传入自定义管理器:
|
||||
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果希望在宿主进程里复用同一个 `Manager` 实例,可传入自定义管理器:
|
||||
|
||||
```go
|
||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
||||
manager := sdkaccess.NewManager()
|
||||
manager.SetProviders(providers)
|
||||
accessManager := sdkaccess.NewManager()
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(coreCfg).
|
||||
WithAccessManager(manager).
|
||||
WithConfigPath("config.yaml").
|
||||
WithRequestAccessManager(accessManager).
|
||||
Build()
|
||||
```
|
||||
|
||||
服务会复用该管理器处理每一个入站请求,实现与 CLI 二进制一致的访问控制体验。
|
||||
请在调用 `Build()` 之前完成自定义 provider 的注册(通常通过空白导入触发 `init`),以确保它们被包含在全局 registry 的快照中。
|
||||
|
||||
### 动态热更新提供者
|
||||
|
||||
当配置发生变化时,可以重新构建提供者并替换当前列表:
|
||||
当配置发生变化时,刷新依赖配置的 provider,然后重置 manager 的 provider 链:
|
||||
|
||||
```go
|
||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
||||
if err != nil {
|
||||
log.Errorf("reload auth providers failed: %v", err)
|
||||
return
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
这一流程与 `cliproxy.Service.refreshAccessProviders` 和 `api.Server.applyAccessConfig` 保持一致,避免为更新访问策略而重启进程。
|
||||
这一流程与 `internal/access.ApplyAccessProviders` 保持一致,避免为更新访问策略而重启进程。
|
||||
|
||||
@@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request,
|
||||
return clipexec.Response{}, errors.New("count tokens not implemented")
|
||||
}
|
||||
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
|
||||
ch := make(chan clipexec.StreamChunk, 1)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
|
||||
}()
|
||||
return ch, nil
|
||||
return &clipexec.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
|
||||
@@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c
|
||||
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) {
|
||||
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
||||
}
|
||||
|
||||
|
||||
23
go.mod
23
go.mod
@@ -1,9 +1,13 @@
|
||||
module github.com/router-for-me/CLIProxyAPI/v6
|
||||
|
||||
go 1.24.0
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.0.6
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/charmbracelet/bubbles v1.0.0
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/fxamacker/cbor/v2 v2.9.0
|
||||
github.com/gin-gonic/gin v1.10.1
|
||||
@@ -33,8 +37,16 @@ require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.6 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
@@ -42,6 +54,7 @@ require (
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-git/gcfg/v2 v2.0.2 // indirect
|
||||
@@ -58,19 +71,27 @@ require (
|
||||
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/sergi/go-diff v1.4.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
|
||||
45
go.sum
45
go.sum
@@ -10,10 +10,34 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
|
||||
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
|
||||
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
|
||||
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
|
||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
||||
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
@@ -33,6 +57,8 @@ github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o
|
||||
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
|
||||
@@ -101,8 +127,14 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
|
||||
@@ -114,6 +146,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||
@@ -124,6 +162,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
@@ -161,6 +201,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
@@ -168,12 +210,15 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -4,19 +4,28 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
var registerOnce sync.Once
|
||||
|
||||
// Register ensures the config-access provider is available to the access manager.
|
||||
func Register() {
|
||||
registerOnce.Do(func() {
|
||||
sdkaccess.RegisterProvider(sdkconfig.AccessProviderTypeConfigAPIKey, newProvider)
|
||||
})
|
||||
func Register(cfg *sdkconfig.SDKConfig) {
|
||||
if cfg == nil {
|
||||
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
|
||||
return
|
||||
}
|
||||
|
||||
keys := normalizeKeys(cfg.APIKeys)
|
||||
if len(keys) == 0 {
|
||||
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
|
||||
return
|
||||
}
|
||||
|
||||
sdkaccess.RegisterProvider(
|
||||
sdkaccess.AccessProviderTypeConfigAPIKey,
|
||||
newProvider(sdkaccess.DefaultAccessProviderName, keys),
|
||||
)
|
||||
}
|
||||
|
||||
type provider struct {
|
||||
@@ -24,34 +33,31 @@ type provider struct {
|
||||
keys map[string]struct{}
|
||||
}
|
||||
|
||||
func newProvider(cfg *sdkconfig.AccessProvider, _ *sdkconfig.SDKConfig) (sdkaccess.Provider, error) {
|
||||
name := cfg.Name
|
||||
if name == "" {
|
||||
name = sdkconfig.DefaultAccessProviderName
|
||||
func newProvider(name string, keys []string) *provider {
|
||||
providerName := strings.TrimSpace(name)
|
||||
if providerName == "" {
|
||||
providerName = sdkaccess.DefaultAccessProviderName
|
||||
}
|
||||
keys := make(map[string]struct{}, len(cfg.APIKeys))
|
||||
for _, key := range cfg.APIKeys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
keys[key] = struct{}{}
|
||||
keySet := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
return &provider{name: name, keys: keys}, nil
|
||||
return &provider{name: providerName, keys: keySet}
|
||||
}
|
||||
|
||||
func (p *provider) Identifier() string {
|
||||
if p == nil || p.name == "" {
|
||||
return sdkconfig.DefaultAccessProviderName
|
||||
return sdkaccess.DefaultAccessProviderName
|
||||
}
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
if p == nil {
|
||||
return nil, sdkaccess.ErrNotHandled
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if len(p.keys) == 0 {
|
||||
return nil, sdkaccess.ErrNotHandled
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
||||
@@ -63,7 +69,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
||||
queryAuthToken = r.URL.Query().Get("auth_token")
|
||||
}
|
||||
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNoCredentialsError()
|
||||
}
|
||||
|
||||
apiKey := extractBearerToken(authHeader)
|
||||
@@ -94,7 +100,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
||||
}
|
||||
}
|
||||
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
|
||||
func extractBearerToken(header string) string {
|
||||
@@ -110,3 +116,26 @@ func extractBearerToken(header string) string {
|
||||
}
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
func normalizeKeys(keys []string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]string, 0, len(keys))
|
||||
seen := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
trimmedKey := strings.TrimSpace(key)
|
||||
if trimmedKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[trimmedKey]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmedKey] = struct{}{}
|
||||
normalized = append(normalized, trimmedKey)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkConfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -17,26 +17,26 @@ import (
|
||||
// ordered provider slice along with the identifiers of providers that were added, updated, or
|
||||
// removed compared to the previous configuration.
|
||||
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Provider) (result []sdkaccess.Provider, added, updated, removed []string, err error) {
|
||||
_ = oldCfg
|
||||
if newCfg == nil {
|
||||
return nil, nil, nil, nil, nil
|
||||
}
|
||||
|
||||
result = sdkaccess.RegisteredProviders()
|
||||
|
||||
existingMap := make(map[string]sdkaccess.Provider, len(existing))
|
||||
for _, provider := range existing {
|
||||
if provider == nil {
|
||||
providerID := identifierFromProvider(provider)
|
||||
if providerID == "" {
|
||||
continue
|
||||
}
|
||||
existingMap[provider.Identifier()] = provider
|
||||
existingMap[providerID] = provider
|
||||
}
|
||||
|
||||
oldCfgMap := accessProviderMap(oldCfg)
|
||||
newEntries := collectProviderEntries(newCfg)
|
||||
|
||||
result = make([]sdkaccess.Provider, 0, len(newEntries))
|
||||
finalIDs := make(map[string]struct{}, len(newEntries))
|
||||
finalIDs := make(map[string]struct{}, len(result))
|
||||
|
||||
isInlineProvider := func(id string) bool {
|
||||
return strings.EqualFold(id, sdkConfig.DefaultAccessProviderName)
|
||||
return strings.EqualFold(id, sdkaccess.DefaultAccessProviderName)
|
||||
}
|
||||
appendChange := func(list *[]string, id string) {
|
||||
if isInlineProvider(id) {
|
||||
@@ -45,85 +45,28 @@ func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Prov
|
||||
*list = append(*list, id)
|
||||
}
|
||||
|
||||
for _, providerCfg := range newEntries {
|
||||
key := providerIdentifier(providerCfg)
|
||||
if key == "" {
|
||||
for _, provider := range result {
|
||||
providerID := identifierFromProvider(provider)
|
||||
if providerID == "" {
|
||||
continue
|
||||
}
|
||||
finalIDs[providerID] = struct{}{}
|
||||
|
||||
forceRebuild := strings.EqualFold(strings.TrimSpace(providerCfg.Type), sdkConfig.AccessProviderTypeConfigAPIKey)
|
||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||
isAliased := oldCfgProvider == providerCfg
|
||||
if !forceRebuild && !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
|
||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||
result = append(result, existingProvider)
|
||||
finalIDs[key] = struct{}{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
existingProvider, exists := existingMap[providerID]
|
||||
if !exists {
|
||||
appendChange(&added, providerID)
|
||||
continue
|
||||
}
|
||||
|
||||
provider, buildErr := sdkaccess.BuildProvider(providerCfg, &newCfg.SDKConfig)
|
||||
if buildErr != nil {
|
||||
return nil, nil, nil, nil, buildErr
|
||||
}
|
||||
if _, ok := oldCfgMap[key]; ok {
|
||||
if _, existed := existingMap[key]; existed {
|
||||
appendChange(&updated, key)
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
result = append(result, provider)
|
||||
finalIDs[key] = struct{}{}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(newCfg.APIKeys); inline != nil {
|
||||
key := providerIdentifier(inline)
|
||||
if key != "" {
|
||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||
if providerConfigEqual(oldCfgProvider, inline) {
|
||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||
result = append(result, existingProvider)
|
||||
finalIDs[key] = struct{}{}
|
||||
goto inlineDone
|
||||
}
|
||||
}
|
||||
}
|
||||
provider, buildErr := sdkaccess.BuildProvider(inline, &newCfg.SDKConfig)
|
||||
if buildErr != nil {
|
||||
return nil, nil, nil, nil, buildErr
|
||||
}
|
||||
if _, existed := existingMap[key]; existed {
|
||||
appendChange(&updated, key)
|
||||
} else if _, hadOld := oldCfgMap[key]; hadOld {
|
||||
appendChange(&updated, key)
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
result = append(result, provider)
|
||||
finalIDs[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
inlineDone:
|
||||
}
|
||||
|
||||
removedSet := make(map[string]struct{})
|
||||
for id := range existingMap {
|
||||
if _, ok := finalIDs[id]; !ok {
|
||||
if isInlineProvider(id) {
|
||||
continue
|
||||
}
|
||||
removedSet[id] = struct{}{}
|
||||
if !providerInstanceEqual(existingProvider, provider) {
|
||||
appendChange(&updated, providerID)
|
||||
}
|
||||
}
|
||||
|
||||
removed = make([]string, 0, len(removedSet))
|
||||
for id := range removedSet {
|
||||
removed = append(removed, id)
|
||||
for providerID := range existingMap {
|
||||
if _, exists := finalIDs[providerID]; exists {
|
||||
continue
|
||||
}
|
||||
appendChange(&removed, providerID)
|
||||
}
|
||||
|
||||
sort.Strings(added)
|
||||
@@ -142,6 +85,7 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
|
||||
}
|
||||
|
||||
existing := manager.Providers()
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
providers, added, updated, removed, err := ReconcileProviders(oldCfg, newCfg, existing)
|
||||
if err != nil {
|
||||
log.Errorf("failed to reconcile request auth providers: %v", err)
|
||||
@@ -160,111 +104,24 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func accessProviderMap(cfg *config.Config) map[string]*sdkConfig.AccessProvider {
|
||||
result := make(map[string]*sdkConfig.AccessProvider)
|
||||
if cfg == nil {
|
||||
return result
|
||||
}
|
||||
for i := range cfg.Access.Providers {
|
||||
providerCfg := &cfg.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
continue
|
||||
}
|
||||
key := providerIdentifier(providerCfg)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
result[key] = providerCfg
|
||||
}
|
||||
if len(result) == 0 && len(cfg.APIKeys) > 0 {
|
||||
if provider := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); provider != nil {
|
||||
if key := providerIdentifier(provider); key != "" {
|
||||
result[key] = provider
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func collectProviderEntries(cfg *config.Config) []*sdkConfig.AccessProvider {
|
||||
entries := make([]*sdkConfig.AccessProvider, 0, len(cfg.Access.Providers))
|
||||
for i := range cfg.Access.Providers {
|
||||
providerCfg := &cfg.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
continue
|
||||
}
|
||||
if key := providerIdentifier(providerCfg); key != "" {
|
||||
entries = append(entries, providerCfg)
|
||||
}
|
||||
}
|
||||
if len(entries) == 0 && len(cfg.APIKeys) > 0 {
|
||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); inline != nil {
|
||||
entries = append(entries, inline)
|
||||
}
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func providerIdentifier(provider *sdkConfig.AccessProvider) string {
|
||||
func identifierFromProvider(provider sdkaccess.Provider) string {
|
||||
if provider == nil {
|
||||
return ""
|
||||
}
|
||||
if name := strings.TrimSpace(provider.Name); name != "" {
|
||||
return name
|
||||
}
|
||||
typ := strings.TrimSpace(provider.Type)
|
||||
if typ == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.EqualFold(typ, sdkConfig.AccessProviderTypeConfigAPIKey) {
|
||||
return sdkConfig.DefaultAccessProviderName
|
||||
}
|
||||
return typ
|
||||
return strings.TrimSpace(provider.Identifier())
|
||||
}
|
||||
|
||||
func providerConfigEqual(a, b *sdkConfig.AccessProvider) bool {
|
||||
func providerInstanceEqual(a, b sdkaccess.Provider) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == nil && b == nil
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
|
||||
if reflect.TypeOf(a) != reflect.TypeOf(b) {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
|
||||
return false
|
||||
valueA := reflect.ValueOf(a)
|
||||
valueB := reflect.ValueOf(b)
|
||||
if valueA.Kind() == reflect.Pointer && valueB.Kind() == reflect.Pointer {
|
||||
return valueA.Pointer() == valueB.Pointer()
|
||||
}
|
||||
if !stringSetEqual(a.APIKeys, b.APIKeys) {
|
||||
return false
|
||||
}
|
||||
if len(a.Config) != len(b.Config) {
|
||||
return false
|
||||
}
|
||||
if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func stringSetEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
seen := make(map[string]int, len(a))
|
||||
for _, val := range a {
|
||||
seen[val]++
|
||||
}
|
||||
for _, val := range b {
|
||||
count := seen[val]
|
||||
if count == 0 {
|
||||
return false
|
||||
}
|
||||
if count == 1 {
|
||||
delete(seen, val)
|
||||
} else {
|
||||
seen[val] = count - 1
|
||||
}
|
||||
}
|
||||
return len(seen) == 0
|
||||
return reflect.DeepEqual(a, b)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -189,9 +190,21 @@ func (h *Handler) APICall(c *gin.Context) {
|
||||
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
|
||||
}
|
||||
|
||||
// When caller indicates CBOR in request headers, convert JSON string payload to CBOR bytes.
|
||||
useCBORPayload := headerContainsValue(reqHeaders, "Content-Type", "application/cbor")
|
||||
|
||||
var requestBody io.Reader
|
||||
if body.Data != "" {
|
||||
requestBody = strings.NewReader(body.Data)
|
||||
if useCBORPayload {
|
||||
cborPayload, errEncode := encodeJSONStringToCBOR(body.Data)
|
||||
if errEncode != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json data for cbor content-type"})
|
||||
return
|
||||
}
|
||||
requestBody = bytes.NewReader(cborPayload)
|
||||
} else {
|
||||
requestBody = strings.NewReader(body.Data)
|
||||
}
|
||||
}
|
||||
|
||||
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
|
||||
@@ -234,10 +247,18 @@ func (h *Handler) APICall(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// For CBOR upstream responses, decode into plain text or JSON string before returning.
|
||||
responseBodyText := string(respBody)
|
||||
if headerContainsValue(reqHeaders, "Accept", "application/cbor") || strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "application/cbor") {
|
||||
if decodedBody, errDecode := decodeCBORBodyToTextOrJSON(respBody); errDecode == nil {
|
||||
responseBodyText = decodedBody
|
||||
}
|
||||
}
|
||||
|
||||
response := apiCallResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header,
|
||||
Body: string(respBody),
|
||||
Body: responseBodyText,
|
||||
}
|
||||
|
||||
// If this is a GitHub Copilot token endpoint response, try to enrich with quota information
|
||||
@@ -747,6 +768,83 @@ func buildProxyTransport(proxyStr string) *http.Transport {
|
||||
return nil
|
||||
}
|
||||
|
||||
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
|
||||
func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
|
||||
if len(headers) == 0 {
|
||||
return false
|
||||
}
|
||||
for key, value := range headers {
|
||||
if !strings.EqualFold(strings.TrimSpace(key), strings.TrimSpace(targetKey)) {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(strings.ToLower(value), strings.ToLower(strings.TrimSpace(targetValue))) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// encodeJSONStringToCBOR converts a JSON string payload into CBOR bytes.
|
||||
func encodeJSONStringToCBOR(jsonString string) ([]byte, error) {
|
||||
var payload any
|
||||
if errUnmarshal := json.Unmarshal([]byte(jsonString), &payload); errUnmarshal != nil {
|
||||
return nil, errUnmarshal
|
||||
}
|
||||
return cbor.Marshal(payload)
|
||||
}
|
||||
|
||||
// decodeCBORBodyToTextOrJSON decodes CBOR bytes to plain text (for string payloads) or JSON string.
|
||||
func decodeCBORBodyToTextOrJSON(raw []byte) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var payload any
|
||||
if errUnmarshal := cbor.Unmarshal(raw, &payload); errUnmarshal != nil {
|
||||
return "", errUnmarshal
|
||||
}
|
||||
|
||||
jsonCompatible := cborValueToJSONCompatible(payload)
|
||||
switch typed := jsonCompatible.(type) {
|
||||
case string:
|
||||
return typed, nil
|
||||
case []byte:
|
||||
return string(typed), nil
|
||||
default:
|
||||
jsonBytes, errMarshal := json.Marshal(jsonCompatible)
|
||||
if errMarshal != nil {
|
||||
return "", errMarshal
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
}
|
||||
|
||||
// cborValueToJSONCompatible recursively converts CBOR-decoded values into JSON-marshalable values.
|
||||
func cborValueToJSONCompatible(value any) any {
|
||||
switch typed := value.(type) {
|
||||
case map[any]any:
|
||||
out := make(map[string]any, len(typed))
|
||||
for key, item := range typed {
|
||||
out[fmt.Sprint(key)] = cborValueToJSONCompatible(item)
|
||||
}
|
||||
return out
|
||||
case map[string]any:
|
||||
out := make(map[string]any, len(typed))
|
||||
for key, item := range typed {
|
||||
out[key] = cborValueToJSONCompatible(item)
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]any, len(typed))
|
||||
for i, item := range typed {
|
||||
out[i] = cborValueToJSONCompatible(item)
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return typed
|
||||
}
|
||||
}
|
||||
|
||||
// QuotaDetail represents quota information for a specific resource type
|
||||
type QuotaDetail struct {
|
||||
Entitlement float64 `json:"entitlement"`
|
||||
|
||||
@@ -29,6 +29,8 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
@@ -812,6 +814,87 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||
}
|
||||
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
|
||||
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
if h.authManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Prefix *string `json:"prefix"`
|
||||
ProxyURL *string `json:"proxy_url"`
|
||||
Priority *int `json:"priority"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(req.Name)
|
||||
if name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Find auth by name or ID
|
||||
var targetAuth *coreauth.Auth
|
||||
if auth, ok := h.authManager.GetByID(name); ok {
|
||||
targetAuth = auth
|
||||
} else {
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth.FileName == name {
|
||||
targetAuth = auth
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if targetAuth == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
|
||||
return
|
||||
}
|
||||
|
||||
changed := false
|
||||
if req.Prefix != nil {
|
||||
targetAuth.Prefix = *req.Prefix
|
||||
changed = true
|
||||
}
|
||||
if req.ProxyURL != nil {
|
||||
targetAuth.ProxyURL = *req.ProxyURL
|
||||
changed = true
|
||||
}
|
||||
if req.Priority != nil {
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if *req.Priority == 0 {
|
||||
delete(targetAuth.Metadata, "priority")
|
||||
} else {
|
||||
targetAuth.Metadata["priority"] = *req.Priority
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
|
||||
if !changed {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"})
|
||||
return
|
||||
}
|
||||
|
||||
targetAuth.UpdatedAt = time.Now()
|
||||
|
||||
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||
if h == nil || h.authManager == nil {
|
||||
return
|
||||
@@ -1192,6 +1275,30 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
ts.Checked = true
|
||||
} else if strings.EqualFold(requestedProjectID, "GOOGLE_ONE") {
|
||||
ts.Auto = false
|
||||
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
SetOAuthSessionError(state, "Google One auto-discovery failed")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
log.Error("Google One auto-discovery returned empty project ID")
|
||||
SetOAuthSessionError(state, "Google One auto-discovery returned empty project ID")
|
||||
return
|
||||
}
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the auto-discovered project")
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
@@ -1613,6 +1720,82 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
fmt.Println("Initializing Kimi authentication...")
|
||||
|
||||
state := fmt.Sprintf("kmi-%d", time.Now().UnixNano())
|
||||
// Initialize Kimi auth service
|
||||
kimiAuth := kimi.NewKimiAuth(h.cfg)
|
||||
|
||||
// Generate authorization URL
|
||||
deviceFlow, errStartDeviceFlow := kimiAuth.StartDeviceFlow(ctx)
|
||||
if errStartDeviceFlow != nil {
|
||||
log.Errorf("Failed to generate authorization URL: %v", errStartDeviceFlow)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||
return
|
||||
}
|
||||
authURL := deviceFlow.VerificationURIComplete
|
||||
if authURL == "" {
|
||||
authURL = deviceFlow.VerificationURI
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "kimi")
|
||||
|
||||
go func() {
|
||||
fmt.Println("Waiting for authentication...")
|
||||
authBundle, errWaitForAuthorization := kimiAuth.WaitForAuthorization(ctx, deviceFlow)
|
||||
if errWaitForAuthorization != nil {
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errWaitForAuthorization)
|
||||
return
|
||||
}
|
||||
|
||||
// Create token storage
|
||||
tokenStorage := kimiAuth.CreateTokenStorage(authBundle)
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "kimi",
|
||||
"access_token": authBundle.TokenData.AccessToken,
|
||||
"refresh_token": authBundle.TokenData.RefreshToken,
|
||||
"token_type": authBundle.TokenData.TokenType,
|
||||
"scope": authBundle.TokenData.Scope,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
if authBundle.TokenData.ExpiresAt > 0 {
|
||||
expired := time.Unix(authBundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
metadata["expired"] = expired
|
||||
}
|
||||
if strings.TrimSpace(authBundle.DeviceID) != "" {
|
||||
metadata["device_id"] = strings.TrimSpace(authBundle.DeviceID)
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("kimi-%d.json", time.Now().UnixMilli())
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kimi",
|
||||
FileName: fileName,
|
||||
Label: "Kimi User",
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use Kimi services through this CLI")
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("kimi")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -2047,7 +2230,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
// Auto-discovery: try onboardUser without specifying a project
|
||||
// to let Google auto-provision one (matches Gemini CLI headless behavior
|
||||
// and Antigravity's FetchProjectID pattern).
|
||||
autoOnboardReq := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer autoCancel()
|
||||
for attempt := 1; ; attempt++ {
|
||||
var onboardResp map[string]any
|
||||
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
|
||||
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
|
||||
}
|
||||
|
||||
if done, okDone := onboardResp["done"].(bool); okDone && done {
|
||||
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
|
||||
switch v := resp["cloudaicompanionProject"].(type) {
|
||||
case string:
|
||||
projectID = strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if id, okID := v["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
|
||||
select {
|
||||
case <-autoCtx.Done():
|
||||
return &projectSelectionRequiredError{}
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
}
|
||||
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
|
||||
}
|
||||
|
||||
onboardReqBody := map[string]any{
|
||||
@@ -2591,3 +2815,88 @@ func generateKiroPKCE() (verifier, challenge string, err error) {
|
||||
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
func (h *Handler) RequestKiloToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
fmt.Println("Initializing Kilo authentication...")
|
||||
|
||||
state := fmt.Sprintf("kil-%d", time.Now().UnixNano())
|
||||
kilocodeAuth := kilo.NewKiloAuth()
|
||||
|
||||
resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initiate device flow: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "kilo")
|
||||
|
||||
go func() {
|
||||
fmt.Printf("Please visit %s and enter code: %s\n", resp.VerificationURL, resp.Code)
|
||||
|
||||
status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
|
||||
if err != nil {
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := kilocodeAuth.GetProfile(ctx, status.Token)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to fetch profile: %v", err)
|
||||
profile = &kilo.Profile{Email: status.UserEmail}
|
||||
}
|
||||
|
||||
var orgID string
|
||||
if len(profile.Orgs) > 0 {
|
||||
orgID = profile.Orgs[0].ID
|
||||
}
|
||||
|
||||
defaults, err := kilocodeAuth.GetDefaults(ctx, status.Token, orgID)
|
||||
if err != nil {
|
||||
defaults = &kilo.Defaults{}
|
||||
}
|
||||
|
||||
ts := &kilo.KiloTokenStorage{
|
||||
Token: status.Token,
|
||||
OrganizationID: orgID,
|
||||
Model: defaults.Model,
|
||||
Email: status.UserEmail,
|
||||
Type: "kilo",
|
||||
}
|
||||
|
||||
fileName := kilo.CredentialFileName(status.UserEmail)
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kilo",
|
||||
FileName: fileName,
|
||||
Storage: ts,
|
||||
Metadata: map[string]any{
|
||||
"email": status.UserEmail,
|
||||
"organization_id": orgID,
|
||||
"model": defaults.Model,
|
||||
},
|
||||
}
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("kilo")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": "ok",
|
||||
"url": resp.VerificationURL,
|
||||
"state": state,
|
||||
"user_code": resp.Code,
|
||||
"verification_uri": resp.VerificationURL,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -28,8 +28,7 @@ func (h *Handler) GetConfig(c *gin.Context) {
|
||||
c.JSON(200, gin.H{})
|
||||
return
|
||||
}
|
||||
cfgCopy := *h.cfg
|
||||
c.JSON(200, &cfgCopy)
|
||||
c.JSON(200, new(*h.cfg))
|
||||
}
|
||||
|
||||
type releaseInfo struct {
|
||||
|
||||
@@ -109,14 +109,13 @@ func (h *Handler) GetAPIKeys(c *gin.Context) { c.JSON(200, gin.H{"api-keys": h.c
|
||||
func (h *Handler) PutAPIKeys(c *gin.Context) {
|
||||
h.putStringList(c, func(v []string) {
|
||||
h.cfg.APIKeys = append([]string(nil), v...)
|
||||
h.cfg.Access.Providers = nil
|
||||
}, nil)
|
||||
}
|
||||
func (h *Handler) PatchAPIKeys(c *gin.Context) {
|
||||
h.patchStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
||||
h.patchStringList(c, &h.cfg.APIKeys, func() {})
|
||||
}
|
||||
func (h *Handler) DeleteAPIKeys(c *gin.Context) {
|
||||
h.deleteFromStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
||||
h.deleteFromStringList(c, &h.cfg.APIKeys, func() {})
|
||||
}
|
||||
|
||||
// gemini-api-key: []GeminiKey
|
||||
@@ -797,10 +796,10 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
|
||||
c.JSON(404, gin.H{"error": "channel not found"})
|
||||
return
|
||||
}
|
||||
delete(h.cfg.OAuthModelAlias, channel)
|
||||
if len(h.cfg.OAuthModelAlias) == 0 {
|
||||
h.cfg.OAuthModelAlias = nil
|
||||
}
|
||||
// Set to nil instead of deleting the key so that the "explicitly disabled"
|
||||
// marker survives config reload and prevents SanitizeOAuthModelAlias from
|
||||
// re-injecting default aliases (fixes #222).
|
||||
h.cfg.OAuthModelAlias[channel] = nil
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
|
||||
@@ -15,10 +15,12 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
|
||||
|
||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||
// It captures detailed information about the request and response, including headers and body,
|
||||
// and uses the provided RequestLogger to record this data. When logging is disabled in the
|
||||
// logger, it still captures data so that upstream errors can be persisted.
|
||||
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
|
||||
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
|
||||
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if logger == nil {
|
||||
@@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.Method == http.MethodGet {
|
||||
if shouldSkipMethodForRequestLogging(c.Request) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
loggerEnabled := logger.IsEnabled()
|
||||
|
||||
// Capture request information
|
||||
requestInfo, err := captureRequestInfo(c)
|
||||
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
|
||||
if err != nil {
|
||||
// Log error but continue processing
|
||||
// In a real implementation, you might want to use a proper logger here
|
||||
@@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
|
||||
// Create response writer wrapper
|
||||
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
||||
if !logger.IsEnabled() {
|
||||
if !loggerEnabled {
|
||||
wrapper.logOnErrorOnly = true
|
||||
}
|
||||
c.Writer = wrapper
|
||||
@@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
|
||||
if req == nil {
|
||||
return true
|
||||
}
|
||||
if req.Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
return !isResponsesWebsocketUpgrade(req)
|
||||
}
|
||||
|
||||
func isResponsesWebsocketUpgrade(req *http.Request) bool {
|
||||
if req == nil || req.URL == nil {
|
||||
return false
|
||||
}
|
||||
if req.URL.Path != "/v1/responses" {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket")
|
||||
}
|
||||
|
||||
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
|
||||
if loggerEnabled {
|
||||
return true
|
||||
}
|
||||
if req == nil || req.Body == nil {
|
||||
return false
|
||||
}
|
||||
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return false
|
||||
}
|
||||
if req.ContentLength <= 0 {
|
||||
return false
|
||||
}
|
||||
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
|
||||
}
|
||||
|
||||
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
||||
// It captures the URL, method, headers, and body. The request body is read and then
|
||||
// restored so that it can be processed by subsequent handlers.
|
||||
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
|
||||
// Capture URL with sensitive query parameters masked
|
||||
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||
url := c.Request.URL.Path
|
||||
@@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
|
||||
// Capture request body
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
if captureBody && c.Request.Body != nil {
|
||||
// Read the body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
|
||||
138
internal/api/middleware/request_logging_test.go
Normal file
138
internal/api/middleware/request_logging_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *http.Request
|
||||
skip bool
|
||||
}{
|
||||
{
|
||||
name: "nil request",
|
||||
req: nil,
|
||||
skip: true,
|
||||
},
|
||||
{
|
||||
name: "post request should not skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodPost,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
},
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "plain get should skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/models"},
|
||||
Header: http.Header{},
|
||||
},
|
||||
skip: true,
|
||||
},
|
||||
{
|
||||
name: "responses websocket upgrade should not skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
Header: http.Header{"Upgrade": []string{"websocket"}},
|
||||
},
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "responses get without upgrade should skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
Header: http.Header{},
|
||||
},
|
||||
skip: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
got := shouldSkipMethodForRequestLogging(tests[i].req)
|
||||
if got != tests[i].skip {
|
||||
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldCaptureRequestBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
loggerEnabled bool
|
||||
req *http.Request
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "logger enabled always captures",
|
||||
loggerEnabled: true,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
ContentLength: -1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nil request",
|
||||
loggerEnabled: false,
|
||||
req: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "small known size json in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
ContentLength: 2,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "large known size skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unknown size skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: -1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "multipart skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: 1,
|
||||
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
|
||||
if got != tests[i].want {
|
||||
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
)
|
||||
|
||||
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||
|
||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||
type RequestInfo struct {
|
||||
URL string // URL is the request URL.
|
||||
@@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
||||
|
||||
// Only fall back to request payload hints when Content-Type is not set yet.
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
bodyStr := string(w.requestInfo.Body)
|
||||
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
|
||||
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
|
||||
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||
@@ -361,16 +363,32 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||
if c != nil {
|
||||
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
||||
switch value := bodyOverride.(type) {
|
||||
case []byte:
|
||||
if len(value) > 0 {
|
||||
return bytes.Clone(value)
|
||||
}
|
||||
case string:
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return []byte(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
return w.requestInfo.Body
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
if w.requestInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if len(w.requestInfo.Body) > 0 {
|
||||
requestBody = w.requestInfo.Body
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||
}); ok {
|
||||
|
||||
43
internal/api/middleware/response_writer_test.go
Normal file
43
internal/api/middleware/response_writer_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{
|
||||
requestInfo: &RequestInfo{Body: []byte("original-body")},
|
||||
}
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
if string(body) != "original-body" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "original-body")
|
||||
}
|
||||
|
||||
c.Set(requestBodyOverrideContextKey, []byte("override-body"))
|
||||
body = wrapper.extractRequestBody(c)
|
||||
if string(body) != "override-body" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{}
|
||||
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
if string(body) != "override-as-string" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||
}
|
||||
}
|
||||
@@ -127,8 +127,7 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
||||
|
||||
// Store initial config for partial reload comparison
|
||||
settingsCopy := settings
|
||||
m.lastConfig = &settingsCopy
|
||||
m.lastConfig = new(settings)
|
||||
|
||||
// Initialize localhost restriction setting (hot-reloadable)
|
||||
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)
|
||||
|
||||
@@ -122,7 +122,7 @@ func (rw *ResponseRewriter) Flush() {
|
||||
}
|
||||
|
||||
// modelFieldPaths lists all JSON paths where model name may appear
|
||||
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||
|
||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
||||
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
||||
|
||||
110
internal/api/modules/amp/response_rewriter_test.go
Normal file
110
internal/api/modules/amp/response_rewriter_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRewriteModelInResponse_TopLevel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"id":"resp_1","model":"gpt-5.3-codex","output":[]}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"id":"resp_1","model":"gpt-5.2-codex","output":[]}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_ResponseModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"completed"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"completed"}}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_ResponseCreated(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"in_progress"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"in_progress"}}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_NoModelField(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.output_item.added","item":{"id":"item_1","type":"message"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_EmptyOriginalModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: ""}
|
||||
|
||||
input := []byte(`{"model":"gpt-5.3-codex"}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification when originalModel is empty, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_SSEWithResponseModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
chunk := []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.3-codex\",\"status\":\"completed\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
expected := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.2-codex\",\"status\":\"completed\"}}\n\n"
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_MultipleEvents(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
chunk := []byte("data: {\"type\":\"response.created\",\"response\":{\"model\":\"gpt-5.3-codex\"}}\n\ndata: {\"type\":\"response.output_item.added\",\"item\":{\"id\":\"item_1\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
if string(result) == string(chunk) {
|
||||
t.Error("expected response.model to be rewritten in SSE stream")
|
||||
}
|
||||
if !contains(result, []byte(`"model":"gpt-5.2-codex"`)) {
|
||||
t.Errorf("expected rewritten model in output, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_MessageModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "claude-opus-4.5"}
|
||||
|
||||
chunk := []byte("data: {\"message\":{\"model\":\"claude-sonnet-4\",\"role\":\"assistant\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
expected := "data: {\"message\":{\"model\":\"claude-opus-4.5\",\"role\":\"assistant\"}}\n\n"
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func contains(data, substr []byte) bool {
|
||||
for i := 0; i <= len(data)-len(substr); i++ {
|
||||
if string(data[i:i+len(substr)]) == string(substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -285,8 +285,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
optionState.routerConfigurator(engine, s.handlers, cfg)
|
||||
}
|
||||
|
||||
// Register management routes when configuration or environment secrets are available.
|
||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
|
||||
// Register management routes when configuration or environment secrets are available,
|
||||
// or when a local management password is provided (e.g. TUI mode).
|
||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
|
||||
s.managementRoutesEnabled.Store(hasManagementSecret)
|
||||
if hasManagementSecret {
|
||||
s.registerManagementRoutes()
|
||||
@@ -329,6 +330,7 @@ func (s *Server) setupRoutes() {
|
||||
v1.POST("/completions", openaiHandlers.Completions)
|
||||
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
||||
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
||||
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
|
||||
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
||||
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
|
||||
}
|
||||
@@ -642,6 +644,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
|
||||
mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields)
|
||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||
|
||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||
@@ -649,6 +652,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
|
||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||
@@ -682,14 +687,17 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
||||
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
// Synchronously ensure management.html is available with a detached context.
|
||||
// Control panel bootstrap should not be canceled by client disconnects.
|
||||
if !managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository) {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.WithError(err).Error("failed to stat management control panel asset")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithError(err).Error("failed to stat management control panel asset")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.File(filePath)
|
||||
@@ -979,10 +987,6 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
staticDir := managementasset.StaticDir(s.configFilePath)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
if s.mgmt != nil {
|
||||
s.mgmt.SetConfig(cfg)
|
||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||
@@ -1061,14 +1065,10 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing API key"})
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
|
||||
default:
|
||||
statusCode := err.HTTPStatusCode()
|
||||
if statusCode >= http.StatusInternalServerError {
|
||||
log.Errorf("authentication middleware error: %v", err)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authentication service error"})
|
||||
}
|
||||
c.AbortWithStatusJSON(statusCode, gin.H{"error": err.Message})
|
||||
}
|
||||
}
|
||||
|
||||
168
internal/auth/kilo/kilo_auth.go
Normal file
168
internal/auth/kilo/kilo_auth.go
Normal file
@@ -0,0 +1,168 @@
|
||||
// Package kilo provides authentication and token management functionality
|
||||
// for Kilo AI services.
|
||||
package kilo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// BaseURL is the base URL for the Kilo AI API.
|
||||
BaseURL = "https://api.kilo.ai/api"
|
||||
)
|
||||
|
||||
// DeviceAuthResponse represents the response from initiating device flow.
|
||||
type DeviceAuthResponse struct {
|
||||
Code string `json:"code"`
|
||||
VerificationURL string `json:"verificationUrl"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
// DeviceStatusResponse represents the response when polling for device flow status.
|
||||
type DeviceStatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Token string `json:"token"`
|
||||
UserEmail string `json:"userEmail"`
|
||||
}
|
||||
|
||||
// Profile represents the user profile from Kilo AI.
|
||||
type Profile struct {
|
||||
Email string `json:"email"`
|
||||
Orgs []Organization `json:"organizations"`
|
||||
}
|
||||
|
||||
// Organization represents a Kilo AI organization.
|
||||
type Organization struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Defaults represents default settings for an organization or user.
|
||||
type Defaults struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// KiloAuth provides methods for handling the Kilo AI authentication flow.
|
||||
type KiloAuth struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewKiloAuth creates a new instance of KiloAuth.
|
||||
func NewKiloAuth() *KiloAuth {
|
||||
return &KiloAuth{
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateDeviceFlow starts the device authentication flow.
|
||||
func (k *KiloAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceAuthResponse, error) {
|
||||
resp, err := k.client.Post(BaseURL+"/device-auth/codes", "application/json", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to initiate device flow: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var data DeviceAuthResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// PollForToken polls for the device flow completion.
|
||||
func (k *KiloAuth) PollForToken(ctx context.Context, code string) (*DeviceStatusResponse, error) {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
resp, err := k.client.Get(BaseURL + "/device-auth/codes/" + code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var data DeviceStatusResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch data.Status {
|
||||
case "approved":
|
||||
return &data, nil
|
||||
case "denied", "expired":
|
||||
return nil, fmt.Errorf("device flow %s", data.Status)
|
||||
case "pending":
|
||||
continue
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown status: %s", data.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetProfile fetches the user's profile.
|
||||
func (k *KiloAuth) GetProfile(ctx context.Context, token string) (*Profile, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", BaseURL+"/profile", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create get profile request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := k.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get profile: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var profile Profile
|
||||
if err := json.NewDecoder(resp.Body).Decode(&profile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
// GetDefaults fetches default settings for an organization.
|
||||
func (k *KiloAuth) GetDefaults(ctx context.Context, token, orgID string) (*Defaults, error) {
|
||||
url := BaseURL + "/defaults"
|
||||
if orgID != "" {
|
||||
url = BaseURL + "/organizations/" + orgID + "/defaults"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create get defaults request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := k.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get defaults: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var defaults Defaults
|
||||
if err := json.NewDecoder(resp.Body).Decode(&defaults); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &defaults, nil
|
||||
}
|
||||
60
internal/auth/kilo/kilo_token.go
Normal file
60
internal/auth/kilo/kilo_token.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Package kilo provides authentication and token management functionality
|
||||
// for Kilo AI services.
|
||||
package kilo
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// KiloTokenStorage stores token information for Kilo AI authentication.
|
||||
type KiloTokenStorage struct {
|
||||
// Token is the Kilo access token.
|
||||
Token string `json:"kilocodeToken"`
|
||||
|
||||
// OrganizationID is the Kilo organization ID.
|
||||
OrganizationID string `json:"kilocodeOrganizationId"`
|
||||
|
||||
// Model is the default model to use.
|
||||
Model string `json:"kilocodeModel"`
|
||||
|
||||
// Email is the email address of the authenticated user.
|
||||
Email string `json:"email"`
|
||||
|
||||
// Type indicates the authentication provider type, always "kilo" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Kilo token storage to a JSON file.
|
||||
func (ts *KiloTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "kilo"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := f.Close(); errClose != nil {
|
||||
log.Errorf("failed to close file: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CredentialFileName returns the filename used to persist Kilo credentials.
|
||||
func CredentialFileName(email string) string {
|
||||
return fmt.Sprintf("kilo-%s.json", email)
|
||||
}
|
||||
396
internal/auth/kimi/kimi.go
Normal file
396
internal/auth/kimi/kimi.go
Normal file
@@ -0,0 +1,396 @@
|
||||
// Package kimi provides authentication and token management for Kimi (Moonshot AI) API.
|
||||
// It handles the RFC 8628 OAuth2 Device Authorization Grant flow for secure authentication.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// kimiClientID is Kimi Code's OAuth client ID.
|
||||
kimiClientID = "17e5f671-d194-4dfb-9706-5516cb48c098"
|
||||
// kimiOAuthHost is the OAuth server endpoint.
|
||||
kimiOAuthHost = "https://auth.kimi.com"
|
||||
// kimiDeviceCodeURL is the endpoint for requesting device codes.
|
||||
kimiDeviceCodeURL = kimiOAuthHost + "/api/oauth/device_authorization"
|
||||
// kimiTokenURL is the endpoint for exchanging device codes for tokens.
|
||||
kimiTokenURL = kimiOAuthHost + "/api/oauth/token"
|
||||
// KimiAPIBaseURL is the base URL for Kimi API requests.
|
||||
KimiAPIBaseURL = "https://api.kimi.com/coding"
|
||||
// defaultPollInterval is the default interval for polling token endpoint.
|
||||
defaultPollInterval = 5 * time.Second
|
||||
// maxPollDuration is the maximum time to wait for user authorization.
|
||||
maxPollDuration = 15 * time.Minute
|
||||
// refreshThresholdSeconds is when to refresh token before expiry (5 minutes).
|
||||
refreshThresholdSeconds = 300
|
||||
)
|
||||
|
||||
// KimiAuth handles Kimi authentication flow.
|
||||
type KimiAuth struct {
|
||||
deviceClient *DeviceFlowClient
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKimiAuth creates a new KimiAuth service instance.
|
||||
func NewKimiAuth(cfg *config.Config) *KimiAuth {
|
||||
return &KimiAuth{
|
||||
deviceClient: NewDeviceFlowClient(cfg),
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// StartDeviceFlow initiates the device flow authentication.
|
||||
func (k *KimiAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
return k.deviceClient.RequestDeviceCode(ctx)
|
||||
}
|
||||
|
||||
// WaitForAuthorization polls for user authorization and returns the auth bundle.
|
||||
func (k *KimiAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiAuthBundle, error) {
|
||||
tokenData, err := k.deviceClient.PollForToken(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &KimiAuthBundle{
|
||||
TokenData: tokenData,
|
||||
DeviceID: k.deviceClient.deviceID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new KimiTokenStorage from auth bundle.
|
||||
func (k *KimiAuth) CreateTokenStorage(bundle *KimiAuthBundle) *KimiTokenStorage {
|
||||
expired := ""
|
||||
if bundle.TokenData.ExpiresAt > 0 {
|
||||
expired = time.Unix(bundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
return &KimiTokenStorage{
|
||||
AccessToken: bundle.TokenData.AccessToken,
|
||||
RefreshToken: bundle.TokenData.RefreshToken,
|
||||
TokenType: bundle.TokenData.TokenType,
|
||||
Scope: bundle.TokenData.Scope,
|
||||
DeviceID: strings.TrimSpace(bundle.DeviceID),
|
||||
Expired: expired,
|
||||
Type: "kimi",
|
||||
}
|
||||
}
|
||||
|
||||
// DeviceFlowClient handles the OAuth2 device flow for Kimi.
|
||||
type DeviceFlowClient struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
deviceID string
|
||||
}
|
||||
|
||||
// NewDeviceFlowClient creates a new device flow client.
|
||||
func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
||||
return NewDeviceFlowClientWithDeviceID(cfg, "")
|
||||
}
|
||||
|
||||
// NewDeviceFlowClientWithDeviceID creates a new device flow client with the specified device ID.
|
||||
func NewDeviceFlowClientWithDeviceID(cfg *config.Config, deviceID string) *DeviceFlowClient {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
resolvedDeviceID := strings.TrimSpace(deviceID)
|
||||
if resolvedDeviceID == "" {
|
||||
resolvedDeviceID = getOrCreateDeviceID()
|
||||
}
|
||||
return &DeviceFlowClient{
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
deviceID: resolvedDeviceID,
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateDeviceID returns an in-memory device ID for the current authentication flow.
|
||||
func getOrCreateDeviceID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// getDeviceModel returns a device model string.
|
||||
func getDeviceModel() string {
|
||||
osName := runtime.GOOS
|
||||
arch := runtime.GOARCH
|
||||
|
||||
switch osName {
|
||||
case "darwin":
|
||||
return fmt.Sprintf("macOS %s", arch)
|
||||
case "windows":
|
||||
return fmt.Sprintf("Windows %s", arch)
|
||||
case "linux":
|
||||
return fmt.Sprintf("Linux %s", arch)
|
||||
default:
|
||||
return fmt.Sprintf("%s %s", osName, arch)
|
||||
}
|
||||
}
|
||||
|
||||
// getHostname returns the machine hostname.
|
||||
func getHostname() string {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// commonHeaders returns headers required for Kimi API requests.
|
||||
func (c *DeviceFlowClient) commonHeaders() map[string]string {
|
||||
return map[string]string{
|
||||
"X-Msh-Platform": "cli-proxy-api",
|
||||
"X-Msh-Version": "1.0.0",
|
||||
"X-Msh-Device-Name": getHostname(),
|
||||
"X-Msh-Device-Model": getDeviceModel(),
|
||||
"X-Msh-Device-Id": c.deviceID,
|
||||
}
|
||||
}
|
||||
|
||||
// RequestDeviceCode initiates the device flow by requesting a device code from Kimi.
|
||||
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create device code request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: device code request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi device code: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read device code response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("kimi: device code request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var deviceCode DeviceCodeResponse
|
||||
if err = json.Unmarshal(bodyBytes, &deviceCode); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse device code response: %w", err)
|
||||
}
|
||||
|
||||
return &deviceCode, nil
|
||||
}
|
||||
|
||||
// PollForToken polls the token endpoint until the user authorizes or the device code expires.
|
||||
func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiTokenData, error) {
|
||||
if deviceCode == nil {
|
||||
return nil, fmt.Errorf("kimi: device code is nil")
|
||||
}
|
||||
|
||||
interval := time.Duration(deviceCode.Interval) * time.Second
|
||||
if interval < defaultPollInterval {
|
||||
interval = defaultPollInterval
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(maxPollDuration)
|
||||
if deviceCode.ExpiresIn > 0 {
|
||||
codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
|
||||
if codeDeadline.Before(deadline) {
|
||||
deadline = codeDeadline
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("kimi: context cancelled: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("kimi: device code expired")
|
||||
}
|
||||
|
||||
token, pollErr, shouldContinue := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode)
|
||||
if token != nil {
|
||||
return token, nil
|
||||
}
|
||||
if !shouldContinue {
|
||||
return nil, pollErr
|
||||
}
|
||||
// Continue polling
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// exchangeDeviceCode attempts to exchange the device code for an access token.
|
||||
// Returns (token, error, shouldContinue).
|
||||
func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*KimiTokenData, error, bool) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
data.Set("device_code", deviceCode)
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create token request: %w", err), false
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: token request failed: %w", err), false
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi token exchange: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read token response: %w", err), false
|
||||
}
|
||||
|
||||
// Parse response - Kimi returns 200 for both success and pending states
|
||||
var oauthResp struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn float64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse token response: %w", err), false
|
||||
}
|
||||
|
||||
if oauthResp.Error != "" {
|
||||
switch oauthResp.Error {
|
||||
case "authorization_pending":
|
||||
return nil, nil, true // Continue polling
|
||||
case "slow_down":
|
||||
return nil, nil, true // Continue polling (with increased interval handled by caller)
|
||||
case "expired_token":
|
||||
return nil, fmt.Errorf("kimi: device code expired"), false
|
||||
case "access_denied":
|
||||
return nil, fmt.Errorf("kimi: access denied by user"), false
|
||||
default:
|
||||
return nil, fmt.Errorf("kimi: OAuth error: %s - %s", oauthResp.Error, oauthResp.ErrorDescription), false
|
||||
}
|
||||
}
|
||||
|
||||
if oauthResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("kimi: empty access token in response"), false
|
||||
}
|
||||
|
||||
var expiresAt int64
|
||||
if oauthResp.ExpiresIn > 0 {
|
||||
expiresAt = time.Now().Unix() + int64(oauthResp.ExpiresIn)
|
||||
}
|
||||
|
||||
return &KimiTokenData{
|
||||
AccessToken: oauthResp.AccessToken,
|
||||
RefreshToken: oauthResp.RefreshToken,
|
||||
TokenType: oauthResp.TokenType,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: oauthResp.Scope,
|
||||
}, nil, false
|
||||
}
|
||||
|
||||
// RefreshToken exchanges a refresh token for a new access token.
|
||||
func (c *DeviceFlowClient) RefreshToken(ctx context.Context, refreshToken string) (*KimiTokenData, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
data.Set("grant_type", "refresh_token")
|
||||
data.Set("refresh_token", refreshToken)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create refresh request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: refresh request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi refresh token: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read refresh response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
return nil, fmt.Errorf("kimi: refresh token rejected (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("kimi: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn float64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("kimi: empty access token in refresh response")
|
||||
}
|
||||
|
||||
var expiresAt int64
|
||||
if tokenResp.ExpiresIn > 0 {
|
||||
expiresAt = time.Now().Unix() + int64(tokenResp.ExpiresIn)
|
||||
}
|
||||
|
||||
return &KimiTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: tokenResp.Scope,
|
||||
}, nil
|
||||
}
|
||||
116
internal/auth/kimi/token.go
Normal file
116
internal/auth/kimi/token.go
Normal file
@@ -0,0 +1,116 @@
|
||||
// Package kimi provides authentication and token management functionality
|
||||
// for Kimi (Moonshot AI) services. It handles OAuth2 device flow token storage,
|
||||
// serialization, and retrieval for maintaining authenticated sessions with the Kimi API.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// KimiTokenStorage stores OAuth2 token information for Kimi API authentication.
|
||||
type KimiTokenStorage struct {
|
||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// TokenType is the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// Scope is the OAuth2 scope granted to the token.
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// DeviceID is the OAuth device flow identifier used for Kimi requests.
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
// Expired is the RFC3339 timestamp when the access token expires.
|
||||
Expired string `json:"expired,omitempty"`
|
||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||
type KimiTokenData struct {
|
||||
// AccessToken is the OAuth2 access token.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is the OAuth2 refresh token.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// TokenType is the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// ExpiresAt is the Unix timestamp when the token expires.
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
// Scope is the OAuth2 scope granted to the token.
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// KimiAuthBundle bundles authentication data for storage.
|
||||
type KimiAuthBundle struct {
|
||||
// TokenData contains the OAuth token information.
|
||||
TokenData *KimiTokenData
|
||||
// DeviceID is the device identifier used during OAuth device flow.
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
// DeviceCodeResponse represents Kimi's device code response.
|
||||
type DeviceCodeResponse struct {
|
||||
// DeviceCode is the device verification code.
|
||||
DeviceCode string `json:"device_code"`
|
||||
// UserCode is the code the user must enter at the verification URI.
|
||||
UserCode string `json:"user_code"`
|
||||
// VerificationURI is the URL where the user should enter the code.
|
||||
VerificationURI string `json:"verification_uri,omitempty"`
|
||||
// VerificationURIComplete is the URL with the code pre-filled.
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
// ExpiresIn is the number of seconds until the device code expires.
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
// Interval is the minimum number of seconds to wait between polling requests.
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Kimi token storage to a JSON file.
|
||||
func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "kimi"
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
encoder := json.NewEncoder(f)
|
||||
encoder.SetIndent("", " ")
|
||||
if err = encoder.Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExpired checks if the token has expired.
|
||||
func (ts *KimiTokenStorage) IsExpired() bool {
|
||||
if ts.Expired == "" {
|
||||
return false // No expiry set, assume valid
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, ts.Expired)
|
||||
if err != nil {
|
||||
return true // Has expiry string but can't parse
|
||||
}
|
||||
// Consider expired if within refresh threshold
|
||||
return time.Now().Add(time.Duration(refreshThresholdSeconds) * time.Second).After(t)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the token should be refreshed.
|
||||
func (ts *KimiTokenStorage) NeedsRefresh() bool {
|
||||
if ts.RefreshToken == "" {
|
||||
return false // Can't refresh without refresh token
|
||||
}
|
||||
return ts.IsExpired()
|
||||
}
|
||||
@@ -238,7 +238,7 @@ func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroToken
|
||||
Description string `json:"description"`
|
||||
RateMultiplier float64 `json:"rateMultiplier"`
|
||||
RateUnit string `json:"rateUnit"`
|
||||
TokenLimits struct {
|
||||
TokenLimits *struct {
|
||||
MaxInputTokens int `json:"maxInputTokens"`
|
||||
} `json:"tokenLimits"`
|
||||
} `json:"models"`
|
||||
@@ -250,13 +250,17 @@ func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroToken
|
||||
|
||||
models := make([]*KiroModel, 0, len(result.Models))
|
||||
for _, m := range result.Models {
|
||||
maxInputTokens := 0
|
||||
if m.TokenLimits != nil {
|
||||
maxInputTokens = m.TokenLimits.MaxInputTokens
|
||||
}
|
||||
models = append(models, &KiroModel{
|
||||
ModelID: m.ModelID,
|
||||
ModelName: m.ModelName,
|
||||
Description: m.Description,
|
||||
RateMultiplier: m.RateMultiplier,
|
||||
RateUnit: m.RateUnit,
|
||||
MaxInputTokens: m.TokenLimits.MaxInputTokens,
|
||||
MaxInputTokens: maxInputTokens,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -40,8 +40,7 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||
if err != nil {
|
||||
var authErr *claude.AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
if authErr, ok := errors.AsType[*claude.AuthenticationError](err); ok {
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == claude.ErrPortInUse.Type {
|
||||
os.Exit(claude.ErrPortInUse.Code)
|
||||
|
||||
@@ -19,8 +19,10 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewQwenAuthenticator(),
|
||||
sdkAuth.NewIFlowAuthenticator(),
|
||||
sdkAuth.NewAntigravityAuthenticator(),
|
||||
sdkAuth.NewKimiAuthenticator(),
|
||||
sdkAuth.NewKiroAuthenticator(),
|
||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||
sdkAuth.NewKiloAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
@@ -32,8 +32,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
|
||||
if err != nil {
|
||||
var emailErr *sdkAuth.EmailRequiredError
|
||||
if errors.As(err, &emailErr) {
|
||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
||||
log.Error(emailErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
54
internal/cmd/kilo_login.go
Normal file
54
internal/cmd/kilo_login.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
)
|
||||
|
||||
// DoKiloLogin handles the Kilo device flow using the shared authentication manager.
|
||||
// It initiates the device-based authentication process for Kilo AI services and saves
|
||||
// the authentication tokens to the configured auth directory.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration
|
||||
// - options: Login options including browser behavior and prompts
|
||||
func DoKiloLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = func(prompt string) (string, error) {
|
||||
fmt.Print(prompt)
|
||||
var value string
|
||||
fmt.Scanln(&value)
|
||||
return strings.TrimSpace(value), nil
|
||||
}
|
||||
}
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "kilo", cfg, authOpts)
|
||||
if err != nil {
|
||||
fmt.Printf("Kilo authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
|
||||
fmt.Println("Kilo authentication successful!")
|
||||
}
|
||||
44
internal/cmd/kimi_login.go
Normal file
44
internal/cmd/kimi_login.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoKimiLogin triggers the OAuth device flow for Kimi (Moonshot AI) and saves tokens.
|
||||
// It initiates the device flow authentication, displays the verification URL for the user,
|
||||
// and waits for authorization before saving the tokens.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration containing proxy and auth directory settings
|
||||
// - options: Login options including browser behavior settings
|
||||
func DoKimiLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "kimi", cfg, authOpts)
|
||||
if err != nil {
|
||||
log.Errorf("Kimi authentication failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||
}
|
||||
fmt.Println("Kimi authentication successful!")
|
||||
}
|
||||
@@ -100,49 +100,74 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
|
||||
log.Info("Authentication successful.")
|
||||
|
||||
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
||||
if errProjects != nil {
|
||||
log.Errorf("Failed to get project list: %v", errProjects)
|
||||
return
|
||||
var activatedProjects []string
|
||||
|
||||
useGoogleOne := false
|
||||
if trimmedProjectID == "" && promptFn != nil {
|
||||
fmt.Println("\nSelect login mode:")
|
||||
fmt.Println(" 1. Code Assist (GCP project, manual selection)")
|
||||
fmt.Println(" 2. Google One (personal account, auto-discover project)")
|
||||
choice, errPrompt := promptFn("Enter choice [1/2] (default: 1): ")
|
||||
if errPrompt == nil && strings.TrimSpace(choice) == "2" {
|
||||
useGoogleOne = true
|
||||
}
|
||||
}
|
||||
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
return
|
||||
}
|
||||
if len(projectSelections) == 0 {
|
||||
log.Error("No project selected; aborting login.")
|
||||
return
|
||||
}
|
||||
|
||||
activatedProjects := make([]string, 0, len(projectSelections))
|
||||
seenProjects := make(map[string]bool)
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
var projectErr *projectSelectionRequiredError
|
||||
if errors.As(errSetup, &projectErr) {
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
showProjectSelectionHelp(storage.Email, projects)
|
||||
return
|
||||
}
|
||||
log.Errorf("Failed to complete user setup: %v", errSetup)
|
||||
if useGoogleOne {
|
||||
log.Info("Google One mode: auto-discovering project...")
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
return
|
||||
}
|
||||
finalID := strings.TrimSpace(storage.ProjectID)
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
autoProject := strings.TrimSpace(storage.ProjectID)
|
||||
if autoProject == "" {
|
||||
log.Error("Google One auto-discovery returned empty project ID")
|
||||
return
|
||||
}
|
||||
log.Infof("Auto-discovered project: %s", autoProject)
|
||||
activatedProjects = []string{autoProject}
|
||||
} else {
|
||||
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
||||
if errProjects != nil {
|
||||
log.Errorf("Failed to get project list: %v", errProjects)
|
||||
return
|
||||
}
|
||||
|
||||
// Skip duplicates
|
||||
if seenProjects[finalID] {
|
||||
log.Infof("Project %s already activated, skipping", finalID)
|
||||
continue
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
return
|
||||
}
|
||||
if len(projectSelections) == 0 {
|
||||
log.Error("No project selected; aborting login.")
|
||||
return
|
||||
}
|
||||
|
||||
seenProjects := make(map[string]bool)
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
if _, ok := errors.AsType[*projectSelectionRequiredError](errSetup); ok {
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
showProjectSelectionHelp(storage.Email, projects)
|
||||
return
|
||||
}
|
||||
log.Errorf("Failed to complete user setup: %v", errSetup)
|
||||
return
|
||||
}
|
||||
finalID := strings.TrimSpace(storage.ProjectID)
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
}
|
||||
|
||||
if seenProjects[finalID] {
|
||||
log.Infof("Project %s already activated, skipping", finalID)
|
||||
continue
|
||||
}
|
||||
seenProjects[finalID] = true
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
seenProjects[finalID] = true
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
|
||||
storage.Auto = false
|
||||
@@ -235,7 +260,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
// Auto-discovery: try onboardUser without specifying a project
|
||||
// to let Google auto-provision one (matches Gemini CLI headless behavior
|
||||
// and Antigravity's FetchProjectID pattern).
|
||||
autoOnboardReq := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer autoCancel()
|
||||
for attempt := 1; ; attempt++ {
|
||||
var onboardResp map[string]any
|
||||
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
|
||||
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
|
||||
}
|
||||
|
||||
if done, okDone := onboardResp["done"].(bool); okDone && done {
|
||||
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
|
||||
switch v := resp["cloudaicompanionProject"].(type) {
|
||||
case string:
|
||||
projectID = strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if id, okID := v["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
|
||||
select {
|
||||
case <-autoCtx.Done():
|
||||
return &projectSelectionRequiredError{}
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
}
|
||||
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
|
||||
}
|
||||
|
||||
onboardReqBody := map[string]any{
|
||||
@@ -617,7 +683,7 @@ func updateAuthRecord(record *cliproxyauth.Auth, storage *gemini.GeminiTokenStor
|
||||
return
|
||||
}
|
||||
|
||||
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, false)
|
||||
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, true)
|
||||
|
||||
if record.Metadata == nil {
|
||||
record.Metadata = make(map[string]any)
|
||||
|
||||
@@ -54,8 +54,7 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
if err != nil {
|
||||
var authErr *codex.AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == codex.ErrPortInUse.Type {
|
||||
os.Exit(codex.ErrPortInUse.Code)
|
||||
|
||||
@@ -44,8 +44,7 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
||||
if err != nil {
|
||||
var emailErr *sdkAuth.EmailRequiredError
|
||||
if errors.As(err, &emailErr) {
|
||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
||||
log.Error(emailErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,6 +55,34 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
|
||||
}
|
||||
}
|
||||
|
||||
// StartServiceBackground starts the proxy service in a background goroutine
|
||||
// and returns a cancel function for shutdown and a done channel.
|
||||
func StartServiceBackground(cfg *config.Config, configPath string, localPassword string) (cancel func(), done <-chan struct{}) {
|
||||
builder := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath(configPath).
|
||||
WithLocalManagementPassword(localPassword)
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
service, err := builder.Build()
|
||||
if err != nil {
|
||||
log.Errorf("failed to build proxy service: %v", err)
|
||||
close(doneCh)
|
||||
return cancelFn, doneCh
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(doneCh)
|
||||
if err := service.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("proxy service exited with error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return cancelFn, doneCh
|
||||
}
|
||||
|
||||
// WaitForCloudDeploy waits indefinitely for shutdown signals in cloud deploy mode
|
||||
// when no configuration file is available.
|
||||
func WaitForCloudDeploy() {
|
||||
|
||||
@@ -97,6 +97,10 @@ type Config struct {
|
||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||
|
||||
// ClaudeHeaderDefaults configures default header values for Claude API requests.
|
||||
// These are used as fallbacks when the client does not send its own headers.
|
||||
ClaudeHeaderDefaults ClaudeHeaderDefaults `yaml:"claude-header-defaults" json:"claude-header-defaults"`
|
||||
|
||||
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
|
||||
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
|
||||
|
||||
@@ -130,6 +134,15 @@ type Config struct {
|
||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
|
||||
// when the client does not send them. Update these when Claude Code releases a new version.
|
||||
type ClaudeHeaderDefaults struct {
|
||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||
Timeout string `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// TLSConfig holds HTTPS server settings.
|
||||
type TLSConfig struct {
|
||||
// Enable toggles HTTPS server mode.
|
||||
@@ -368,6 +381,9 @@ type CodexKey struct {
|
||||
// If empty, the default Codex API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
|
||||
// Websockets enables the Responses API websocket transport for this credential.
|
||||
Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
|
||||
|
||||
// ProxyURL overrides the global proxy setting for this API key if provided.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
@@ -535,14 +551,15 @@ func LoadConfig(configFile string) (*Config, error) {
|
||||
// If optional is true and the file is missing, it returns an empty Config.
|
||||
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Perform oauth-model-alias migration before loading config.
|
||||
// This migrates oauth-model-mappings to oauth-model-alias if needed.
|
||||
if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||
// Log warning but don't fail - config loading should still work
|
||||
fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||
} else if migrated {
|
||||
fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||
}
|
||||
// NOTE: Startup oauth-model-alias migration is intentionally disabled.
|
||||
// Reason: avoid mutating config.yaml during server startup.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||
// // Log warning but don't fail - config loading should still work
|
||||
// fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||
// } else if migrated {
|
||||
// fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||
// }
|
||||
|
||||
// Read the entire configuration file into memory.
|
||||
data, err := os.ReadFile(configFile)
|
||||
@@ -583,18 +600,21 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
var legacy legacyConfigData
|
||||
if errLegacy := yaml.Unmarshal(data, &legacy); errLegacy == nil {
|
||||
if cfg.migrateLegacyGeminiKeys(legacy.LegacyGeminiKeys) {
|
||||
cfg.legacyMigrationPending = true
|
||||
}
|
||||
if cfg.migrateLegacyOpenAICompatibilityKeys(legacy.OpenAICompat) {
|
||||
cfg.legacyMigrationPending = true
|
||||
}
|
||||
if cfg.migrateLegacyAmpConfig(&legacy) {
|
||||
cfg.legacyMigrationPending = true
|
||||
}
|
||||
}
|
||||
// NOTE: Startup legacy key migration is intentionally disabled.
|
||||
// Reason: avoid mutating config.yaml during server startup.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// var legacy legacyConfigData
|
||||
// if errLegacy := yaml.Unmarshal(data, &legacy); errLegacy == nil {
|
||||
// if cfg.migrateLegacyGeminiKeys(legacy.LegacyGeminiKeys) {
|
||||
// cfg.legacyMigrationPending = true
|
||||
// }
|
||||
// if cfg.migrateLegacyOpenAICompatibilityKeys(legacy.OpenAICompat) {
|
||||
// cfg.legacyMigrationPending = true
|
||||
// }
|
||||
// if cfg.migrateLegacyAmpConfig(&legacy) {
|
||||
// cfg.legacyMigrationPending = true
|
||||
// }
|
||||
// }
|
||||
|
||||
// Hash remote management key if plaintext is detected (nested)
|
||||
// We consider a value to be already hashed if it looks like a bcrypt hash ($2a$, $2b$, or $2y$ prefix).
|
||||
@@ -628,9 +648,6 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
}
|
||||
|
||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||
syncInlineAccessProvider(&cfg)
|
||||
|
||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||
cfg.SanitizeGeminiKeys()
|
||||
|
||||
@@ -658,17 +675,20 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Validate raw payload rules and drop invalid entries.
|
||||
cfg.SanitizePayloadRules()
|
||||
|
||||
if cfg.legacyMigrationPending {
|
||||
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||
if !optional && configFile != "" {
|
||||
if err := SaveConfigPreserveComments(configFile, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist migrated legacy config: %w", err)
|
||||
}
|
||||
fmt.Println("Legacy configuration normalized and persisted.")
|
||||
} else {
|
||||
fmt.Println("Legacy configuration normalized in memory; persistence skipped.")
|
||||
}
|
||||
}
|
||||
// NOTE: Legacy migration persistence is intentionally disabled together with
|
||||
// startup legacy migration to keep startup read-only for config.yaml.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// if cfg.legacyMigrationPending {
|
||||
// fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||
// if !optional && configFile != "" {
|
||||
// if err := SaveConfigPreserveComments(configFile, &cfg); err != nil {
|
||||
// return nil, fmt.Errorf("failed to persist migrated legacy config: %w", err)
|
||||
// }
|
||||
// fmt.Println("Legacy configuration normalized and persisted.")
|
||||
// } else {
|
||||
// fmt.Println("Legacy configuration normalized in memory; persistence skipped.")
|
||||
// }
|
||||
// }
|
||||
|
||||
// Return the populated configuration struct.
|
||||
return &cfg, nil
|
||||
@@ -732,14 +752,44 @@ func payloadRawString(value any) ([]byte, bool) {
|
||||
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||
// It also injects default aliases for channels that have built-in defaults (e.g., kiro)
|
||||
// when no user-configured aliases exist for those channels.
|
||||
func (cfg *Config) SanitizeOAuthModelAlias() {
|
||||
if cfg == nil || len(cfg.OAuthModelAlias) == 0 {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Inject default Kiro aliases if no user-configured kiro aliases exist
|
||||
if cfg.OAuthModelAlias == nil {
|
||||
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
|
||||
}
|
||||
if _, hasKiro := cfg.OAuthModelAlias["kiro"]; !hasKiro {
|
||||
// Check case-insensitive too
|
||||
found := false
|
||||
for k := range cfg.OAuthModelAlias {
|
||||
if strings.EqualFold(strings.TrimSpace(k), "kiro") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
|
||||
}
|
||||
}
|
||||
|
||||
if len(cfg.OAuthModelAlias) == 0 {
|
||||
return
|
||||
}
|
||||
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
|
||||
for rawChannel, aliases := range cfg.OAuthModelAlias {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(aliases) == 0 {
|
||||
if channel == "" {
|
||||
continue
|
||||
}
|
||||
// Preserve channels that were explicitly set to empty/nil – they act
|
||||
// as "disabled" markers so default injection won't re-add them (#222).
|
||||
if len(aliases) == 0 {
|
||||
out[channel] = nil
|
||||
continue
|
||||
}
|
||||
seenAlias := make(map[string]struct{}, len(aliases))
|
||||
@@ -881,18 +931,6 @@ func normalizeModelPrefix(prefix string) string {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func syncInlineAccessProvider(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
if len(cfg.APIKeys) == 0 {
|
||||
if provider := cfg.ConfigAPIKeyProvider(); provider != nil && len(provider.APIKeys) > 0 {
|
||||
cfg.APIKeys = append([]string(nil), provider.APIKeys...)
|
||||
}
|
||||
}
|
||||
cfg.Access.Providers = nil
|
||||
}
|
||||
|
||||
// looksLikeBcrypt returns true if the provided string appears to be a bcrypt hash.
|
||||
func looksLikeBcrypt(s string) bool {
|
||||
return len(s) > 4 && (s[:4] == "$2a$" || s[:4] == "$2b$" || s[:4] == "$2y$")
|
||||
@@ -980,7 +1018,7 @@ func hashSecret(secret string) (string, error) {
|
||||
// SaveConfigPreserveComments writes the config back to YAML while preserving existing comments
|
||||
// and key ordering by loading the original file into a yaml.Node tree and updating values in-place.
|
||||
func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
persistCfg := sanitizeConfigForPersist(cfg)
|
||||
persistCfg := cfg
|
||||
// Load original YAML as a node tree to preserve comments and ordering.
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
@@ -1048,16 +1086,6 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func sanitizeConfigForPersist(cfg *Config) *Config {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *cfg
|
||||
clone.SDKConfig = cfg.SDKConfig
|
||||
clone.SDKConfig.Access = AccessConfig{}
|
||||
return &clone
|
||||
}
|
||||
|
||||
// SaveConfigPreserveCommentsUpdateNestedScalar updates a nested scalar key path like ["a","b"]
|
||||
// while preserving comments and positions.
|
||||
func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error {
|
||||
@@ -1154,8 +1182,13 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
|
||||
|
||||
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
|
||||
// key order and comments of existing keys in dst. New keys are only added if their
|
||||
// value is non-zero to avoid polluting the config with defaults.
|
||||
func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
// value is non-zero and not a known default to avoid polluting the config with defaults.
|
||||
func mergeMappingPreserve(dst, src *yaml.Node, path ...[]string) {
|
||||
var currentPath []string
|
||||
if len(path) > 0 {
|
||||
currentPath = path[0]
|
||||
}
|
||||
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
@@ -1169,16 +1202,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
sk := src.Content[i]
|
||||
sv := src.Content[i+1]
|
||||
idx := findMapKeyIndex(dst, sk.Value)
|
||||
childPath := appendPath(currentPath, sk.Value)
|
||||
if idx >= 0 {
|
||||
// Merge into existing value node (always update, even to zero values)
|
||||
dv := dst.Content[idx+1]
|
||||
mergeNodePreserve(dv, sv)
|
||||
mergeNodePreserve(dv, sv, childPath)
|
||||
} else {
|
||||
// New key: only add if value is non-zero to avoid polluting config with defaults
|
||||
if isZeroValueNode(sv) {
|
||||
// New key: only add if value is non-zero and not a known default
|
||||
candidate := deepCopyNode(sv)
|
||||
pruneKnownDefaultsInNewNode(childPath, candidate)
|
||||
if isKnownDefaultValue(childPath, candidate) {
|
||||
continue
|
||||
}
|
||||
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
|
||||
dst.Content = append(dst.Content, deepCopyNode(sk), candidate)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1186,7 +1222,12 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
// mergeNodePreserve merges src into dst for scalars, mappings and sequences while
|
||||
// reusing destination nodes to keep comments and anchors. For sequences, it updates
|
||||
// in-place by index.
|
||||
func mergeNodePreserve(dst, src *yaml.Node) {
|
||||
func mergeNodePreserve(dst, src *yaml.Node, path ...[]string) {
|
||||
var currentPath []string
|
||||
if len(path) > 0 {
|
||||
currentPath = path[0]
|
||||
}
|
||||
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
@@ -1195,7 +1236,7 @@ func mergeNodePreserve(dst, src *yaml.Node) {
|
||||
if dst.Kind != yaml.MappingNode {
|
||||
copyNodeShallow(dst, src)
|
||||
}
|
||||
mergeMappingPreserve(dst, src)
|
||||
mergeMappingPreserve(dst, src, currentPath)
|
||||
case yaml.SequenceNode:
|
||||
// Preserve explicit null style if dst was null and src is empty sequence
|
||||
if dst.Kind == yaml.ScalarNode && dst.Tag == "!!null" && len(src.Content) == 0 {
|
||||
@@ -1218,7 +1259,7 @@ func mergeNodePreserve(dst, src *yaml.Node) {
|
||||
dst.Content[i] = deepCopyNode(src.Content[i])
|
||||
continue
|
||||
}
|
||||
mergeNodePreserve(dst.Content[i], src.Content[i])
|
||||
mergeNodePreserve(dst.Content[i], src.Content[i], currentPath)
|
||||
if dst.Content[i] != nil && src.Content[i] != nil &&
|
||||
dst.Content[i].Kind == yaml.MappingNode && src.Content[i].Kind == yaml.MappingNode {
|
||||
pruneMissingMapKeys(dst.Content[i], src.Content[i])
|
||||
@@ -1260,6 +1301,94 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
|
||||
return -1
|
||||
}
|
||||
|
||||
// appendPath appends a key to the path, returning a new slice to avoid modifying the original.
|
||||
func appendPath(path []string, key string) []string {
|
||||
if len(path) == 0 {
|
||||
return []string{key}
|
||||
}
|
||||
newPath := make([]string, len(path)+1)
|
||||
copy(newPath, path)
|
||||
newPath[len(path)] = key
|
||||
return newPath
|
||||
}
|
||||
|
||||
// isKnownDefaultValue returns true if the given node at the specified path
|
||||
// represents a known default value that should not be written to the config file.
|
||||
// This prevents non-zero defaults from polluting the config.
|
||||
func isKnownDefaultValue(path []string, node *yaml.Node) bool {
|
||||
// First check if it's a zero value
|
||||
if isZeroValueNode(node) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Match known non-zero defaults by exact dotted path.
|
||||
if len(path) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
fullPath := strings.Join(path, ".")
|
||||
|
||||
// Check string defaults
|
||||
if node.Kind == yaml.ScalarNode && node.Tag == "!!str" {
|
||||
switch fullPath {
|
||||
case "pprof.addr":
|
||||
return node.Value == DefaultPprofAddr
|
||||
case "remote-management.panel-github-repository":
|
||||
return node.Value == DefaultPanelGitHubRepository
|
||||
case "routing.strategy":
|
||||
return node.Value == "round-robin"
|
||||
}
|
||||
}
|
||||
|
||||
// Check integer defaults
|
||||
if node.Kind == yaml.ScalarNode && node.Tag == "!!int" {
|
||||
switch fullPath {
|
||||
case "error-logs-max-files":
|
||||
return node.Value == "10"
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// pruneKnownDefaultsInNewNode removes default-valued descendants from a new node
|
||||
// before it is appended into the destination YAML tree.
|
||||
func pruneKnownDefaultsInNewNode(path []string, node *yaml.Node) {
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch node.Kind {
|
||||
case yaml.MappingNode:
|
||||
filtered := make([]*yaml.Node, 0, len(node.Content))
|
||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||
keyNode := node.Content[i]
|
||||
valueNode := node.Content[i+1]
|
||||
if keyNode == nil || valueNode == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
childPath := appendPath(path, keyNode.Value)
|
||||
if isKnownDefaultValue(childPath, valueNode) {
|
||||
continue
|
||||
}
|
||||
|
||||
pruneKnownDefaultsInNewNode(childPath, valueNode)
|
||||
if (valueNode.Kind == yaml.MappingNode || valueNode.Kind == yaml.SequenceNode) &&
|
||||
len(valueNode.Content) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
filtered = append(filtered, keyNode, valueNode)
|
||||
}
|
||||
node.Content = filtered
|
||||
case yaml.SequenceNode:
|
||||
for _, child := range node.Content {
|
||||
pruneKnownDefaultsInNewNode(path, child)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isZeroValueNode returns true if the YAML node represents a zero/default value
|
||||
// that should not be written as a new key to preserve config cleanliness.
|
||||
// For mappings and sequences, recursively checks if all children are zero values.
|
||||
|
||||
@@ -17,6 +17,29 @@ var antigravityModelConversionTable = map[string]string{
|
||||
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
|
||||
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||
}
|
||||
|
||||
// defaultKiroAliases returns the default oauth-model-alias configuration
|
||||
// for the kiro channel. Maps kiro-prefixed model names to standard Claude model
|
||||
// names so that clients like Claude Code can use standard names directly.
|
||||
func defaultKiroAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
// Sonnet 4.5
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||
// Sonnet 4
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
|
||||
// Opus 4.6
|
||||
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
|
||||
// Opus 4.5
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
|
||||
// Haiku 4.5
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
|
||||
}
|
||||
}
|
||||
|
||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
||||
@@ -30,6 +53,7 @@ func defaultAntigravityAliases() []OAuthModelAlias {
|
||||
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
|
||||
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
|
||||
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
|
||||
{Name: "claude-opus-4-6-thinking", Alias: "gemini-claude-opus-4-6-thinking"},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -131,6 +131,9 @@ func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
|
||||
if !strings.Contains(content, "claude-opus-4-5-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-opus-4-6-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-6-thinking to be added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
|
||||
|
||||
@@ -54,3 +54,132 @@ func TestSanitizeOAuthModelAlias_AllowsMultipleAliasesForSameName(t *testing.T)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroAliases(t *testing.T) {
|
||||
// When no kiro aliases are configured, defaults should be injected
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"codex": {
|
||||
{Name: "gpt-5", Alias: "g5"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) == 0 {
|
||||
t.Fatal("expected default kiro aliases to be injected")
|
||||
}
|
||||
|
||||
// Check that standard Claude model names are present
|
||||
aliasSet := make(map[string]bool)
|
||||
for _, a := range kiroAliases {
|
||||
aliasSet[a.Alias] = true
|
||||
}
|
||||
expectedAliases := []string{
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-sonnet-4",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-haiku-4-5",
|
||||
}
|
||||
for _, expected := range expectedAliases {
|
||||
if !aliasSet[expected] {
|
||||
t.Fatalf("expected default kiro alias %q to be present", expected)
|
||||
}
|
||||
}
|
||||
|
||||
// All should have fork=true
|
||||
for _, a := range kiroAliases {
|
||||
if !a.Fork {
|
||||
t.Fatalf("expected all default kiro aliases to have fork=true, got fork=false for %q", a.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
// Codex aliases should still be preserved
|
||||
if len(cfg.OAuthModelAlias["codex"]) != 1 {
|
||||
t.Fatal("expected codex aliases to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
|
||||
// When user has configured kiro aliases, defaults should NOT be injected
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"kiro": {
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "my-custom-sonnet", Fork: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) != 1 {
|
||||
t.Fatalf("expected 1 user-configured kiro alias, got %d", len(kiroAliases))
|
||||
}
|
||||
if kiroAliases[0].Alias != "my-custom-sonnet" {
|
||||
t.Fatalf("expected user alias to be preserved, got %q", kiroAliases[0].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
|
||||
// When user explicitly deletes kiro aliases (key exists with nil value),
|
||||
// defaults should NOT be re-injected on subsequent sanitize calls (#222).
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"kiro": nil, // explicitly deleted
|
||||
"codex": {{Name: "gpt-5", Alias: "g5"}},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) != 0 {
|
||||
t.Fatalf("expected kiro aliases to remain empty after explicit deletion, got %d aliases", len(kiroAliases))
|
||||
}
|
||||
// The key itself must still be present to prevent re-injection on next reload
|
||||
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
|
||||
t.Fatal("expected kiro key to be preserved as nil marker after sanitization")
|
||||
}
|
||||
// Other channels should be unaffected
|
||||
if len(cfg.OAuthModelAlias["codex"]) != 1 {
|
||||
t.Fatal("expected codex aliases to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
|
||||
// Same as above but with empty slice instead of nil (PUT with empty body).
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"kiro": {}, // explicitly set to empty
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
if len(cfg.OAuthModelAlias["kiro"]) != 0 {
|
||||
t.Fatalf("expected kiro aliases to remain empty, got %d aliases", len(cfg.OAuthModelAlias["kiro"]))
|
||||
}
|
||||
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
|
||||
t.Fatal("expected kiro key to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroWhenEmpty(t *testing.T) {
|
||||
// When OAuthModelAlias is nil, kiro defaults should still be injected
|
||||
cfg := &Config{}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) == 0 {
|
||||
t.Fatal("expected default kiro aliases to be injected when OAuthModelAlias is nil")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,8 +20,9 @@ type SDKConfig struct {
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
|
||||
// Default is false (disabled).
|
||||
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
|
||||
|
||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||
@@ -42,65 +43,3 @@ type StreamingConfig struct {
|
||||
// <= 0 disables bootstrap retries. Default is 0.
|
||||
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||
}
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
type AccessConfig struct {
|
||||
// Providers lists configured authentication providers.
|
||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||
}
|
||||
|
||||
// AccessProvider describes a request authentication provider entry.
|
||||
type AccessProvider struct {
|
||||
// Name is the instance identifier for the provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Type selects the provider implementation registered via the SDK.
|
||||
Type string `yaml:"type" json:"type"`
|
||||
|
||||
// SDK optionally names a third-party SDK module providing this provider.
|
||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||
|
||||
// APIKeys lists inline keys for providers that require them.
|
||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||
|
||||
// Config passes provider-specific options to the implementation.
|
||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||
|
||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||
DefaultAccessProviderName = "config-inline"
|
||||
)
|
||||
|
||||
// ConfigAPIKeyProvider returns the first inline API key provider if present.
|
||||
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range c.Access.Providers {
|
||||
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
|
||||
if c.Access.Providers[i].Name == "" {
|
||||
c.Access.Providers[i].Name = DefaultAccessProviderName
|
||||
}
|
||||
return &c.Access.Providers[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||
// It returns nil when no keys are supplied.
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
provider := &AccessProvider{
|
||||
Name: DefaultAccessProviderName,
|
||||
Type: AccessProviderTypeConfigAPIKey,
|
||||
APIKeys: append([]string(nil), keys...),
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
@@ -27,4 +27,7 @@ const (
|
||||
|
||||
// Kiro represents the AWS CodeWhisperer (Kiro) provider identifier.
|
||||
Kiro = "kiro"
|
||||
|
||||
// Kilo represents the Kilo AI provider identifier.
|
||||
Kilo = "kilo"
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,6 +29,7 @@ const (
|
||||
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
managementSyncMinInterval = 30 * time.Second
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
@@ -37,11 +39,10 @@ const ManagementFileName = managementAssetName
|
||||
var (
|
||||
lastUpdateCheckMu sync.Mutex
|
||||
lastUpdateCheckTime time.Time
|
||||
|
||||
currentConfigPtr atomic.Pointer[config.Config]
|
||||
disableControlPanel atomic.Bool
|
||||
schedulerOnce sync.Once
|
||||
schedulerConfigPath atomic.Value
|
||||
sfGroup singleflight.Group
|
||||
)
|
||||
|
||||
// SetCurrentConfig stores the latest configuration snapshot for management asset decisions.
|
||||
@@ -50,16 +51,7 @@ func SetCurrentConfig(cfg *config.Config) {
|
||||
currentConfigPtr.Store(nil)
|
||||
return
|
||||
}
|
||||
|
||||
prevDisabled := disableControlPanel.Load()
|
||||
currentConfigPtr.Store(cfg)
|
||||
disableControlPanel.Store(cfg.RemoteManagement.DisableControlPanel)
|
||||
|
||||
if prevDisabled && !cfg.RemoteManagement.DisableControlPanel {
|
||||
lastUpdateCheckMu.Lock()
|
||||
lastUpdateCheckTime = time.Time{}
|
||||
lastUpdateCheckMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// StartAutoUpdater launches a background goroutine that periodically ensures the management asset is up to date.
|
||||
@@ -92,7 +84,7 @@ func runAutoUpdater(ctx context.Context) {
|
||||
log.Debug("management asset auto-updater skipped: config not yet available")
|
||||
return
|
||||
}
|
||||
if disableControlPanel.Load() {
|
||||
if cfg.RemoteManagement.DisableControlPanel {
|
||||
log.Debug("management asset auto-updater skipped: control panel disabled")
|
||||
return
|
||||
}
|
||||
@@ -181,103 +173,106 @@ func FilePath(configFilePath string) string {
|
||||
}
|
||||
|
||||
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
|
||||
// The function is designed to run in a background goroutine and will never panic.
|
||||
// It enforces a 3-hour rate limit to avoid frequent checks on config/auth file changes.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) {
|
||||
// It coalesces concurrent sync attempts and returns whether the asset exists after the sync attempt.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) bool {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
if disableControlPanel.Load() {
|
||||
log.Debug("management asset sync skipped: control panel disabled by configuration")
|
||||
return
|
||||
}
|
||||
|
||||
staticDir = strings.TrimSpace(staticDir)
|
||||
if staticDir == "" {
|
||||
log.Debug("management asset sync skipped: empty static directory")
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting: check only once every 3 hours
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
timeSinceLastCheck := now.Sub(lastUpdateCheckTime)
|
||||
if timeSinceLastCheck < updateCheckInterval {
|
||||
_, _, _ = sfGroup.Do(localPath, func() (interface{}, error) {
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
timeSinceLastAttempt := now.Sub(lastUpdateCheckTime)
|
||||
if !lastUpdateCheckTime.IsZero() && timeSinceLastAttempt < managementSyncMinInterval {
|
||||
lastUpdateCheckMu.Unlock()
|
||||
log.Debugf(
|
||||
"management asset sync skipped by throttle: last attempt %v ago (interval %v)",
|
||||
timeSinceLastAttempt.Round(time.Second),
|
||||
managementSyncMinInterval,
|
||||
)
|
||||
return nil, nil
|
||||
}
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
log.Debugf("management asset update check skipped: last check was %v ago (interval: %v)", timeSinceLastCheck.Round(time.Second), updateCheckInterval)
|
||||
return
|
||||
}
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return
|
||||
}
|
||||
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.WithError(err).Debug("failed to read local management asset hash")
|
||||
}
|
||||
localHash = ""
|
||||
}
|
||||
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return
|
||||
}
|
||||
|
||||
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
|
||||
log.Debug("management asset is already up to date")
|
||||
return
|
||||
}
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.WithError(err).Debug("failed to read local management asset hash")
|
||||
}
|
||||
return
|
||||
localHash = ""
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return
|
||||
}
|
||||
|
||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
||||
}
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to update management asset on disk")
|
||||
return
|
||||
}
|
||||
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
|
||||
log.Debug("management asset is already up to date")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to update management asset on disk")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
_, err := os.Stat(localPath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
|
||||
|
||||
21
internal/registry/kilo_models.go
Normal file
21
internal/registry/kilo_models.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Package registry provides model definitions for various AI service providers.
|
||||
package registry
|
||||
|
||||
// GetKiloModels returns the Kilo model definitions
|
||||
func GetKiloModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
// --- Base Models ---
|
||||
{
|
||||
ID: "kilo/auto",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "kilo",
|
||||
Type: "kilo",
|
||||
DisplayName: "Kilo Auto",
|
||||
Description: "Automatic model selection by Kilo",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,9 @@ import (
|
||||
// - codex
|
||||
// - qwen
|
||||
// - iflow
|
||||
// - kimi
|
||||
// - kiro
|
||||
// - kilo
|
||||
// - github-copilot
|
||||
// - kiro
|
||||
// - amazonq
|
||||
@@ -43,10 +45,14 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
return GetQwenModels()
|
||||
case "iflow":
|
||||
return GetIFlowModels()
|
||||
case "kimi":
|
||||
return GetKimiModels()
|
||||
case "github-copilot":
|
||||
return GetGitHubCopilotModels()
|
||||
case "kiro":
|
||||
return GetKiroModels()
|
||||
case "kilo":
|
||||
return GetKiloModels()
|
||||
case "amazonq":
|
||||
return GetAmazonQModels()
|
||||
case "antigravity":
|
||||
@@ -93,8 +99,10 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
GetOpenAIModels(),
|
||||
GetQwenModels(),
|
||||
GetIFlowModels(),
|
||||
GetKimiModels(),
|
||||
GetGitHubCopilotModels(),
|
||||
GetKiroModels(),
|
||||
GetKiloModels(),
|
||||
GetAmazonQModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
@@ -144,6 +152,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5-mini",
|
||||
@@ -156,6 +165,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5-codex",
|
||||
@@ -168,6 +178,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1",
|
||||
@@ -180,6 +191,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex",
|
||||
@@ -192,6 +204,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex-mini",
|
||||
@@ -204,6 +217,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex-max",
|
||||
@@ -216,6 +230,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2",
|
||||
@@ -228,6 +243,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2-codex",
|
||||
@@ -240,6 +256,20 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.3-codex",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.3 Codex",
|
||||
Description: "OpenAI GPT-5.3 Codex via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-haiku-4.5",
|
||||
@@ -277,6 +307,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4.6",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Claude Opus 4.6",
|
||||
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4",
|
||||
Object: "model",
|
||||
@@ -301,6 +343,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4.6",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Claude Sonnet 4.6",
|
||||
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "gemini-2.5-pro",
|
||||
Object: "model",
|
||||
@@ -376,6 +430,30 @@ func GetKiroModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-opus-4-6",
|
||||
Object: "model",
|
||||
Created: 1736899200, // 2025-01-15
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Opus 4.6",
|
||||
Description: "Claude Opus 4.6 via Kiro (2.2x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4-6",
|
||||
Object: "model",
|
||||
Created: 1739836800, // 2025-02-18
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Sonnet 4.6",
|
||||
Description: "Claude Sonnet 4.6 via Kiro (1.3x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-opus-4-5",
|
||||
Object: "model",
|
||||
@@ -424,7 +502,112 @@ func GetKiroModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
// --- 第三方模型 (通过 Kiro 接入) ---
|
||||
{
|
||||
ID: "kiro-deepseek-3-2",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro DeepSeek 3.2",
|
||||
Description: "DeepSeek 3.2 via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-minimax-m2-1",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro MiniMax M2.1",
|
||||
Description: "MiniMax M2.1 via Kiro",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-qwen3-coder-next",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Qwen3 Coder Next",
|
||||
Description: "Qwen3 Coder Next via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-4o",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-4o",
|
||||
Description: "OpenAI GPT-4o via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-4",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-4",
|
||||
Description: "OpenAI GPT-4 via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 8192,
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-4-turbo",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-4 Turbo",
|
||||
Description: "OpenAI GPT-4 Turbo via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-3-5-turbo",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-3.5 Turbo",
|
||||
Description: "OpenAI GPT-3.5 Turbo via Kiro",
|
||||
ContextLength: 16384,
|
||||
MaxCompletionTokens: 4096,
|
||||
},
|
||||
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
|
||||
{
|
||||
ID: "kiro-claude-opus-4-6-agentic",
|
||||
Object: "model",
|
||||
Created: 1736899200, // 2025-01-15
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Opus 4.6 (Agentic)",
|
||||
Description: "Claude Opus 4.6 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4-6-agentic",
|
||||
Object: "model",
|
||||
Created: 1739836800, // 2025-02-18
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Sonnet 4.6 (Agentic)",
|
||||
Description: "Claude Sonnet 4.6 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-opus-4-5-agentic",
|
||||
Object: "model",
|
||||
|
||||
@@ -15,7 +15,7 @@ func GetClaudeModels() []*ModelInfo {
|
||||
DisplayName: "Claude 4.5 Haiku",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
// Thinking: not supported for Haiku models
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
@@ -28,6 +28,41 @@ func GetClaudeModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-6",
|
||||
Object: "model",
|
||||
Created: 1771372800, // 2026-02-17
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.6 Sonnet",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-6",
|
||||
Object: "model",
|
||||
Created: 1770318000, // 2026-02-05
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.6 Opus",
|
||||
Description: "Premium model combining maximum intelligence with practical performance",
|
||||
ContextLength: 1000000,
|
||||
MaxCompletionTokens: 128000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-6",
|
||||
Object: "model",
|
||||
Created: 1771286400, // 2026-02-17
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.6 Sonnet",
|
||||
Description: "Best combination of speed and intelligence",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-5-20251101",
|
||||
Object: "model",
|
||||
@@ -716,6 +751,34 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.3-codex",
|
||||
Object: "model",
|
||||
Created: 1770307200,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.3",
|
||||
DisplayName: "GPT 5.3 Codex",
|
||||
Description: "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.",
|
||||
ContextLength: 400000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.3-codex-spark",
|
||||
Object: "model",
|
||||
Created: 1770912000,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.3",
|
||||
DisplayName: "GPT 5.3 Codex Spark",
|
||||
Description: "Ultra-fast coding model.",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -748,6 +811,19 @@ func GetQwenModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 2048,
|
||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||
},
|
||||
{
|
||||
ID: "coder-model",
|
||||
Object: "model",
|
||||
Created: 1771171200,
|
||||
OwnedBy: "qwen",
|
||||
Type: "qwen",
|
||||
Version: "3.5",
|
||||
DisplayName: "Qwen 3.5 Plus",
|
||||
Description: "efficient hybrid model with leading coding performance",
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||
},
|
||||
{
|
||||
ID: "vision-model",
|
||||
Object: "model",
|
||||
@@ -788,6 +864,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
@@ -802,6 +879,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.5", DisplayName: "MiniMax-M2.5", Description: "MiniMax M2.5", Created: 1770825600, Thinking: iFlowThinkingSupport},
|
||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
@@ -838,10 +916,54 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
}
|
||||
}
|
||||
|
||||
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions
|
||||
func GetKimiModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
{
|
||||
ID: "kimi-k2",
|
||||
Object: "model",
|
||||
Created: 1752192000, // 2025-07-11
|
||||
OwnedBy: "moonshot",
|
||||
Type: "kimi",
|
||||
DisplayName: "Kimi K2",
|
||||
Description: "Kimi K2 - Moonshot AI's flagship coding model",
|
||||
ContextLength: 131072,
|
||||
MaxCompletionTokens: 32768,
|
||||
},
|
||||
{
|
||||
ID: "kimi-k2-thinking",
|
||||
Object: "model",
|
||||
Created: 1762387200, // 2025-11-06
|
||||
OwnedBy: "moonshot",
|
||||
Type: "kimi",
|
||||
DisplayName: "Kimi K2 Thinking",
|
||||
Description: "Kimi K2 Thinking - Extended reasoning model",
|
||||
ContextLength: 131072,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kimi-k2.5",
|
||||
Object: "model",
|
||||
Created: 1769472000, // 2026-01-26
|
||||
OwnedBy: "moonshot",
|
||||
Type: "kimi",
|
||||
DisplayName: "Kimi K2.5",
|
||||
Description: "Kimi K2.5 - Latest Moonshot AI coding model with improved capabilities",
|
||||
ContextLength: 131072,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -601,8 +601,7 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
now := time.Now()
|
||||
registration.QuotaExceededClients[clientID] = &now
|
||||
registration.QuotaExceededClients[clientID] = new(time.Now())
|
||||
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: bytes.Clone(body.payload),
|
||||
Body: body.payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -156,20 +156,20 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
|
||||
if len(wsResp.Body) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body))
|
||||
appendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
|
||||
}
|
||||
if wsResp.Status < 200 || wsResp.Status >= 300 {
|
||||
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
|
||||
}
|
||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the AI Studio API.
|
||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -199,7 +199,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: bytes.Clone(body.payload),
|
||||
Body: body.payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -225,7 +225,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
var body bytes.Buffer
|
||||
if len(firstEvent.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(firstEvent.Payload))
|
||||
appendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
|
||||
body.Write(firstEvent.Payload)
|
||||
}
|
||||
if firstEvent.Type == wsrelay.MessageTypeStreamEnd {
|
||||
@@ -244,7 +244,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
metadataLogged = true
|
||||
}
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
body.Write(event.Payload)
|
||||
}
|
||||
if event.Type == wsrelay.MessageTypeStreamEnd {
|
||||
@@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(first wsrelay.StreamEvent) {
|
||||
defer close(out)
|
||||
var param any
|
||||
@@ -274,12 +273,12 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
case wsrelay.MessageTypeStreamChunk:
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
filtered := FilterSSEUsageMetadata(event.Payload)
|
||||
if detail, ok := parseGeminiStreamUsage(filtered); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(filtered), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
}
|
||||
@@ -293,9 +292,9 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
metadataLogged = true
|
||||
}
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
}
|
||||
@@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
}
|
||||
}(firstEvent)
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request using the AI Studio API.
|
||||
@@ -350,7 +349,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: bytes.Clone(body.payload),
|
||||
Body: body.payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -364,7 +363,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
|
||||
if len(resp.Body) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
|
||||
appendAPIResponseChunk(ctx, e.cfg, resp.Body)
|
||||
}
|
||||
if resp.Status < 200 || resp.Status >= 300 {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
|
||||
@@ -373,7 +372,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
if totalTokens <= 0 {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
|
||||
}
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, bytes.Clone(resp.Body))
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
}
|
||||
|
||||
@@ -393,12 +392,13 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, translatedPayload{}, err
|
||||
|
||||
@@ -133,12 +133,13 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -230,8 +231,8 @@ attemptLoop:
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
@@ -274,12 +275,13 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -433,8 +435,8 @@ attemptLoop:
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
return resp, nil
|
||||
@@ -643,7 +645,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -665,12 +667,13 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -772,7 +775,6 @@ attemptLoop:
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -800,12 +802,12 @@ attemptLoop:
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), ¶m)
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m)
|
||||
for i := range tail {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
||||
}
|
||||
@@ -817,7 +819,7 @@ attemptLoop:
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}(httpResp)
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
@@ -872,7 +874,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
// Prepare payload once (doesn't depend on baseURL)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -965,7 +967,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
@@ -1004,7 +1006,12 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||
exec := &AntigravityExecutor{cfg: cfg}
|
||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil || token == "" {
|
||||
if errToken != nil {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken)
|
||||
return nil
|
||||
}
|
||||
if token == "" {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: got empty token", auth.ID)
|
||||
return nil
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
@@ -1018,6 +1025,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
modelsURL := baseURL + antigravityModelsPath
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
||||
if errReq != nil {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: create request error: %v", auth.ID, errReq)
|
||||
return nil
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
@@ -1030,12 +1038,14 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: context canceled: %v", auth.ID, errDo)
|
||||
return nil
|
||||
}
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: request error: %v", auth.ID, errDo)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1048,6 +1058,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: read body error: %v", auth.ID, errRead)
|
||||
return nil
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
@@ -1055,11 +1066,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: unexpected status %d, body: %s", auth.ID, httpResp.StatusCode, string(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: no models field in response, body: %s", auth.ID, string(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestAntigravityBuildRequest_SanitizesGeminiToolSchema(t *testing.T) {
|
||||
body := buildRequestBodyFromPayload(t, "gemini-2.5-pro")
|
||||
|
||||
decl := extractFirstFunctionDeclaration(t, body)
|
||||
if _, ok := decl["parametersJsonSchema"]; ok {
|
||||
t.Fatalf("parametersJsonSchema should be renamed to parameters")
|
||||
}
|
||||
|
||||
params, ok := decl["parameters"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parameters missing or invalid type")
|
||||
}
|
||||
assertSchemaSanitizedAndPropertyPreserved(t, params)
|
||||
}
|
||||
|
||||
func TestAntigravityBuildRequest_SanitizesAntigravityToolSchema(t *testing.T) {
|
||||
body := buildRequestBodyFromPayload(t, "claude-opus-4-6")
|
||||
|
||||
decl := extractFirstFunctionDeclaration(t, body)
|
||||
params, ok := decl["parameters"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parameters missing or invalid type")
|
||||
}
|
||||
assertSchemaSanitizedAndPropertyPreserved(t, params)
|
||||
}
|
||||
|
||||
func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
executor := &AntigravityExecutor{}
|
||||
auth := &cliproxyauth.Auth{}
|
||||
payload := []byte(`{
|
||||
"request": {
|
||||
"tools": [
|
||||
{
|
||||
"function_declarations": [
|
||||
{
|
||||
"name": "tool_1",
|
||||
"parametersJsonSchema": {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"$id": "root-schema",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"$id": {"type": "string"},
|
||||
"arg": {
|
||||
"type": "object",
|
||||
"prefill": "hello",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b"],
|
||||
"enumTitles": ["A", "B"]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"patternProperties": {
|
||||
"^x-": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
|
||||
req, err := executor.buildRequest(context.Background(), auth, "token", modelName, payload, false, "", "https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("buildRequest error: %v", err)
|
||||
}
|
||||
|
||||
raw, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body error: %v", err)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(raw, &body); err != nil {
|
||||
t.Fatalf("unmarshal request body error: %v, body=%s", err, string(raw))
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func extractFirstFunctionDeclaration(t *testing.T, body map[string]any) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
request, ok := body["request"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("request missing or invalid type")
|
||||
}
|
||||
tools, ok := request["tools"].([]any)
|
||||
if !ok || len(tools) == 0 {
|
||||
t.Fatalf("tools missing or empty")
|
||||
}
|
||||
tool, ok := tools[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("first tool invalid type")
|
||||
}
|
||||
decls, ok := tool["function_declarations"].([]any)
|
||||
if !ok || len(decls) == 0 {
|
||||
t.Fatalf("function_declarations missing or empty")
|
||||
}
|
||||
decl, ok := decls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("first function declaration invalid type")
|
||||
}
|
||||
return decl
|
||||
}
|
||||
|
||||
func assertSchemaSanitizedAndPropertyPreserved(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
|
||||
if _, ok := params["$id"]; ok {
|
||||
t.Fatalf("root $id should be removed from schema")
|
||||
}
|
||||
if _, ok := params["patternProperties"]; ok {
|
||||
t.Fatalf("patternProperties should be removed from schema")
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("properties missing or invalid type")
|
||||
}
|
||||
if _, ok := props["$id"]; !ok {
|
||||
t.Fatalf("property named $id should be preserved")
|
||||
}
|
||||
|
||||
arg, ok := props["arg"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("arg property missing or invalid type")
|
||||
}
|
||||
if _, ok := arg["prefill"]; ok {
|
||||
t.Fatalf("prefill should be removed from nested schema")
|
||||
}
|
||||
|
||||
argProps, ok := arg["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("arg.properties missing or invalid type")
|
||||
}
|
||||
mode, ok := argProps["mode"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("mode property missing or invalid type")
|
||||
}
|
||||
if _, ok := mode["enumTitles"]; ok {
|
||||
t.Fatalf("enumTitles should be removed from nested schema")
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -100,12 +101,13 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -133,7 +135,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
bodyForTranslation := body
|
||||
bodyForUpstream := body
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
@@ -142,7 +144,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -207,7 +209,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
} else {
|
||||
reporter.publish(ctx, parseClaudeUsage(data))
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
|
||||
}
|
||||
var param any
|
||||
@@ -216,16 +218,16 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
to,
|
||||
from,
|
||||
req.Model,
|
||||
bytes.Clone(opts.OriginalRequest),
|
||||
opts.OriginalRequest,
|
||||
bodyForTranslation,
|
||||
data,
|
||||
¶m,
|
||||
)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -240,12 +242,13 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -273,7 +276,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
bodyForTranslation := body
|
||||
bodyForUpstream := body
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
@@ -282,7 +285,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -327,7 +330,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -346,7 +348,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
}
|
||||
// Forward the line as-is to preserve SSE format
|
||||
@@ -373,7 +375,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(
|
||||
@@ -381,7 +383,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
to,
|
||||
from,
|
||||
req.Model,
|
||||
bytes.Clone(opts.OriginalRequest),
|
||||
opts.OriginalRequest,
|
||||
bodyForTranslation,
|
||||
bytes.Clone(line),
|
||||
¶m,
|
||||
@@ -396,7 +398,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
@@ -411,7 +413,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
if !strings.HasPrefix(baseModel, "claude-3-5-haiku") {
|
||||
@@ -421,7 +423,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
// Extract betas from body and convert to header (for count_tokens too)
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
body = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
@@ -430,7 +432,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -485,7 +487,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "input_tokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
@@ -636,7 +638,49 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
|
||||
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
|
||||
func mapStainlessOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return "MacOS"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "freebsd":
|
||||
return "FreeBSD"
|
||||
default:
|
||||
return "Other::" + runtime.GOOS
|
||||
}
|
||||
}
|
||||
|
||||
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
|
||||
func mapStainlessArch() string {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return "x64"
|
||||
case "arm64":
|
||||
return "arm64"
|
||||
case "386":
|
||||
return "x86"
|
||||
default:
|
||||
return "other::" + runtime.GOARCH
|
||||
}
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
|
||||
hdrDefault := func(cfgVal, fallback string) string {
|
||||
if cfgVal != "" {
|
||||
return cfgVal
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
var hd config.ClaudeHeaderDefaults
|
||||
if cfg != nil {
|
||||
hd = cfg.ClaudeHeaderDefaults
|
||||
}
|
||||
|
||||
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
|
||||
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
|
||||
if isAnthropicBase && useAPIKey {
|
||||
@@ -683,16 +727,17 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", "v24.3.0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", "0.55.1")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", "arm64")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", "MacOS")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", "60")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "claude-cli/1.0.83 (external, cli)")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
|
||||
r.Header.Set("Connection", "keep-alive")
|
||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||
if stream {
|
||||
@@ -700,6 +745,8 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
} else {
|
||||
r.Header.Set("Accept", "application/json")
|
||||
}
|
||||
// Keep OS/Arch mapping dynamic (not configurable).
|
||||
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
@@ -751,11 +798,21 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// Collect built-in tool names (those with a non-empty "type" field) so we can
|
||||
// skip them consistently in both tools and message history.
|
||||
builtinTools := map[string]bool{}
|
||||
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
|
||||
builtinTools[name] = true
|
||||
}
|
||||
|
||||
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
||||
tools.ForEach(func(index, tool gjson.Result) bool {
|
||||
// Skip built-in tools (web_search, code_execution, etc.) which have
|
||||
// a "type" field and require their name to remain unchanged.
|
||||
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
|
||||
if n := tool.Get("name").String(); n != "" {
|
||||
builtinTools[n] = true
|
||||
}
|
||||
return true
|
||||
}
|
||||
name := tool.Get("name").String()
|
||||
@@ -770,7 +827,7 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
|
||||
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
|
||||
name := gjson.GetBytes(body, "tool_choice.name").String()
|
||||
if name != "" && !strings.HasPrefix(name, prefix) {
|
||||
if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] {
|
||||
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
|
||||
}
|
||||
}
|
||||
@@ -782,15 +839,38 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
return true
|
||||
}
|
||||
content.ForEach(func(contentIndex, part gjson.Result) bool {
|
||||
if part.Get("type").String() != "tool_use" {
|
||||
return true
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "tool_use":
|
||||
name := part.Get("name").String()
|
||||
if name == "" || strings.HasPrefix(name, prefix) || builtinTools[name] {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||
case "tool_reference":
|
||||
toolName := part.Get("tool_name").String()
|
||||
if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+toolName)
|
||||
case "tool_result":
|
||||
// Handle nested tool_reference blocks inside tool_result.content[]
|
||||
nestedContent := part.Get("content")
|
||||
if nestedContent.Exists() && nestedContent.IsArray() {
|
||||
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
|
||||
if nestedPart.Get("type").String() == "tool_reference" {
|
||||
nestedToolName := nestedPart.Get("tool_name").String()
|
||||
if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] {
|
||||
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
name := part.Get("name").String()
|
||||
if name == "" || strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||
return true
|
||||
})
|
||||
return true
|
||||
@@ -809,15 +889,38 @@ func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
|
||||
return body
|
||||
}
|
||||
content.ForEach(func(index, part gjson.Result) bool {
|
||||
if part.Get("type").String() != "tool_use" {
|
||||
return true
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "tool_use":
|
||||
name := part.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
|
||||
case "tool_reference":
|
||||
toolName := part.Get("tool_name").String()
|
||||
if !strings.HasPrefix(toolName, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.tool_name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix))
|
||||
case "tool_result":
|
||||
// Handle nested tool_reference blocks inside tool_result.content[]
|
||||
nestedContent := part.Get("content")
|
||||
if nestedContent.Exists() && nestedContent.IsArray() {
|
||||
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
|
||||
if nestedPart.Get("type").String() == "tool_reference" {
|
||||
nestedToolName := nestedPart.Get("tool_name").String()
|
||||
if strings.HasPrefix(nestedToolName, prefix) {
|
||||
nestedPath := fmt.Sprintf("content.%d.content.%d.tool_name", index.Int(), nestedIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, nestedPath, strings.TrimPrefix(nestedToolName, prefix))
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
name := part.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
|
||||
return true
|
||||
})
|
||||
return body
|
||||
@@ -832,15 +935,34 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
|
||||
return line
|
||||
}
|
||||
contentBlock := gjson.GetBytes(payload, "content_block")
|
||||
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
|
||||
if !contentBlock.Exists() {
|
||||
return line
|
||||
}
|
||||
name := contentBlock.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
|
||||
if err != nil {
|
||||
|
||||
blockType := contentBlock.Get("type").String()
|
||||
var updated []byte
|
||||
var err error
|
||||
|
||||
switch blockType {
|
||||
case "tool_use":
|
||||
name := contentBlock.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err = sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
case "tool_reference":
|
||||
toolName := contentBlock.Get("tool_name").String()
|
||||
if !strings.HasPrefix(toolName, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err = sjson.SetBytes(payload, "content_block.tool_name", strings.TrimPrefix(toolName, prefix))
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
default:
|
||||
return line
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.tool_name").String(); got != "proxy_beta" {
|
||||
t.Fatalf("messages.0.content.0.tool_name = %q, want %q", got, "proxy_beta")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != "proxy_gamma" {
|
||||
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, "proxy_gamma")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
@@ -37,6 +49,97 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_BuiltinToolSkipped(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"type": "web_search_20250305", "name": "web_search", "max_uses": 5},
|
||||
{"name": "Read"}
|
||||
],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}},
|
||||
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"name": "Read"}
|
||||
],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_CustomToolsPrefixed(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [{"name": "Read"}, {"name": "Write"}],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}},
|
||||
{"type": "tool_use", "name": "Write", "id": "w1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Write" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Write")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Write" {
|
||||
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"type": "web_search_20250305", "name": "web_search"},
|
||||
{"name": "Read"}
|
||||
],
|
||||
"tool_choice": {"type": "tool", "name": "web_search"}
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
@@ -49,6 +152,18 @@ func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "content.0.tool_name").String(); got != "alpha" {
|
||||
t.Fatalf("content.0.tool_name = %q, want %q", got, "alpha")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "content.1.tool_name").String(); got != "bravo" {
|
||||
t.Fatalf("content.1.tool_name = %q, want %q", got, "bravo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
|
||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||
@@ -61,3 +176,53 @@ func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) {
|
||||
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`)
|
||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||
|
||||
payload := bytes.TrimSpace(out)
|
||||
if bytes.HasPrefix(payload, []byte("data:")) {
|
||||
payload = bytes.TrimSpace(payload[len("data:"):])
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "content_block.tool_name").String(); got != "beta" {
|
||||
t.Fatalf("content_block.tool_name = %q, want %q", got, "beta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
|
||||
if got != "proxy_mcp__nia__manage_resource" {
|
||||
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "proxy_mcp__nia__manage_resource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "content.0.content.0.tool_name").String()
|
||||
if got != "mcp__nia__manage_resource" {
|
||||
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "mcp__nia__manage_resource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) {
|
||||
// tool_result.content can be a string - should not be processed
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content").String()
|
||||
if got != "plain string result" {
|
||||
t.Fatalf("string content should remain unchanged = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
|
||||
if got != "web_search" {
|
||||
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
codexClientVersion = "0.101.0"
|
||||
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
)
|
||||
|
||||
var dataTag = []byte("data:")
|
||||
|
||||
// CodexExecutor is a stateless executor for Codex (OpenAI Responses API entrypoint).
|
||||
@@ -88,12 +93,13 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -176,8 +182,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
}
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, line, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
||||
@@ -197,12 +203,13 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai-response")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -265,12 +272,12 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||
reporter.ensurePublished(ctx)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
|
||||
}
|
||||
@@ -286,12 +293,13 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -354,7 +362,6 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -378,7 +385,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -389,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
@@ -397,7 +404,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -634,10 +641,9 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", "0.21.0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "codex_cli_rs/0.50.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
1408
internal/runtime/executor/codex_websockets_executor.go
Normal file
1408
internal/runtime/executor/codex_websockets_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -119,12 +119,13 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -223,8 +224,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), payload, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -255,7 +256,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Gemini CLI API.
|
||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -272,12 +273,13 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -380,7 +382,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response, reqBody []byte, attemptModel string) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -399,14 +400,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -428,18 +429,18 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
if len(lastBody) > 0 {
|
||||
@@ -485,7 +486,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
|
||||
// Gemini CLI endpoint when iterating fallback variants.
|
||||
for range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -544,7 +545,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
lastStatus = resp.StatusCode
|
||||
lastBody = append([]byte(nil), data...)
|
||||
@@ -897,8 +898,7 @@ func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
|
||||
if matches := re.FindStringSubmatch(message); len(matches) > 1 {
|
||||
seconds, err := strconv.Atoi(matches[1])
|
||||
if err == nil {
|
||||
duration := time.Duration(seconds) * time.Second
|
||||
return &duration, nil
|
||||
return new(time.Duration(seconds) * time.Second), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,12 +116,13 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
// Official Gemini API via API key or OAuth bearer
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -203,13 +204,13 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Gemini API.
|
||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -222,12 +223,13 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -296,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -318,12 +319,12 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseGeminiStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(payload), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
@@ -333,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request using the Gemini API.
|
||||
@@ -344,7 +345,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -414,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
||||
|
||||
@@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Vertex AI API.
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -318,12 +318,13 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -417,8 +418,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -432,12 +433,13 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -521,13 +523,13 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -536,12 +538,13 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -615,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -632,12 +634,12 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
@@ -647,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -660,12 +662,13 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -739,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -756,12 +758,12 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
@@ -771,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// countTokensWithServiceAccount counts tokens using service account credentials.
|
||||
@@ -781,7 +783,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -855,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||
@@ -865,7 +867,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -939,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
333
internal/runtime/executor/github_copilot_executor_test.go
Normal file
333
internal/runtime/executor/github_copilot_executor_test.go
Normal file
@@ -0,0 +1,333 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "suffix stripped",
|
||||
model: "claude-opus-4.6(medium)",
|
||||
wantModel: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "no suffix unchanged",
|
||||
model: "claude-opus-4.6",
|
||||
wantModel: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "different suffix stripped",
|
||||
model: "gpt-4o(high)",
|
||||
wantModel: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "numeric suffix stripped",
|
||||
model: "gemini-2.5-pro(8192)",
|
||||
wantModel: "gemini-2.5-pro",
|
||||
},
|
||||
}
|
||||
|
||||
e := &GitHubCopilotExecutor{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"` + tt.model + `","messages":[]}`)
|
||||
got := e.normalizeModel(tt.model, body)
|
||||
|
||||
gotModel := gjson.GetBytes(got, "model").String()
|
||||
if gotModel != tt.wantModel {
|
||||
t.Fatalf("normalizeModel() model = %q, want %q", gotModel, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") {
|
||||
t.Fatal("expected openai-response source to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") {
|
||||
t.Fatal("expected codex model to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
|
||||
t.Fatal("expected default openai source with non-codex model to use /chat/completions")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`)
|
||||
got := normalizeGitHubCopilotChatTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want 1", len(tools))
|
||||
}
|
||||
if tools[0].Get("type").String() != "function" {
|
||||
t.Fatalf("tool type = %q, want function", tools[0].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`)
|
||||
got := normalizeGitHubCopilotChatTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
|
||||
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
in := gjson.GetBytes(got, "input")
|
||||
if !in.IsArray() {
|
||||
t.Fatalf("input type = %v, want array", in.Type)
|
||||
}
|
||||
raw := in.Raw
|
||||
if !strings.Contains(raw, "sys text") || !strings.Contains(raw, "user text") || !strings.Contains(raw, "assistant text") {
|
||||
t.Fatalf("input = %s, want structured array with all texts", raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "messages").Exists() {
|
||||
t.Fatal("messages should be removed after conversion")
|
||||
}
|
||||
if gjson.GetBytes(got, "system").Exists() {
|
||||
t.Fatal("system should be removed after conversion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":{"foo":"bar"}}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
in := gjson.GetBytes(got, "input")
|
||||
if in.Type != gjson.String {
|
||||
t.Fatalf("input type = %v, want string", in.Type)
|
||||
}
|
||||
if !strings.Contains(in.String(), "foo") {
|
||||
t.Fatalf("input = %q, want stringified object", in.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want 1", len(tools))
|
||||
}
|
||||
if tools[0].Get("name").String() != "sum" {
|
||||
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String())
|
||||
}
|
||||
if !tools[0].Get("parameters").Exists() {
|
||||
t.Fatal("expected parameters to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("tools len = %d, want 2", len(tools))
|
||||
}
|
||||
if tools[0].Get("type").String() != "function" {
|
||||
t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String())
|
||||
}
|
||||
if tools[0].Get("name").String() != "Bash" {
|
||||
t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String())
|
||||
}
|
||||
if tools[0].Get("description").String() != "Run commands" {
|
||||
t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String())
|
||||
}
|
||||
if !tools[0].Get("parameters").Exists() {
|
||||
t.Fatal("expected parameters to be set from input_schema")
|
||||
}
|
||||
if tools[0].Get("parameters.properties.command").Exists() != true {
|
||||
t.Fatal("expected parameters.properties.command to exist")
|
||||
}
|
||||
if tools[1].Get("name").String() != "Read" {
|
||||
t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice.type").String() != "function" {
|
||||
t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String())
|
||||
}
|
||||
if gjson.GetBytes(got, "tool_choice.name").String() != "sum" {
|
||||
t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tool_choice":{"type":"function"}}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
|
||||
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.Get(out, "type").String() != "message" {
|
||||
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.type").String() != "text" {
|
||||
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.text").String() != "hello" {
|
||||
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.Get(out, "content.0.type").String() != "tool_use" {
|
||||
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.name").String() != "sum" {
|
||||
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
|
||||
}
|
||||
if gjson.Get(out, "stop_reason").String() != "tool_use" {
|
||||
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
var param any
|
||||
|
||||
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m)
|
||||
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
|
||||
t.Fatalf("created events = %#v, want message_start", created)
|
||||
}
|
||||
|
||||
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m)
|
||||
joinedDelta := strings.Join(delta, "")
|
||||
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
|
||||
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
|
||||
}
|
||||
|
||||
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m)
|
||||
joinedCompleted := strings.Join(completed, "")
|
||||
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
|
||||
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for X-Initiator detection logic (Problem L) ---
|
||||
|
||||
func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"messages":[{"role":"system","content":"sys"},{"role":"user","content":"hello"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_AgentWithAssistantAndUserToolResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// Claude Code typical flow: last message is user (tool result), but has assistant in history
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (assistant exists in messages)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for x-github-api-version header (Problem M) ---
|
||||
|
||||
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
e.applyHeaders(req, "token", nil)
|
||||
if got := req.Header.Get("X-Github-Api-Version"); got != "2025-04-01" {
|
||||
t.Fatalf("X-Github-Api-Version = %q, want 2025-04-01", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for vision detection (Problem P) ---
|
||||
|
||||
func TestDetectVisionContent_WithImageURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
|
||||
if !detectVisionContent(body) {
|
||||
t.Fatal("expected vision content to be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVisionContent_WithImageType(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"image","source":{"data":"abc","media_type":"image/png"}}]}]}`)
|
||||
if !detectVisionContent(body) {
|
||||
t.Fatal("expected image type to be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVisionContent_NoVision(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
if detectVisionContent(body) {
|
||||
t.Fatal("expected no vision content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVisionContent_NoMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
// After Responses API normalization, messages is removed — detection should return false
|
||||
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`)
|
||||
if detectVisionContent(body) {
|
||||
t.Fatal("expected no vision content when messages field is absent")
|
||||
}
|
||||
}
|
||||
@@ -4,12 +4,16 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
@@ -87,12 +91,13 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||
@@ -163,13 +168,13 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request.
|
||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -189,12 +194,13 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||
@@ -256,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -274,7 +279,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -288,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
@@ -296,7 +301,7 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
enc, err := tokenizerForModel(baseModel)
|
||||
if err != nil {
|
||||
@@ -451,6 +456,20 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
r.Header.Set("User-Agent", iflowUserAgent)
|
||||
|
||||
// Generate session-id
|
||||
sessionID := "session-" + generateUUID()
|
||||
r.Header.Set("session-id", sessionID)
|
||||
|
||||
// Generate timestamp and signature
|
||||
timestamp := time.Now().UnixMilli()
|
||||
r.Header.Set("x-iflow-timestamp", fmt.Sprintf("%d", timestamp))
|
||||
|
||||
signature := createIFlowSignature(iflowUserAgent, sessionID, timestamp, apiKey)
|
||||
if signature != "" {
|
||||
r.Header.Set("x-iflow-signature", signature)
|
||||
}
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
@@ -458,6 +477,23 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// createIFlowSignature generates HMAC-SHA256 signature for iFlow API requests.
|
||||
// The signature payload format is: userAgent:sessionId:timestamp
|
||||
func createIFlowSignature(userAgent, sessionID string, timestamp int64, apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return ""
|
||||
}
|
||||
payload := fmt.Sprintf("%s:%s:%d", userAgent, sessionID, timestamp)
|
||||
h := hmac.New(sha256.New, []byte(apiKey))
|
||||
h.Write([]byte(payload))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// generateUUID generates a random UUID v4 string.
|
||||
func generateUUID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
func iflowCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
|
||||
460
internal/runtime/executor/kilo_executor.go
Normal file
460
internal/runtime/executor/kilo_executor.go
Normal file
@@ -0,0 +1,460 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// KiloExecutor handles requests to Kilo API.
|
||||
type KiloExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKiloExecutor creates a new Kilo executor instance.
|
||||
func NewKiloExecutor(cfg *config.Config) *KiloExecutor {
|
||||
return &KiloExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
// Identifier returns the unique identifier for this executor.
|
||||
func (e *KiloExecutor) Identifier() string { return "kilo" }
|
||||
|
||||
// PrepareRequest prepares the HTTP request before execution.
|
||||
func (e *KiloExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
accessToken, _ := kiloCredentials(auth)
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return fmt.Errorf("kilo: missing access token")
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest executes a raw HTTP request.
|
||||
func (e *KiloExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("kilo executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request.
|
||||
func (e *KiloExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, orgID := kiloCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return resp, fmt.Errorf("kilo: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
endpoint := "/api/openrouter/chat/completions"
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
url := "https://api.kilo.ai" + endpoint
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if orgID != "" {
|
||||
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, body)
|
||||
reporter.publish(ctx, parseOpenAIUsage(body))
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request.
|
||||
func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, orgID := kiloCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("kilo: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
endpoint := "/api/openrouter/chat/completions"
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := "https://api.kilo.ai" + endpoint
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if orgID != "" {
|
||||
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpReq.Header.Set("Cache-Control", "no-cache")
|
||||
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
httpResp.Body.Close()
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 52_428_800)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: httpResp.Header.Clone(),
|
||||
Chunks: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refresh validates the Kilo token.
|
||||
func (e *KiloExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("missing auth")
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// CountTokens returns the token count for the given request.
|
||||
func (e *KiloExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("kilo: count tokens not supported")
|
||||
}
|
||||
|
||||
// kiloCredentials extracts access token and other info from auth.
|
||||
func kiloCredentials(auth *cliproxyauth.Auth) (accessToken, orgID string) {
|
||||
if auth == nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Prefer kilocode specific keys, then fall back to generic keys.
|
||||
// Check metadata first, then attributes.
|
||||
if auth.Metadata != nil {
|
||||
if token, ok := auth.Metadata["kilocodeToken"].(string); ok && token != "" {
|
||||
accessToken = token
|
||||
} else if token, ok := auth.Metadata["access_token"].(string); ok && token != "" {
|
||||
accessToken = token
|
||||
}
|
||||
|
||||
if org, ok := auth.Metadata["kilocodeOrganizationId"].(string); ok && org != "" {
|
||||
orgID = org
|
||||
} else if org, ok := auth.Metadata["organization_id"].(string); ok && org != "" {
|
||||
orgID = org
|
||||
}
|
||||
}
|
||||
|
||||
if accessToken == "" && auth.Attributes != nil {
|
||||
if token := auth.Attributes["kilocodeToken"]; token != "" {
|
||||
accessToken = token
|
||||
} else if token := auth.Attributes["access_token"]; token != "" {
|
||||
accessToken = token
|
||||
}
|
||||
}
|
||||
|
||||
if orgID == "" && auth.Attributes != nil {
|
||||
if org := auth.Attributes["kilocodeOrganizationId"]; org != "" {
|
||||
orgID = org
|
||||
} else if org := auth.Attributes["organization_id"]; org != "" {
|
||||
orgID = org
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, orgID
|
||||
}
|
||||
|
||||
// FetchKiloModels fetches models from Kilo API.
|
||||
func FetchKiloModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||
accessToken, orgID := kiloCredentials(auth)
|
||||
if accessToken == "" {
|
||||
log.Infof("kilo: no access token found, skipping dynamic model fetch (using static kilo/auto)")
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
log.Debugf("kilo: fetching dynamic models (orgID: %s)", orgID)
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.kilo.ai/api/openrouter/models", nil)
|
||||
if err != nil {
|
||||
log.Warnf("kilo: failed to create model fetch request: %v", err)
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if orgID != "" {
|
||||
req.Header.Set("X-Kilocode-OrganizationID", orgID)
|
||||
}
|
||||
req.Header.Set("User-Agent", "cli-proxy-kilo")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Warnf("kilo: fetch models canceled: %v", err)
|
||||
} else {
|
||||
log.Warnf("kilo: using static models (API fetch failed: %v)", err)
|
||||
}
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Warnf("kilo: failed to read models response: %v", err)
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("kilo: fetch models failed: status %d, body: %s", resp.StatusCode, string(body))
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(body, "data")
|
||||
if !result.Exists() {
|
||||
// Try root if data field is missing
|
||||
result = gjson.ParseBytes(body)
|
||||
if !result.IsArray() {
|
||||
log.Debugf("kilo: response body: %s", string(body))
|
||||
log.Warn("kilo: invalid API response format (expected array or data field with array)")
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
}
|
||||
|
||||
var dynamicModels []*registry.ModelInfo
|
||||
now := time.Now().Unix()
|
||||
count := 0
|
||||
totalCount := 0
|
||||
|
||||
result.ForEach(func(key, value gjson.Result) bool {
|
||||
totalCount++
|
||||
id := value.Get("id").String()
|
||||
pIdxResult := value.Get("preferredIndex")
|
||||
preferredIndex := pIdxResult.Int()
|
||||
|
||||
// Filter models where preferredIndex > 0 (Kilo-curated models)
|
||||
if preferredIndex <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if it's free. We look for :free suffix, is_free flag, or zero pricing.
|
||||
isFree := strings.HasSuffix(id, ":free") || id == "giga-potato" || value.Get("is_free").Bool()
|
||||
if !isFree {
|
||||
// Check pricing as fallback
|
||||
promptPricing := value.Get("pricing.prompt").String()
|
||||
if promptPricing == "0" || promptPricing == "0.0" {
|
||||
isFree = true
|
||||
}
|
||||
}
|
||||
|
||||
if !isFree {
|
||||
log.Debugf("kilo: skipping curated paid model: %s", id)
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debugf("kilo: found curated model: %s (preferredIndex: %d)", id, preferredIndex)
|
||||
|
||||
dynamicModels = append(dynamicModels, ®istry.ModelInfo{
|
||||
ID: id,
|
||||
DisplayName: value.Get("name").String(),
|
||||
ContextLength: int(value.Get("context_length").Int()),
|
||||
OwnedBy: "kilo",
|
||||
Type: "kilo",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
})
|
||||
count++
|
||||
return true
|
||||
})
|
||||
|
||||
log.Infof("kilo: fetched %d models from API, %d curated free (preferredIndex > 0)", totalCount, count)
|
||||
if count == 0 && totalCount > 0 {
|
||||
log.Warn("kilo: no curated free models found (check API response fields)")
|
||||
}
|
||||
|
||||
staticModels := registry.GetKiloModels()
|
||||
// Always include kilo/auto (first static model)
|
||||
allModels := append(staticModels[:1], dynamicModels...)
|
||||
|
||||
return allModels
|
||||
}
|
||||
617
internal/runtime/executor/kimi_executor.go
Normal file
617
internal/runtime/executor/kimi_executor.go
Normal file
@@ -0,0 +1,617 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// KimiExecutor is a stateless executor for Kimi API using OpenAI-compatible chat completions.
|
||||
type KimiExecutor struct {
|
||||
ClaudeExecutor
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKimiExecutor creates a new Kimi executor.
|
||||
func NewKimiExecutor(cfg *config.Config) *KimiExecutor { return &KimiExecutor{cfg: cfg} }
|
||||
|
||||
// Identifier returns the executor identifier.
|
||||
func (e *KimiExecutor) Identifier() string { return "kimi" }
|
||||
|
||||
// PrepareRequest injects Kimi credentials into the outgoing HTTP request.
|
||||
func (e *KimiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
token := kimiCreds(auth)
|
||||
if strings.TrimSpace(token) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects Kimi credentials into the request and executes it.
|
||||
func (e *KimiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("kimi executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming chat completion request to Kimi.
|
||||
func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
from := opts.SourceFormat
|
||||
if from.String() == "claude" {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
return e.ClaudeExecutor.Execute(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
token := kimiCreds(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := bytes.Clone(originalPayloadSource)
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
// Strip kimi- prefix for upstream API
|
||||
upstreamModel := stripKimiPrefix(baseModel)
|
||||
body, err = sjson.SetBytes(body, "model", upstreamModel)
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("kimi executor: failed to set model in payload: %w", err)
|
||||
}
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "kimi", e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, err = normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyKimiHeadersWithAuth(httpReq, token, false, auth)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request to Kimi.
|
||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
from := opts.SourceFormat
|
||||
if from.String() == "claude" {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
return e.ClaudeExecutor.ExecuteStream(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
token := kimiCreds(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := bytes.Clone(originalPayloadSource)
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
// Strip kimi- prefix for upstream API
|
||||
upstreamModel := stripKimiPrefix(baseModel)
|
||||
body, err = sjson.SetBytes(body, "model", upstreamModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi executor: failed to set model in payload: %w", err)
|
||||
}
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "kimi", e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, err = normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyKimiHeadersWithAuth(httpReq, token, true, auth)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 1_048_576) // 1MB
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens estimates token count for Kimi requests.
|
||||
func (e *KimiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
return e.ClaudeExecutor.CountTokens(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
out := body
|
||||
pending := make([]string, 0)
|
||||
patched := 0
|
||||
patchedReasoning := 0
|
||||
ambiguous := 0
|
||||
latestReasoning := ""
|
||||
hasLatestReasoning := false
|
||||
|
||||
removePending := func(id string) {
|
||||
for idx := range pending {
|
||||
if pending[idx] != id {
|
||||
continue
|
||||
}
|
||||
pending = append(pending[:idx], pending[idx+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
msgs := messages.Array()
|
||||
for msgIdx := range msgs {
|
||||
msg := msgs[msgIdx]
|
||||
role := strings.TrimSpace(msg.Get("role").String())
|
||||
switch role {
|
||||
case "assistant":
|
||||
reasoning := msg.Get("reasoning_content")
|
||||
if reasoning.Exists() {
|
||||
reasoningText := reasoning.String()
|
||||
if strings.TrimSpace(reasoningText) != "" {
|
||||
latestReasoning = reasoningText
|
||||
hasLatestReasoning = true
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls := msg.Get("tool_calls")
|
||||
if !toolCalls.Exists() || !toolCalls.IsArray() || len(toolCalls.Array()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if !reasoning.Exists() || strings.TrimSpace(reasoning.String()) == "" {
|
||||
reasoningText := fallbackAssistantReasoning(msg, hasLatestReasoning, latestReasoning)
|
||||
path := fmt.Sprintf("messages.%d.reasoning_content", msgIdx)
|
||||
next, err := sjson.SetBytes(out, path, reasoningText)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi executor: failed to set assistant reasoning_content: %w", err)
|
||||
}
|
||||
out = next
|
||||
patchedReasoning++
|
||||
}
|
||||
|
||||
for _, tc := range toolCalls.Array() {
|
||||
id := strings.TrimSpace(tc.Get("id").String())
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
pending = append(pending, id)
|
||||
}
|
||||
case "tool":
|
||||
toolCallID := strings.TrimSpace(msg.Get("tool_call_id").String())
|
||||
if toolCallID == "" {
|
||||
toolCallID = strings.TrimSpace(msg.Get("call_id").String())
|
||||
if toolCallID != "" {
|
||||
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
|
||||
next, err := sjson.SetBytes(out, path, toolCallID)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi executor: failed to set tool_call_id from call_id: %w", err)
|
||||
}
|
||||
out = next
|
||||
patched++
|
||||
}
|
||||
}
|
||||
if toolCallID == "" {
|
||||
if len(pending) == 1 {
|
||||
toolCallID = pending[0]
|
||||
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
|
||||
next, err := sjson.SetBytes(out, path, toolCallID)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi executor: failed to infer tool_call_id: %w", err)
|
||||
}
|
||||
out = next
|
||||
patched++
|
||||
} else if len(pending) > 1 {
|
||||
ambiguous++
|
||||
}
|
||||
}
|
||||
if toolCallID != "" {
|
||||
removePending(toolCallID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if patched > 0 || patchedReasoning > 0 {
|
||||
log.WithFields(log.Fields{
|
||||
"patched_tool_messages": patched,
|
||||
"patched_reasoning_messages": patchedReasoning,
|
||||
}).Debug("kimi executor: normalized tool message fields")
|
||||
}
|
||||
if ambiguous > 0 {
|
||||
log.WithFields(log.Fields{
|
||||
"ambiguous_tool_messages": ambiguous,
|
||||
"pending_tool_calls": len(pending),
|
||||
}).Warn("kimi executor: tool messages missing tool_call_id with ambiguous candidates")
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func fallbackAssistantReasoning(msg gjson.Result, hasLatest bool, latest string) string {
|
||||
if hasLatest && strings.TrimSpace(latest) != "" {
|
||||
return latest
|
||||
}
|
||||
|
||||
content := msg.Get("content")
|
||||
if content.Type == gjson.String {
|
||||
if text := strings.TrimSpace(content.String()); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
if content.IsArray() {
|
||||
parts := make([]string, 0, len(content.Array()))
|
||||
for _, item := range content.Array() {
|
||||
text := strings.TrimSpace(item.Get("text").String())
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, text)
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
return "[reasoning unavailable]"
|
||||
}
|
||||
|
||||
// Refresh refreshes the Kimi token using the refresh token.
|
||||
func (e *KimiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
log.Debugf("kimi executor: refresh called")
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("kimi executor: auth is nil")
|
||||
}
|
||||
// Expect refresh_token in metadata for OAuth-based accounts
|
||||
var refreshToken string
|
||||
if auth.Metadata != nil {
|
||||
if v, ok := auth.Metadata["refresh_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
refreshToken = v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
// Nothing to refresh
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
client := kimiauth.NewDeviceFlowClientWithDeviceID(e.cfg, resolveKimiDeviceID(auth))
|
||||
td, err := client.RefreshToken(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
auth.Metadata["access_token"] = td.AccessToken
|
||||
if td.RefreshToken != "" {
|
||||
auth.Metadata["refresh_token"] = td.RefreshToken
|
||||
}
|
||||
if td.ExpiresAt > 0 {
|
||||
exp := time.Unix(td.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
auth.Metadata["expired"] = exp
|
||||
}
|
||||
auth.Metadata["type"] = "kimi"
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
auth.Metadata["last_refresh"] = now
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// applyKimiHeaders sets required headers for Kimi API requests.
|
||||
// Headers match kimi-cli client for compatibility.
|
||||
func applyKimiHeaders(r *http.Request, token string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
// Match kimi-cli headers exactly
|
||||
r.Header.Set("User-Agent", "KimiCLI/1.10.6")
|
||||
r.Header.Set("X-Msh-Platform", "kimi_cli")
|
||||
r.Header.Set("X-Msh-Version", "1.10.6")
|
||||
r.Header.Set("X-Msh-Device-Name", getKimiHostname())
|
||||
r.Header.Set("X-Msh-Device-Model", getKimiDeviceModel())
|
||||
r.Header.Set("X-Msh-Device-Id", getKimiDeviceID())
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
return
|
||||
}
|
||||
r.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
func resolveKimiDeviceIDFromAuth(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
deviceIDRaw, ok := auth.Metadata["device_id"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
deviceID, ok := deviceIDRaw.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(deviceID)
|
||||
}
|
||||
|
||||
func resolveKimiDeviceIDFromStorage(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
storage, ok := auth.Storage.(*kimiauth.KimiTokenStorage)
|
||||
if !ok || storage == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(storage.DeviceID)
|
||||
}
|
||||
|
||||
func resolveKimiDeviceID(auth *cliproxyauth.Auth) string {
|
||||
deviceID := resolveKimiDeviceIDFromAuth(auth)
|
||||
if deviceID != "" {
|
||||
return deviceID
|
||||
}
|
||||
return resolveKimiDeviceIDFromStorage(auth)
|
||||
}
|
||||
|
||||
func applyKimiHeadersWithAuth(r *http.Request, token string, stream bool, auth *cliproxyauth.Auth) {
|
||||
applyKimiHeaders(r, token, stream)
|
||||
|
||||
if deviceID := resolveKimiDeviceID(auth); deviceID != "" {
|
||||
r.Header.Set("X-Msh-Device-Id", deviceID)
|
||||
}
|
||||
}
|
||||
|
||||
// getKimiHostname returns the machine hostname.
|
||||
func getKimiHostname() string {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// getKimiDeviceModel returns a device model string matching kimi-cli format.
|
||||
func getKimiDeviceModel() string {
|
||||
return fmt.Sprintf("%s %s", runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
// getKimiDeviceID returns a stable device ID, matching kimi-cli storage location.
|
||||
func getKimiDeviceID() string {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "cli-proxy-api-device"
|
||||
}
|
||||
// Check kimi-cli's device_id location first (platform-specific)
|
||||
var kimiShareDir string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
kimiShareDir = filepath.Join(homeDir, "Library", "Application Support", "kimi")
|
||||
case "windows":
|
||||
appData := os.Getenv("APPDATA")
|
||||
if appData == "" {
|
||||
appData = filepath.Join(homeDir, "AppData", "Roaming")
|
||||
}
|
||||
kimiShareDir = filepath.Join(appData, "kimi")
|
||||
default: // linux and other unix-like
|
||||
kimiShareDir = filepath.Join(homeDir, ".local", "share", "kimi")
|
||||
}
|
||||
deviceIDPath := filepath.Join(kimiShareDir, "device_id")
|
||||
if data, err := os.ReadFile(deviceIDPath); err == nil {
|
||||
return strings.TrimSpace(string(data))
|
||||
}
|
||||
return "cli-proxy-api-device"
|
||||
}
|
||||
|
||||
// kimiCreds extracts the access token from auth.
|
||||
func kimiCreds(a *cliproxyauth.Auth) (token string) {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
// Check metadata first (OAuth flow stores tokens here)
|
||||
if a.Metadata != nil {
|
||||
if v, ok := a.Metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
// Fallback to attributes (API key style)
|
||||
if a.Attributes != nil {
|
||||
if v := a.Attributes["access_token"]; v != "" {
|
||||
return v
|
||||
}
|
||||
if v := a.Attributes["api_key"]; v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// stripKimiPrefix removes the "kimi-" prefix from model names for the upstream API.
|
||||
func stripKimiPrefix(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if strings.HasPrefix(strings.ToLower(model), "kimi-") {
|
||||
return model[5:]
|
||||
}
|
||||
return model
|
||||
}
|
||||
205
internal/runtime/executor/kimi_executor_test.go
Normal file
205
internal/runtime/executor/kimi_executor_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_UsesCallIDFallback(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"list_directory:1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||
{"role":"tool","call_id":"list_directory:1","content":"[]"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||
if got != "list_directory:1" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "list_directory:1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_InferSinglePendingID(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_123","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
|
||||
{"role":"tool","content":"file-content"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||
if got != "call_123" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_AmbiguousMissingIDIsNotInferred(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[
|
||||
{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}},
|
||||
{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}
|
||||
]},
|
||||
{"role":"tool","content":"result-without-id"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
if gjson.GetBytes(out, "messages.1.tool_call_id").Exists() {
|
||||
t.Fatalf("messages.1.tool_call_id should be absent for ambiguous case, got %q", gjson.GetBytes(out, "messages.1.tool_call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_PreservesExistingToolCallID(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||
{"role":"tool","tool_call_id":"call_1","call_id":"different-id","content":"result"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||
if got != "call_1" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_InheritsPreviousReasoningForAssistantToolCalls(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":"plan","reasoning_content":"previous reasoning"},
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.reasoning_content").String()
|
||||
if got != "previous reasoning" {
|
||||
t.Fatalf("messages.1.reasoning_content = %q, want %q", got, "previous reasoning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_InsertsFallbackReasoningWhenMissing(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
reasoning := gjson.GetBytes(out, "messages.0.reasoning_content")
|
||||
if !reasoning.Exists() {
|
||||
t.Fatalf("messages.0.reasoning_content should exist")
|
||||
}
|
||||
if reasoning.String() != "[reasoning unavailable]" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", reasoning.String(), "[reasoning unavailable]")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_UsesContentAsReasoningFallback(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":[{"type":"text","text":"first line"},{"type":"text","text":"second line"}],"tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||
if got != "first line\nsecond line" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "first line\nsecond line")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_ReplacesEmptyReasoningContent(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":"assistant summary","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":""}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||
if got != "assistant summary" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "assistant summary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_PreservesExistingAssistantReasoning(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"keep me"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||
if got != "keep me" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "keep me")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_RepairsIDsAndReasoningTogether(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"r1"},
|
||||
{"role":"tool","call_id":"call_1","content":"[]"},
|
||||
{"role":"assistant","tool_calls":[{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
|
||||
{"role":"tool","call_id":"call_2","content":"file"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.1.tool_call_id").String(); got != "call_1" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.3.tool_call_id").String(); got != "call_2" {
|
||||
t.Fatalf("messages.3.tool_call_id = %q, want %q", got, "call_2")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.2.reasoning_content").String(); got != "r1" {
|
||||
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "r1")
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -385,6 +386,35 @@ func buildKiroEndpointConfigs(region string) []kiroEndpointConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// resolveKiroAPIRegion determines the AWS region for Kiro API calls.
|
||||
// Region priority:
|
||||
// 1. auth.Metadata["api_region"] - explicit API region override
|
||||
// 2. ProfileARN region - extracted from arn:aws:service:REGION:account:resource
|
||||
// 3. kiroDefaultRegion (us-east-1) - fallback
|
||||
// Note: OIDC "region" is NOT used - it's for token refresh, not API calls
|
||||
func resolveKiroAPIRegion(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return kiroDefaultRegion
|
||||
}
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
log.Debugf("kiro: using region %s (source: api_region)", r)
|
||||
return r
|
||||
}
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
|
||||
log.Debugf("kiro: using region %s (source: profile_arn)", arnRegion)
|
||||
return arnRegion
|
||||
}
|
||||
}
|
||||
// Note: OIDC "region" field is NOT used for API endpoint
|
||||
// Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
|
||||
// Using OIDC region for API calls causes DNS failures
|
||||
log.Debugf("kiro: using region %s (source: default)", kiroDefaultRegion)
|
||||
return kiroDefaultRegion
|
||||
}
|
||||
|
||||
// kiroEndpointConfigs is kept for backward compatibility with default us-east-1 region.
|
||||
// Prefer using buildKiroEndpointConfigs(region) for dynamic region support.
|
||||
var kiroEndpointConfigs = buildKiroEndpointConfigs(kiroDefaultRegion)
|
||||
@@ -403,30 +433,8 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
|
||||
return kiroEndpointConfigs
|
||||
}
|
||||
|
||||
// Determine API region with priority: api_region > profile_arn > region > default
|
||||
region := kiroDefaultRegion
|
||||
regionSource := "default"
|
||||
|
||||
if auth.Metadata != nil {
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
region = r
|
||||
regionSource = "api_region"
|
||||
} else {
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
|
||||
region = arnRegion
|
||||
regionSource = "profile_arn"
|
||||
}
|
||||
}
|
||||
// Note: OIDC "region" field is NOT used for API endpoint
|
||||
// Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
|
||||
// Using OIDC region for API calls causes DNS failures
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro: using region %s (source: %s)", region, regionSource)
|
||||
// Determine API region using shared resolution logic
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
|
||||
// Build endpoint configs for the specified region
|
||||
endpointConfigs := buildKiroEndpointConfigs(region)
|
||||
@@ -519,8 +527,12 @@ func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string,
|
||||
case "openai":
|
||||
log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String())
|
||||
return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
|
||||
case "kiro":
|
||||
// Body is already in Kiro format — pass through directly
|
||||
log.Debugf("kiro: body already in Kiro format, passing through directly")
|
||||
return body, false
|
||||
default:
|
||||
// Default to Claude format (also handles "claude", "kiro", etc.)
|
||||
// Default to Claude format
|
||||
log.Debugf("kiro: using Claude payload builder for source format: %s", sourceFormat.String())
|
||||
return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
|
||||
}
|
||||
@@ -636,10 +648,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
// Check if token is expired before making request
|
||||
// Check if token is expired before making request (covers both normal and web_search paths)
|
||||
if e.isTokenExpired(accessToken) {
|
||||
log.Infof("kiro: access token expired, attempting recovery")
|
||||
|
||||
@@ -668,6 +677,16 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
}
|
||||
}
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint")
|
||||
return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -1014,8 +1033,9 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
|
||||
// Build response in Claude format for Kiro translator
|
||||
// stopReason is extracted from upstream response by parseEventStream
|
||||
kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, requestedModel, usageInfo, stopReason)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, requestedModel, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -1033,7 +1053,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
|
||||
// ExecuteStream handles streaming requests to Kiro API.
|
||||
// Supports automatic token refresh on 401/403 errors and quota fallback on 429.
|
||||
func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
accessToken, profileArn := kiroCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("kiro: access token not found in auth")
|
||||
@@ -1057,10 +1077,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
// Check if token is expired before making request
|
||||
// Check if token is expired before making request (covers both normal and web_search paths)
|
||||
if e.isTokenExpired(accessToken) {
|
||||
log.Infof("kiro: access token expired, attempting recovery before stream request")
|
||||
|
||||
@@ -1089,6 +1106,20 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
}
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
|
||||
streamWebSearch, errWebSearch := e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
if errWebSearch != nil {
|
||||
return nil, errWebSearch
|
||||
}
|
||||
return &cliproxyexecutor.StreamResult{Chunks: streamWebSearch}, nil
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -1101,7 +1132,11 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
// Execute stream with retry on 401/403 and 429 (quota exhausted)
|
||||
// Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint
|
||||
return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
||||
streamKiro, errStreamKiro := e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
||||
if errStreamKiro != nil {
|
||||
return nil, errStreamKiro
|
||||
}
|
||||
return &cliproxyexecutor.StreamResult{Chunks: streamKiro}, nil
|
||||
}
|
||||
|
||||
// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors.
|
||||
@@ -1405,7 +1440,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
// So we always enable thinking parsing for Kiro responses
|
||||
log.Debugf("kiro: stream thinkingEnabled = %v (always true for Kiro)", thinkingEnabled)
|
||||
|
||||
e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter, thinkingEnabled)
|
||||
e.streamToChannel(ctx, resp.Body, out, from, payloadRequestedModel(opts, req.Model), opts.OriginalRequest, body, reporter, thinkingEnabled)
|
||||
}(httpResp, thinkingEnabled)
|
||||
|
||||
return out, nil
|
||||
@@ -1681,6 +1716,8 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
modelMap := map[string]string{
|
||||
// Amazon Q format (amazonq- prefix) - same API as Kiro
|
||||
"amazonq-auto": "auto",
|
||||
"amazonq-claude-opus-4-6": "claude-opus-4.6",
|
||||
"amazonq-claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"amazonq-claude-opus-4-5": "claude-opus-4.5",
|
||||
"amazonq-claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||
"amazonq-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||
@@ -1688,6 +1725,8 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
"amazonq-claude-sonnet-4-20250514": "claude-sonnet-4",
|
||||
"amazonq-claude-haiku-4-5": "claude-haiku-4.5",
|
||||
// Kiro format (kiro- prefix) - valid model names that should be preserved
|
||||
"kiro-claude-opus-4-6": "claude-opus-4.6",
|
||||
"kiro-claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"kiro-claude-opus-4-5": "claude-opus-4.5",
|
||||
"kiro-claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||
"kiro-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||
@@ -1696,6 +1735,10 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
"kiro-claude-haiku-4-5": "claude-haiku-4.5",
|
||||
"kiro-auto": "auto",
|
||||
// Native format (no prefix) - used by Kiro IDE directly
|
||||
"claude-opus-4-6": "claude-opus-4.6",
|
||||
"claude-opus-4.6": "claude-opus-4.6",
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"claude-sonnet-4.6": "claude-sonnet-4.6",
|
||||
"claude-opus-4-5": "claude-opus-4.5",
|
||||
"claude-opus-4.5": "claude-opus-4.5",
|
||||
"claude-haiku-4-5": "claude-haiku-4.5",
|
||||
@@ -1707,10 +1750,14 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
"claude-sonnet-4-20250514": "claude-sonnet-4",
|
||||
"auto": "auto",
|
||||
// Agentic variants (same backend model IDs, but with special system prompt)
|
||||
"claude-opus-4.6-agentic": "claude-opus-4.6",
|
||||
"claude-sonnet-4.6-agentic": "claude-sonnet-4.6",
|
||||
"claude-opus-4.5-agentic": "claude-opus-4.5",
|
||||
"claude-sonnet-4.5-agentic": "claude-sonnet-4.5",
|
||||
"claude-sonnet-4-agentic": "claude-sonnet-4",
|
||||
"claude-haiku-4.5-agentic": "claude-haiku-4.5",
|
||||
"kiro-claude-opus-4-6-agentic": "claude-opus-4.6",
|
||||
"kiro-claude-sonnet-4-6-agentic": "claude-sonnet-4.6",
|
||||
"kiro-claude-opus-4-5-agentic": "claude-opus-4.5",
|
||||
"kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5",
|
||||
"kiro-claude-sonnet-4-agentic": "claude-sonnet-4",
|
||||
@@ -1736,6 +1783,10 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
log.Debugf("kiro: unknown Sonnet 3.7 model '%s', mapping to claude-3-7-sonnet-20250219", model)
|
||||
return "claude-3-7-sonnet-20250219"
|
||||
}
|
||||
if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") {
|
||||
log.Debugf("kiro: unknown Sonnet 4.6 model '%s', mapping to claude-sonnet-4.6", model)
|
||||
return "claude-sonnet-4.6"
|
||||
}
|
||||
if strings.Contains(modelLower, "4-5") || strings.Contains(modelLower, "4.5") {
|
||||
log.Debugf("kiro: unknown Sonnet 4.5 model '%s', mapping to claude-sonnet-4.5", model)
|
||||
return "claude-sonnet-4.5"
|
||||
@@ -1747,6 +1798,10 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
|
||||
// Check for Opus variants
|
||||
if strings.Contains(modelLower, "opus") {
|
||||
if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") {
|
||||
log.Debugf("kiro: unknown Opus 4.6 model '%s', mapping to claude-opus-4.6", model)
|
||||
return "claude-opus-4.6"
|
||||
}
|
||||
log.Debugf("kiro: unknown Opus model '%s', mapping to claude-opus-4.5", model)
|
||||
return "claude-opus-4.5"
|
||||
}
|
||||
@@ -2096,6 +2151,22 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroclaude.Ki
|
||||
}
|
||||
}
|
||||
|
||||
case "contextUsageEvent":
|
||||
// Handle context usage events from Kiro API
|
||||
// Format: {"contextUsageEvent": {"contextUsagePercentage": 0.53}}
|
||||
if ctxUsage, ok := event["contextUsageEvent"].(map[string]interface{}); ok {
|
||||
if ctxPct, ok := ctxUsage["contextUsagePercentage"].(float64); ok {
|
||||
upstreamContextPercentage = ctxPct
|
||||
log.Debugf("kiro: parseEventStream received contextUsageEvent: %.2f%%", ctxPct*100)
|
||||
}
|
||||
} else {
|
||||
// Try direct field (fallback)
|
||||
if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
|
||||
upstreamContextPercentage = ctxPct
|
||||
log.Debugf("kiro: parseEventStream received contextUsagePercentage (direct): %.2f%%", ctxPct*100)
|
||||
}
|
||||
}
|
||||
|
||||
case "error", "exception", "internalServerException", "invalidStateEvent":
|
||||
// Handle error events from Kiro API stream
|
||||
errMsg := ""
|
||||
@@ -2699,6 +2770,22 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
}
|
||||
}
|
||||
|
||||
case "contextUsageEvent":
|
||||
// Handle context usage events from Kiro API
|
||||
// Format: {"contextUsageEvent": {"contextUsagePercentage": 0.53}}
|
||||
if ctxUsage, ok := event["contextUsageEvent"].(map[string]interface{}); ok {
|
||||
if ctxPct, ok := ctxUsage["contextUsagePercentage"].(float64); ok {
|
||||
upstreamContextPercentage = ctxPct
|
||||
log.Debugf("kiro: streamToChannel received contextUsageEvent: %.2f%%", ctxPct*100)
|
||||
}
|
||||
} else {
|
||||
// Try direct field (fallback)
|
||||
if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
|
||||
upstreamContextPercentage = ctxPct
|
||||
log.Debugf("kiro: streamToChannel received contextUsagePercentage (direct): %.2f%%", ctxPct*100)
|
||||
}
|
||||
}
|
||||
|
||||
case "error", "exception", "internalServerException":
|
||||
// Handle error events from Kiro API stream
|
||||
errMsg := ""
|
||||
@@ -4058,6 +4145,659 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
|
||||
return isExpired
|
||||
}
|
||||
|
||||
// NOTE: Message merging functions moved to internal/translator/kiro/common/message_merge.go
|
||||
// NOTE: Tool calling support functions moved to internal/translator/kiro/claude/kiro_claude_tools.go
|
||||
// The executor now uses kiroclaude.* and kirocommon.* functions instead
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
// Web Search Handler (MCP API)
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
// fetchToolDescription caching:
|
||||
// Uses a mutex + fetched flag to ensure only one goroutine fetches at a time,
|
||||
// with automatic retry on failure:
|
||||
// - On failure, fetched stays false so subsequent calls will retry
|
||||
// - On success, fetched is set to true — subsequent calls skip immediately (mutex-free fast path)
|
||||
// The cached description is stored in the translator package via kiroclaude.SetWebSearchDescription(),
|
||||
// enabling the translator's convertClaudeToolsToKiro to read it when building Kiro requests.
|
||||
var (
|
||||
toolDescMu sync.Mutex
|
||||
toolDescFetched atomic.Bool
|
||||
)
|
||||
|
||||
// fetchToolDescription calls MCP tools/list to get the web_search tool description
|
||||
// and caches it. Safe to call concurrently — only one goroutine fetches at a time.
|
||||
// If the fetch fails, subsequent calls will retry. On success, no further fetches occur.
|
||||
// The httpClient parameter allows reusing a shared pooled HTTP client.
|
||||
func fetchToolDescription(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) {
|
||||
// Fast path: already fetched successfully, no lock needed
|
||||
if toolDescFetched.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
toolDescMu.Lock()
|
||||
defer toolDescMu.Unlock()
|
||||
|
||||
// Double-check after acquiring lock
|
||||
if toolDescFetched.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, authToken, httpClient, auth, authAttrs)
|
||||
reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`)
|
||||
log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", mcpEndpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to create tools/list request: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Reuse same headers as callMcpAPI
|
||||
handler.setMcpHeaders(req)
|
||||
|
||||
resp, err := handler.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: tools/list request failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil || resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body))
|
||||
|
||||
// Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}}
|
||||
var result struct {
|
||||
Result *struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
} `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
|
||||
log.Warnf("kiro/websearch: failed to parse tools/list response")
|
||||
return
|
||||
}
|
||||
|
||||
for _, tool := range result.Result.Tools {
|
||||
if tool.Name == "web_search" && tool.Description != "" {
|
||||
kiroclaude.SetWebSearchDescription(tool.Description)
|
||||
toolDescFetched.Store(true) // success — no more fetches
|
||||
log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// web_search tool not found in response
|
||||
log.Warnf("kiro/websearch: web_search tool not found in tools/list response")
|
||||
}
|
||||
|
||||
// webSearchHandler handles web search requests via Kiro MCP API
|
||||
type webSearchHandler struct {
|
||||
ctx context.Context
|
||||
mcpEndpoint string
|
||||
httpClient *http.Client
|
||||
authToken string
|
||||
auth *cliproxyauth.Auth // for applyDynamicFingerprint
|
||||
authAttrs map[string]string // optional, for custom headers from auth.Attributes
|
||||
}
|
||||
|
||||
// newWebSearchHandler creates a new webSearchHandler.
|
||||
// If httpClient is nil, a default client with 30s timeout is used.
|
||||
// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse.
|
||||
func newWebSearchHandler(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) *webSearchHandler {
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
return &webSearchHandler{
|
||||
ctx: ctx,
|
||||
mcpEndpoint: mcpEndpoint,
|
||||
httpClient: httpClient,
|
||||
authToken: authToken,
|
||||
auth: auth,
|
||||
authAttrs: authAttrs,
|
||||
}
|
||||
}
|
||||
|
||||
// setMcpHeaders sets standard MCP API headers on the request,
|
||||
// aligned with the GAR request pattern.
|
||||
func (h *webSearchHandler) setMcpHeaders(req *http.Request) {
|
||||
// 1. Content-Type & Accept (aligned with GAR)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
|
||||
// 2. Kiro-specific headers (aligned with GAR)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
|
||||
// 3. User-Agent: Reuse applyDynamicFingerprint for consistency
|
||||
applyDynamicFingerprint(req, h.auth)
|
||||
|
||||
// 4. AWS SDK identifiers
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
|
||||
// 5. Authentication
|
||||
req.Header.Set("Authorization", "Bearer "+h.authToken)
|
||||
|
||||
// 6. Custom headers from auth attributes
|
||||
util.ApplyCustomHeadersFromAttrs(req, h.authAttrs)
|
||||
}
|
||||
|
||||
// mcpMaxRetries is the maximum number of retries for MCP API calls.
|
||||
const mcpMaxRetries = 2
|
||||
|
||||
// callMcpAPI calls the Kiro MCP API with the given request.
|
||||
// Includes retry logic with exponential backoff for retryable errors.
|
||||
func (h *webSearchHandler) callMcpAPI(request *kiroclaude.McpRequest) (*kiroclaude.McpResponse, error) {
|
||||
requestBody, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal MCP request: %w", err)
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.mcpEndpoint, len(requestBody))
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= mcpMaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
backoff := time.Duration(1<<attempt) * time.Second
|
||||
if backoff > 10*time.Second {
|
||||
backoff = 10 * time.Second
|
||||
}
|
||||
log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr)
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return nil, h.ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(h.ctx, "POST", h.mcpEndpoint, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
h.setMcpHeaders(req)
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("MCP API request failed: %w", err)
|
||||
continue // network error → retry
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read MCP response: %w", err)
|
||||
continue // read error → retry
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body))
|
||||
|
||||
// Retryable HTTP status codes (aligned with GAR: 502, 503, 504)
|
||||
if resp.StatusCode >= 502 && resp.StatusCode <= 504 {
|
||||
lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body))
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var mcpResponse kiroclaude.McpResponse
|
||||
if err := json.Unmarshal(body, &mcpResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse MCP response: %w", err)
|
||||
}
|
||||
|
||||
if mcpResponse.Error != nil {
|
||||
code := -1
|
||||
if mcpResponse.Error.Code != nil {
|
||||
code = *mcpResponse.Error.Code
|
||||
}
|
||||
msg := "Unknown error"
|
||||
if mcpResponse.Error.Message != nil {
|
||||
msg = *mcpResponse.Error.Message
|
||||
}
|
||||
return nil, fmt.Errorf("MCP error %d: %s", code, msg)
|
||||
}
|
||||
|
||||
return &mcpResponse, nil
|
||||
}
|
||||
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// webSearchAuthAttrs extracts auth attributes for MCP calls.
|
||||
// Used by handleWebSearch and handleWebSearchStream to pass custom headers.
|
||||
func webSearchAuthAttrs(auth *cliproxyauth.Auth) map[string]string {
|
||||
if auth != nil {
|
||||
return auth.Attributes
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxWebSearchIterations = 5
|
||||
|
||||
// handleWebSearchStream handles web_search requests:
|
||||
// Step 1: tools/list (sync) → fetch/cache tool description
|
||||
// Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop
|
||||
// Note: We skip the "model decides to search" step because Claude Code already
|
||||
// decided to use web_search. The Kiro tool description restricts non-coding
|
||||
// topics, so asking the model again would cause it to refuse valid searches.
|
||||
func (e *KiroExecutor) handleWebSearchStream(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
// Extract search query from Claude Code's web_search tool_use
|
||||
query := kiroclaude.ExtractSearchQuery(req.Payload)
|
||||
if query == "" {
|
||||
log.Warnf("kiro/websearch: failed to extract search query, falling back to normal flow")
|
||||
return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
|
||||
|
||||
// ── Step 1: tools/list (SYNC) — cache tool description ──
|
||||
{
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
}
|
||||
|
||||
// Create output channel
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
|
||||
// Usage reporting: track web search requests like normal streaming requests
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
|
||||
go func() {
|
||||
var wsErr error
|
||||
defer reporter.trackFailure(ctx, &wsErr)
|
||||
defer close(out)
|
||||
|
||||
// Estimate input tokens using tokenizer (matching streamToChannel pattern)
|
||||
var totalUsage usage.Detail
|
||||
if enc, tokErr := getTokenizer(req.Model); tokErr == nil {
|
||||
if inp, e := countClaudeChatTokens(enc, req.Payload); e == nil && inp > 0 {
|
||||
totalUsage.InputTokens = inp
|
||||
} else {
|
||||
totalUsage.InputTokens = int64(len(req.Payload) / 4)
|
||||
}
|
||||
} else {
|
||||
totalUsage.InputTokens = int64(len(req.Payload) / 4)
|
||||
}
|
||||
if totalUsage.InputTokens == 0 && len(req.Payload) > 0 {
|
||||
totalUsage.InputTokens = 1
|
||||
}
|
||||
var accumulatedOutputLen int
|
||||
defer func() {
|
||||
if wsErr != nil {
|
||||
return // let trackFailure handle failure reporting
|
||||
}
|
||||
totalUsage.OutputTokens = int64(accumulatedOutputLen / 4)
|
||||
if accumulatedOutputLen > 0 && totalUsage.OutputTokens == 0 {
|
||||
totalUsage.OutputTokens = 1
|
||||
}
|
||||
reporter.publish(ctx, totalUsage)
|
||||
}()
|
||||
|
||||
// Send message_start event to client (aligned with streamToChannel pattern)
|
||||
// Use payloadRequestedModel to return user's original model alias
|
||||
msgStart := kiroclaude.BuildClaudeMessageStartEvent(
|
||||
payloadRequestedModel(opts, req.Model),
|
||||
totalUsage.InputTokens,
|
||||
)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: append(msgStart, '\n', '\n')}:
|
||||
}
|
||||
|
||||
// ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ──
|
||||
contentBlockIndex := 0
|
||||
currentQuery := query
|
||||
|
||||
// Replace web_search tool description with a minimal one that allows re-search.
|
||||
// The original tools/list description from Kiro restricts non-coding topics,
|
||||
// but we've already decided to search. We keep the tool so the model can
|
||||
// request additional searches when results are insufficient.
|
||||
simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload))
|
||||
if simplifyErr != nil {
|
||||
log.Warnf("kiro/websearch: failed to simplify web_search tool: %v, using original payload", simplifyErr)
|
||||
simplifiedPayload = bytes.Clone(req.Payload)
|
||||
}
|
||||
|
||||
currentClaudePayload := simplifiedPayload
|
||||
totalSearches := 0
|
||||
|
||||
// Generate toolUseId for the first iteration (Claude Code already decided to search)
|
||||
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
|
||||
|
||||
for iteration := 0; iteration < maxWebSearchIterations; iteration++ {
|
||||
log.Infof("kiro/websearch: search iteration %d/%d",
|
||||
iteration+1, maxWebSearchIterations)
|
||||
|
||||
// MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery)
|
||||
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
log.Warnf("kiro/websearch: MCP API call failed: %v, continuing with empty results", mcpErr)
|
||||
} else {
|
||||
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
|
||||
}
|
||||
|
||||
resultCount := 0
|
||||
if searchResults != nil {
|
||||
resultCount = len(searchResults.Results)
|
||||
}
|
||||
totalSearches++
|
||||
log.Infof("kiro/websearch: iteration %d — got %d search results", iteration+1, resultCount)
|
||||
|
||||
// Send search indicator events to client
|
||||
searchEvents := kiroclaude.GenerateSearchIndicatorEvents(currentQuery, currentToolUseId, searchResults, contentBlockIndex)
|
||||
for _, event := range searchEvents {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: event}:
|
||||
}
|
||||
}
|
||||
contentBlockIndex += 2
|
||||
|
||||
// Inject tool_use + tool_result into Claude payload, then call GAR
|
||||
var err error
|
||||
currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, currentToolUseId, currentQuery, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to inject tool results: %v", err)
|
||||
wsErr = fmt.Errorf("failed to inject tool results: %w", err)
|
||||
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
|
||||
return
|
||||
}
|
||||
|
||||
// Call GAR with modified Claude payload (full translation pipeline)
|
||||
modifiedReq := req
|
||||
modifiedReq.Payload = currentClaudePayload
|
||||
kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if kiroErr != nil {
|
||||
log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr)
|
||||
wsErr = fmt.Errorf("Kiro API failed at iteration %d: %w", iteration+1, kiroErr)
|
||||
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
|
||||
return
|
||||
}
|
||||
|
||||
// Analyze response
|
||||
analysis := kiroclaude.AnalyzeBufferedStream(kiroChunks)
|
||||
log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v",
|
||||
iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse)
|
||||
|
||||
if analysis.HasWebSearchToolUse && analysis.WebSearchQuery != "" && iteration+1 < maxWebSearchIterations {
|
||||
// Model wants another search
|
||||
filteredChunks := kiroclaude.FilterChunksForClient(kiroChunks, analysis.WebSearchToolUseIndex, contentBlockIndex)
|
||||
for _, chunk := range filteredChunks {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
|
||||
}
|
||||
}
|
||||
|
||||
currentQuery = analysis.WebSearchQuery
|
||||
currentToolUseId = analysis.WebSearchToolUseId
|
||||
continue
|
||||
}
|
||||
|
||||
// Model returned final response — stream to client
|
||||
for _, chunk := range kiroChunks {
|
||||
if contentBlockIndex > 0 && len(chunk) > 0 {
|
||||
adjusted, shouldForward := kiroclaude.AdjustSSEChunk(chunk, contentBlockIndex)
|
||||
if !shouldForward {
|
||||
continue
|
||||
}
|
||||
accumulatedOutputLen += len(adjusted)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}:
|
||||
}
|
||||
} else {
|
||||
accumulatedOutputLen += len(chunk)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Infof("kiro/websearch: completed after %d search iteration(s), total searches: %d", iteration+1, totalSearches)
|
||||
return
|
||||
}
|
||||
|
||||
log.Warnf("kiro/websearch: reached max iterations (%d), stopping search loop", maxWebSearchIterations)
|
||||
}()
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// handleWebSearch handles web_search requests for non-streaming Execute path.
|
||||
// Performs MCP search synchronously, injects results into the request payload,
|
||||
// then calls the normal non-streaming Kiro API path which returns a proper
|
||||
// Claude JSON response (not SSE chunks).
|
||||
func (e *KiroExecutor) handleWebSearch(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (cliproxyexecutor.Response, error) {
|
||||
// Extract search query from Claude Code's web_search tool_use
|
||||
query := kiroclaude.ExtractSearchQuery(req.Payload)
|
||||
if query == "" {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute")
|
||||
// Fall through to normal non-streaming path
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
|
||||
|
||||
// Step 1: Fetch/cache tool description (sync)
|
||||
{
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
}
|
||||
|
||||
// Step 2: Perform MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(query)
|
||||
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr)
|
||||
} else {
|
||||
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
|
||||
}
|
||||
|
||||
resultCount := 0
|
||||
if searchResults != nil {
|
||||
resultCount = len(searchResults.Results)
|
||||
}
|
||||
log.Infof("kiro/websearch: non-stream: got %d search results", resultCount)
|
||||
|
||||
// Step 3: Replace restrictive web_search tool description (align with streaming path)
|
||||
simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload))
|
||||
if simplifyErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to simplify web_search tool: %v, using original payload", simplifyErr)
|
||||
simplifiedPayload = bytes.Clone(req.Payload)
|
||||
}
|
||||
|
||||
// Step 4: Inject search tool_use + tool_result into Claude payload
|
||||
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
|
||||
modifiedPayload, err := kiroclaude.InjectToolResultsClaude(simplifiedPayload, currentToolUseId, query, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err)
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Step 5: Call Kiro API via the normal non-streaming path (executeWithRetry)
|
||||
// This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream
|
||||
// to produce a proper Claude JSON response
|
||||
modifiedReq := req
|
||||
modifiedReq.Payload = modifiedPayload
|
||||
|
||||
resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Step 6: Inject server_tool_use + web_search_tool_result into response
|
||||
// so Claude Code can display "Did X searches in Ys"
|
||||
indicators := []kiroclaude.SearchIndicator{
|
||||
{
|
||||
ToolUseID: currentToolUseId,
|
||||
Query: query,
|
||||
Results: searchResults,
|
||||
},
|
||||
}
|
||||
injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators)
|
||||
if injErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr)
|
||||
} else {
|
||||
resp.Payload = injectedPayload
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// callKiroAndBuffer calls the Kiro API and buffers all response chunks.
|
||||
// Returns the buffered chunks for analysis before forwarding to client.
|
||||
// Usage reporting is NOT done here — the caller (handleWebSearchStream) manages its own reporter.
|
||||
func (e *KiroExecutor) callKiroAndBuffer(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) ([][]byte, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
log.Debugf("kiro/websearch GAR request: %d bytes", len(body))
|
||||
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
kiroStream, err := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Buffer all chunks
|
||||
var chunks [][]byte
|
||||
for chunk := range kiroStream {
|
||||
if chunk.Err != nil {
|
||||
return chunks, chunk.Err
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
chunks = append(chunks, bytes.Clone(chunk.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro/websearch GAR response: %d chunks buffered", len(chunks))
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// callKiroDirectStream creates a direct streaming channel to Kiro API without search.
|
||||
func (e *KiroExecutor) callKiroDirectStream(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
var streamErr error
|
||||
defer reporter.trackFailure(ctx, &streamErr)
|
||||
|
||||
stream, streamErr := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
return stream, streamErr
|
||||
}
|
||||
|
||||
// sendFallbackText sends a simple text response when the Kiro API fails during the search loop.
|
||||
// Delegates SSE event construction to kiroclaude.BuildFallbackTextEvents() for alignment
|
||||
// with how streamToChannel() uses BuildClaude*Event() functions.
|
||||
func (e *KiroExecutor) sendFallbackText(
|
||||
ctx context.Context,
|
||||
out chan<- cliproxyexecutor.StreamChunk,
|
||||
contentBlockIndex int,
|
||||
query string,
|
||||
searchResults *kiroclaude.WebSearchResults,
|
||||
) {
|
||||
events := kiroclaude.BuildFallbackTextEvents(contentBlockIndex, query, searchResults)
|
||||
for _, event := range events {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: append(event, '\n', '\n')}:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// executeNonStreamFallback runs the standard non-streaming Execute path for a request.
|
||||
// Used by handleWebSearch after injecting search results, or as a fallback.
|
||||
func (e *KiroExecutor) executeNonStreamFallback(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (cliproxyexecutor.Response, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
var err error
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
resp, err := e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequ
|
||||
writeHeaders(builder, info.Headers)
|
||||
builder.WriteString("\nBody:\n")
|
||||
if len(info.Body) > 0 {
|
||||
builder.WriteString(string(bytes.Clone(info.Body)))
|
||||
builder.WriteString(string(info.Body))
|
||||
} else {
|
||||
builder.WriteString("<empty>")
|
||||
}
|
||||
@@ -152,7 +152,7 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
data := bytes.TrimSpace(bytes.Clone(chunk))
|
||||
data := bytes.TrimSpace(chunk)
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -88,12 +88,13 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
to = sdktranslator.FromString("openai-response")
|
||||
endpoint = "/responses/compact"
|
||||
}
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
if opts.Alt == "responses/compact" {
|
||||
@@ -170,12 +171,12 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.ensurePublished(ctx)
|
||||
// Translate response back to source format when needed
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -189,12 +190,13 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
@@ -256,7 +258,6 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -283,7 +284,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
|
||||
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
|
||||
// Pass through translator; it yields one or more chunks for the target schema.
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -296,7 +297,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
// Ensure we record the request if no usage chunk was ever seen
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
@@ -304,7 +305,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
modelForCounting := baseModel
|
||||
|
||||
|
||||
@@ -22,9 +22,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
qwenUserAgent = "google-api-nodejs-client/9.15.1"
|
||||
qwenXGoogAPIClient = "gl-node/22.17.0"
|
||||
qwenClientMetadataValue = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||
)
|
||||
|
||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||
@@ -81,12 +79,13 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -150,12 +149,12 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -171,12 +170,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -236,7 +236,6 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -253,12 +252,12 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||
}
|
||||
@@ -268,7 +267,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
@@ -276,7 +275,7 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
@@ -342,8 +341,18 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
r.Header.Set("User-Agent", qwenUserAgent)
|
||||
r.Header.Set("X-Goog-Api-Client", qwenXGoogAPIClient)
|
||||
r.Header.Set("Client-Metadata", qwenClientMetadataValue)
|
||||
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
|
||||
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
|
||||
r.Header.Set("Sec-Fetch-Mode", "cors")
|
||||
r.Header.Set("X-Stainless-Lang", "js")
|
||||
r.Header.Set("X-Stainless-Arch", "arm64")
|
||||
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
|
||||
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
|
||||
r.Header.Set("X-Stainless-Retry-Count", "0")
|
||||
r.Header.Set("X-Stainless-Os", "MacOS")
|
||||
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
|
||||
r.Header.Set("X-Stainless-Runtime", "node")
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
return
|
||||
|
||||
@@ -7,5 +7,6 @@ import (
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/kimi"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
|
||||
)
|
||||
|
||||
@@ -21,6 +21,9 @@ import (
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// gcInterval defines minimum time between garbage collection runs.
|
||||
const gcInterval = 5 * time.Minute
|
||||
|
||||
// GitTokenStore persists token records and auth metadata using git as the backing storage.
|
||||
type GitTokenStore struct {
|
||||
mu sync.Mutex
|
||||
@@ -31,6 +34,7 @@ type GitTokenStore struct {
|
||||
remote string
|
||||
username string
|
||||
password string
|
||||
lastGC time.Time
|
||||
}
|
||||
|
||||
// NewGitTokenStore creates a token store that saves credentials to disk through the
|
||||
@@ -613,6 +617,7 @@ func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string)
|
||||
} else if errRewrite := s.rewriteHeadAsSingleCommit(repo, headRef.Name(), commitHash, message, signature); errRewrite != nil {
|
||||
return errRewrite
|
||||
}
|
||||
s.maybeRunGC(repo)
|
||||
if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil {
|
||||
if errors.Is(err, git.NoErrAlreadyUpToDate) {
|
||||
return nil
|
||||
@@ -652,6 +657,23 @@ func (s *GitTokenStore) rewriteHeadAsSingleCommit(repo *git.Repository, branch p
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GitTokenStore) maybeRunGC(repo *git.Repository) {
|
||||
now := time.Now()
|
||||
if now.Sub(s.lastGC) < gcInterval {
|
||||
return
|
||||
}
|
||||
s.lastGC = now
|
||||
|
||||
pruneOpts := git.PruneOptions{
|
||||
OnlyObjectsOlderThan: now,
|
||||
Handler: repo.DeleteObject,
|
||||
}
|
||||
if err := repo.Prune(pruneOpts); err != nil && !errors.Is(err, git.ErrLooseObjectsNotSupported) {
|
||||
return
|
||||
}
|
||||
_ = repo.RepackObjects(&git.RepackConfig{})
|
||||
}
|
||||
|
||||
// PersistConfig commits and pushes configuration changes to git.
|
||||
func (s *GitTokenStore) PersistConfig(_ context.Context) error {
|
||||
if err := s.EnsureRepository(); err != nil {
|
||||
|
||||
@@ -18,6 +18,7 @@ var providerAppliers = map[string]ProviderApplier{
|
||||
"codex": nil,
|
||||
"iflow": nil,
|
||||
"antigravity": nil,
|
||||
"kimi": nil,
|
||||
}
|
||||
|
||||
// GetProviderApplier returns the ProviderApplier for the given provider name.
|
||||
@@ -326,6 +327,9 @@ func extractThinkingConfig(body []byte, provider string) ThinkingConfig {
|
||||
return config
|
||||
}
|
||||
return extractOpenAIConfig(body)
|
||||
case "kimi":
|
||||
// Kimi uses OpenAI-compatible reasoning_effort format
|
||||
return extractOpenAIConfig(body)
|
||||
default:
|
||||
return ThinkingConfig{}
|
||||
}
|
||||
|
||||
126
internal/thinking/provider/kimi/apply.go
Normal file
126
internal/thinking/provider/kimi/apply.go
Normal file
@@ -0,0 +1,126 @@
|
||||
// Package kimi implements thinking configuration for Kimi (Moonshot AI) models.
|
||||
//
|
||||
// Kimi models use the OpenAI-compatible reasoning_effort format with discrete levels
|
||||
// (low/medium/high). The provider strips any existing thinking config and applies
|
||||
// the unified ThinkingConfig in OpenAI format.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// Applier implements thinking.ProviderApplier for Kimi models.
|
||||
//
|
||||
// Kimi-specific behavior:
|
||||
// - Output format: reasoning_effort (string: low/medium/high)
|
||||
// - Uses OpenAI-compatible format
|
||||
// - Supports budget-to-level conversion
|
||||
type Applier struct{}
|
||||
|
||||
var _ thinking.ProviderApplier = (*Applier)(nil)
|
||||
|
||||
// NewApplier creates a new Kimi thinking applier.
|
||||
func NewApplier() *Applier {
|
||||
return &Applier{}
|
||||
}
|
||||
|
||||
func init() {
|
||||
thinking.RegisterProvider("kimi", NewApplier())
|
||||
}
|
||||
|
||||
// Apply applies thinking configuration to Kimi request body.
|
||||
//
|
||||
// Expected output format:
|
||||
//
|
||||
// {
|
||||
// "reasoning_effort": "high"
|
||||
// }
|
||||
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||
if thinking.IsUserDefinedModel(modelInfo) {
|
||||
return applyCompatibleKimi(body, config)
|
||||
}
|
||||
if modelInfo.Thinking == nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
body = []byte(`{}`)
|
||||
}
|
||||
|
||||
var effort string
|
||||
switch config.Mode {
|
||||
case thinking.ModeLevel:
|
||||
if config.Level == "" {
|
||||
return body, nil
|
||||
}
|
||||
effort = string(config.Level)
|
||||
case thinking.ModeNone:
|
||||
// Kimi uses "none" to disable thinking
|
||||
effort = string(thinking.LevelNone)
|
||||
case thinking.ModeBudget:
|
||||
// Convert budget to level using threshold mapping
|
||||
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
effort = level
|
||||
case thinking.ModeAuto:
|
||||
// Auto mode maps to "auto" effort
|
||||
effort = string(thinking.LevelAuto)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
|
||||
if effort == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// applyCompatibleKimi applies thinking config for user-defined Kimi models.
|
||||
func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
body = []byte(`{}`)
|
||||
}
|
||||
|
||||
var effort string
|
||||
switch config.Mode {
|
||||
case thinking.ModeLevel:
|
||||
if config.Level == "" {
|
||||
return body, nil
|
||||
}
|
||||
effort = string(config.Level)
|
||||
case thinking.ModeNone:
|
||||
effort = string(thinking.LevelNone)
|
||||
if config.Level != "" {
|
||||
effort = string(config.Level)
|
||||
}
|
||||
case thinking.ModeAuto:
|
||||
effort = string(thinking.LevelAuto)
|
||||
case thinking.ModeBudget:
|
||||
// Convert budget to level
|
||||
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
effort = level
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -10,10 +10,53 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// validReasoningEffortLevels contains the standard values accepted by the
|
||||
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
|
||||
// auto) are NOT in this set and must be clamped before use.
|
||||
var validReasoningEffortLevels = map[string]struct{}{
|
||||
"none": {},
|
||||
"low": {},
|
||||
"medium": {},
|
||||
"high": {},
|
||||
}
|
||||
|
||||
// clampReasoningEffort maps any thinking level string to a value that is safe
|
||||
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
|
||||
// mapped to the nearest standard equivalent.
|
||||
//
|
||||
// Mapping rules:
|
||||
// - none / low / medium / high → returned as-is (already valid)
|
||||
// - xhigh → "high" (nearest lower standard level)
|
||||
// - minimal → "low" (nearest higher standard level)
|
||||
// - auto → "medium" (reasonable default)
|
||||
// - anything else → "medium" (safe default)
|
||||
func clampReasoningEffort(level string) string {
|
||||
if _, ok := validReasoningEffortLevels[level]; ok {
|
||||
return level
|
||||
}
|
||||
var clamped string
|
||||
switch level {
|
||||
case string(thinking.LevelXHigh):
|
||||
clamped = string(thinking.LevelHigh)
|
||||
case string(thinking.LevelMinimal):
|
||||
clamped = string(thinking.LevelLow)
|
||||
case string(thinking.LevelAuto):
|
||||
clamped = string(thinking.LevelMedium)
|
||||
default:
|
||||
clamped = string(thinking.LevelMedium)
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
"original": level,
|
||||
"clamped": clamped,
|
||||
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
|
||||
return clamped
|
||||
}
|
||||
|
||||
// Applier implements thinking.ProviderApplier for OpenAI models.
|
||||
//
|
||||
// OpenAI-specific behavior:
|
||||
@@ -58,7 +101,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
}
|
||||
|
||||
if config.Mode == thinking.ModeLevel {
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -79,7 +122,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -114,7 +157,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
@@ -37,7 +36,7 @@ import (
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
enableThoughtTranslate := true
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
// system instruction
|
||||
systemInstructionJSON := ""
|
||||
@@ -232,8 +231,12 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
} else if functionResponseResult.IsObject() {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
} else {
|
||||
} else if functionResponseResult.Raw != "" {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
} else {
|
||||
// Content field is missing entirely — .Raw is empty which
|
||||
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
@@ -345,7 +348,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// Inject interleaved thinking hint when both tools and thinking are active
|
||||
hasTools := toolDeclCount > 0
|
||||
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && thinkingResult.Get("type").String() == "enabled"
|
||||
thinkingType := thinkingResult.Get("type").String()
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive")
|
||||
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
|
||||
|
||||
if hasTools && hasThinking && isClaudeThinking {
|
||||
@@ -378,12 +382,18 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
switch t.Get("type").String() {
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
|
||||
@@ -661,6 +661,85 @@ func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultNoContent(t *testing.T) {
|
||||
// Bug repro: tool_result with no content field produces invalid JSON
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "MyTool-123-456",
|
||||
"name": "MyTool",
|
||||
"input": {"key": "value"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "MyTool-123-456"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Errorf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
// Verify the functionResponse has a valid result value
|
||||
fr := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse.response.result")
|
||||
if !fr.Exists() {
|
||||
t.Error("functionResponse.response.result should exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
|
||||
// Bug repro: tool_result with null content produces invalid JSON
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "MyTool-123-456",
|
||||
"name": "MyTool",
|
||||
"input": {"key": "value"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "MyTool-123-456",
|
||||
"content": null
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Errorf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||
// When tools + thinking but no system instruction, should create one with hint
|
||||
inputJSON := []byte(`{
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -34,7 +33,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
|
||||
)
|
||||
|
||||
func ConvertOpenAIResponsesRequestToAntigravity(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = ConvertOpenAIResponsesRequestToGemini(modelName, rawJSON, stream)
|
||||
return ConvertGeminiRequestToAntigravity(modelName, rawJSON, stream)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -30,7 +28,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Claude Code API format
|
||||
func ConvertGeminiCLIRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
// Extract the inner request object and promote it to the top level
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -46,7 +45,7 @@ var (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Claude Code API format
|
||||
func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
if account == "" {
|
||||
u, _ := uuid.NewRandom()
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -44,7 +43,7 @@ var (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Claude Code API format
|
||||
func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
if account == "" {
|
||||
u, _ := uuid.NewRandom()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -32,7 +31,7 @@ var (
|
||||
// - max_output_tokens -> max_tokens
|
||||
// - stream passthrough via parameter
|
||||
func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
if account == "" {
|
||||
u, _ := uuid.NewRandom()
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -35,7 +34,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in internal client format
|
||||
func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
template := `{"model":"","instructions":"","input":[]}`
|
||||
|
||||
@@ -223,6 +222,10 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
case "adaptive":
|
||||
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
||||
// and let ApplyThinking normalize per target model capability.
|
||||
reasoningEffort = string(thinking.LevelXHigh)
|
||||
case "disabled":
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
|
||||
@@ -113,10 +113,10 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
|
||||
stopReason := rootResult.Get("response.stop_reason").String()
|
||||
if stopReason != "" {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
|
||||
} else if p {
|
||||
if p {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
|
||||
} else if stopReason == "max_tokens" || stopReason == "stop" {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -30,7 +28,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Codex API format
|
||||
func ConvertGeminiCLIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
@@ -37,7 +36,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Codex API format
|
||||
func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base template
|
||||
out := `{"model":"","instructions":"","input":[]}`
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -29,7 +27,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in OpenAI Responses API format
|
||||
func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Start with empty JSON object
|
||||
out := `{"instructions":""}`
|
||||
|
||||
|
||||
@@ -20,10 +20,12 @@ var (
|
||||
|
||||
// ConvertCliToOpenAIParams holds parameters for response conversion.
|
||||
type ConvertCliToOpenAIParams struct {
|
||||
ResponseID string
|
||||
CreatedAt int64
|
||||
Model string
|
||||
FunctionCallIndex int
|
||||
ResponseID string
|
||||
CreatedAt int64
|
||||
Model string
|
||||
FunctionCallIndex int
|
||||
HasReceivedArgumentsDelta bool
|
||||
HasToolCallAnnounced bool
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToOpenAI translates a single chunk of a streaming response from the
|
||||
@@ -43,10 +45,12 @@ type ConvertCliToOpenAIParams struct {
|
||||
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
if *param == nil {
|
||||
*param = &ConvertCliToOpenAIParams{
|
||||
Model: modelName,
|
||||
CreatedAt: 0,
|
||||
ResponseID: "",
|
||||
FunctionCallIndex: -1,
|
||||
Model: modelName,
|
||||
CreatedAt: 0,
|
||||
ResponseID: "",
|
||||
FunctionCallIndex: -1,
|
||||
HasReceivedArgumentsDelta: false,
|
||||
HasToolCallAnnounced: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +94,9 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
}
|
||||
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
}
|
||||
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
}
|
||||
@@ -115,35 +122,93 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||
} else if dataType == "response.output_item.done" {
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
} else if dataType == "response.output_item.added" {
|
||||
itemResult := rootResult.Get("item")
|
||||
if itemResult.Exists() {
|
||||
if itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// set the index
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened
|
||||
name := itemResult.Get("name").String()
|
||||
// Build reverse map on demand from original request tools
|
||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Increment index for this new function call item.
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
|
||||
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
|
||||
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened.
|
||||
name := itemResult.Get("name").String()
|
||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
|
||||
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.function_call_arguments.delta" {
|
||||
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
|
||||
|
||||
deltaValue := rootResult.Get("delta").String()
|
||||
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.function_call_arguments.done" {
|
||||
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
|
||||
// Arguments were already streamed via delta events; nothing to emit.
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Fallback: no delta events were received, emit the full arguments as a single chunk.
|
||||
fullArgs := rootResult.Get("arguments").String()
|
||||
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.output_item.done" {
|
||||
itemResult := rootResult.Get("item")
|
||||
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
|
||||
// Tool call was already announced via output_item.added; skip emission.
|
||||
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Fallback path: model skipped output_item.added, so emit complete tool call now.
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened.
|
||||
name := itemResult.Get("name").String()
|
||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else {
|
||||
return []string{}
|
||||
}
|
||||
@@ -205,6 +270,9 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
}
|
||||
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
}
|
||||
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -9,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
inputResult := gjson.GetBytes(rawJSON, "input")
|
||||
if inputResult.Type == gjson.String {
|
||||
@@ -28,6 +27,9 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||
|
||||
// Delete the user field as it is not supported by the Codex upstream.
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
||||
|
||||
// Convert role "system" to "developer" in input array to comply with Codex API requirements.
|
||||
rawJSON = convertSystemRoleToDeveloper(rawJSON)
|
||||
|
||||
|
||||
@@ -263,3 +263,20 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
|
||||
t.Errorf("Expected third role 'assistant', got '%s'", thirdRole.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserFieldDeletion(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
"user": "test-user",
|
||||
"input": [{"role": "user", "content": "Hello"}]
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Verify user field is deleted
|
||||
userField := gjson.Get(outputStr, "user")
|
||||
if userField.Exists() {
|
||||
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
@@ -173,12 +173,18 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
switch t.Get("type").String() {
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -33,7 +32,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -28,7 +27,7 @@ const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base envelope (no default thinkingConfig)
|
||||
out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -77,14 +78,20 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the finish reason.
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
finishReason := ""
|
||||
if stopReasonResult := gjson.GetBytes(rawJSON, "response.stop_reason"); stopReasonResult.Exists() {
|
||||
finishReason = stopReasonResult.String()
|
||||
}
|
||||
if finishReason == "" {
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
finishReason = finishReasonResult.String()
|
||||
}
|
||||
}
|
||||
finishReason = strings.ToLower(finishReason)
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
@@ -97,6 +104,14 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
// Include cached token count if present (indicates prompt caching is working)
|
||||
if cachedTokenCount > 0 {
|
||||
var err error
|
||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("gemini-cli openai response: failed to set cached_tokens: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
@@ -187,6 +202,12 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
if hasFunctionCall {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
} else if finishReason != "" && (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex == 0 {
|
||||
// Only pass through specific finish reasons
|
||||
if finishReason == "max_tokens" || finishReason == "stop" {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||
}
|
||||
}
|
||||
|
||||
return []string{template}
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini-cli/gemini"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
|
||||
)
|
||||
|
||||
func ConvertOpenAIResponsesRequestToGeminiCLI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = ConvertOpenAIResponsesRequestToGemini(modelName, rawJSON, stream)
|
||||
return ConvertGeminiRequestToGeminiCLI(modelName, rawJSON, stream)
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ const geminiClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request in Gemini CLI format.
|
||||
func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
@@ -154,12 +154,18 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled
|
||||
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
switch t.Get("type").String() {
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -19,7 +18,7 @@ import (
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the internal client.
|
||||
func ConvertGeminiCLIRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String())
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -19,7 +18,7 @@ import (
|
||||
//
|
||||
// It keeps the payload otherwise unchanged.
|
||||
func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Fast path: if no contents field, only attach safety settings
|
||||
contents := gjson.GetBytes(rawJSON, "contents")
|
||||
if !contents.Exists() {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user