mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-21 16:40:22 +00:00
Compare commits
133 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6bcac3a55a | ||
|
|
d24ea4ce2a | ||
|
|
2c30c981ae | ||
|
|
aa1da8a858 | ||
|
|
f1e9a787d7 | ||
|
|
4eeec297de | ||
|
|
77cc4ce3a0 | ||
|
|
37dfea1d3f | ||
|
|
e6626c672a | ||
|
|
c66cb0afd2 | ||
|
|
fb48eee973 | ||
|
|
bb44e5ec44 | ||
|
|
c785c1a3ca | ||
|
|
0659ffab75 | ||
|
|
7cb398d167 | ||
|
|
c3e12c5e58 | ||
|
|
1825fc7503 | ||
|
|
48732ba05e | ||
|
|
acf483c9e6 | ||
|
|
3b3e0d1141 | ||
|
|
7acd428507 | ||
|
|
450d1227bd | ||
|
|
492b9c46f0 | ||
|
|
6e634fe3f9 | ||
|
|
eb7571936c | ||
|
|
5382764d8a | ||
|
|
49c8ec69d0 | ||
|
|
3b421c8181 | ||
|
|
21d2329947 | ||
|
|
0993413bab | ||
|
|
713388dd7b | ||
|
|
e6c7af0fa9 | ||
|
|
837aa6e3aa | ||
|
|
d210be06c2 | ||
|
|
afc8a0f9be | ||
|
|
af8e9ef458 | ||
|
|
cec6f993ad | ||
|
|
950de29f48 | ||
|
|
d6ec33e8e1 | ||
|
|
081cfe806e | ||
|
|
c1c62a6c04 | ||
|
|
d693d7993b | ||
|
|
5936f9895c | ||
|
|
2fdf5d2793 | ||
|
|
b3da00d2ed | ||
|
|
740277a9f2 | ||
|
|
f91807b6b9 | ||
|
|
57d18bb226 | ||
|
|
10b9c6cb8a | ||
|
|
b24786f8a7 | ||
|
|
7b0eb41ebc | ||
|
|
70949929db | ||
|
|
7c9c89dace | ||
|
|
ef5901c81b | ||
|
|
d4829c82f7 | ||
|
|
a5f4166a9b | ||
|
|
0cbfe7f457 | ||
|
|
f2b1ec4f9e | ||
|
|
1cc21cc45b | ||
|
|
07cf616e2b | ||
|
|
2b8c466e88 | ||
|
|
ca2174ea48 | ||
|
|
c09fb2a79d | ||
|
|
4445a165e9 | ||
|
|
e92e2af71a | ||
|
|
a6bdd9a652 | ||
|
|
349a6349b3 | ||
|
|
00822770ec | ||
|
|
1a0ceda0fc | ||
|
|
b9ae4ab803 | ||
|
|
72add453d2 | ||
|
|
2789396435 | ||
|
|
61da7bd981 | ||
|
|
ae4c502792 | ||
|
|
ec6068060b | ||
|
|
ecb01d3dcd | ||
|
|
22c0c00bd4 | ||
|
|
9eb3e7a6c4 | ||
|
|
357c191510 | ||
|
|
5db244af76 | ||
|
|
dc375d1b74 | ||
|
|
9c040445af | ||
|
|
fff866424e | ||
|
|
2d12becfd6 | ||
|
|
252f7e0751 | ||
|
|
b2b17528cb | ||
|
|
55f938164b | ||
|
|
76294f0c59 | ||
|
|
2bcee78c6e | ||
|
|
7fe8246a9f | ||
|
|
93fe58e31e | ||
|
|
e5b5dc870f | ||
|
|
a54877c023 | ||
|
|
bb86a0c0c4 | ||
|
|
5fa23c7f41 | ||
|
|
f9a09b7f23 | ||
|
|
b0cde626fe | ||
|
|
e42ef9a95d | ||
|
|
abf1629ec7 | ||
|
|
73dc0b10b8 | ||
|
|
2ea95266e3 | ||
|
|
922d4141c0 | ||
|
|
1f8f198c45 | ||
|
|
c55275342c | ||
|
|
9261b0c20b | ||
|
|
7cc725496e | ||
|
|
5726a99c80 | ||
|
|
b5756bf729 | ||
|
|
709d999f9f | ||
|
|
24c18614f0 | ||
|
|
603f06a762 | ||
|
|
98f0a3e3bd | ||
|
|
2c8821891c | ||
|
|
0a2555b0f3 | ||
|
|
020df41efe | ||
|
|
f31f7f701a | ||
|
|
54ad7c1b6b | ||
|
|
a45c6defa7 | ||
|
|
40bee3e8d9 | ||
|
|
93147dddeb | ||
|
|
c0f9b15a58 | ||
|
|
6f2fbdcbae | ||
|
|
65debb874f | ||
|
|
3caadac003 | ||
|
|
6a9e3a6b84 | ||
|
|
269972440a | ||
|
|
cce13e6ad2 | ||
|
|
8a565dcad8 | ||
|
|
d536110404 | ||
|
|
48e957ddff | ||
|
|
94563d622c | ||
|
|
ce0c6aa82b | ||
|
|
3c85d2a4d7 |
@@ -8,6 +8,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -26,6 +27,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
@@ -70,6 +72,7 @@ func main() {
|
|||||||
// Command-line flags to control the application's behavior.
|
// Command-line flags to control the application's behavior.
|
||||||
var login bool
|
var login bool
|
||||||
var codexLogin bool
|
var codexLogin bool
|
||||||
|
var codexDeviceLogin bool
|
||||||
var claudeLogin bool
|
var claudeLogin bool
|
||||||
var qwenLogin bool
|
var qwenLogin bool
|
||||||
var kiloLogin bool
|
var kiloLogin bool
|
||||||
@@ -89,12 +92,15 @@ func main() {
|
|||||||
var vertexImport string
|
var vertexImport string
|
||||||
var configPath string
|
var configPath string
|
||||||
var password string
|
var password string
|
||||||
|
var tuiMode bool
|
||||||
|
var standalone bool
|
||||||
var noIncognito bool
|
var noIncognito bool
|
||||||
var useIncognito bool
|
var useIncognito bool
|
||||||
|
|
||||||
// Define command-line flags for different operation modes.
|
// Define command-line flags for different operation modes.
|
||||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||||
|
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
||||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||||
@@ -116,6 +122,8 @@ func main() {
|
|||||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||||
flag.StringVar(&password, "password", "", "")
|
flag.StringVar(&password, "password", "", "")
|
||||||
|
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||||
|
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||||
|
|
||||||
flag.CommandLine.Usage = func() {
|
flag.CommandLine.Usage = func() {
|
||||||
out := flag.CommandLine.Output()
|
out := flag.CommandLine.Output()
|
||||||
@@ -496,6 +504,9 @@ func main() {
|
|||||||
} else if codexLogin {
|
} else if codexLogin {
|
||||||
// Handle Codex login
|
// Handle Codex login
|
||||||
cmd.DoCodexLogin(cfg, options)
|
cmd.DoCodexLogin(cfg, options)
|
||||||
|
} else if codexDeviceLogin {
|
||||||
|
// Handle Codex device-code login
|
||||||
|
cmd.DoCodexDeviceLogin(cfg, options)
|
||||||
} else if claudeLogin {
|
} else if claudeLogin {
|
||||||
// Handle Claude login
|
// Handle Claude login
|
||||||
cmd.DoClaudeLogin(cfg, options)
|
cmd.DoClaudeLogin(cfg, options)
|
||||||
@@ -540,15 +551,89 @@ func main() {
|
|||||||
cmd.WaitForCloudDeploy()
|
cmd.WaitForCloudDeploy()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Start the main proxy service
|
if tuiMode {
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
if standalone {
|
||||||
|
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||||
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
hook := tui.NewLogHook(2000)
|
||||||
|
hook.SetFormatter(&logging.LogFormatter{})
|
||||||
|
log.AddHook(hook)
|
||||||
|
|
||||||
// 初始化并启动 Kiro token 后台刷新
|
origStdout := os.Stdout
|
||||||
if cfg.AuthDir != "" {
|
origStderr := os.Stderr
|
||||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
origLogOutput := log.StandardLogger().Out
|
||||||
defer kiro.StopGlobalRefreshManager()
|
log.SetOutput(io.Discard)
|
||||||
|
|
||||||
|
devNull, errOpenDevNull := os.Open(os.DevNull)
|
||||||
|
if errOpenDevNull == nil {
|
||||||
|
os.Stdout = devNull
|
||||||
|
os.Stderr = devNull
|
||||||
|
}
|
||||||
|
|
||||||
|
restoreIO := func() {
|
||||||
|
os.Stdout = origStdout
|
||||||
|
os.Stderr = origStderr
|
||||||
|
log.SetOutput(origLogOutput)
|
||||||
|
if devNull != nil {
|
||||||
|
_ = devNull.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
localMgmtPassword := fmt.Sprintf("tui-%d-%d", os.Getpid(), time.Now().UnixNano())
|
||||||
|
if password == "" {
|
||||||
|
password = localMgmtPassword
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel, done := cmd.StartServiceBackground(cfg, configFilePath, password)
|
||||||
|
|
||||||
|
client := tui.NewClient(cfg.Port, password)
|
||||||
|
ready := false
|
||||||
|
backoff := 100 * time.Millisecond
|
||||||
|
for i := 0; i < 30; i++ {
|
||||||
|
if _, errGetConfig := client.GetConfig(); errGetConfig == nil {
|
||||||
|
ready = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(backoff)
|
||||||
|
if backoff < time.Second {
|
||||||
|
backoff = time.Duration(float64(backoff) * 1.5)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ready {
|
||||||
|
restoreIO()
|
||||||
|
cancel()
|
||||||
|
<-done
|
||||||
|
fmt.Fprintf(os.Stderr, "TUI error: embedded server is not ready\n")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if errRun := tui.Run(cfg.Port, password, hook, origStdout); errRun != nil {
|
||||||
|
restoreIO()
|
||||||
|
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
|
||||||
|
} else {
|
||||||
|
restoreIO()
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
<-done
|
||||||
|
} else {
|
||||||
|
// Default TUI mode: pure management client.
|
||||||
|
// The proxy server must already be running.
|
||||||
|
if errRun := tui.Run(cfg.Port, password, nil, os.Stdout); errRun != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Start the main proxy service
|
||||||
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
|
||||||
|
if cfg.AuthDir != "" {
|
||||||
|
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||||
|
defer kiro.StopGlobalRefreshManager()
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.StartService(cfg, configFilePath, password)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.StartService(cfg, configFilePath, password)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,6 +73,10 @@ proxy-url: ''
|
|||||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||||
force-model-prefix: false
|
force-model-prefix: false
|
||||||
|
|
||||||
|
# When true, forward filtered upstream response headers to downstream clients.
|
||||||
|
# Default is false (disabled).
|
||||||
|
passthrough-headers: false
|
||||||
|
|
||||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||||
request-retry: 3
|
request-retry: 3
|
||||||
|
|
||||||
@@ -160,6 +164,15 @@ nonstream-keepalive-interval: 0
|
|||||||
# sensitive-words: # optional: words to obfuscate with zero-width characters
|
# sensitive-words: # optional: words to obfuscate with zero-width characters
|
||||||
# - "API"
|
# - "API"
|
||||||
# - "proxy"
|
# - "proxy"
|
||||||
|
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||||
|
|
||||||
|
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||||
|
# These are used as fallbacks when the client does not send its own headers.
|
||||||
|
# claude-header-defaults:
|
||||||
|
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
||||||
|
# package-version: "0.74.0"
|
||||||
|
# runtime-version: "v24.3.0"
|
||||||
|
# timeout: "600"
|
||||||
|
|
||||||
# Kiro (AWS CodeWhisperer) configuration
|
# Kiro (AWS CodeWhisperer) configuration
|
||||||
# Note: Kiro API currently only operates in us-east-1 region
|
# Note: Kiro API currently only operates in us-east-1 region
|
||||||
|
|||||||
@@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request,
|
|||||||
return clipexec.Response{}, errors.New("count tokens not implemented")
|
return clipexec.Response{}, errors.New("count tokens not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
|
||||||
ch := make(chan clipexec.StreamChunk, 1)
|
ch := make(chan clipexec.StreamChunk, 1)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
|
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
|
||||||
}()
|
}()
|
||||||
return ch, nil
|
return &clipexec.StreamResult{Chunks: ch}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c
|
|||||||
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) {
|
||||||
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
21
go.mod
21
go.mod
@@ -4,6 +4,10 @@ go 1.26.0
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/andybalholm/brotli v1.0.6
|
github.com/andybalholm/brotli v1.0.6
|
||||||
|
github.com/atotto/clipboard v0.1.4
|
||||||
|
github.com/charmbracelet/bubbles v1.0.0
|
||||||
|
github.com/charmbracelet/bubbletea v1.3.10
|
||||||
|
github.com/charmbracelet/lipgloss v1.1.0
|
||||||
github.com/fsnotify/fsnotify v1.9.0
|
github.com/fsnotify/fsnotify v1.9.0
|
||||||
github.com/fxamacker/cbor/v2 v2.9.0
|
github.com/fxamacker/cbor/v2 v2.9.0
|
||||||
github.com/gin-gonic/gin v1.10.1
|
github.com/gin-gonic/gin v1.10.1
|
||||||
@@ -33,8 +37,16 @@ require (
|
|||||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||||
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
||||||
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||||
github.com/bytedance/sonic v1.11.6 // indirect
|
github.com/bytedance/sonic v1.11.6 // indirect
|
||||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||||
|
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||||
|
github.com/charmbracelet/x/ansi v0.11.6 // indirect
|
||||||
|
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||||
|
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||||
|
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
||||||
|
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||||
|
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
||||||
github.com/cloudflare/circl v1.6.1 // indirect
|
github.com/cloudflare/circl v1.6.1 // indirect
|
||||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||||
@@ -42,6 +54,7 @@ require (
|
|||||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/emirpasic/gods v1.18.1 // indirect
|
github.com/emirpasic/gods v1.18.1 // indirect
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
github.com/go-git/gcfg/v2 v2.0.2 // indirect
|
github.com/go-git/gcfg/v2 v2.0.2 // indirect
|
||||||
@@ -58,19 +71,27 @@ require (
|
|||||||
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
github.com/leodido/go-urn v1.4.0 // indirect
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||||
|
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||||
github.com/minio/md5-simd v1.1.2 // indirect
|
github.com/minio/md5-simd v1.1.2 // indirect
|
||||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||||
|
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||||
|
github.com/muesli/termenv v0.16.0 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
||||||
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
github.com/rs/xid v1.5.0 // indirect
|
github.com/rs/xid v1.5.0 // indirect
|
||||||
github.com/sergi/go-diff v1.4.0 // indirect
|
github.com/sergi/go-diff v1.4.0 // indirect
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
github.com/tidwall/pretty v1.2.0 // indirect
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||||
github.com/x448/float16 v0.8.4 // indirect
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/sys v0.38.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
|
|||||||
45
go.sum
45
go.sum
@@ -10,10 +10,34 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
|
|||||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
||||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
||||||
|
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||||
|
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||||
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||||
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||||
|
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
|
||||||
|
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
|
||||||
|
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||||
|
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||||
|
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
|
||||||
|
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||||
|
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||||
|
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||||
|
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||||
|
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||||
|
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||||
|
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||||
|
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||||
|
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||||
|
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
|
||||||
|
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
|
||||||
|
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||||
|
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||||
|
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||||
|
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||||
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
||||||
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||||
@@ -33,6 +57,8 @@ github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o
|
|||||||
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
|
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
|
||||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
|
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
|
||||||
@@ -101,8 +127,14 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
|||||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||||
|
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||||
|
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||||
|
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||||
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
|
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
|
||||||
@@ -114,6 +146,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
|||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||||
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||||
|
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||||
|
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||||
|
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||||
|
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||||
@@ -124,6 +162,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
||||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||||
|
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||||
|
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||||
@@ -161,6 +201,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||||
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
@@ -168,12 +210,15 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
|||||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||||
|
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||||
|
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
|
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
|||||||
@@ -814,6 +814,87 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
|
||||||
|
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||||
|
if h.authManager == nil {
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Prefix *string `json:"prefix"`
|
||||||
|
ProxyURL *string `json:"proxy_url"`
|
||||||
|
Priority *int `json:"priority"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
name := strings.TrimSpace(req.Name)
|
||||||
|
if name == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
// Find auth by name or ID
|
||||||
|
var targetAuth *coreauth.Auth
|
||||||
|
if auth, ok := h.authManager.GetByID(name); ok {
|
||||||
|
targetAuth = auth
|
||||||
|
} else {
|
||||||
|
auths := h.authManager.List()
|
||||||
|
for _, auth := range auths {
|
||||||
|
if auth.FileName == name {
|
||||||
|
targetAuth = auth
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetAuth == nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
changed := false
|
||||||
|
if req.Prefix != nil {
|
||||||
|
targetAuth.Prefix = *req.Prefix
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
if req.ProxyURL != nil {
|
||||||
|
targetAuth.ProxyURL = *req.ProxyURL
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
if req.Priority != nil {
|
||||||
|
if targetAuth.Metadata == nil {
|
||||||
|
targetAuth.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
if *req.Priority == 0 {
|
||||||
|
delete(targetAuth.Metadata, "priority")
|
||||||
|
} else {
|
||||||
|
targetAuth.Metadata["priority"] = *req.Priority
|
||||||
|
}
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !changed {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
targetAuth.UpdatedAt = time.Now()
|
||||||
|
|
||||||
|
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||||
if h == nil || h.authManager == nil {
|
if h == nil || h.authManager == nil {
|
||||||
return
|
return
|
||||||
@@ -870,11 +951,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
|||||||
if store == nil {
|
if store == nil {
|
||||||
return "", fmt.Errorf("token store unavailable")
|
return "", fmt.Errorf("token store unavailable")
|
||||||
}
|
}
|
||||||
|
if h.postAuthHook != nil {
|
||||||
|
if err := h.postAuthHook(ctx, record); err != nil {
|
||||||
|
return "", fmt.Errorf("post-auth hook failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return store.Save(ctx, record)
|
return store.Save(ctx, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Claude authentication...")
|
fmt.Println("Initializing Claude authentication...")
|
||||||
|
|
||||||
@@ -1019,6 +1106,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
||||||
|
|
||||||
@@ -1277,6 +1365,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Codex authentication...")
|
fmt.Println("Initializing Codex authentication...")
|
||||||
|
|
||||||
@@ -1422,6 +1511,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Antigravity authentication...")
|
fmt.Println("Initializing Antigravity authentication...")
|
||||||
|
|
||||||
@@ -1586,6 +1676,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Qwen authentication...")
|
fmt.Println("Initializing Qwen authentication...")
|
||||||
|
|
||||||
@@ -1641,6 +1732,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Kimi authentication...")
|
fmt.Println("Initializing Kimi authentication...")
|
||||||
|
|
||||||
@@ -1717,6 +1809,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing iFlow authentication...")
|
fmt.Println("Initializing iFlow authentication...")
|
||||||
|
|
||||||
@@ -2440,6 +2533,14 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PopulateAuthContext extracts request info and adds it to the context
|
||||||
|
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
|
||||||
|
info := &coreauth.RequestInfo{
|
||||||
|
Query: c.Request.URL.Query(),
|
||||||
|
Headers: c.Request.Header,
|
||||||
|
}
|
||||||
|
return coreauth.WithRequestInfo(ctx, info)
|
||||||
|
}
|
||||||
const kiroCallbackPort = 9876
|
const kiroCallbackPort = 9876
|
||||||
|
|
||||||
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ type Handler struct {
|
|||||||
allowRemoteOverride bool
|
allowRemoteOverride bool
|
||||||
envSecret string
|
envSecret string
|
||||||
logDir string
|
logDir string
|
||||||
|
postAuthHook coreauth.PostAuthHook
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new management handler instance.
|
// NewHandler creates a new management handler instance.
|
||||||
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
|
|||||||
h.logDir = dir
|
h.logDir = dir
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
|
||||||
|
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||||
|
h.postAuthHook = hook
|
||||||
|
}
|
||||||
|
|
||||||
// Middleware enforces access control for management endpoints.
|
// Middleware enforces access control for management endpoints.
|
||||||
// All requests (local and remote) require a valid management key.
|
// All requests (local and remote) require a valid management key.
|
||||||
// Additionally, remote access requires allow-remote-management=true.
|
// Additionally, remote access requires allow-remote-management=true.
|
||||||
|
|||||||
@@ -15,10 +15,12 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
|
||||||
|
|
||||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||||
// It captures detailed information about the request and response, including headers and body,
|
// It captures detailed information about the request and response, including headers and body,
|
||||||
// and uses the provided RequestLogger to record this data. When logging is disabled in the
|
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
|
||||||
// logger, it still captures data so that upstream errors can be persisted.
|
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
|
||||||
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
if logger == nil {
|
if logger == nil {
|
||||||
@@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Request.Method == http.MethodGet {
|
if shouldSkipMethodForRequestLogging(c.Request) {
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
loggerEnabled := logger.IsEnabled()
|
||||||
|
|
||||||
// Capture request information
|
// Capture request information
|
||||||
requestInfo, err := captureRequestInfo(c)
|
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log error but continue processing
|
// Log error but continue processing
|
||||||
// In a real implementation, you might want to use a proper logger here
|
// In a real implementation, you might want to use a proper logger here
|
||||||
@@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
|||||||
|
|
||||||
// Create response writer wrapper
|
// Create response writer wrapper
|
||||||
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
||||||
if !logger.IsEnabled() {
|
if !loggerEnabled {
|
||||||
wrapper.logOnErrorOnly = true
|
wrapper.logOnErrorOnly = true
|
||||||
}
|
}
|
||||||
c.Writer = wrapper
|
c.Writer = wrapper
|
||||||
@@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
|
||||||
|
if req == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if req.Method != http.MethodGet {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !isResponsesWebsocketUpgrade(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isResponsesWebsocketUpgrade(req *http.Request) bool {
|
||||||
|
if req == nil || req.URL == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if req.URL.Path != "/v1/responses" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket")
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
|
||||||
|
if loggerEnabled {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if req == nil || req.Body == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
|
||||||
|
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if req.ContentLength <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
|
||||||
|
}
|
||||||
|
|
||||||
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
||||||
// It captures the URL, method, headers, and body. The request body is read and then
|
// It captures the URL, method, headers, and body. The request body is read and then
|
||||||
// restored so that it can be processed by subsequent handlers.
|
// restored so that it can be processed by subsequent handlers.
|
||||||
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
|
||||||
// Capture URL with sensitive query parameters masked
|
// Capture URL with sensitive query parameters masked
|
||||||
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||||
url := c.Request.URL.Path
|
url := c.Request.URL.Path
|
||||||
@@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
|||||||
|
|
||||||
// Capture request body
|
// Capture request body
|
||||||
var body []byte
|
var body []byte
|
||||||
if c.Request.Body != nil {
|
if captureBody && c.Request.Body != nil {
|
||||||
// Read the body
|
// Read the body
|
||||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
138
internal/api/middleware/request_logging_test.go
Normal file
138
internal/api/middleware/request_logging_test.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req *http.Request
|
||||||
|
skip bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil request",
|
||||||
|
req: nil,
|
||||||
|
skip: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "post request should not skip",
|
||||||
|
req: &http.Request{
|
||||||
|
Method: http.MethodPost,
|
||||||
|
URL: &url.URL{Path: "/v1/responses"},
|
||||||
|
},
|
||||||
|
skip: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "plain get should skip",
|
||||||
|
req: &http.Request{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
URL: &url.URL{Path: "/v1/models"},
|
||||||
|
Header: http.Header{},
|
||||||
|
},
|
||||||
|
skip: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "responses websocket upgrade should not skip",
|
||||||
|
req: &http.Request{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
URL: &url.URL{Path: "/v1/responses"},
|
||||||
|
Header: http.Header{"Upgrade": []string{"websocket"}},
|
||||||
|
},
|
||||||
|
skip: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "responses get without upgrade should skip",
|
||||||
|
req: &http.Request{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
URL: &url.URL{Path: "/v1/responses"},
|
||||||
|
Header: http.Header{},
|
||||||
|
},
|
||||||
|
skip: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range tests {
|
||||||
|
got := shouldSkipMethodForRequestLogging(tests[i].req)
|
||||||
|
if got != tests[i].skip {
|
||||||
|
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldCaptureRequestBody(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
loggerEnabled bool
|
||||||
|
req *http.Request
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "logger enabled always captures",
|
||||||
|
loggerEnabled: true,
|
||||||
|
req: &http.Request{
|
||||||
|
Body: io.NopCloser(strings.NewReader("{}")),
|
||||||
|
ContentLength: -1,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil request",
|
||||||
|
loggerEnabled: false,
|
||||||
|
req: nil,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "small known size json in error-only mode",
|
||||||
|
loggerEnabled: false,
|
||||||
|
req: &http.Request{
|
||||||
|
Body: io.NopCloser(strings.NewReader("{}")),
|
||||||
|
ContentLength: 2,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large known size skipped in error-only mode",
|
||||||
|
loggerEnabled: false,
|
||||||
|
req: &http.Request{
|
||||||
|
Body: io.NopCloser(strings.NewReader("x")),
|
||||||
|
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown size skipped in error-only mode",
|
||||||
|
loggerEnabled: false,
|
||||||
|
req: &http.Request{
|
||||||
|
Body: io.NopCloser(strings.NewReader("x")),
|
||||||
|
ContentLength: -1,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multipart skipped in error-only mode",
|
||||||
|
loggerEnabled: false,
|
||||||
|
req: &http.Request{
|
||||||
|
Body: io.NopCloser(strings.NewReader("x")),
|
||||||
|
ContentLength: 1,
|
||||||
|
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range tests {
|
||||||
|
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
|
||||||
|
if got != tests[i].want {
|
||||||
|
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,6 +14,8 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||||
|
|
||||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||||
type RequestInfo struct {
|
type RequestInfo struct {
|
||||||
URL string // URL is the request URL.
|
URL string // URL is the request URL.
|
||||||
@@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
|||||||
|
|
||||||
// Only fall back to request payload hints when Content-Type is not set yet.
|
// Only fall back to request payload hints when Content-Type is not set yet.
|
||||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||||
bodyStr := string(w.requestInfo.Body)
|
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
|
||||||
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
|
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
@@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||||
@@ -361,16 +363,32 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
|
|||||||
return time.Time{}
|
return time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||||
|
if c != nil {
|
||||||
|
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
||||||
|
switch value := bodyOverride.(type) {
|
||||||
|
case []byte:
|
||||||
|
if len(value) > 0 {
|
||||||
|
return bytes.Clone(value)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(value) != "" {
|
||||||
|
return []byte(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||||
|
return w.requestInfo.Body
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||||
if w.requestInfo == nil {
|
if w.requestInfo == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestBody []byte
|
|
||||||
if len(w.requestInfo.Body) > 0 {
|
|
||||||
requestBody = w.requestInfo.Body
|
|
||||||
}
|
|
||||||
|
|
||||||
if loggerWithOptions, ok := w.logger.(interface {
|
if loggerWithOptions, ok := w.logger.(interface {
|
||||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||||
}); ok {
|
}); ok {
|
||||||
|
|||||||
43
internal/api/middleware/response_writer_test.go
Normal file
43
internal/api/middleware/response_writer_test.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{
|
||||||
|
requestInfo: &RequestInfo{Body: []byte("original-body")},
|
||||||
|
}
|
||||||
|
|
||||||
|
body := wrapper.extractRequestBody(c)
|
||||||
|
if string(body) != "original-body" {
|
||||||
|
t.Fatalf("request body = %q, want %q", string(body), "original-body")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(requestBodyOverrideContextKey, []byte("override-body"))
|
||||||
|
body = wrapper.extractRequestBody(c)
|
||||||
|
if string(body) != "override-body" {
|
||||||
|
t.Fatalf("request body = %q, want %q", string(body), "override-body")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{}
|
||||||
|
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||||
|
|
||||||
|
body := wrapper.extractRequestBody(c)
|
||||||
|
if string(body) != "override-as-string" {
|
||||||
|
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -215,7 +215,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
|||||||
|
|
||||||
// Don't log as error for context canceled - it's usually client closing connection
|
// Don't log as error for context canceled - it's usually client closing connection
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
|
return
|
||||||
} else {
|
} else {
|
||||||
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -493,6 +493,30 @@ func TestReverseProxy_ErrorHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) {
|
||||||
|
// Test that context.Canceled errors return 499 without generic error response
|
||||||
|
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource(""))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a canceled context to trigger the cancellation path
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(ctx)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Directly invoke the ErrorHandler with context.Canceled
|
||||||
|
proxy.ErrorHandler(rr, req, context.Canceled)
|
||||||
|
|
||||||
|
// Body should be empty for canceled requests (no JSON error response)
|
||||||
|
body := rr.Body.Bytes()
|
||||||
|
if len(body) > 0 {
|
||||||
|
t.Fatalf("expected empty body for canceled context, got: %s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
|
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
|
||||||
// Upstream returns gzipped JSON without Content-Encoding header
|
// Upstream returns gzipped JSON without Content-Encoding header
|
||||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ type serverOptionConfig struct {
|
|||||||
keepAliveEnabled bool
|
keepAliveEnabled bool
|
||||||
keepAliveTimeout time.Duration
|
keepAliveTimeout time.Duration
|
||||||
keepAliveOnTimeout func()
|
keepAliveOnTimeout func()
|
||||||
|
postAuthHook auth.PostAuthHook
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerOption customises HTTP server construction.
|
// ServerOption customises HTTP server construction.
|
||||||
@@ -112,6 +113,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithPostAuthHook registers a hook to be called after auth record creation.
|
||||||
|
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
|
||||||
|
return func(cfg *serverOptionConfig) {
|
||||||
|
cfg.postAuthHook = hook
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Server represents the main API server.
|
// Server represents the main API server.
|
||||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@@ -263,6 +271,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
}
|
}
|
||||||
logDir := logging.ResolveLogDirectory(cfg)
|
logDir := logging.ResolveLogDirectory(cfg)
|
||||||
s.mgmt.SetLogDirectory(logDir)
|
s.mgmt.SetLogDirectory(logDir)
|
||||||
|
if optionState.postAuthHook != nil {
|
||||||
|
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
|
||||||
|
}
|
||||||
s.localPassword = optionState.localPassword
|
s.localPassword = optionState.localPassword
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes
|
||||||
@@ -285,8 +296,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
optionState.routerConfigurator(engine, s.handlers, cfg)
|
optionState.routerConfigurator(engine, s.handlers, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register management routes when configuration or environment secrets are available.
|
// Register management routes when configuration or environment secrets are available,
|
||||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
|
// or when a local management password is provided (e.g. TUI mode).
|
||||||
|
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
|
||||||
s.managementRoutesEnabled.Store(hasManagementSecret)
|
s.managementRoutesEnabled.Store(hasManagementSecret)
|
||||||
if hasManagementSecret {
|
if hasManagementSecret {
|
||||||
s.registerManagementRoutes()
|
s.registerManagementRoutes()
|
||||||
@@ -329,6 +341,7 @@ func (s *Server) setupRoutes() {
|
|||||||
v1.POST("/completions", openaiHandlers.Completions)
|
v1.POST("/completions", openaiHandlers.Completions)
|
||||||
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
||||||
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
||||||
|
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
|
||||||
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
||||||
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
|
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
|
||||||
}
|
}
|
||||||
@@ -642,6 +655,7 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||||
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
|
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
|
||||||
|
mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields)
|
||||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||||
|
|
||||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
// OAuth configuration constants for Claude/Anthropic
|
// OAuth configuration constants for Claude/Anthropic
|
||||||
const (
|
const (
|
||||||
AuthURL = "https://claude.ai/oauth/authorize"
|
AuthURL = "https://claude.ai/oauth/authorize"
|
||||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
TokenURL = "https://api.anthropic.com/v1/oauth/token"
|
||||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||||
RedirectURI = "http://localhost:54545/callback"
|
RedirectURI = "http://localhost:54545/callback"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
|
|||||||
|
|
||||||
// Expire is the timestamp when the current access token expires.
|
// Expire is the timestamp when the current access token expires.
|
||||||
Expire string `json:"expired"`
|
Expire string `json:"expired"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
// Encode and write the token data as JSON
|
// Encode and write the token data as JSON
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
|||||||
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
||||||
// authorization code and PKCE verifier.
|
// authorization code and PKCE verifier.
|
||||||
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||||
|
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
|
||||||
|
// a caller-provided redirect URI. This supports alternate auth flows such as device
|
||||||
|
// login while preserving the existing token parsing and storage behavior.
|
||||||
|
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||||
if pkceCodes == nil {
|
if pkceCodes == nil {
|
||||||
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(redirectURI) == "" {
|
||||||
|
return nil, fmt.Errorf("redirect URI is required for token exchange")
|
||||||
|
}
|
||||||
|
|
||||||
// Prepare token exchange request
|
// Prepare token exchange request
|
||||||
data := url.Values{
|
data := url.Values{
|
||||||
"grant_type": {"authorization_code"},
|
"grant_type": {"authorization_code"},
|
||||||
"client_id": {ClientID},
|
"client_id": {ClientID},
|
||||||
"code": {code},
|
"code": {code},
|
||||||
"redirect_uri": {RedirectURI},
|
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||||
"code_verifier": {pkceCodes.CodeVerifier},
|
"code_verifier": {pkceCodes.CodeVerifier},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return tokenData, nil
|
return tokenData, nil
|
||||||
}
|
}
|
||||||
|
if isNonRetryableRefreshErr(err) {
|
||||||
|
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
lastErr = err
|
lastErr = err
|
||||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||||
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
|||||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isNonRetryableRefreshErr(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
raw := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(raw, "refresh_token_reused")
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
||||||
// This is typically called after a successful token refresh to persist the new credentials.
|
// This is typically called after a successful token refresh to persist the new credentials.
|
||||||
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
||||||
|
|||||||
44
internal/auth/codex/openai_auth_test.go
Normal file
44
internal/auth/codex/openai_auth_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return f(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
auth := &CodexAuth{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
|
||||||
|
Header: make(http.Header),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error for non-retryable refresh failure")
|
||||||
|
}
|
||||||
|
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
|
||||||
|
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
|
||||||
|
}
|
||||||
|
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||||
|
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
// Expire is the timestamp when the current access token expires.
|
// Expire is the timestamp when the current access token expires.
|
||||||
Expire string `json:"expired"`
|
Expire string `json:"expired"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
|
|||||||
|
|
||||||
// Type indicates the authentication provider type, always "gemini" for this storage.
|
// Type indicates the authentication provider type, always "gemini" for this storage.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
|
|||||||
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
misc.LogSavingCredentials(authFilePath)
|
misc.LogSavingCredentials(authFilePath)
|
||||||
ts.Type = "gemini"
|
ts.Type = "gemini"
|
||||||
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||||
return fmt.Errorf("failed to create directory: %v", err)
|
return fmt.Errorf("failed to create directory: %v", err)
|
||||||
}
|
}
|
||||||
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
enc := json.NewEncoder(f)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
if err := enc.Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
|
|||||||
Scope string `json:"scope"`
|
Scope string `json:"scope"`
|
||||||
Cookie string `json:"cookie"`
|
Cookie string `json:"cookie"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serialises the token storage to disk.
|
// SaveTokenToFile serialises the token storage to disk.
|
||||||
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
}
|
}
|
||||||
defer func() { _ = f.Close() }()
|
defer func() { _ = f.Close() }()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
|
|||||||
Expired string `json:"expired,omitempty"`
|
Expired string `json:"expired,omitempty"`
|
||||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||||
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
encoder := json.NewEncoder(f)
|
encoder := json.NewEncoder(f)
|
||||||
encoder.SetIndent("", " ")
|
encoder.SetIndent("", " ")
|
||||||
if err = encoder.Encode(ts); err != nil {
|
if err = encoder.Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
// Expire is the timestamp when the current access token expires.
|
// Expire is the timestamp when the current access token expires.
|
||||||
Expire string `json:"expired"`
|
Expire string `json:"expired"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
60
internal/cmd/openai_device_login.go
Normal file
60
internal/cmd/openai_device_login.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
codexLoginModeMetadataKey = "codex_login_mode"
|
||||||
|
codexLoginModeDevice = "device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
|
||||||
|
// existing codex-login OAuth callback flow intact.
|
||||||
|
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptFn := options.Prompt
|
||||||
|
if promptFn == nil {
|
||||||
|
promptFn = defaultProjectPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
CallbackPort: options.CallbackPort,
|
||||||
|
Metadata: map[string]string{
|
||||||
|
codexLoginModeMetadataKey: codexLoginModeDevice,
|
||||||
|
},
|
||||||
|
Prompt: promptFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||||
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
|
if authErr.Type == codex.ErrPortInUse.Type {
|
||||||
|
os.Exit(codex.ErrPortInUse.Code)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("Codex device authentication failed: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
fmt.Println("Codex device authentication successful!")
|
||||||
|
}
|
||||||
@@ -55,6 +55,34 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartServiceBackground starts the proxy service in a background goroutine
|
||||||
|
// and returns a cancel function for shutdown and a done channel.
|
||||||
|
func StartServiceBackground(cfg *config.Config, configPath string, localPassword string) (cancel func(), done <-chan struct{}) {
|
||||||
|
builder := cliproxy.NewBuilder().
|
||||||
|
WithConfig(cfg).
|
||||||
|
WithConfigPath(configPath).
|
||||||
|
WithLocalManagementPassword(localPassword)
|
||||||
|
|
||||||
|
ctx, cancelFn := context.WithCancel(context.Background())
|
||||||
|
doneCh := make(chan struct{})
|
||||||
|
|
||||||
|
service, err := builder.Build()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to build proxy service: %v", err)
|
||||||
|
close(doneCh)
|
||||||
|
return cancelFn, doneCh
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(doneCh)
|
||||||
|
if err := service.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
log.Errorf("proxy service exited with error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return cancelFn, doneCh
|
||||||
|
}
|
||||||
|
|
||||||
// WaitForCloudDeploy waits indefinitely for shutdown signals in cloud deploy mode
|
// WaitForCloudDeploy waits indefinitely for shutdown signals in cloud deploy mode
|
||||||
// when no configuration file is available.
|
// when no configuration file is available.
|
||||||
func WaitForCloudDeploy() {
|
func WaitForCloudDeploy() {
|
||||||
|
|||||||
@@ -97,6 +97,10 @@ type Config struct {
|
|||||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||||
|
|
||||||
|
// ClaudeHeaderDefaults configures default header values for Claude API requests.
|
||||||
|
// These are used as fallbacks when the client does not send its own headers.
|
||||||
|
ClaudeHeaderDefaults ClaudeHeaderDefaults `yaml:"claude-header-defaults" json:"claude-header-defaults"`
|
||||||
|
|
||||||
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
|
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
|
||||||
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
|
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
|
||||||
|
|
||||||
@@ -130,6 +134,15 @@ type Config struct {
|
|||||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
|
||||||
|
// when the client does not send them. Update these when Claude Code releases a new version.
|
||||||
|
type ClaudeHeaderDefaults struct {
|
||||||
|
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||||
|
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||||
|
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||||
|
Timeout string `yaml:"timeout" json:"timeout"`
|
||||||
|
}
|
||||||
|
|
||||||
// TLSConfig holds HTTPS server settings.
|
// TLSConfig holds HTTPS server settings.
|
||||||
type TLSConfig struct {
|
type TLSConfig struct {
|
||||||
// Enable toggles HTTPS server mode.
|
// Enable toggles HTTPS server mode.
|
||||||
@@ -301,6 +314,10 @@ type CloakConfig struct {
|
|||||||
// SensitiveWords is a list of words to obfuscate with zero-width characters.
|
// SensitiveWords is a list of words to obfuscate with zero-width characters.
|
||||||
// This can help bypass certain content filters.
|
// This can help bypass certain content filters.
|
||||||
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
|
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
|
||||||
|
|
||||||
|
// CacheUserID controls whether Claude user_id values are cached per API key.
|
||||||
|
// When false, a fresh random user_id is generated for every request.
|
||||||
|
CacheUserID *bool `yaml:"cache-user-id,omitempty" json:"cache-user-id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaudeKey represents the configuration for a Claude API key,
|
// ClaudeKey represents the configuration for a Claude API key,
|
||||||
@@ -368,6 +385,9 @@ type CodexKey struct {
|
|||||||
// If empty, the default Codex API URL will be used.
|
// If empty, the default Codex API URL will be used.
|
||||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||||
|
|
||||||
|
// Websockets enables the Responses API websocket transport for this credential.
|
||||||
|
Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
|
||||||
|
|
||||||
// ProxyURL overrides the global proxy setting for this API key if provided.
|
// ProxyURL overrides the global proxy setting for this API key if provided.
|
||||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||||
|
|
||||||
@@ -743,22 +763,24 @@ func (cfg *Config) SanitizeOAuthModelAlias() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Inject default Kiro aliases if no user-configured kiro aliases exist
|
// Inject channel defaults when the channel is absent in user config.
|
||||||
|
// Presence is checked case-insensitively and includes explicit nil/empty markers.
|
||||||
if cfg.OAuthModelAlias == nil {
|
if cfg.OAuthModelAlias == nil {
|
||||||
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
|
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
|
||||||
}
|
}
|
||||||
if _, hasKiro := cfg.OAuthModelAlias["kiro"]; !hasKiro {
|
hasChannel := func(channel string) bool {
|
||||||
// Check case-insensitive too
|
|
||||||
found := false
|
|
||||||
for k := range cfg.OAuthModelAlias {
|
for k := range cfg.OAuthModelAlias {
|
||||||
if strings.EqualFold(strings.TrimSpace(k), "kiro") {
|
if strings.EqualFold(strings.TrimSpace(k), channel) {
|
||||||
found = true
|
return true
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
return false
|
||||||
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
|
}
|
||||||
}
|
if !hasChannel("kiro") {
|
||||||
|
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
|
||||||
|
}
|
||||||
|
if !hasChannel("github-copilot") {
|
||||||
|
cfg.OAuthModelAlias["github-copilot"] = defaultGitHubCopilotAliases()
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cfg.OAuthModelAlias) == 0 {
|
if len(cfg.OAuthModelAlias) == 0 {
|
||||||
|
|||||||
@@ -42,6 +42,21 @@ func defaultKiroAliases() []OAuthModelAlias {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// defaultGitHubCopilotAliases returns default oauth-model-alias entries that
|
||||||
|
// expose Claude hyphen-style IDs for GitHub Copilot Claude models.
|
||||||
|
// This keeps compatibility with clients (e.g. Claude Code) that use
|
||||||
|
// Anthropic-style model IDs like "claude-opus-4-6".
|
||||||
|
func defaultGitHubCopilotAliases() []OAuthModelAlias {
|
||||||
|
return []OAuthModelAlias{
|
||||||
|
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
|
||||||
|
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
|
||||||
|
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
|
||||||
|
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
|
||||||
|
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||||
|
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
||||||
// for the antigravity channel when neither field exists.
|
// for the antigravity channel when neither field exists.
|
||||||
func defaultAntigravityAliases() []OAuthModelAlias {
|
func defaultAntigravityAliases() []OAuthModelAlias {
|
||||||
|
|||||||
@@ -107,6 +107,44 @@ func TestSanitizeOAuthModelAlias_InjectsDefaultKiroAliases(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSanitizeOAuthModelAlias_InjectsDefaultGitHubCopilotAliases(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||||
|
"codex": {
|
||||||
|
{Name: "gpt-5", Alias: "g5"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.SanitizeOAuthModelAlias()
|
||||||
|
|
||||||
|
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
|
||||||
|
if len(copilotAliases) == 0 {
|
||||||
|
t.Fatal("expected default github-copilot aliases to be injected")
|
||||||
|
}
|
||||||
|
|
||||||
|
aliasSet := make(map[string]bool, len(copilotAliases))
|
||||||
|
for _, a := range copilotAliases {
|
||||||
|
aliasSet[a.Alias] = true
|
||||||
|
if !a.Fork {
|
||||||
|
t.Fatalf("expected all default github-copilot aliases to have fork=true, got fork=false for %q", a.Alias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
expectedAliases := []string{
|
||||||
|
"claude-haiku-4-5",
|
||||||
|
"claude-opus-4-1",
|
||||||
|
"claude-opus-4-5",
|
||||||
|
"claude-opus-4-6",
|
||||||
|
"claude-sonnet-4-5",
|
||||||
|
"claude-sonnet-4-6",
|
||||||
|
}
|
||||||
|
for _, expected := range expectedAliases {
|
||||||
|
if !aliasSet[expected] {
|
||||||
|
t.Fatalf("expected default github-copilot alias %q to be present", expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
|
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
|
||||||
// When user has configured kiro aliases, defaults should NOT be injected
|
// When user has configured kiro aliases, defaults should NOT be injected
|
||||||
cfg := &Config{
|
cfg := &Config{
|
||||||
@@ -128,6 +166,26 @@ func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserGitHubCopilotAliases(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||||
|
"github-copilot": {
|
||||||
|
{Name: "claude-opus-4.6", Alias: "my-opus", Fork: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.SanitizeOAuthModelAlias()
|
||||||
|
|
||||||
|
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
|
||||||
|
if len(copilotAliases) != 1 {
|
||||||
|
t.Fatalf("expected 1 user-configured github-copilot alias, got %d", len(copilotAliases))
|
||||||
|
}
|
||||||
|
if copilotAliases[0].Alias != "my-opus" {
|
||||||
|
t.Fatalf("expected user alias to be preserved, got %q", copilotAliases[0].Alias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
|
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
|
||||||
// When user explicitly deletes kiro aliases (key exists with nil value),
|
// When user explicitly deletes kiro aliases (key exists with nil value),
|
||||||
// defaults should NOT be re-injected on subsequent sanitize calls (#222).
|
// defaults should NOT be re-injected on subsequent sanitize calls (#222).
|
||||||
@@ -154,6 +212,24 @@ func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSanitizeOAuthModelAlias_GitHubCopilotDoesNotReinjectAfterExplicitDeletion(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||||
|
"github-copilot": nil, // explicitly deleted
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.SanitizeOAuthModelAlias()
|
||||||
|
|
||||||
|
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
|
||||||
|
if len(copilotAliases) != 0 {
|
||||||
|
t.Fatalf("expected github-copilot aliases to remain empty after explicit deletion, got %d aliases", len(copilotAliases))
|
||||||
|
}
|
||||||
|
if _, exists := cfg.OAuthModelAlias["github-copilot"]; !exists {
|
||||||
|
t.Fatal("expected github-copilot key to be preserved as nil marker after sanitization")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
|
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
|
||||||
// Same as above but with empty slice instead of nil (PUT with empty body).
|
// Same as above but with empty slice instead of nil (PUT with empty body).
|
||||||
cfg := &Config{
|
cfg := &Config{
|
||||||
|
|||||||
@@ -20,6 +20,10 @@ type SDKConfig struct {
|
|||||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||||
|
|
||||||
|
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
|
||||||
|
// Default is false (disabled).
|
||||||
|
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
|
||||||
|
|
||||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package misc
|
package misc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
|
|||||||
func LogCredentialSeparator() {
|
func LogCredentialSeparator() {
|
||||||
log.Debug(credentialSeparator)
|
log.Debug(credentialSeparator)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
|
||||||
|
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
|
||||||
|
var data map[string]any
|
||||||
|
|
||||||
|
// Fast path: if source is already a map, just copy it to avoid mutation of original
|
||||||
|
if srcMap, ok := source.(map[string]any); ok {
|
||||||
|
data = make(map[string]any, len(srcMap)+len(metadata))
|
||||||
|
for k, v := range srcMap {
|
||||||
|
data[k] = v
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Slow path: marshal to JSON and back to map to respect JSON tags
|
||||||
|
temp, err := json.Marshal(source)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal source: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(temp, &data); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge extra metadata
|
||||||
|
if metadata != nil {
|
||||||
|
if data == nil {
|
||||||
|
data = make(map[string]any)
|
||||||
|
}
|
||||||
|
for k, v := range metadata {
|
||||||
|
data[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -129,7 +129,19 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
||||||
func GetGitHubCopilotModels() []*ModelInfo {
|
func GetGitHubCopilotModels() []*ModelInfo {
|
||||||
now := int64(1732752000) // 2024-11-27
|
now := int64(1732752000) // 2024-11-27
|
||||||
return []*ModelInfo{
|
gpt4oEntries := []struct {
|
||||||
|
ID string
|
||||||
|
DisplayName string
|
||||||
|
Description string
|
||||||
|
}{
|
||||||
|
{ID: "gpt-4o-2024-11-20", DisplayName: "GPT-4o (2024-11-20)", Description: "OpenAI GPT-4o 2024-11-20 via GitHub Copilot"},
|
||||||
|
{ID: "gpt-4o-2024-08-06", DisplayName: "GPT-4o (2024-08-06)", Description: "OpenAI GPT-4o 2024-08-06 via GitHub Copilot"},
|
||||||
|
{ID: "gpt-4o-2024-05-13", DisplayName: "GPT-4o (2024-05-13)", Description: "OpenAI GPT-4o 2024-05-13 via GitHub Copilot"},
|
||||||
|
{ID: "gpt-4o", DisplayName: "GPT-4o", Description: "OpenAI GPT-4o via GitHub Copilot"},
|
||||||
|
{ID: "gpt-4-o-preview", DisplayName: "GPT-4-o Preview", Description: "OpenAI GPT-4-o Preview via GitHub Copilot"},
|
||||||
|
}
|
||||||
|
|
||||||
|
models := []*ModelInfo{
|
||||||
{
|
{
|
||||||
ID: "gpt-4.1",
|
ID: "gpt-4.1",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -141,6 +153,23 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
ContextLength: 128000,
|
ContextLength: 128000,
|
||||||
MaxCompletionTokens: 16384,
|
MaxCompletionTokens: 16384,
|
||||||
},
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range gpt4oEntries {
|
||||||
|
models = append(models, &ModelInfo{
|
||||||
|
ID: entry.ID,
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: entry.DisplayName,
|
||||||
|
Description: entry.Description,
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 16384,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return append(models, []*ModelInfo{
|
||||||
{
|
{
|
||||||
ID: "gpt-5",
|
ID: "gpt-5",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -258,6 +287,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
SupportedEndpoints: []string{"/responses"},
|
SupportedEndpoints: []string{"/responses"},
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.3-codex",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5.3 Codex",
|
||||||
|
Description: "OpenAI GPT-5.3 Codex via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/responses"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-haiku-4.5",
|
ID: "claude-haiku-4.5",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -330,6 +372,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-sonnet-4.6",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Claude Sonnet 4.6",
|
||||||
|
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-2.5-pro",
|
ID: "gemini-2.5-pro",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -352,6 +406,17 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
ContextLength: 1048576,
|
ContextLength: 1048576,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3.1-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Gemini 3.1 Pro (Preview)",
|
||||||
|
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
|
||||||
|
ContextLength: 1048576,
|
||||||
|
MaxCompletionTokens: 65536,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-flash-preview",
|
ID: "gemini-3-flash-preview",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -386,7 +451,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
MaxCompletionTokens: 16384,
|
MaxCompletionTokens: 16384,
|
||||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||||
},
|
},
|
||||||
}
|
}...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
|
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
|
||||||
@@ -417,6 +482,18 @@ func GetKiroModels() []*ModelInfo {
|
|||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-sonnet-4-6",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1739836800, // 2025-02-18
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Sonnet 4.6",
|
||||||
|
Description: "Claude Sonnet 4.6 via Kiro (1.3x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "kiro-claude-opus-4-5",
|
ID: "kiro-claude-opus-4-5",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -559,6 +636,18 @@ func GetKiroModels() []*ModelInfo {
|
|||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-sonnet-4-6-agentic",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1739836800, // 2025-02-18
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Sonnet 4.6 (Agentic)",
|
||||||
|
Description: "Claude Sonnet 4.6 optimized for coding agents (chunked writes)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "kiro-claude-opus-4-5-agentic",
|
ID: "kiro-claude-opus-4-5-agentic",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
|
|||||||
@@ -28,6 +28,17 @@ func GetClaudeModels() []*ModelInfo {
|
|||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-sonnet-4-6",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1771372800, // 2026-02-17
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Type: "claude",
|
||||||
|
DisplayName: "Claude 4.6 Sonnet",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-opus-4-6",
|
ID: "claude-opus-4-6",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -40,6 +51,18 @@ func GetClaudeModels() []*ModelInfo {
|
|||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-sonnet-4-6",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1771286400, // 2026-02-17
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Type: "claude",
|
||||||
|
DisplayName: "Claude 4.6 Sonnet",
|
||||||
|
Description: "Best combination of speed and intelligence",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-opus-4-5-20251101",
|
ID: "claude-opus-4-5-20251101",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -173,6 +196,21 @@ func GetGeminiModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3.1-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1771459200,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3.1-pro-preview",
|
||||||
|
Version: "3.1",
|
||||||
|
DisplayName: "Gemini 3.1 Pro Preview",
|
||||||
|
Description: "Gemini 3.1 Pro Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-flash-preview",
|
ID: "gemini-3-flash-preview",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -283,6 +321,21 @@ func GetGeminiVertexModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3.1-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1771459200,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3.1-pro-preview",
|
||||||
|
Version: "3.1",
|
||||||
|
DisplayName: "Gemini 3.1 Pro Preview",
|
||||||
|
Description: "Gemini 3.1 Pro Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-pro-image-preview",
|
ID: "gemini-3-pro-image-preview",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -425,6 +478,21 @@ func GetGeminiCLIModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3.1-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1771459200,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3.1-pro-preview",
|
||||||
|
Version: "3.1",
|
||||||
|
DisplayName: "Gemini 3.1 Pro Preview",
|
||||||
|
Description: "Gemini 3.1 Pro Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-flash-preview",
|
ID: "gemini-3-flash-preview",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -506,6 +574,21 @@ func GetAIStudioModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3.1-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1771459200,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3.1-pro-preview",
|
||||||
|
Version: "3.1",
|
||||||
|
DisplayName: "Gemini 3.1 Pro Preview",
|
||||||
|
Description: "Gemini 3.1 Pro Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-flash-preview",
|
ID: "gemini-3-flash-preview",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -892,11 +975,14 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
|||||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
|
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
|
||||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||||
|
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
|
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
|
||||||
|
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"gpt-oss-120b-medium": {},
|
"gpt-oss-120b-medium": {},
|
||||||
"tab_flash_lite_preview": {},
|
"tab_flash_lite_preview": {},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -164,12 +164,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
|
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request to the AI Studio API.
|
// ExecuteStream performs a streaming request to the AI Studio API.
|
||||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
|
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func(first wsrelay.StreamEvent) {
|
go func(first wsrelay.StreamEvent) {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
var param any
|
var param any
|
||||||
@@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(firstEvent)
|
}(firstEvent)
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens counts tokens for the given request using the AI Studio API.
|
// CountTokens counts tokens for the given request using the AI Studio API.
|
||||||
|
|||||||
@@ -232,7 +232,7 @@ attemptLoop:
|
|||||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||||
var param any
|
var param any
|
||||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
@@ -436,7 +436,7 @@ attemptLoop:
|
|||||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||||
var param any
|
var param any
|
||||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
@@ -645,7 +645,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request to the Antigravity API.
|
// ExecuteStream performs a streaming request to the Antigravity API.
|
||||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -775,7 +775,6 @@ attemptLoop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func(resp *http.Response) {
|
go func(resp *http.Response) {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -820,7 +819,7 @@ attemptLoop:
|
|||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
}
|
}
|
||||||
}(httpResp)
|
}(httpResp)
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
@@ -968,7 +967,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -116,7 +117,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
|
|
||||||
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
||||||
// based on client type and configuration.
|
// based on client type and configuration.
|
||||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
@@ -134,7 +135,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
bodyForTranslation := body
|
bodyForTranslation := body
|
||||||
bodyForUpstream := body
|
bodyForUpstream := body
|
||||||
if isClaudeOAuthToken(apiKey) {
|
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,7 +144,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
@@ -208,7 +209,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
} else {
|
} else {
|
||||||
reporter.publish(ctx, parseClaudeUsage(data))
|
reporter.publish(ctx, parseClaudeUsage(data))
|
||||||
}
|
}
|
||||||
if isClaudeOAuthToken(apiKey) {
|
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||||
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
|
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
|
||||||
}
|
}
|
||||||
var param any
|
var param any
|
||||||
@@ -222,11 +223,11 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
data,
|
data,
|
||||||
¶m,
|
¶m,
|
||||||
)
|
)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -257,7 +258,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
|
|
||||||
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
||||||
// based on client type and configuration.
|
// based on client type and configuration.
|
||||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
@@ -275,7 +276,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
bodyForTranslation := body
|
bodyForTranslation := body
|
||||||
bodyForUpstream := body
|
bodyForUpstream := body
|
||||||
if isClaudeOAuthToken(apiKey) {
|
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -284,7 +285,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
|
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas, e.cfg)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
@@ -329,7 +330,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -348,7 +348,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.publish(ctx, detail)
|
||||||
}
|
}
|
||||||
if isClaudeOAuthToken(apiKey) {
|
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||||
}
|
}
|
||||||
// Forward the line as-is to preserve SSE format
|
// Forward the line as-is to preserve SSE format
|
||||||
@@ -375,7 +375,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.publish(ctx, detail)
|
||||||
}
|
}
|
||||||
if isClaudeOAuthToken(apiKey) {
|
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||||
}
|
}
|
||||||
chunks := sdktranslator.TranslateStream(
|
chunks := sdktranslator.TranslateStream(
|
||||||
@@ -398,7 +398,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
@@ -423,7 +423,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
// Extract betas from body and convert to header (for count_tokens too)
|
// Extract betas from body and convert to header (for count_tokens too)
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
if isClaudeOAuthToken(apiKey) {
|
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||||
body = applyClaudeToolPrefix(body, claudeToolPrefix)
|
body = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -432,7 +432,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
@@ -487,7 +487,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
count := gjson.GetBytes(data, "input_tokens").Int()
|
count := gjson.GetBytes(data, "input_tokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
@@ -638,7 +638,49 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
|
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
|
||||||
|
func mapStainlessOS() string {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
return "MacOS"
|
||||||
|
case "windows":
|
||||||
|
return "Windows"
|
||||||
|
case "linux":
|
||||||
|
return "Linux"
|
||||||
|
case "freebsd":
|
||||||
|
return "FreeBSD"
|
||||||
|
default:
|
||||||
|
return "Other::" + runtime.GOOS
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
|
||||||
|
func mapStainlessArch() string {
|
||||||
|
switch runtime.GOARCH {
|
||||||
|
case "amd64":
|
||||||
|
return "x64"
|
||||||
|
case "arm64":
|
||||||
|
return "arm64"
|
||||||
|
case "386":
|
||||||
|
return "x86"
|
||||||
|
default:
|
||||||
|
return "other::" + runtime.GOARCH
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
|
||||||
|
hdrDefault := func(cfgVal, fallback string) string {
|
||||||
|
if cfgVal != "" {
|
||||||
|
return cfgVal
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
var hd config.ClaudeHeaderDefaults
|
||||||
|
if cfg != nil {
|
||||||
|
hd = cfg.ClaudeHeaderDefaults
|
||||||
|
}
|
||||||
|
|
||||||
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
|
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
|
||||||
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
|
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
|
||||||
if isAnthropicBase && useAPIKey {
|
if isAnthropicBase && useAPIKey {
|
||||||
@@ -685,16 +727,17 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
|||||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||||
|
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", "v24.3.0")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", "0.55.1")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", "arm64")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", "MacOS")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", "60")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "claude-cli/1.0.83 (external, cli)")
|
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
|
||||||
r.Header.Set("Connection", "keep-alive")
|
r.Header.Set("Connection", "keep-alive")
|
||||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||||
if stream {
|
if stream {
|
||||||
@@ -702,6 +745,8 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
|||||||
} else {
|
} else {
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
}
|
}
|
||||||
|
// Keep OS/Arch mapping dynamic (not configurable).
|
||||||
|
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
|
||||||
var attrs map[string]string
|
var attrs map[string]string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
attrs = auth.Attributes
|
attrs = auth.Attributes
|
||||||
@@ -753,11 +798,21 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Collect built-in tool names (those with a non-empty "type" field) so we can
|
||||||
|
// skip them consistently in both tools and message history.
|
||||||
|
builtinTools := map[string]bool{}
|
||||||
|
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
|
||||||
|
builtinTools[name] = true
|
||||||
|
}
|
||||||
|
|
||||||
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
||||||
tools.ForEach(func(index, tool gjson.Result) bool {
|
tools.ForEach(func(index, tool gjson.Result) bool {
|
||||||
// Skip built-in tools (web_search, code_execution, etc.) which have
|
// Skip built-in tools (web_search, code_execution, etc.) which have
|
||||||
// a "type" field and require their name to remain unchanged.
|
// a "type" field and require their name to remain unchanged.
|
||||||
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
|
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
|
||||||
|
if n := tool.Get("name").String(); n != "" {
|
||||||
|
builtinTools[n] = true
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
name := tool.Get("name").String()
|
name := tool.Get("name").String()
|
||||||
@@ -772,7 +827,7 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
|||||||
|
|
||||||
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
|
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
|
||||||
name := gjson.GetBytes(body, "tool_choice.name").String()
|
name := gjson.GetBytes(body, "tool_choice.name").String()
|
||||||
if name != "" && !strings.HasPrefix(name, prefix) {
|
if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] {
|
||||||
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
|
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -784,15 +839,38 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
content.ForEach(func(contentIndex, part gjson.Result) bool {
|
content.ForEach(func(contentIndex, part gjson.Result) bool {
|
||||||
if part.Get("type").String() != "tool_use" {
|
partType := part.Get("type").String()
|
||||||
return true
|
switch partType {
|
||||||
|
case "tool_use":
|
||||||
|
name := part.Get("name").String()
|
||||||
|
if name == "" || strings.HasPrefix(name, prefix) || builtinTools[name] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||||
|
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||||
|
case "tool_reference":
|
||||||
|
toolName := part.Get("tool_name").String()
|
||||||
|
if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
|
||||||
|
body, _ = sjson.SetBytes(body, path, prefix+toolName)
|
||||||
|
case "tool_result":
|
||||||
|
// Handle nested tool_reference blocks inside tool_result.content[]
|
||||||
|
nestedContent := part.Get("content")
|
||||||
|
if nestedContent.Exists() && nestedContent.IsArray() {
|
||||||
|
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
|
||||||
|
if nestedPart.Get("type").String() == "tool_reference" {
|
||||||
|
nestedToolName := nestedPart.Get("tool_name").String()
|
||||||
|
if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] {
|
||||||
|
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
|
||||||
|
body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
name := part.Get("name").String()
|
|
||||||
if name == "" || strings.HasPrefix(name, prefix) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
|
||||||
body, _ = sjson.SetBytes(body, path, prefix+name)
|
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return true
|
return true
|
||||||
@@ -811,15 +889,38 @@ func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
content.ForEach(func(index, part gjson.Result) bool {
|
content.ForEach(func(index, part gjson.Result) bool {
|
||||||
if part.Get("type").String() != "tool_use" {
|
partType := part.Get("type").String()
|
||||||
return true
|
switch partType {
|
||||||
|
case "tool_use":
|
||||||
|
name := part.Get("name").String()
|
||||||
|
if !strings.HasPrefix(name, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||||
|
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
|
||||||
|
case "tool_reference":
|
||||||
|
toolName := part.Get("tool_name").String()
|
||||||
|
if !strings.HasPrefix(toolName, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
path := fmt.Sprintf("content.%d.tool_name", index.Int())
|
||||||
|
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix))
|
||||||
|
case "tool_result":
|
||||||
|
// Handle nested tool_reference blocks inside tool_result.content[]
|
||||||
|
nestedContent := part.Get("content")
|
||||||
|
if nestedContent.Exists() && nestedContent.IsArray() {
|
||||||
|
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
|
||||||
|
if nestedPart.Get("type").String() == "tool_reference" {
|
||||||
|
nestedToolName := nestedPart.Get("tool_name").String()
|
||||||
|
if strings.HasPrefix(nestedToolName, prefix) {
|
||||||
|
nestedPath := fmt.Sprintf("content.%d.content.%d.tool_name", index.Int(), nestedIndex.Int())
|
||||||
|
body, _ = sjson.SetBytes(body, nestedPath, strings.TrimPrefix(nestedToolName, prefix))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
name := part.Get("name").String()
|
|
||||||
if !strings.HasPrefix(name, prefix) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
path := fmt.Sprintf("content.%d.name", index.Int())
|
|
||||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
|
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return body
|
return body
|
||||||
@@ -834,15 +935,34 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
|
|||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
contentBlock := gjson.GetBytes(payload, "content_block")
|
contentBlock := gjson.GetBytes(payload, "content_block")
|
||||||
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
|
if !contentBlock.Exists() {
|
||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
name := contentBlock.Get("name").String()
|
|
||||||
if !strings.HasPrefix(name, prefix) {
|
blockType := contentBlock.Get("type").String()
|
||||||
return line
|
var updated []byte
|
||||||
}
|
var err error
|
||||||
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
|
|
||||||
if err != nil {
|
switch blockType {
|
||||||
|
case "tool_use":
|
||||||
|
name := contentBlock.Get("name").String()
|
||||||
|
if !strings.HasPrefix(name, prefix) {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
updated, err = sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
|
||||||
|
if err != nil {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
case "tool_reference":
|
||||||
|
toolName := contentBlock.Get("tool_name").String()
|
||||||
|
if !strings.HasPrefix(toolName, prefix) {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
updated, err = sjson.SetBytes(payload, "content_block.tool_name", strings.TrimPrefix(toolName, prefix))
|
||||||
|
if err != nil {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
default:
|
||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -862,10 +982,10 @@ func getClientUserAgent(ctx context.Context) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
|
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
|
||||||
// Returns (cloakMode, strictMode, sensitiveWords).
|
// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID).
|
||||||
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
|
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) {
|
||||||
if auth == nil || auth.Attributes == nil {
|
if auth == nil || auth.Attributes == nil {
|
||||||
return "auto", false, nil
|
return "auto", false, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
cloakMode := auth.Attributes["cloak_mode"]
|
cloakMode := auth.Attributes["cloak_mode"]
|
||||||
@@ -883,7 +1003,9 @@ func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return cloakMode, strictMode, sensitiveWords
|
cacheUserID := strings.EqualFold(strings.TrimSpace(auth.Attributes["cloak_cache_user_id"]), "true")
|
||||||
|
|
||||||
|
return cloakMode, strictMode, sensitiveWords, cacheUserID
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
|
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
|
||||||
@@ -916,16 +1038,24 @@ func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *c
|
|||||||
}
|
}
|
||||||
|
|
||||||
// injectFakeUserID generates and injects a fake user ID into the request metadata.
|
// injectFakeUserID generates and injects a fake user ID into the request metadata.
|
||||||
func injectFakeUserID(payload []byte) []byte {
|
// When useCache is false, a new user ID is generated for every call.
|
||||||
|
func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
|
||||||
|
generateID := func() string {
|
||||||
|
if useCache {
|
||||||
|
return cachedUserID(apiKey)
|
||||||
|
}
|
||||||
|
return generateFakeUserID()
|
||||||
|
}
|
||||||
|
|
||||||
metadata := gjson.GetBytes(payload, "metadata")
|
metadata := gjson.GetBytes(payload, "metadata")
|
||||||
if !metadata.Exists() {
|
if !metadata.Exists() {
|
||||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
|
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
|
|
||||||
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
|
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||||
if existingUserID == "" || !isValidUserID(existingUserID) {
|
if existingUserID == "" || !isValidUserID(existingUserID) {
|
||||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
|
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
|
||||||
}
|
}
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
@@ -962,7 +1092,7 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
|||||||
|
|
||||||
// applyCloaking applies cloaking transformations to the payload based on config and client.
|
// applyCloaking applies cloaking transformations to the payload based on config and client.
|
||||||
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
|
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
|
||||||
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string) []byte {
|
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte {
|
||||||
clientUserAgent := getClientUserAgent(ctx)
|
clientUserAgent := getClientUserAgent(ctx)
|
||||||
|
|
||||||
// Get cloak config from ClaudeKey configuration
|
// Get cloak config from ClaudeKey configuration
|
||||||
@@ -972,16 +1102,20 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
|||||||
var cloakMode string
|
var cloakMode string
|
||||||
var strictMode bool
|
var strictMode bool
|
||||||
var sensitiveWords []string
|
var sensitiveWords []string
|
||||||
|
var cacheUserID bool
|
||||||
|
|
||||||
if cloakCfg != nil {
|
if cloakCfg != nil {
|
||||||
cloakMode = cloakCfg.Mode
|
cloakMode = cloakCfg.Mode
|
||||||
strictMode = cloakCfg.StrictMode
|
strictMode = cloakCfg.StrictMode
|
||||||
sensitiveWords = cloakCfg.SensitiveWords
|
sensitiveWords = cloakCfg.SensitiveWords
|
||||||
|
if cloakCfg.CacheUserID != nil {
|
||||||
|
cacheUserID = *cloakCfg.CacheUserID
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to auth attributes if no config found
|
// Fallback to auth attributes if no config found
|
||||||
if cloakMode == "" {
|
if cloakMode == "" {
|
||||||
attrMode, attrStrict, attrWords := getCloakConfigFromAuth(auth)
|
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
|
||||||
cloakMode = attrMode
|
cloakMode = attrMode
|
||||||
if !strictMode {
|
if !strictMode {
|
||||||
strictMode = attrStrict
|
strictMode = attrStrict
|
||||||
@@ -989,6 +1123,12 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
|||||||
if len(sensitiveWords) == 0 {
|
if len(sensitiveWords) == 0 {
|
||||||
sensitiveWords = attrWords
|
sensitiveWords = attrWords
|
||||||
}
|
}
|
||||||
|
if cloakCfg == nil || cloakCfg.CacheUserID == nil {
|
||||||
|
cacheUserID = attrCache
|
||||||
|
}
|
||||||
|
} else if cloakCfg == nil || cloakCfg.CacheUserID == nil {
|
||||||
|
_, _, _, attrCache := getCloakConfigFromAuth(auth)
|
||||||
|
cacheUserID = attrCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine if cloaking should be applied
|
// Determine if cloaking should be applied
|
||||||
@@ -1002,7 +1142,7 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Inject fake user ID
|
// Inject fake user ID
|
||||||
payload = injectFakeUserID(payload)
|
payload = injectFakeUserID(payload, apiKey, cacheUserID)
|
||||||
|
|
||||||
// Apply sensitive word obfuscation
|
// Apply sensitive word obfuscation
|
||||||
if len(sensitiveWords) > 0 {
|
if len(sensitiveWords) > 0 {
|
||||||
|
|||||||
@@ -2,9 +2,18 @@ package executor
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestApplyClaudeToolPrefix(t *testing.T) {
|
func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||||
@@ -25,6 +34,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) {
|
||||||
|
input := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`)
|
||||||
|
out := applyClaudeToolPrefix(input, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.0.tool_name").String(); got != "proxy_beta" {
|
||||||
|
t.Fatalf("messages.0.content.0.tool_name = %q, want %q", got, "proxy_beta")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != "proxy_gamma" {
|
||||||
|
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, "proxy_gamma")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||||
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
|
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
|
||||||
out := applyClaudeToolPrefix(input, "proxy_")
|
out := applyClaudeToolPrefix(input, "proxy_")
|
||||||
@@ -37,6 +58,97 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix_BuiltinToolSkipped(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"tools": [
|
||||||
|
{"type": "web_search_20250305", "name": "web_search", "max_uses": 5},
|
||||||
|
{"name": "Read"}
|
||||||
|
],
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": [
|
||||||
|
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}},
|
||||||
|
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}}
|
||||||
|
]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
out := applyClaudeToolPrefix(body, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
|
||||||
|
t.Fatalf("tools.0.name = %q, want %q", got, "web_search")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
|
||||||
|
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Read" {
|
||||||
|
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Read")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Read" {
|
||||||
|
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Read")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"tools": [
|
||||||
|
{"name": "Read"}
|
||||||
|
],
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": [
|
||||||
|
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}}
|
||||||
|
]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
out := applyClaudeToolPrefix(body, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
|
||||||
|
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||||
|
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix_CustomToolsPrefixed(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"tools": [{"name": "Read"}, {"name": "Write"}],
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": [
|
||||||
|
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}},
|
||||||
|
{"type": "tool_use", "name": "Write", "id": "w1", "input": {}}
|
||||||
|
]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
out := applyClaudeToolPrefix(body, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||||
|
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Write" {
|
||||||
|
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Write")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Read" {
|
||||||
|
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Read")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Write" {
|
||||||
|
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Write")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"tools": [
|
||||||
|
{"type": "web_search_20250305", "name": "web_search"},
|
||||||
|
{"name": "Read"}
|
||||||
|
],
|
||||||
|
"tool_choice": {"type": "tool", "name": "web_search"}
|
||||||
|
}`)
|
||||||
|
out := applyClaudeToolPrefix(body, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "web_search" {
|
||||||
|
t.Fatalf("tool_choice.name = %q, want %q", got, "web_search")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||||
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||||
@@ -49,6 +161,18 @@ func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) {
|
||||||
|
input := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`)
|
||||||
|
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "content.0.tool_name").String(); got != "alpha" {
|
||||||
|
t.Fatalf("content.0.tool_name = %q, want %q", got, "alpha")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "content.1.tool_name").String(); got != "bravo" {
|
||||||
|
t.Fatalf("content.1.tool_name = %q, want %q", got, "bravo")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||||
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
|
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
|
||||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||||
@@ -61,3 +185,166 @@ func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
|||||||
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
|
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) {
|
||||||
|
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`)
|
||||||
|
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||||
|
|
||||||
|
payload := bytes.TrimSpace(out)
|
||||||
|
if bytes.HasPrefix(payload, []byte("data:")) {
|
||||||
|
payload = bytes.TrimSpace(payload[len("data:"):])
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(payload, "content_block.tool_name").String(); got != "beta" {
|
||||||
|
t.Fatalf("content_block.tool_name = %q, want %q", got, "beta")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
|
||||||
|
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`)
|
||||||
|
out := applyClaudeToolPrefix(input, "proxy_")
|
||||||
|
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
|
||||||
|
if got != "proxy_mcp__nia__manage_resource" {
|
||||||
|
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "proxy_mcp__nia__manage_resource")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
|
||||||
|
resetUserIDCache()
|
||||||
|
|
||||||
|
var userIDs []string
|
||||||
|
var requestModels []string
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
userID := gjson.GetBytes(body, "metadata.user_id").String()
|
||||||
|
model := gjson.GetBytes(body, "model").String()
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
requestModels = append(requestModels, model)
|
||||||
|
t.Logf("HTTP Server received request: model=%s, user_id=%s, url=%s", model, userID, r.URL.String())
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
t.Logf("End-to-end test: Fake HTTP server started at %s", server.URL)
|
||||||
|
|
||||||
|
cacheEnabled := true
|
||||||
|
executor := NewClaudeExecutor(&config.Config{
|
||||||
|
ClaudeKey: []config.ClaudeKey{
|
||||||
|
{
|
||||||
|
APIKey: "key-123",
|
||||||
|
BaseURL: server.URL,
|
||||||
|
Cloak: &config.CloakConfig{
|
||||||
|
CacheUserID: &cacheEnabled,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
models := []string{"claude-3-5-sonnet", "claude-3-5-haiku"}
|
||||||
|
for _, model := range models {
|
||||||
|
t.Logf("Sending request for model: %s", model)
|
||||||
|
modelPayload, _ := sjson.SetBytes(payload, "model", model)
|
||||||
|
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: model,
|
||||||
|
Payload: modelPayload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("Execute(%s) error: %v", model, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(userIDs) != 2 {
|
||||||
|
t.Fatalf("expected 2 requests, got %d", len(userIDs))
|
||||||
|
}
|
||||||
|
if userIDs[0] == "" || userIDs[1] == "" {
|
||||||
|
t.Fatal("expected user_id to be populated")
|
||||||
|
}
|
||||||
|
t.Logf("user_id[0] (model=%s): %s", requestModels[0], userIDs[0])
|
||||||
|
t.Logf("user_id[1] (model=%s): %s", requestModels[1], userIDs[1])
|
||||||
|
if userIDs[0] != userIDs[1] {
|
||||||
|
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
|
||||||
|
}
|
||||||
|
if !isValidUserID(userIDs[0]) {
|
||||||
|
t.Fatalf("user_id %q is not valid", userIDs[0])
|
||||||
|
}
|
||||||
|
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
|
||||||
|
resetUserIDCache()
|
||||||
|
|
||||||
|
var userIDs []string
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
userIDs = append(userIDs, gjson.GetBytes(body, "metadata.user_id").String())
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("Execute call %d error: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(userIDs) != 2 {
|
||||||
|
t.Fatalf("expected 2 requests, got %d", len(userIDs))
|
||||||
|
}
|
||||||
|
if userIDs[0] == "" || userIDs[1] == "" {
|
||||||
|
t.Fatal("expected user_id to be populated")
|
||||||
|
}
|
||||||
|
if userIDs[0] == userIDs[1] {
|
||||||
|
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
|
||||||
|
}
|
||||||
|
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) {
|
||||||
|
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
|
||||||
|
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
|
||||||
|
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||||
|
got := gjson.GetBytes(out, "content.0.content.0.tool_name").String()
|
||||||
|
if got != "mcp__nia__manage_resource" {
|
||||||
|
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "mcp__nia__manage_resource")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) {
|
||||||
|
// tool_result.content can be a string - should not be processed
|
||||||
|
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`)
|
||||||
|
out := applyClaudeToolPrefix(input, "proxy_")
|
||||||
|
got := gjson.GetBytes(out, "messages.0.content.0.content").String()
|
||||||
|
if got != "plain string result" {
|
||||||
|
t.Fatalf("string content should remain unchanged = %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
|
||||||
|
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)
|
||||||
|
out := applyClaudeToolPrefix(input, "proxy_")
|
||||||
|
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
|
||||||
|
if got != "web_search" {
|
||||||
|
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
|
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
||||||
@@ -273,11 +273,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
|||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
|
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
|
||||||
}
|
}
|
||||||
@@ -362,7 +362,6 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -397,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
@@ -643,7 +642,6 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
|
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
|
||||||
|
|
||||||
|
|||||||
1408
internal/runtime/executor/codex_websockets_executor.go
Normal file
1408
internal/runtime/executor/codex_websockets_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -225,7 +225,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,7 +256,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request to the Gemini CLI API.
|
// ExecuteStream performs a streaming request to the Gemini CLI API.
|
||||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -382,7 +382,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func(resp *http.Response, reqBody []byte, attemptModel string) {
|
go func(resp *http.Response, reqBody []byte, attemptModel string) {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -441,7 +440,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
||||||
|
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(lastBody) > 0 {
|
if len(lastBody) > 0 {
|
||||||
@@ -546,7 +545,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
|||||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
lastStatus = resp.StatusCode
|
lastStatus = resp.StatusCode
|
||||||
lastBody = append([]byte(nil), data...)
|
lastBody = append([]byte(nil), data...)
|
||||||
|
|||||||
@@ -205,12 +205,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.publish(ctx, parseGeminiUsage(data))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request to the Gemini API.
|
// ExecuteStream performs a streaming request to the Gemini API.
|
||||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -298,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -335,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens counts tokens for the given request using the Gemini API.
|
// CountTokens counts tokens for the given request using the Gemini API.
|
||||||
@@ -416,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
|
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request to the Vertex AI API.
|
// ExecuteStream performs a streaming request to the Vertex AI API.
|
||||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -524,12 +524,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.publish(ctx, parseGeminiUsage(data))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
||||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
@@ -618,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -650,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
||||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
@@ -743,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -775,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// countTokensWithServiceAccount counts tokens using service account credentials.
|
// countTokensWithServiceAccount counts tokens using service account credentials.
|
||||||
@@ -859,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||||
@@ -943,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||||
|
|||||||
@@ -39,7 +39,8 @@ const (
|
|||||||
copilotEditorVersion = "vscode/1.107.0"
|
copilotEditorVersion = "vscode/1.107.0"
|
||||||
copilotPluginVersion = "copilot-chat/0.35.0"
|
copilotPluginVersion = "copilot-chat/0.35.0"
|
||||||
copilotIntegrationID = "vscode-chat"
|
copilotIntegrationID = "vscode-chat"
|
||||||
copilotOpenAIIntent = "conversation-edits"
|
copilotOpenAIIntent = "conversation-panel"
|
||||||
|
copilotGitHubAPIVer = "2025-04-01"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
|
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
|
||||||
@@ -51,8 +52,9 @@ type GitHubCopilotExecutor struct {
|
|||||||
|
|
||||||
// cachedAPIToken stores a cached Copilot API token with its expiry.
|
// cachedAPIToken stores a cached Copilot API token with its expiry.
|
||||||
type cachedAPIToken struct {
|
type cachedAPIToken struct {
|
||||||
token string
|
token string
|
||||||
expiresAt time.Time
|
apiEndpoint string
|
||||||
|
expiresAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGitHubCopilotExecutor constructs a new executor instance.
|
// NewGitHubCopilotExecutor constructs a new executor instance.
|
||||||
@@ -75,7 +77,7 @@ func (e *GitHubCopilotExecutor) PrepareRequest(req *http.Request, auth *cliproxy
|
|||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
}
|
}
|
||||||
apiToken, errToken := e.ensureAPIToken(ctx, auth)
|
apiToken, _, errToken := e.ensureAPIToken(ctx, auth)
|
||||||
if errToken != nil {
|
if errToken != nil {
|
||||||
return errToken
|
return errToken
|
||||||
}
|
}
|
||||||
@@ -101,7 +103,7 @@ func (e *GitHubCopilotExecutor) HttpRequest(ctx context.Context, auth *cliproxya
|
|||||||
|
|
||||||
// Execute handles non-streaming requests to GitHub Copilot.
|
// Execute handles non-streaming requests to GitHub Copilot.
|
||||||
func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
apiToken, errToken := e.ensureAPIToken(ctx, auth)
|
apiToken, baseURL, errToken := e.ensureAPIToken(ctx, auth)
|
||||||
if errToken != nil {
|
if errToken != nil {
|
||||||
return resp, errToken
|
return resp, errToken
|
||||||
}
|
}
|
||||||
@@ -124,6 +126,9 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
|||||||
body = e.normalizeModel(req.Model, body)
|
body = e.normalizeModel(req.Model, body)
|
||||||
body = flattenAssistantContent(body)
|
body = flattenAssistantContent(body)
|
||||||
|
|
||||||
|
// Detect vision content before input normalization removes messages
|
||||||
|
hasVision := detectVisionContent(body)
|
||||||
|
|
||||||
thinkingProvider := "openai"
|
thinkingProvider := "openai"
|
||||||
if useResponses {
|
if useResponses {
|
||||||
thinkingProvider = "codex"
|
thinkingProvider = "codex"
|
||||||
@@ -147,7 +152,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
|||||||
if useResponses {
|
if useResponses {
|
||||||
path = githubCopilotResponsesPath
|
path = githubCopilotResponsesPath
|
||||||
}
|
}
|
||||||
url := githubCopilotBaseURL + path
|
url := baseURL + path
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
@@ -155,7 +160,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
|||||||
e.applyHeaders(httpReq, apiToken, body)
|
e.applyHeaders(httpReq, apiToken, body)
|
||||||
|
|
||||||
// Add Copilot-Vision-Request header if the request contains vision content
|
// Add Copilot-Vision-Request header if the request contains vision content
|
||||||
if detectVisionContent(body) {
|
if hasVision {
|
||||||
httpReq.Header.Set("Copilot-Vision-Request", "true")
|
httpReq.Header.Set("Copilot-Vision-Request", "true")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,8 +232,8 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream handles streaming requests to GitHub Copilot.
|
// ExecuteStream handles streaming requests to GitHub Copilot.
|
||||||
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
apiToken, errToken := e.ensureAPIToken(ctx, auth)
|
apiToken, baseURL, errToken := e.ensureAPIToken(ctx, auth)
|
||||||
if errToken != nil {
|
if errToken != nil {
|
||||||
return nil, errToken
|
return nil, errToken
|
||||||
}
|
}
|
||||||
@@ -251,6 +256,9 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
body = e.normalizeModel(req.Model, body)
|
body = e.normalizeModel(req.Model, body)
|
||||||
body = flattenAssistantContent(body)
|
body = flattenAssistantContent(body)
|
||||||
|
|
||||||
|
// Detect vision content before input normalization removes messages
|
||||||
|
hasVision := detectVisionContent(body)
|
||||||
|
|
||||||
thinkingProvider := "openai"
|
thinkingProvider := "openai"
|
||||||
if useResponses {
|
if useResponses {
|
||||||
thinkingProvider = "codex"
|
thinkingProvider = "codex"
|
||||||
@@ -278,7 +286,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
if useResponses {
|
if useResponses {
|
||||||
path = githubCopilotResponsesPath
|
path = githubCopilotResponsesPath
|
||||||
}
|
}
|
||||||
url := githubCopilotBaseURL + path
|
url := baseURL + path
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -286,7 +294,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
e.applyHeaders(httpReq, apiToken, body)
|
e.applyHeaders(httpReq, apiToken, body)
|
||||||
|
|
||||||
// Add Copilot-Vision-Request header if the request contains vision content
|
// Add Copilot-Vision-Request header if the request contains vision content
|
||||||
if detectVisionContent(body) {
|
if hasVision {
|
||||||
httpReq.Header.Set("Copilot-Vision-Request", "true")
|
httpReq.Header.Set("Copilot-Vision-Request", "true")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,7 +341,6 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
@@ -386,7 +393,10 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{
|
||||||
|
Headers: httpResp.Header.Clone(),
|
||||||
|
Chunks: out,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens is not supported for GitHub Copilot.
|
// CountTokens is not supported for GitHub Copilot.
|
||||||
@@ -418,22 +428,22 @@ func (e *GitHubCopilotExecutor) Refresh(ctx context.Context, auth *cliproxyauth.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ensureAPIToken gets or refreshes the Copilot API token.
|
// ensureAPIToken gets or refreshes the Copilot API token.
|
||||||
func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, error) {
|
func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, string, error) {
|
||||||
if auth == nil {
|
if auth == nil {
|
||||||
return "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
|
return "", "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the GitHub access token
|
// Get the GitHub access token
|
||||||
accessToken := metaStringValue(auth.Metadata, "access_token")
|
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"}
|
return "", "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for cached API token using thread-safe access
|
// Check for cached API token using thread-safe access
|
||||||
e.mu.RLock()
|
e.mu.RLock()
|
||||||
if cached, ok := e.cache[accessToken]; ok && cached.expiresAt.After(time.Now().Add(tokenExpiryBuffer)) {
|
if cached, ok := e.cache[accessToken]; ok && cached.expiresAt.After(time.Now().Add(tokenExpiryBuffer)) {
|
||||||
e.mu.RUnlock()
|
e.mu.RUnlock()
|
||||||
return cached.token, nil
|
return cached.token, cached.apiEndpoint, nil
|
||||||
}
|
}
|
||||||
e.mu.RUnlock()
|
e.mu.RUnlock()
|
||||||
|
|
||||||
@@ -441,7 +451,13 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro
|
|||||||
copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
|
copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
|
||||||
apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
|
apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)}
|
return "", "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use endpoint from token response, fall back to default
|
||||||
|
apiEndpoint := githubCopilotBaseURL
|
||||||
|
if apiToken.Endpoints.API != "" {
|
||||||
|
apiEndpoint = strings.TrimRight(apiToken.Endpoints.API, "/")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache the token with thread-safe access
|
// Cache the token with thread-safe access
|
||||||
@@ -451,12 +467,13 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro
|
|||||||
}
|
}
|
||||||
e.mu.Lock()
|
e.mu.Lock()
|
||||||
e.cache[accessToken] = &cachedAPIToken{
|
e.cache[accessToken] = &cachedAPIToken{
|
||||||
token: apiToken.Token,
|
token: apiToken.Token,
|
||||||
expiresAt: expiresAt,
|
apiEndpoint: apiEndpoint,
|
||||||
|
expiresAt: expiresAt,
|
||||||
}
|
}
|
||||||
e.mu.Unlock()
|
e.mu.Unlock()
|
||||||
|
|
||||||
return apiToken.Token, nil
|
return apiToken.Token, apiEndpoint, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyHeaders sets the required headers for GitHub Copilot API requests.
|
// applyHeaders sets the required headers for GitHub Copilot API requests.
|
||||||
@@ -469,16 +486,17 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, b
|
|||||||
r.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
|
r.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
|
||||||
r.Header.Set("Openai-Intent", copilotOpenAIIntent)
|
r.Header.Set("Openai-Intent", copilotOpenAIIntent)
|
||||||
r.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
|
r.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
|
||||||
|
r.Header.Set("X-Github-Api-Version", copilotGitHubAPIVer)
|
||||||
r.Header.Set("X-Request-Id", uuid.NewString())
|
r.Header.Set("X-Request-Id", uuid.NewString())
|
||||||
|
|
||||||
initiator := "user"
|
initiator := "user"
|
||||||
if len(body) > 0 {
|
if len(body) > 0 {
|
||||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||||
arr := messages.Array()
|
for _, msg := range messages.Array() {
|
||||||
if len(arr) > 0 {
|
role := msg.Get("role").String()
|
||||||
lastRole := arr[len(arr)-1].Get("role").String()
|
if role == "assistant" || role == "tool" {
|
||||||
if lastRole != "" && lastRole != "user" {
|
|
||||||
initiator = "agent"
|
initiator = "agent"
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package executor
|
package executor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -247,3 +248,86 @@ func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.
|
|||||||
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
|
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Tests for X-Initiator detection logic (Problem L) ---
|
||||||
|
|
||||||
|
func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
|
body := []byte(`{"messages":[{"role":"system","content":"sys"},{"role":"user","content":"hello"}]}`)
|
||||||
|
e.applyHeaders(req, "token", body)
|
||||||
|
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||||
|
t.Fatalf("X-Initiator = %q, want user", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyHeaders_XInitiator_AgentWithAssistantAndUserToolResult(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
|
// Claude Code typical flow: last message is user (tool result), but has assistant in history
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
|
||||||
|
e.applyHeaders(req, "token", body)
|
||||||
|
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||||
|
t.Fatalf("X-Initiator = %q, want agent (assistant exists in messages)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
|
||||||
|
e.applyHeaders(req, "token", body)
|
||||||
|
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||||
|
t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Tests for x-github-api-version header (Problem M) ---
|
||||||
|
|
||||||
|
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
|
e.applyHeaders(req, "token", nil)
|
||||||
|
if got := req.Header.Get("X-Github-Api-Version"); got != "2025-04-01" {
|
||||||
|
t.Fatalf("X-Github-Api-Version = %q, want 2025-04-01", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Tests for vision detection (Problem P) ---
|
||||||
|
|
||||||
|
func TestDetectVisionContent_WithImageURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
|
||||||
|
if !detectVisionContent(body) {
|
||||||
|
t.Fatal("expected vision content to be detected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectVisionContent_WithImageType(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":[{"type":"image","source":{"data":"abc","media_type":"image/png"}}]}]}`)
|
||||||
|
if !detectVisionContent(body) {
|
||||||
|
t.Fatal("expected image type to be detected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectVisionContent_NoVision(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||||
|
if detectVisionContent(body) {
|
||||||
|
t.Fatal("expected no vision content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectVisionContent_NoMessages(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// After Responses API normalization, messages is removed — detection should return false
|
||||||
|
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`)
|
||||||
|
if detectVisionContent(body) {
|
||||||
|
t.Fatal("expected no vision content when messages field is absent")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -169,12 +169,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
// the original model name in the response for client compatibility.
|
// the original model name in the response for client compatibility.
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming chat completion request.
|
// ExecuteStream performs a streaming chat completion request.
|
||||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -262,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -294,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ func (e *KiloExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request.
|
// ExecuteStream performs a streaming request.
|
||||||
func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
@@ -253,7 +253,6 @@ func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer httpResp.Body.Close()
|
defer httpResp.Body.Close()
|
||||||
@@ -286,7 +285,10 @@ func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{
|
||||||
|
Headers: httpResp.Header.Clone(),
|
||||||
|
Chunks: out,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh validates the Kilo token.
|
// Refresh validates the Kilo token.
|
||||||
@@ -456,4 +458,3 @@ func FetchKiloModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.C
|
|||||||
|
|
||||||
return allModels
|
return allModels
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -161,12 +161,12 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
// the original model name in the response for client compatibility.
|
// the original model name in the response for client compatibility.
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming chat completion request to Kimi.
|
// ExecuteStream performs a streaming chat completion request to Kimi.
|
||||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
if from.String() == "claude" {
|
if from.String() == "claude" {
|
||||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||||
@@ -253,7 +253,6 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -285,7 +284,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens estimates token count for Kimi requests.
|
// CountTokens estimates token count for Kimi requests.
|
||||||
|
|||||||
@@ -1053,7 +1053,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
|||||||
|
|
||||||
// ExecuteStream handles streaming requests to Kiro API.
|
// ExecuteStream handles streaming requests to Kiro API.
|
||||||
// Supports automatic token refresh on 401/403 errors and quota fallback on 429.
|
// Supports automatic token refresh on 401/403 errors and quota fallback on 429.
|
||||||
func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
accessToken, profileArn := kiroCredentials(auth)
|
accessToken, profileArn := kiroCredentials(auth)
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, fmt.Errorf("kiro: access token not found in auth")
|
return nil, fmt.Errorf("kiro: access token not found in auth")
|
||||||
@@ -1110,7 +1110,11 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
// Route to MCP endpoint instead of normal Kiro API
|
// Route to MCP endpoint instead of normal Kiro API
|
||||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||||
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
|
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
|
||||||
return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
|
streamWebSearch, errWebSearch := e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||||
|
if errWebSearch != nil {
|
||||||
|
return nil, errWebSearch
|
||||||
|
}
|
||||||
|
return &cliproxyexecutor.StreamResult{Chunks: streamWebSearch}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
@@ -1128,7 +1132,11 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
|
|
||||||
// Execute stream with retry on 401/403 and 429 (quota exhausted)
|
// Execute stream with retry on 401/403 and 429 (quota exhausted)
|
||||||
// Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint
|
// Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint
|
||||||
return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
streamKiro, errStreamKiro := e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
||||||
|
if errStreamKiro != nil {
|
||||||
|
return nil, errStreamKiro
|
||||||
|
}
|
||||||
|
return &cliproxyexecutor.StreamResult{Chunks: streamKiro}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors.
|
// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors.
|
||||||
@@ -1709,6 +1717,7 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
|||||||
// Amazon Q format (amazonq- prefix) - same API as Kiro
|
// Amazon Q format (amazonq- prefix) - same API as Kiro
|
||||||
"amazonq-auto": "auto",
|
"amazonq-auto": "auto",
|
||||||
"amazonq-claude-opus-4-6": "claude-opus-4.6",
|
"amazonq-claude-opus-4-6": "claude-opus-4.6",
|
||||||
|
"amazonq-claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||||
"amazonq-claude-opus-4-5": "claude-opus-4.5",
|
"amazonq-claude-opus-4-5": "claude-opus-4.5",
|
||||||
"amazonq-claude-sonnet-4-5": "claude-sonnet-4.5",
|
"amazonq-claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||||
"amazonq-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
"amazonq-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||||
@@ -1717,6 +1726,7 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
|||||||
"amazonq-claude-haiku-4-5": "claude-haiku-4.5",
|
"amazonq-claude-haiku-4-5": "claude-haiku-4.5",
|
||||||
// Kiro format (kiro- prefix) - valid model names that should be preserved
|
// Kiro format (kiro- prefix) - valid model names that should be preserved
|
||||||
"kiro-claude-opus-4-6": "claude-opus-4.6",
|
"kiro-claude-opus-4-6": "claude-opus-4.6",
|
||||||
|
"kiro-claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||||
"kiro-claude-opus-4-5": "claude-opus-4.5",
|
"kiro-claude-opus-4-5": "claude-opus-4.5",
|
||||||
"kiro-claude-sonnet-4-5": "claude-sonnet-4.5",
|
"kiro-claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||||
"kiro-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
"kiro-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||||
@@ -1727,6 +1737,8 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
|||||||
// Native format (no prefix) - used by Kiro IDE directly
|
// Native format (no prefix) - used by Kiro IDE directly
|
||||||
"claude-opus-4-6": "claude-opus-4.6",
|
"claude-opus-4-6": "claude-opus-4.6",
|
||||||
"claude-opus-4.6": "claude-opus-4.6",
|
"claude-opus-4.6": "claude-opus-4.6",
|
||||||
|
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||||
|
"claude-sonnet-4.6": "claude-sonnet-4.6",
|
||||||
"claude-opus-4-5": "claude-opus-4.5",
|
"claude-opus-4-5": "claude-opus-4.5",
|
||||||
"claude-opus-4.5": "claude-opus-4.5",
|
"claude-opus-4.5": "claude-opus-4.5",
|
||||||
"claude-haiku-4-5": "claude-haiku-4.5",
|
"claude-haiku-4-5": "claude-haiku-4.5",
|
||||||
@@ -1739,11 +1751,13 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
|||||||
"auto": "auto",
|
"auto": "auto",
|
||||||
// Agentic variants (same backend model IDs, but with special system prompt)
|
// Agentic variants (same backend model IDs, but with special system prompt)
|
||||||
"claude-opus-4.6-agentic": "claude-opus-4.6",
|
"claude-opus-4.6-agentic": "claude-opus-4.6",
|
||||||
|
"claude-sonnet-4.6-agentic": "claude-sonnet-4.6",
|
||||||
"claude-opus-4.5-agentic": "claude-opus-4.5",
|
"claude-opus-4.5-agentic": "claude-opus-4.5",
|
||||||
"claude-sonnet-4.5-agentic": "claude-sonnet-4.5",
|
"claude-sonnet-4.5-agentic": "claude-sonnet-4.5",
|
||||||
"claude-sonnet-4-agentic": "claude-sonnet-4",
|
"claude-sonnet-4-agentic": "claude-sonnet-4",
|
||||||
"claude-haiku-4.5-agentic": "claude-haiku-4.5",
|
"claude-haiku-4.5-agentic": "claude-haiku-4.5",
|
||||||
"kiro-claude-opus-4-6-agentic": "claude-opus-4.6",
|
"kiro-claude-opus-4-6-agentic": "claude-opus-4.6",
|
||||||
|
"kiro-claude-sonnet-4-6-agentic": "claude-sonnet-4.6",
|
||||||
"kiro-claude-opus-4-5-agentic": "claude-opus-4.5",
|
"kiro-claude-opus-4-5-agentic": "claude-opus-4.5",
|
||||||
"kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5",
|
"kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5",
|
||||||
"kiro-claude-sonnet-4-agentic": "claude-sonnet-4",
|
"kiro-claude-sonnet-4-agentic": "claude-sonnet-4",
|
||||||
@@ -1769,6 +1783,10 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
|||||||
log.Debugf("kiro: unknown Sonnet 3.7 model '%s', mapping to claude-3-7-sonnet-20250219", model)
|
log.Debugf("kiro: unknown Sonnet 3.7 model '%s', mapping to claude-3-7-sonnet-20250219", model)
|
||||||
return "claude-3-7-sonnet-20250219"
|
return "claude-3-7-sonnet-20250219"
|
||||||
}
|
}
|
||||||
|
if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") {
|
||||||
|
log.Debugf("kiro: unknown Sonnet 4.6 model '%s', mapping to claude-sonnet-4.6", model)
|
||||||
|
return "claude-sonnet-4.6"
|
||||||
|
}
|
||||||
if strings.Contains(modelLower, "4-5") || strings.Contains(modelLower, "4.5") {
|
if strings.Contains(modelLower, "4-5") || strings.Contains(modelLower, "4.5") {
|
||||||
log.Debugf("kiro: unknown Sonnet 4.5 model '%s', mapping to claude-sonnet-4.5", model)
|
log.Debugf("kiro: unknown Sonnet 4.5 model '%s', mapping to claude-sonnet-4.5", model)
|
||||||
return "claude-sonnet-4.5"
|
return "claude-sonnet-4.5"
|
||||||
@@ -1780,6 +1798,10 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
|||||||
|
|
||||||
// Check for Opus variants
|
// Check for Opus variants
|
||||||
if strings.Contains(modelLower, "opus") {
|
if strings.Contains(modelLower, "opus") {
|
||||||
|
if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") {
|
||||||
|
log.Debugf("kiro: unknown Opus 4.6 model '%s', mapping to claude-opus-4.6", model)
|
||||||
|
return "claude-opus-4.6"
|
||||||
|
}
|
||||||
log.Debugf("kiro: unknown Opus model '%s', mapping to claude-opus-4.5", model)
|
log.Debugf("kiro: unknown Opus model '%s', mapping to claude-opus-4.5", model)
|
||||||
return "claude-opus-4.5"
|
return "claude-opus-4.5"
|
||||||
}
|
}
|
||||||
@@ -2529,6 +2551,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
isThinkingBlockOpen := false // Track if thinking content block SSE event is open
|
isThinkingBlockOpen := false // Track if thinking content block SSE event is open
|
||||||
thinkingBlockIndex := -1 // Index of the thinking content block
|
thinkingBlockIndex := -1 // Index of the thinking content block
|
||||||
var accumulatedThinkingContent strings.Builder // Accumulate thinking content for token counting
|
var accumulatedThinkingContent strings.Builder // Accumulate thinking content for token counting
|
||||||
|
hasOfficialReasoningEvent := false // Disable tag parsing after official reasoning events appear
|
||||||
|
|
||||||
// Buffer for handling partial tag matches at chunk boundaries
|
// Buffer for handling partial tag matches at chunk boundaries
|
||||||
var pendingContent strings.Builder // Buffer content that might be part of a tag
|
var pendingContent strings.Builder // Buffer content that might be part of a tag
|
||||||
@@ -2964,6 +2987,31 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
lastUsageUpdateTime = time.Now()
|
lastUsageUpdateTime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if hasOfficialReasoningEvent {
|
||||||
|
processText := strings.TrimSpace(strings.ReplaceAll(strings.ReplaceAll(contentDelta, kirocommon.ThinkingStartTag, ""), kirocommon.ThinkingEndTag, ""))
|
||||||
|
if processText != "" {
|
||||||
|
if !isTextBlockOpen {
|
||||||
|
contentBlockIndex++
|
||||||
|
isTextBlockOpen = true
|
||||||
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
||||||
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||||
|
for _, chunk := range sseData {
|
||||||
|
if chunk != "" {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processText, contentBlockIndex)
|
||||||
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
||||||
|
for _, chunk := range sseData {
|
||||||
|
if chunk != "" {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// TAG-BASED THINKING PARSING: Parse <thinking> tags from content
|
// TAG-BASED THINKING PARSING: Parse <thinking> tags from content
|
||||||
// Combine pending content with new content for processing
|
// Combine pending content with new content for processing
|
||||||
pendingContent.WriteString(contentDelta)
|
pendingContent.WriteString(contentDelta)
|
||||||
@@ -3242,6 +3290,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
}
|
}
|
||||||
|
|
||||||
if thinkingText != "" {
|
if thinkingText != "" {
|
||||||
|
hasOfficialReasoningEvent = true
|
||||||
// Close text block if open before starting thinking block
|
// Close text block if open before starting thinking block
|
||||||
if isTextBlockOpen && contentBlockIndex >= 0 {
|
if isTextBlockOpen && contentBlockIndex >= 0 {
|
||||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||||
|
|||||||
@@ -172,11 +172,11 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
// Translate response back to source format when needed
|
// Translate response back to source format when needed
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
@@ -258,7 +258,6 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -298,7 +297,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
// Ensure we record the request if no usage chunk was ever seen
|
// Ensure we record the request if no usage chunk was ever seen
|
||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||||
@@ -22,9 +23,151 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||||
|
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
|
||||||
|
qwenRateLimitWindow = time.Minute // sliding window duration
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
|
||||||
|
var qwenBeijingLoc = func() *time.Location {
|
||||||
|
loc, err := time.LoadLocation("Asia/Shanghai")
|
||||||
|
if err != nil || loc == nil {
|
||||||
|
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
|
||||||
|
return time.FixedZone("CST", 8*3600)
|
||||||
|
}
|
||||||
|
return loc
|
||||||
|
}()
|
||||||
|
|
||||||
|
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
|
||||||
|
var qwenQuotaCodes = map[string]struct{}{
|
||||||
|
"insufficient_quota": {},
|
||||||
|
"quota_exceeded": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// qwenRateLimiter tracks request timestamps per credential for rate limiting.
|
||||||
|
// Qwen has a limit of 60 requests per minute per account.
|
||||||
|
var qwenRateLimiter = struct {
|
||||||
|
sync.Mutex
|
||||||
|
requests map[string][]time.Time // authID -> request timestamps
|
||||||
|
}{
|
||||||
|
requests: make(map[string][]time.Time),
|
||||||
|
}
|
||||||
|
|
||||||
|
// redactAuthID returns a redacted version of the auth ID for safe logging.
|
||||||
|
// Keeps a small prefix/suffix to allow correlation across events.
|
||||||
|
func redactAuthID(id string) string {
|
||||||
|
if id == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(id) <= 8 {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
return id[:4] + "..." + id[len(id)-4:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
|
||||||
|
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
|
||||||
|
func checkQwenRateLimit(authID string) error {
|
||||||
|
if authID == "" {
|
||||||
|
// Empty authID should not bypass rate limiting in production
|
||||||
|
// Use debug level to avoid log spam for certain auth flows
|
||||||
|
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
windowStart := now.Add(-qwenRateLimitWindow)
|
||||||
|
|
||||||
|
qwenRateLimiter.Lock()
|
||||||
|
defer qwenRateLimiter.Unlock()
|
||||||
|
|
||||||
|
// Get and filter timestamps within the window
|
||||||
|
timestamps := qwenRateLimiter.requests[authID]
|
||||||
|
var validTimestamps []time.Time
|
||||||
|
for _, ts := range timestamps {
|
||||||
|
if ts.After(windowStart) {
|
||||||
|
validTimestamps = append(validTimestamps, ts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always prune expired entries to prevent memory leak
|
||||||
|
// Delete empty entries, otherwise update with pruned slice
|
||||||
|
if len(validTimestamps) == 0 {
|
||||||
|
delete(qwenRateLimiter.requests, authID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if rate limit exceeded
|
||||||
|
if len(validTimestamps) >= qwenRateLimitPerMin {
|
||||||
|
// Calculate when the oldest request will expire
|
||||||
|
oldestInWindow := validTimestamps[0]
|
||||||
|
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
|
||||||
|
if retryAfter < time.Second {
|
||||||
|
retryAfter = time.Second
|
||||||
|
}
|
||||||
|
retryAfterSec := int(retryAfter.Seconds())
|
||||||
|
return statusErr{
|
||||||
|
code: http.StatusTooManyRequests,
|
||||||
|
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
|
||||||
|
retryAfter: &retryAfter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record this request and update the map with pruned timestamps
|
||||||
|
validTimestamps = append(validTimestamps, now)
|
||||||
|
qwenRateLimiter.requests[authID] = validTimestamps
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
|
||||||
|
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
|
||||||
|
func isQwenQuotaError(body []byte) bool {
|
||||||
|
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
|
||||||
|
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
|
||||||
|
|
||||||
|
// Primary check: exact match on error.code or error.type (most reliable)
|
||||||
|
if _, ok := qwenQuotaCodes[code]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, ok := qwenQuotaCodes[errType]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: check message only if code/type don't match (less reliable)
|
||||||
|
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
|
||||||
|
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
|
||||||
|
strings.Contains(msg, "free allocated quota exceeded") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
|
||||||
|
// Returns the appropriate status code and retryAfter duration for statusErr.
|
||||||
|
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
|
||||||
|
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
|
||||||
|
errCode = httpCode
|
||||||
|
// Only check quota errors for expected status codes to avoid false positives
|
||||||
|
// Qwen returns 403 for quota errors, 429 for rate limits
|
||||||
|
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
|
||||||
|
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
|
||||||
|
cooldown := timeUntilNextDay()
|
||||||
|
retryAfter = &cooldown
|
||||||
|
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
|
||||||
|
}
|
||||||
|
return errCode, retryAfter
|
||||||
|
}
|
||||||
|
|
||||||
|
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
|
||||||
|
// Qwen's daily quota resets at 00:00 Beijing time.
|
||||||
|
func timeUntilNextDay() time.Duration {
|
||||||
|
now := time.Now()
|
||||||
|
nowLocal := now.In(qwenBeijingLoc)
|
||||||
|
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
|
||||||
|
return tomorrow.Sub(now)
|
||||||
|
}
|
||||||
|
|
||||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||||
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
||||||
type QwenExecutor struct {
|
type QwenExecutor struct {
|
||||||
@@ -67,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit before proceeding
|
||||||
|
var authID string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
}
|
||||||
|
if err := checkQwenRateLimit(authID); err != nil {
|
||||||
|
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
token, baseURL := qwenCreds(auth)
|
token, baseURL := qwenCreds(auth)
|
||||||
@@ -102,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
applyQwenHeaders(httpReq, token, false)
|
applyQwenHeaders(httpReq, token, false)
|
||||||
var authID, authLabel, authType, authValue string
|
var authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
@@ -135,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
@@ -150,14 +305,25 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
// the original model name in the response for client compatibility.
|
// the original model name in the response for client compatibility.
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit before proceeding
|
||||||
|
var authID string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
}
|
||||||
|
if err := checkQwenRateLimit(authID); err != nil {
|
||||||
|
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
token, baseURL := qwenCreds(auth)
|
token, baseURL := qwenCreds(auth)
|
||||||
@@ -200,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
applyQwenHeaders(httpReq, token, true)
|
applyQwenHeaders(httpReq, token, true)
|
||||||
var authID, authLabel, authType, authValue string
|
var authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
@@ -228,15 +393,16 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
||||||
|
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -268,7 +434,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
|||||||
89
internal/runtime/executor/user_id_cache.go
Normal file
89
internal/runtime/executor/user_id_cache.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type userIDCacheEntry struct {
|
||||||
|
value string
|
||||||
|
expire time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
userIDCache = make(map[string]userIDCacheEntry)
|
||||||
|
userIDCacheMu sync.RWMutex
|
||||||
|
userIDCacheCleanupOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
userIDTTL = time.Hour
|
||||||
|
userIDCacheCleanupPeriod = 15 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
func startUserIDCacheCleanup() {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(userIDCacheCleanupPeriod)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
purgeExpiredUserIDs()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func purgeExpiredUserIDs() {
|
||||||
|
now := time.Now()
|
||||||
|
userIDCacheMu.Lock()
|
||||||
|
for key, entry := range userIDCache {
|
||||||
|
if !entry.expire.After(now) {
|
||||||
|
delete(userIDCache, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func userIDCacheKey(apiKey string) string {
|
||||||
|
sum := sha256.Sum256([]byte(apiKey))
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func cachedUserID(apiKey string) string {
|
||||||
|
if apiKey == "" {
|
||||||
|
return generateFakeUserID()
|
||||||
|
}
|
||||||
|
|
||||||
|
userIDCacheCleanupOnce.Do(startUserIDCacheCleanup)
|
||||||
|
|
||||||
|
key := userIDCacheKey(apiKey)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
userIDCacheMu.RLock()
|
||||||
|
entry, ok := userIDCache[key]
|
||||||
|
valid := ok && entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value)
|
||||||
|
userIDCacheMu.RUnlock()
|
||||||
|
if valid {
|
||||||
|
userIDCacheMu.Lock()
|
||||||
|
entry = userIDCache[key]
|
||||||
|
if entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value) {
|
||||||
|
entry.expire = now.Add(userIDTTL)
|
||||||
|
userIDCache[key] = entry
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
return entry.value
|
||||||
|
}
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
newID := generateFakeUserID()
|
||||||
|
|
||||||
|
userIDCacheMu.Lock()
|
||||||
|
entry, ok = userIDCache[key]
|
||||||
|
if !ok || entry.value == "" || !entry.expire.After(now) || !isValidUserID(entry.value) {
|
||||||
|
entry.value = newID
|
||||||
|
}
|
||||||
|
entry.expire = now.Add(userIDTTL)
|
||||||
|
userIDCache[key] = entry
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
return entry.value
|
||||||
|
}
|
||||||
86
internal/runtime/executor/user_id_cache_test.go
Normal file
86
internal/runtime/executor/user_id_cache_test.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func resetUserIDCache() {
|
||||||
|
userIDCacheMu.Lock()
|
||||||
|
userIDCache = make(map[string]userIDCacheEntry)
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
|
||||||
|
resetUserIDCache()
|
||||||
|
|
||||||
|
first := cachedUserID("api-key-1")
|
||||||
|
second := cachedUserID("api-key-1")
|
||||||
|
|
||||||
|
if first == "" {
|
||||||
|
t.Fatal("expected generated user_id to be non-empty")
|
||||||
|
}
|
||||||
|
if first != second {
|
||||||
|
t.Fatalf("expected cached user_id to be reused, got %q and %q", first, second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
|
||||||
|
resetUserIDCache()
|
||||||
|
|
||||||
|
expiredID := cachedUserID("api-key-expired")
|
||||||
|
cacheKey := userIDCacheKey("api-key-expired")
|
||||||
|
userIDCacheMu.Lock()
|
||||||
|
userIDCache[cacheKey] = userIDCacheEntry{
|
||||||
|
value: expiredID,
|
||||||
|
expire: time.Now().Add(-time.Minute),
|
||||||
|
}
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
|
||||||
|
newID := cachedUserID("api-key-expired")
|
||||||
|
if newID == expiredID {
|
||||||
|
t.Fatalf("expected expired user_id to be replaced, got %q", newID)
|
||||||
|
}
|
||||||
|
if newID == "" {
|
||||||
|
t.Fatal("expected regenerated user_id to be non-empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
|
||||||
|
resetUserIDCache()
|
||||||
|
|
||||||
|
first := cachedUserID("api-key-1")
|
||||||
|
second := cachedUserID("api-key-2")
|
||||||
|
|
||||||
|
if first == second {
|
||||||
|
t.Fatalf("expected different API keys to have different user_ids, got %q", first)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
|
||||||
|
resetUserIDCache()
|
||||||
|
|
||||||
|
key := "api-key-renew"
|
||||||
|
id := cachedUserID(key)
|
||||||
|
cacheKey := userIDCacheKey(key)
|
||||||
|
|
||||||
|
soon := time.Now()
|
||||||
|
userIDCacheMu.Lock()
|
||||||
|
userIDCache[cacheKey] = userIDCacheEntry{
|
||||||
|
value: id,
|
||||||
|
expire: soon.Add(2 * time.Second),
|
||||||
|
}
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
|
||||||
|
if refreshed := cachedUserID(key); refreshed != id {
|
||||||
|
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
|
||||||
|
}
|
||||||
|
|
||||||
|
userIDCacheMu.RLock()
|
||||||
|
entry := userIDCache[cacheKey]
|
||||||
|
userIDCacheMu.RUnlock()
|
||||||
|
|
||||||
|
if entry.expire.Sub(soon) < 30*time.Minute {
|
||||||
|
t.Fatalf("expected TTL to renew, got %v remaining", entry.expire.Sub(soon))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -231,8 +231,12 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
|
|
||||||
} else if functionResponseResult.IsObject() {
|
} else if functionResponseResult.IsObject() {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||||
} else {
|
} else if functionResponseResult.Raw != "" {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||||
|
} else {
|
||||||
|
// Content field is missing entirely — .Raw is empty which
|
||||||
|
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
|
||||||
|
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
partJSON := `{}`
|
partJSON := `{}`
|
||||||
|
|||||||
@@ -661,6 +661,85 @@ func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultNoContent(t *testing.T) {
|
||||||
|
// Bug repro: tool_result with no content field produces invalid JSON
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-opus-4-6-thinking",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "MyTool-123-456",
|
||||||
|
"name": "MyTool",
|
||||||
|
"input": {"key": "value"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "MyTool-123-456"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Errorf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the functionResponse has a valid result value
|
||||||
|
fr := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse.response.result")
|
||||||
|
if !fr.Exists() {
|
||||||
|
t.Error("functionResponse.response.result should exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
|
||||||
|
// Bug repro: tool_result with null content produces invalid JSON
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-opus-4-6-thinking",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "MyTool-123-456",
|
||||||
|
"name": "MyTool",
|
||||||
|
"input": {"key": "value"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "MyTool-123-456",
|
||||||
|
"content": null
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Errorf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||||
// When tools + thinking but no system instruction, should create one with hint
|
// When tools + thinking but no system instruction, should create one with hint
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
|
|||||||
@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
|||||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "file":
|
||||||
|
fileData := part.Get("file.file_data").String()
|
||||||
|
if strings.HasPrefix(fileData, "data:") {
|
||||||
|
semicolonIdx := strings.Index(fileData, ";")
|
||||||
|
commaIdx := strings.Index(fileData, ",")
|
||||||
|
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||||
|
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||||
|
data := fileData[commaIdx+1:]
|
||||||
|
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||||
|
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||||
|
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||||
|
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
var textAggregate strings.Builder
|
var textAggregate strings.Builder
|
||||||
var partsJSON []string
|
var partsJSON []string
|
||||||
hasImage := false
|
hasImage := false
|
||||||
|
hasFile := false
|
||||||
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
||||||
parts.ForEach(func(_, part gjson.Result) bool {
|
parts.ForEach(func(_, part gjson.Result) bool {
|
||||||
ptype := part.Get("type").String()
|
ptype := part.Get("type").String()
|
||||||
@@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
hasImage = true
|
hasImage = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case "input_file":
|
||||||
|
fileData := part.Get("file_data").String()
|
||||||
|
if fileData != "" {
|
||||||
|
mediaType := "application/octet-stream"
|
||||||
|
data := fileData
|
||||||
|
if strings.HasPrefix(fileData, "data:") {
|
||||||
|
trimmed := strings.TrimPrefix(fileData, "data:")
|
||||||
|
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
|
||||||
|
if len(mediaAndData) == 2 {
|
||||||
|
if mediaAndData[0] != "" {
|
||||||
|
mediaType = mediaAndData[0]
|
||||||
|
}
|
||||||
|
data = mediaAndData[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||||
|
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
|
||||||
|
contentPart, _ = sjson.Set(contentPart, "source.data", data)
|
||||||
|
partsJSON = append(partsJSON, contentPart)
|
||||||
|
if role == "" {
|
||||||
|
role = "user"
|
||||||
|
}
|
||||||
|
hasFile = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
@@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
if len(partsJSON) > 0 {
|
if len(partsJSON) > 0 {
|
||||||
msg := `{"role":"","content":[]}`
|
msg := `{"role":"","content":[]}`
|
||||||
msg, _ = sjson.Set(msg, "role", role)
|
msg, _ = sjson.Set(msg, "role", role)
|
||||||
if len(partsJSON) == 1 && !hasImage {
|
if len(partsJSON) == 1 && !hasImage && !hasFile {
|
||||||
// Preserve legacy behavior for single text content
|
// Preserve legacy behavior for single text content
|
||||||
msg, _ = sjson.Delete(msg, "content")
|
msg, _ = sjson.Delete(msg, "content")
|
||||||
textPart := gjson.Parse(partsJSON[0])
|
textPart := gjson.Parse(partsJSON[0])
|
||||||
|
|||||||
@@ -22,8 +22,9 @@ var (
|
|||||||
|
|
||||||
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
|
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
|
||||||
type ConvertCodexResponseToClaudeParams struct {
|
type ConvertCodexResponseToClaudeParams struct {
|
||||||
HasToolCall bool
|
HasToolCall bool
|
||||||
BlockIndex int
|
BlockIndex int
|
||||||
|
HasReceivedArgumentsDelta bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
|
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
|
||||||
@@ -137,6 +138,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
itemType := itemResult.Get("type").String()
|
itemType := itemResult.Get("type").String()
|
||||||
if itemType == "function_call" {
|
if itemType == "function_call" {
|
||||||
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
|
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
|
||||||
|
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
|
||||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||||
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
|
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
|
||||||
@@ -171,12 +173,29 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
output += fmt.Sprintf("data: %s\n\n", template)
|
output += fmt.Sprintf("data: %s\n\n", template)
|
||||||
}
|
}
|
||||||
} else if typeStr == "response.function_call_arguments.delta" {
|
} else if typeStr == "response.function_call_arguments.delta" {
|
||||||
|
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
|
||||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||||
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
|
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
|
||||||
|
|
||||||
output += "event: content_block_delta\n"
|
output += "event: content_block_delta\n"
|
||||||
output += fmt.Sprintf("data: %s\n\n", template)
|
output += fmt.Sprintf("data: %s\n\n", template)
|
||||||
|
} else if typeStr == "response.function_call_arguments.done" {
|
||||||
|
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
|
||||||
|
// in a single "done" event without preceding "delta" events.
|
||||||
|
// Emit the full arguments as a single input_json_delta so the
|
||||||
|
// downstream Claude client receives the complete tool input.
|
||||||
|
// When delta events were already received, skip to avoid duplicating arguments.
|
||||||
|
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
|
||||||
|
if args := rootResult.Get("arguments").String(); args != "" {
|
||||||
|
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||||
|
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||||
|
template, _ = sjson.Set(template, "delta.partial_json", args)
|
||||||
|
|
||||||
|
output += "event: content_block_delta\n"
|
||||||
|
output += fmt.Sprintf("data: %s\n\n", template)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{output}
|
return []string{output}
|
||||||
|
|||||||
@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
|||||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||||
}
|
}
|
||||||
case "file":
|
case "file":
|
||||||
// Files are not specified in examples; skip for now
|
if role == "user" {
|
||||||
|
fileData := it.Get("file.file_data").String()
|
||||||
|
filename := it.Get("file.filename").String()
|
||||||
|
if fileData != "" {
|
||||||
|
part := `{}`
|
||||||
|
part, _ = sjson.Set(part, "type", "input_file")
|
||||||
|
part, _ = sjson.Set(part, "file_data", fileData)
|
||||||
|
if filename != "" {
|
||||||
|
part, _ = sjson.Set(part, "filename", filename)
|
||||||
|
}
|
||||||
|
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,10 +20,12 @@ var (
|
|||||||
|
|
||||||
// ConvertCliToOpenAIParams holds parameters for response conversion.
|
// ConvertCliToOpenAIParams holds parameters for response conversion.
|
||||||
type ConvertCliToOpenAIParams struct {
|
type ConvertCliToOpenAIParams struct {
|
||||||
ResponseID string
|
ResponseID string
|
||||||
CreatedAt int64
|
CreatedAt int64
|
||||||
Model string
|
Model string
|
||||||
FunctionCallIndex int
|
FunctionCallIndex int
|
||||||
|
HasReceivedArgumentsDelta bool
|
||||||
|
HasToolCallAnnounced bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertCodexResponseToOpenAI translates a single chunk of a streaming response from the
|
// ConvertCodexResponseToOpenAI translates a single chunk of a streaming response from the
|
||||||
@@ -43,10 +45,12 @@ type ConvertCliToOpenAIParams struct {
|
|||||||
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||||
if *param == nil {
|
if *param == nil {
|
||||||
*param = &ConvertCliToOpenAIParams{
|
*param = &ConvertCliToOpenAIParams{
|
||||||
Model: modelName,
|
Model: modelName,
|
||||||
CreatedAt: 0,
|
CreatedAt: 0,
|
||||||
ResponseID: "",
|
ResponseID: "",
|
||||||
FunctionCallIndex: -1,
|
FunctionCallIndex: -1,
|
||||||
|
HasReceivedArgumentsDelta: false,
|
||||||
|
HasToolCallAnnounced: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,35 +122,93 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
|||||||
}
|
}
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||||
} else if dataType == "response.output_item.done" {
|
} else if dataType == "response.output_item.added" {
|
||||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
|
||||||
itemResult := rootResult.Get("item")
|
itemResult := rootResult.Get("item")
|
||||||
if itemResult.Exists() {
|
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||||
if itemResult.Get("type").String() != "function_call" {
|
return []string{}
|
||||||
return []string{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// set the index
|
|
||||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
|
||||||
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
|
||||||
|
|
||||||
// Restore original tool name if it was shortened
|
|
||||||
name := itemResult.Get("name").String()
|
|
||||||
// Build reverse map on demand from original request tools
|
|
||||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
|
||||||
if orig, ok := rev[name]; ok {
|
|
||||||
name = orig
|
|
||||||
}
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
|
||||||
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Increment index for this new function call item.
|
||||||
|
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||||
|
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
|
||||||
|
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
|
||||||
|
|
||||||
|
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||||
|
|
||||||
|
// Restore original tool name if it was shortened.
|
||||||
|
name := itemResult.Get("name").String()
|
||||||
|
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||||
|
if orig, ok := rev[name]; ok {
|
||||||
|
name = orig
|
||||||
|
}
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
|
||||||
|
|
||||||
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||||
|
|
||||||
|
} else if dataType == "response.function_call_arguments.delta" {
|
||||||
|
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
|
||||||
|
|
||||||
|
deltaValue := rootResult.Get("delta").String()
|
||||||
|
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
|
||||||
|
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||||
|
|
||||||
|
} else if dataType == "response.function_call_arguments.done" {
|
||||||
|
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
|
||||||
|
// Arguments were already streamed via delta events; nothing to emit.
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: no delta events were received, emit the full arguments as a single chunk.
|
||||||
|
fullArgs := rootResult.Get("arguments").String()
|
||||||
|
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
|
||||||
|
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||||
|
|
||||||
|
} else if dataType == "response.output_item.done" {
|
||||||
|
itemResult := rootResult.Get("item")
|
||||||
|
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
|
||||||
|
// Tool call was already announced via output_item.added; skip emission.
|
||||||
|
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback path: model skipped output_item.added, so emit complete tool call now.
|
||||||
|
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||||
|
|
||||||
|
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||||
|
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||||
|
|
||||||
|
// Restore original tool name if it was shortened.
|
||||||
|
name := itemResult.Get("name").String()
|
||||||
|
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||||
|
if orig, ok := rev[name]; ok {
|
||||||
|
name = orig
|
||||||
|
}
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||||
|
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||||
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
return []string{}
|
return []string{}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
|||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||||
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
|
||||||
|
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
|
||||||
|
|
||||||
// Delete the user field as it is not supported by the Codex upstream.
|
// Delete the user field as it is not supported by the Codex upstream.
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
||||||
@@ -36,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
|||||||
return rawJSON
|
return rawJSON
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction
|
||||||
|
// for Codex upstream compatibility.
|
||||||
|
//
|
||||||
|
// Codex /responses currently rejects context_management with:
|
||||||
|
// {"detail":"Unsupported parameter: context_management"}.
|
||||||
|
//
|
||||||
|
// Compatibility strategy:
|
||||||
|
// 1) Remove context_management before forwarding to Codex upstream.
|
||||||
|
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
|
||||||
|
if !gjson.GetBytes(rawJSON, "context_management").Exists() {
|
||||||
|
return rawJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
|
||||||
|
return rawJSON
|
||||||
|
}
|
||||||
|
|
||||||
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
||||||
// with role "system" to role "developer". This is necessary because Codex API does not
|
// with role "system" to role "developer". This is necessary because Codex API does not
|
||||||
// accept "system" role in the input array.
|
// accept "system" role in the input array.
|
||||||
|
|||||||
@@ -280,3 +280,41 @@ func TestUserFieldDeletion(t *testing.T) {
|
|||||||
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContextManagementCompactionCompatibility(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"context_management": [
|
||||||
|
{
|
||||||
|
"type": "compaction",
|
||||||
|
"compact_threshold": 12000
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"input": [{"role":"user","content":"hello"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if gjson.Get(outputStr, "context_management").Exists() {
|
||||||
|
t.Fatalf("context_management should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
if gjson.Get(outputStr, "truncation").Exists() {
|
||||||
|
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"truncation": "disabled",
|
||||||
|
"input": [{"role":"user","content":"hello"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if gjson.Get(outputStr, "truncation").Exists() {
|
||||||
|
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
|||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
|||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
@@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
|
|
||||||
// usage mapping
|
// usage mapping
|
||||||
if um := root.Get("usageMetadata"); um.Exists() {
|
if um := root.Get("usageMetadata"); um.Exists() {
|
||||||
// input tokens = prompt + thoughts
|
// input tokens = prompt only (thoughts go to output)
|
||||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
input := um.Get("promptTokenCount").Int()
|
||||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
||||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||||
@@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
|||||||
|
|
||||||
// usage mapping
|
// usage mapping
|
||||||
if um := root.Get("usageMetadata"); um.Exists() {
|
if um := root.Get("usageMetadata"); um.Exists() {
|
||||||
// input tokens = prompt + thoughts
|
// input tokens = prompt only (thoughts go to output)
|
||||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
input := um.Get("promptTokenCount").Int()
|
||||||
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
||||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||||
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||||
|
|||||||
@@ -243,13 +243,11 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA
|
|||||||
// Process messages and build history
|
// Process messages and build history
|
||||||
history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin)
|
history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin)
|
||||||
|
|
||||||
// Build content with system prompt (only on first turn to avoid re-injection)
|
// Build content with system prompt.
|
||||||
|
// Keep thinking tags on subsequent turns so multi-turn Claude sessions
|
||||||
|
// continue to emit reasoning events.
|
||||||
if currentUserMsg != nil {
|
if currentUserMsg != nil {
|
||||||
effectiveSystemPrompt := systemPrompt
|
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults)
|
||||||
if len(history) > 0 {
|
|
||||||
effectiveSystemPrompt = "" // Don't re-inject on subsequent turns
|
|
||||||
}
|
|
||||||
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, effectiveSystemPrompt, currentToolResults)
|
|
||||||
|
|
||||||
// Deduplicate currentToolResults
|
// Deduplicate currentToolResults
|
||||||
currentToolResults = deduplicateToolResults(currentToolResults)
|
currentToolResults = deduplicateToolResults(currentToolResults)
|
||||||
@@ -475,6 +473,15 @@ func IsThinkingEnabledWithHeaders(body []byte, headers http.Header) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check model name directly for thinking hints.
|
||||||
|
// This enables thinking variants even when clients don't send explicit thinking fields.
|
||||||
|
model := strings.TrimSpace(gjson.GetBytes(body, "model").String())
|
||||||
|
modelLower := strings.ToLower(model)
|
||||||
|
if strings.Contains(modelLower, "thinking") || strings.Contains(modelLower, "-reason") {
|
||||||
|
log.Debugf("kiro: thinking mode enabled via model name hint: %s", model)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)")
|
log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,19 +53,25 @@ var KnownCommandTools = map[string]bool{
|
|||||||
"execute_python": true,
|
"execute_python": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequiredFieldsByTool maps tool names to their required fields.
|
// RequiredFieldsByTool maps tool names to their required field groups.
|
||||||
// If any of these fields are missing, the tool input is considered truncated.
|
// Each outer element is a required group; each inner slice lists alternative field names (OR logic).
|
||||||
var RequiredFieldsByTool = map[string][]string{
|
// A group is satisfied when ANY one of its alternatives exists in the parsed input.
|
||||||
"Write": {"file_path", "content"},
|
// All groups must be satisfied for the tool input to be considered valid.
|
||||||
"write_to_file": {"path", "content"},
|
//
|
||||||
"fsWrite": {"path", "content"},
|
// Example:
|
||||||
"create_file": {"path", "content"},
|
// {{"cmd", "command"}} means the tool needs EITHER "cmd" OR "command".
|
||||||
"edit_file": {"path"},
|
// {{"file_path"}, {"content"}} means the tool needs BOTH "file_path" AND "content".
|
||||||
"apply_diff": {"path", "diff"},
|
var RequiredFieldsByTool = map[string][][]string{
|
||||||
"str_replace_editor": {"path", "old_str", "new_str"},
|
"Write": {{"file_path"}, {"content"}},
|
||||||
"Bash": {"command"},
|
"write_to_file": {{"path"}, {"content"}},
|
||||||
"execute": {"command"},
|
"fsWrite": {{"path"}, {"content"}},
|
||||||
"run_command": {"command"},
|
"create_file": {{"path"}, {"content"}},
|
||||||
|
"edit_file": {{"path"}},
|
||||||
|
"apply_diff": {{"path"}, {"diff"}},
|
||||||
|
"str_replace_editor": {{"path"}, {"old_str"}, {"new_str"}},
|
||||||
|
"Bash": {{"cmd", "command"}},
|
||||||
|
"execute": {{"command"}},
|
||||||
|
"run_command": {{"command"}},
|
||||||
}
|
}
|
||||||
|
|
||||||
// DetectTruncation checks if the tool use input appears to be truncated.
|
// DetectTruncation checks if the tool use input appears to be truncated.
|
||||||
@@ -104,9 +110,9 @@ func DetectTruncation(toolName, toolUseID, rawInput string, parsedInput map[stri
|
|||||||
|
|
||||||
// Scenario 3: JSON parsed but critical fields are missing
|
// Scenario 3: JSON parsed but critical fields are missing
|
||||||
if parsedInput != nil {
|
if parsedInput != nil {
|
||||||
requiredFields, hasRequirements := RequiredFieldsByTool[toolName]
|
requiredGroups, hasRequirements := RequiredFieldsByTool[toolName]
|
||||||
if hasRequirements {
|
if hasRequirements {
|
||||||
missingFields := findMissingRequiredFields(parsedInput, requiredFields)
|
missingFields := findMissingRequiredFields(parsedInput, requiredGroups)
|
||||||
if len(missingFields) > 0 {
|
if len(missingFields) > 0 {
|
||||||
info.IsTruncated = true
|
info.IsTruncated = true
|
||||||
info.TruncationType = TruncationTypeMissingFields
|
info.TruncationType = TruncationTypeMissingFields
|
||||||
@@ -253,12 +259,21 @@ func extractParsedFieldNames(parsed map[string]interface{}) map[string]string {
|
|||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
// findMissingRequiredFields checks which required fields are missing from the parsed input.
|
// findMissingRequiredFields checks which required field groups are unsatisfied.
|
||||||
func findMissingRequiredFields(parsed map[string]interface{}, required []string) []string {
|
// Each group is a slice of alternative field names; the group is satisfied when ANY alternative exists.
|
||||||
|
// Returns the list of unsatisfied groups (represented by their alternatives joined with "/").
|
||||||
|
func findMissingRequiredFields(parsed map[string]interface{}, requiredGroups [][]string) []string {
|
||||||
var missing []string
|
var missing []string
|
||||||
for _, field := range required {
|
for _, group := range requiredGroups {
|
||||||
if _, exists := parsed[field]; !exists {
|
satisfied := false
|
||||||
missing = append(missing, field)
|
for _, field := range group {
|
||||||
|
if _, exists := parsed[field]; exists {
|
||||||
|
satisfied = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !satisfied {
|
||||||
|
missing = append(missing, strings.Join(group, "/"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return missing
|
return missing
|
||||||
|
|||||||
@@ -234,16 +234,16 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s
|
|||||||
// Kiro API supports official thinking/reasoning mode via <thinking_mode> tag.
|
// Kiro API supports official thinking/reasoning mode via <thinking_mode> tag.
|
||||||
// When set to "enabled", Kiro returns reasoning content as official reasoningContentEvent
|
// When set to "enabled", Kiro returns reasoning content as official reasoningContentEvent
|
||||||
// rather than inline <thinking> tags in assistantResponseEvent.
|
// rather than inline <thinking> tags in assistantResponseEvent.
|
||||||
// We use a high max_thinking_length to allow extensive reasoning.
|
// Use a conservative thinking budget to reduce latency/cost spikes in long sessions.
|
||||||
if thinkingEnabled {
|
if thinkingEnabled {
|
||||||
thinkingHint := `<thinking_mode>enabled</thinking_mode>
|
thinkingHint := `<thinking_mode>enabled</thinking_mode>
|
||||||
<max_thinking_length>200000</max_thinking_length>`
|
<max_thinking_length>16000</max_thinking_length>`
|
||||||
if systemPrompt != "" {
|
if systemPrompt != "" {
|
||||||
systemPrompt = thinkingHint + "\n\n" + systemPrompt
|
systemPrompt = thinkingHint + "\n\n" + systemPrompt
|
||||||
} else {
|
} else {
|
||||||
systemPrompt = thinkingHint
|
systemPrompt = thinkingHint
|
||||||
}
|
}
|
||||||
log.Debugf("kiro-openai: injected thinking prompt (official mode)")
|
log.Infof("kiro-openai: injected thinking prompt (official mode), has_tools: %v", len(kiroTools) > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process messages and build history
|
// Process messages and build history
|
||||||
@@ -578,6 +578,7 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
|||||||
|
|
||||||
// Truncate history if too long to prevent Kiro API errors
|
// Truncate history if too long to prevent Kiro API errors
|
||||||
history = truncateHistoryIfNeeded(history)
|
history = truncateHistoryIfNeeded(history)
|
||||||
|
history, currentToolResults = filterOrphanedToolResults(history, currentToolResults)
|
||||||
|
|
||||||
return history, currentUserMsg, currentToolResults
|
return history, currentUserMsg, currentToolResults
|
||||||
}
|
}
|
||||||
@@ -593,6 +594,61 @@ func truncateHistoryIfNeeded(history []KiroHistoryMessage) []KiroHistoryMessage
|
|||||||
return history[len(history)-kiroMaxHistoryMessages:]
|
return history[len(history)-kiroMaxHistoryMessages:]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func filterOrphanedToolResults(history []KiroHistoryMessage, currentToolResults []KiroToolResult) ([]KiroHistoryMessage, []KiroToolResult) {
|
||||||
|
// Remove tool results with no matching tool_use in retained history.
|
||||||
|
// This happens after truncation when the assistant turn that produced tool_use
|
||||||
|
// is dropped but a later user/tool_result survives.
|
||||||
|
validToolUseIDs := make(map[string]bool)
|
||||||
|
for _, h := range history {
|
||||||
|
if h.AssistantResponseMessage == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, tu := range h.AssistantResponseMessage.ToolUses {
|
||||||
|
validToolUseIDs[tu.ToolUseID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, h := range history {
|
||||||
|
if h.UserInputMessage == nil || h.UserInputMessage.UserInputMessageContext == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ctx := h.UserInputMessage.UserInputMessageContext
|
||||||
|
if len(ctx.ToolResults) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := make([]KiroToolResult, 0, len(ctx.ToolResults))
|
||||||
|
for _, tr := range ctx.ToolResults {
|
||||||
|
if validToolUseIDs[tr.ToolUseID] {
|
||||||
|
filtered = append(filtered, tr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Debugf("kiro-openai: dropping orphaned tool_result in history[%d]: toolUseId=%s (no matching tool_use)", i, tr.ToolUseID)
|
||||||
|
}
|
||||||
|
ctx.ToolResults = filtered
|
||||||
|
if len(ctx.ToolResults) == 0 && len(ctx.Tools) == 0 {
|
||||||
|
h.UserInputMessage.UserInputMessageContext = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(currentToolResults) > 0 {
|
||||||
|
filtered := make([]KiroToolResult, 0, len(currentToolResults))
|
||||||
|
for _, tr := range currentToolResults {
|
||||||
|
if validToolUseIDs[tr.ToolUseID] {
|
||||||
|
filtered = append(filtered, tr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Debugf("kiro-openai: dropping orphaned tool_result in currentMessage: toolUseId=%s (no matching tool_use)", tr.ToolUseID)
|
||||||
|
}
|
||||||
|
if len(filtered) != len(currentToolResults) {
|
||||||
|
log.Infof("kiro-openai: dropped %d orphaned tool_result(s) from currentMessage", len(currentToolResults)-len(filtered))
|
||||||
|
}
|
||||||
|
currentToolResults = filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
return history, currentToolResults
|
||||||
|
}
|
||||||
|
|
||||||
// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
|
// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
|
||||||
func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
|
func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
|
||||||
content := msg.Get("content")
|
content := msg.Get("content")
|
||||||
@@ -831,7 +887,6 @@ func hasThinkingTagInBody(body []byte) bool {
|
|||||||
return strings.Contains(bodyStr, "<thinking_mode>") || strings.Contains(bodyStr, "<max_thinking_length>")
|
return strings.Contains(bodyStr, "<thinking_mode>") || strings.Contains(bodyStr, "<max_thinking_length>")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint.
|
// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint.
|
||||||
// OpenAI tool_choice values:
|
// OpenAI tool_choice values:
|
||||||
// - "none": Don't use any tools
|
// - "none": Don't use any tools
|
||||||
|
|||||||
@@ -384,3 +384,57 @@ func TestAssistantEndsConversation(t *testing.T) {
|
|||||||
t.Error("Expected a 'Continue' message to be created when assistant is last")
|
t.Error("Expected a 'Continue' message to be created when assistant is last")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFilterOrphanedToolResults_RemovesHistoryAndCurrentOrphans(t *testing.T) {
|
||||||
|
history := []KiroHistoryMessage{
|
||||||
|
{
|
||||||
|
AssistantResponseMessage: &KiroAssistantResponseMessage{
|
||||||
|
Content: "assistant",
|
||||||
|
ToolUses: []KiroToolUse{
|
||||||
|
{ToolUseID: "keep-1", Name: "Read", Input: map[string]interface{}{}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UserInputMessage: &KiroUserInputMessage{
|
||||||
|
Content: "user-with-mixed-results",
|
||||||
|
UserInputMessageContext: &KiroUserInputMessageContext{
|
||||||
|
ToolResults: []KiroToolResult{
|
||||||
|
{ToolUseID: "keep-1", Status: "success", Content: []KiroTextContent{{Text: "ok"}}},
|
||||||
|
{ToolUseID: "orphan-1", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UserInputMessage: &KiroUserInputMessage{
|
||||||
|
Content: "user-only-orphans",
|
||||||
|
UserInputMessageContext: &KiroUserInputMessageContext{
|
||||||
|
ToolResults: []KiroToolResult{
|
||||||
|
{ToolUseID: "orphan-2", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
currentToolResults := []KiroToolResult{
|
||||||
|
{ToolUseID: "keep-1", Status: "success", Content: []KiroTextContent{{Text: "ok"}}},
|
||||||
|
{ToolUseID: "orphan-3", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredHistory, filteredCurrent := filterOrphanedToolResults(history, currentToolResults)
|
||||||
|
|
||||||
|
ctx1 := filteredHistory[1].UserInputMessage.UserInputMessageContext
|
||||||
|
if ctx1 == nil || len(ctx1.ToolResults) != 1 || ctx1.ToolResults[0].ToolUseID != "keep-1" {
|
||||||
|
t.Fatalf("expected mixed history message to keep only keep-1, got: %+v", ctx1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if filteredHistory[2].UserInputMessage.UserInputMessageContext != nil {
|
||||||
|
t.Fatalf("expected orphan-only history context to be removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(filteredCurrent) != 1 || filteredCurrent[0].ToolUseID != "keep-1" {
|
||||||
|
t.Fatalf("expected current tool results to keep only keep-1, got: %+v", filteredCurrent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
542
internal/tui/app.go
Normal file
542
internal/tui/app.go
Normal file
@@ -0,0 +1,542 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/bubbles/textinput"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tab identifiers
|
||||||
|
const (
|
||||||
|
tabDashboard = iota
|
||||||
|
tabConfig
|
||||||
|
tabAuthFiles
|
||||||
|
tabAPIKeys
|
||||||
|
tabOAuth
|
||||||
|
tabUsage
|
||||||
|
tabLogs
|
||||||
|
)
|
||||||
|
|
||||||
|
// App is the root bubbletea model that contains all tab sub-models.
|
||||||
|
type App struct {
|
||||||
|
activeTab int
|
||||||
|
tabs []string
|
||||||
|
|
||||||
|
standalone bool
|
||||||
|
logsEnabled bool
|
||||||
|
|
||||||
|
authenticated bool
|
||||||
|
authInput textinput.Model
|
||||||
|
authError string
|
||||||
|
authConnecting bool
|
||||||
|
|
||||||
|
dashboard dashboardModel
|
||||||
|
config configTabModel
|
||||||
|
auth authTabModel
|
||||||
|
keys keysTabModel
|
||||||
|
oauth oauthTabModel
|
||||||
|
usage usageTabModel
|
||||||
|
logs logsTabModel
|
||||||
|
|
||||||
|
client *Client
|
||||||
|
|
||||||
|
width int
|
||||||
|
height int
|
||||||
|
ready bool
|
||||||
|
|
||||||
|
// Track which tabs have been initialized (fetched data)
|
||||||
|
initialized [7]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type authConnectMsg struct {
|
||||||
|
cfg map[string]any
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewApp creates the root TUI application model.
|
||||||
|
func NewApp(port int, secretKey string, hook *LogHook) App {
|
||||||
|
standalone := hook != nil
|
||||||
|
authRequired := !standalone
|
||||||
|
ti := textinput.New()
|
||||||
|
ti.CharLimit = 512
|
||||||
|
ti.EchoMode = textinput.EchoPassword
|
||||||
|
ti.EchoCharacter = '*'
|
||||||
|
ti.SetValue(strings.TrimSpace(secretKey))
|
||||||
|
ti.Focus()
|
||||||
|
|
||||||
|
client := NewClient(port, secretKey)
|
||||||
|
app := App{
|
||||||
|
activeTab: tabDashboard,
|
||||||
|
standalone: standalone,
|
||||||
|
logsEnabled: true,
|
||||||
|
authenticated: !authRequired,
|
||||||
|
authInput: ti,
|
||||||
|
dashboard: newDashboardModel(client),
|
||||||
|
config: newConfigTabModel(client),
|
||||||
|
auth: newAuthTabModel(client),
|
||||||
|
keys: newKeysTabModel(client),
|
||||||
|
oauth: newOAuthTabModel(client),
|
||||||
|
usage: newUsageTabModel(client),
|
||||||
|
logs: newLogsTabModel(client, hook),
|
||||||
|
client: client,
|
||||||
|
initialized: [7]bool{
|
||||||
|
tabDashboard: true,
|
||||||
|
tabLogs: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app.refreshTabs()
|
||||||
|
if authRequired {
|
||||||
|
app.initialized = [7]bool{}
|
||||||
|
}
|
||||||
|
app.setAuthInputPrompt()
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a App) Init() tea.Cmd {
|
||||||
|
if !a.authenticated {
|
||||||
|
return textinput.Blink
|
||||||
|
}
|
||||||
|
cmds := []tea.Cmd{a.dashboard.Init()}
|
||||||
|
if a.logsEnabled {
|
||||||
|
cmds = append(cmds, a.logs.Init())
|
||||||
|
}
|
||||||
|
return tea.Batch(cmds...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case tea.WindowSizeMsg:
|
||||||
|
a.width = msg.Width
|
||||||
|
a.height = msg.Height
|
||||||
|
a.ready = true
|
||||||
|
if a.width > 0 {
|
||||||
|
a.authInput.Width = a.width - 6
|
||||||
|
}
|
||||||
|
contentH := a.height - 4 // tab bar + status bar
|
||||||
|
if contentH < 1 {
|
||||||
|
contentH = 1
|
||||||
|
}
|
||||||
|
contentW := a.width
|
||||||
|
a.dashboard.SetSize(contentW, contentH)
|
||||||
|
a.config.SetSize(contentW, contentH)
|
||||||
|
a.auth.SetSize(contentW, contentH)
|
||||||
|
a.keys.SetSize(contentW, contentH)
|
||||||
|
a.oauth.SetSize(contentW, contentH)
|
||||||
|
a.usage.SetSize(contentW, contentH)
|
||||||
|
a.logs.SetSize(contentW, contentH)
|
||||||
|
return a, nil
|
||||||
|
|
||||||
|
case authConnectMsg:
|
||||||
|
a.authConnecting = false
|
||||||
|
if msg.err != nil {
|
||||||
|
a.authError = fmt.Sprintf(T("auth_gate_connect_fail"), msg.err.Error())
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
a.authError = ""
|
||||||
|
a.authenticated = true
|
||||||
|
a.logsEnabled = a.standalone || isLogsEnabledFromConfig(msg.cfg)
|
||||||
|
a.refreshTabs()
|
||||||
|
a.initialized = [7]bool{}
|
||||||
|
a.initialized[tabDashboard] = true
|
||||||
|
cmds := []tea.Cmd{a.dashboard.Init()}
|
||||||
|
if a.logsEnabled {
|
||||||
|
a.initialized[tabLogs] = true
|
||||||
|
cmds = append(cmds, a.logs.Init())
|
||||||
|
}
|
||||||
|
return a, tea.Batch(cmds...)
|
||||||
|
|
||||||
|
case configUpdateMsg:
|
||||||
|
var cmdLogs tea.Cmd
|
||||||
|
if !a.standalone && msg.err == nil && msg.path == "logging-to-file" {
|
||||||
|
logsEnabledConfig, okConfig := msg.value.(bool)
|
||||||
|
if okConfig {
|
||||||
|
logsEnabledBefore := a.logsEnabled
|
||||||
|
a.logsEnabled = logsEnabledConfig
|
||||||
|
if logsEnabledBefore != a.logsEnabled {
|
||||||
|
a.refreshTabs()
|
||||||
|
}
|
||||||
|
if !a.logsEnabled {
|
||||||
|
a.initialized[tabLogs] = false
|
||||||
|
}
|
||||||
|
if !logsEnabledBefore && a.logsEnabled {
|
||||||
|
a.initialized[tabLogs] = true
|
||||||
|
cmdLogs = a.logs.Init()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmdConfig tea.Cmd
|
||||||
|
a.config, cmdConfig = a.config.Update(msg)
|
||||||
|
if cmdConfig != nil && cmdLogs != nil {
|
||||||
|
return a, tea.Batch(cmdConfig, cmdLogs)
|
||||||
|
}
|
||||||
|
if cmdConfig != nil {
|
||||||
|
return a, cmdConfig
|
||||||
|
}
|
||||||
|
return a, cmdLogs
|
||||||
|
|
||||||
|
case tea.KeyMsg:
|
||||||
|
if !a.authenticated {
|
||||||
|
switch msg.String() {
|
||||||
|
case "ctrl+c", "q":
|
||||||
|
return a, tea.Quit
|
||||||
|
case "L":
|
||||||
|
ToggleLocale()
|
||||||
|
a.refreshTabs()
|
||||||
|
a.setAuthInputPrompt()
|
||||||
|
return a, nil
|
||||||
|
case "enter":
|
||||||
|
if a.authConnecting {
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
password := strings.TrimSpace(a.authInput.Value())
|
||||||
|
if password == "" {
|
||||||
|
a.authError = T("auth_gate_password_required")
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
a.authError = ""
|
||||||
|
a.authConnecting = true
|
||||||
|
return a, a.connectWithPassword(password)
|
||||||
|
default:
|
||||||
|
var cmd tea.Cmd
|
||||||
|
a.authInput, cmd = a.authInput.Update(msg)
|
||||||
|
return a, cmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg.String() {
|
||||||
|
case "ctrl+c":
|
||||||
|
return a, tea.Quit
|
||||||
|
case "q":
|
||||||
|
// Only quit if not in logs tab (where 'q' might be useful)
|
||||||
|
if !a.logsEnabled || a.activeTab != tabLogs {
|
||||||
|
return a, tea.Quit
|
||||||
|
}
|
||||||
|
case "L":
|
||||||
|
ToggleLocale()
|
||||||
|
a.refreshTabs()
|
||||||
|
return a.broadcastToAllTabs(localeChangedMsg{})
|
||||||
|
case "tab":
|
||||||
|
if len(a.tabs) == 0 {
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
prevTab := a.activeTab
|
||||||
|
a.activeTab = (a.activeTab + 1) % len(a.tabs)
|
||||||
|
return a, a.initTabIfNeeded(prevTab)
|
||||||
|
case "shift+tab":
|
||||||
|
if len(a.tabs) == 0 {
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
prevTab := a.activeTab
|
||||||
|
a.activeTab = (a.activeTab - 1 + len(a.tabs)) % len(a.tabs)
|
||||||
|
return a, a.initTabIfNeeded(prevTab)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !a.authenticated {
|
||||||
|
var cmd tea.Cmd
|
||||||
|
a.authInput, cmd = a.authInput.Update(msg)
|
||||||
|
return a, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route msg to active tab
|
||||||
|
var cmd tea.Cmd
|
||||||
|
switch a.activeTab {
|
||||||
|
case tabDashboard:
|
||||||
|
a.dashboard, cmd = a.dashboard.Update(msg)
|
||||||
|
case tabConfig:
|
||||||
|
a.config, cmd = a.config.Update(msg)
|
||||||
|
case tabAuthFiles:
|
||||||
|
a.auth, cmd = a.auth.Update(msg)
|
||||||
|
case tabAPIKeys:
|
||||||
|
a.keys, cmd = a.keys.Update(msg)
|
||||||
|
case tabOAuth:
|
||||||
|
a.oauth, cmd = a.oauth.Update(msg)
|
||||||
|
case tabUsage:
|
||||||
|
a.usage, cmd = a.usage.Update(msg)
|
||||||
|
case tabLogs:
|
||||||
|
a.logs, cmd = a.logs.Update(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep logs polling alive even when logs tab is not active.
|
||||||
|
if a.logsEnabled && a.activeTab != tabLogs {
|
||||||
|
switch msg.(type) {
|
||||||
|
case logsPollMsg, logsTickMsg, logLineMsg:
|
||||||
|
var logCmd tea.Cmd
|
||||||
|
a.logs, logCmd = a.logs.Update(msg)
|
||||||
|
if logCmd != nil {
|
||||||
|
cmd = logCmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return a, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
// localeChangedMsg is broadcast to all tabs when the user toggles locale.
|
||||||
|
type localeChangedMsg struct{}
|
||||||
|
|
||||||
|
func (a *App) refreshTabs() {
|
||||||
|
names := TabNames()
|
||||||
|
if a.logsEnabled {
|
||||||
|
a.tabs = names
|
||||||
|
} else {
|
||||||
|
filtered := make([]string, 0, len(names)-1)
|
||||||
|
for idx, name := range names {
|
||||||
|
if idx == tabLogs {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, name)
|
||||||
|
}
|
||||||
|
a.tabs = filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(a.tabs) == 0 {
|
||||||
|
a.activeTab = tabDashboard
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if a.activeTab >= len(a.tabs) {
|
||||||
|
a.activeTab = len(a.tabs) - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *App) initTabIfNeeded(_ int) tea.Cmd {
|
||||||
|
if a.initialized[a.activeTab] {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
a.initialized[a.activeTab] = true
|
||||||
|
switch a.activeTab {
|
||||||
|
case tabDashboard:
|
||||||
|
return a.dashboard.Init()
|
||||||
|
case tabConfig:
|
||||||
|
return a.config.Init()
|
||||||
|
case tabAuthFiles:
|
||||||
|
return a.auth.Init()
|
||||||
|
case tabAPIKeys:
|
||||||
|
return a.keys.Init()
|
||||||
|
case tabOAuth:
|
||||||
|
return a.oauth.Init()
|
||||||
|
case tabUsage:
|
||||||
|
return a.usage.Init()
|
||||||
|
case tabLogs:
|
||||||
|
if !a.logsEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return a.logs.Init()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a App) View() string {
|
||||||
|
if !a.authenticated {
|
||||||
|
return a.renderAuthView()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !a.ready {
|
||||||
|
return T("initializing_tui")
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
// Tab bar
|
||||||
|
sb.WriteString(a.renderTabBar())
|
||||||
|
sb.WriteString("\n")
|
||||||
|
|
||||||
|
// Content
|
||||||
|
switch a.activeTab {
|
||||||
|
case tabDashboard:
|
||||||
|
sb.WriteString(a.dashboard.View())
|
||||||
|
case tabConfig:
|
||||||
|
sb.WriteString(a.config.View())
|
||||||
|
case tabAuthFiles:
|
||||||
|
sb.WriteString(a.auth.View())
|
||||||
|
case tabAPIKeys:
|
||||||
|
sb.WriteString(a.keys.View())
|
||||||
|
case tabOAuth:
|
||||||
|
sb.WriteString(a.oauth.View())
|
||||||
|
case tabUsage:
|
||||||
|
sb.WriteString(a.usage.View())
|
||||||
|
case tabLogs:
|
||||||
|
if a.logsEnabled {
|
||||||
|
sb.WriteString(a.logs.View())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status bar
|
||||||
|
sb.WriteString("\n")
|
||||||
|
sb.WriteString(a.renderStatusBar())
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a App) renderAuthView() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
sb.WriteString(titleStyle.Render(T("auth_gate_title")))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
sb.WriteString(helpStyle.Render(T("auth_gate_help")))
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
if a.authConnecting {
|
||||||
|
sb.WriteString(warningStyle.Render(T("auth_gate_connecting")))
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(a.authError) != "" {
|
||||||
|
sb.WriteString(errorStyle.Render(a.authError))
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
sb.WriteString(a.authInput.View())
|
||||||
|
sb.WriteString("\n")
|
||||||
|
sb.WriteString(helpStyle.Render(T("auth_gate_enter")))
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a App) renderTabBar() string {
|
||||||
|
var tabs []string
|
||||||
|
for i, name := range a.tabs {
|
||||||
|
if i == a.activeTab {
|
||||||
|
tabs = append(tabs, tabActiveStyle.Render(name))
|
||||||
|
} else {
|
||||||
|
tabs = append(tabs, tabInactiveStyle.Render(name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tabBar := lipgloss.JoinHorizontal(lipgloss.Top, tabs...)
|
||||||
|
return tabBarStyle.Width(a.width).Render(tabBar)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a App) renderStatusBar() string {
|
||||||
|
left := strings.TrimRight(T("status_left"), " ")
|
||||||
|
right := strings.TrimRight(T("status_right"), " ")
|
||||||
|
|
||||||
|
width := a.width
|
||||||
|
if width < 1 {
|
||||||
|
width = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// statusBarStyle has left/right padding(1), so content area is width-2.
|
||||||
|
contentWidth := width - 2
|
||||||
|
if contentWidth < 0 {
|
||||||
|
contentWidth = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if lipgloss.Width(left) > contentWidth {
|
||||||
|
left = fitStringWidth(left, contentWidth)
|
||||||
|
right = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining := contentWidth - lipgloss.Width(left)
|
||||||
|
if remaining < 0 {
|
||||||
|
remaining = 0
|
||||||
|
}
|
||||||
|
if lipgloss.Width(right) > remaining {
|
||||||
|
right = fitStringWidth(right, remaining)
|
||||||
|
}
|
||||||
|
|
||||||
|
gap := contentWidth - lipgloss.Width(left) - lipgloss.Width(right)
|
||||||
|
if gap < 0 {
|
||||||
|
gap = 0
|
||||||
|
}
|
||||||
|
return statusBarStyle.Width(width).Render(left + strings.Repeat(" ", gap) + right)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fitStringWidth(text string, maxWidth int) string {
|
||||||
|
if maxWidth <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if lipgloss.Width(text) <= maxWidth {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
out := ""
|
||||||
|
for _, r := range text {
|
||||||
|
next := out + string(r)
|
||||||
|
if lipgloss.Width(next) > maxWidth {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
out = next
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// isLogsEnabledFromConfig reports whether file logging is turned on in the
// parsed config. A nil config, a missing "logging-to-file" key, or a
// non-boolean value all default to enabled.
func isLogsEnabledFromConfig(cfg map[string]any) bool {
	if cfg == nil {
		return true
	}
	if enabled, isBool := cfg["logging-to-file"].(bool); isBool {
		return enabled
	}
	return true
}
|
||||||
|
|
||||||
|
// setAuthInputPrompt sets the password input's prompt to the localized
// "password" label. Safe to call on a nil receiver (no-op).
func (a *App) setAuthInputPrompt() {
	if a == nil {
		return
	}
	a.authInput.Prompt = fmt.Sprintf(" %s: ", T("auth_gate_password"))
}
|
||||||
|
|
||||||
|
func (a App) connectWithPassword(password string) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
a.client.SetSecretKey(password)
|
||||||
|
cfg, errGetConfig := a.client.GetConfig()
|
||||||
|
return authConnectMsg{cfg: cfg, err: errGetConfig}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the TUI application.
|
||||||
|
// output specifies where bubbletea renders. If nil, defaults to os.Stdout.
|
||||||
|
func Run(port int, secretKey string, hook *LogHook, output io.Writer) error {
|
||||||
|
if output == nil {
|
||||||
|
output = os.Stdout
|
||||||
|
}
|
||||||
|
app := NewApp(port, secretKey, hook)
|
||||||
|
p := tea.NewProgram(app, tea.WithAltScreen(), tea.WithOutput(output))
|
||||||
|
_, err := p.Run()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a App) broadcastToAllTabs(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
var cmds []tea.Cmd
|
||||||
|
var cmd tea.Cmd
|
||||||
|
|
||||||
|
a.dashboard, cmd = a.dashboard.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
a.config, cmd = a.config.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
a.auth, cmd = a.auth.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
a.keys, cmd = a.keys.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
a.oauth, cmd = a.oauth.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
a.usage, cmd = a.usage.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
a.logs, cmd = a.logs.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
return a, tea.Batch(cmds...)
|
||||||
|
}
|
||||||
456
internal/tui/auth_tab.go
Normal file
456
internal/tui/auth_tab.go
Normal file
@@ -0,0 +1,456 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/bubbles/textinput"
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
// editableField represents an editable field on an auth file.
type editableField struct {
	label string // human-readable label shown in the inline edit prompt
	key   string // API field key: "prefix", "proxy_url", "priority"
}

// authEditableFields lists the fields editable inline, in the order they
// are bound to the "1"/"2"/"3" shortcut keys in handleNormalInput.
var authEditableFields = []editableField{
	{label: "Prefix", key: "prefix"},
	{label: "Proxy URL", key: "proxy_url"},
	{label: "Priority", key: "priority"},
}
|
||||||
|
|
||||||
|
// authTabModel displays auth credential files with interactive management
// (browse, expand details, delete with confirmation, inline field edits).
type authTabModel struct {
	client   *Client          // management API client used for all actions
	viewport viewport.Model   // scrollable region holding the rendered list
	files    []map[string]any // auth file records as returned by the API
	err      error            // last fetch error, rendered instead of the list
	width    int
	height   int
	ready    bool // viewport has been created (first SetSize happened)
	cursor   int  // index of the currently selected row
	expanded int  // -1 = none expanded, >=0 = expanded index
	confirm  int  // -1 = no confirmation, >=0 = confirm delete for index
	status   string // transient status line rendered under the list

	// Editing state
	editing      bool            // true when editing a field
	editField    int             // index into authEditableFields
	editInput    textinput.Model // text input for editing
	editFileName string          // name of file being edited
}
|
||||||
|
|
||||||
|
// authFilesMsg carries the result of an auth-file list fetch.
type authFilesMsg struct {
	files []map[string]any
	err   error
}

// authActionMsg reports the outcome of a mutating action on an auth file.
type authActionMsg struct {
	action string // "deleted", "toggled", "updated"
	err    error
}
|
||||||
|
|
||||||
|
func newAuthTabModel(client *Client) authTabModel {
|
||||||
|
ti := textinput.New()
|
||||||
|
ti.CharLimit = 256
|
||||||
|
return authTabModel{
|
||||||
|
client: client,
|
||||||
|
expanded: -1,
|
||||||
|
confirm: -1,
|
||||||
|
editInput: ti,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init kicks off the initial auth-file fetch when the tab is first loaded.
func (m authTabModel) Init() tea.Cmd {
	return m.fetchFiles
}
|
||||||
|
|
||||||
|
func (m authTabModel) fetchFiles() tea.Msg {
|
||||||
|
files, err := m.client.GetAuthFiles()
|
||||||
|
return authFilesMsg{files: files, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update handles messages for the auth tab: data/result messages re-render
// the viewport, key presses are dispatched by the current mode (editing,
// delete confirmation, or normal), and anything else scrolls the viewport.
func (m authTabModel) Update(msg tea.Msg) (authTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		// Re-render so localized strings pick up the new locale.
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case authFilesMsg:
		if msg.err != nil {
			m.err = msg.err
		} else {
			m.err = nil
			m.files = msg.files
			// Clamp the cursor if the list shrank.
			if m.cursor >= len(m.files) {
				m.cursor = max(0, len(m.files)-1)
			}
			m.status = ""
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case authActionMsg:
		if msg.err != nil {
			m.status = errorStyle.Render("✗ " + msg.err.Error())
		} else {
			m.status = successStyle.Render("✓ " + msg.action)
		}
		m.confirm = -1
		m.viewport.SetContent(m.renderContent())
		// Refresh the list after any action so the UI reflects the server.
		return m, m.fetchFiles

	case tea.KeyMsg:
		// ---- Editing mode ----
		if m.editing {
			return m.handleEditInput(msg)
		}

		// ---- Delete confirmation mode ----
		if m.confirm >= 0 {
			return m.handleConfirmInput(msg)
		}

		// ---- Normal mode ----
		return m.handleNormalInput(msg)
	}

	// Non-key messages (e.g. mouse/scroll) go to the viewport.
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
|
||||||
|
|
||||||
|
// startEdit activates inline editing for a field on the currently selected auth file.
|
||||||
|
func (m *authTabModel) startEdit(fieldIdx int) tea.Cmd {
|
||||||
|
if m.cursor >= len(m.files) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
f := m.files[m.cursor]
|
||||||
|
m.editFileName = getString(f, "name")
|
||||||
|
m.editField = fieldIdx
|
||||||
|
m.editing = true
|
||||||
|
|
||||||
|
// Pre-populate with current value
|
||||||
|
key := authEditableFields[fieldIdx].key
|
||||||
|
currentVal := getAnyString(f, key)
|
||||||
|
m.editInput.SetValue(currentVal)
|
||||||
|
m.editInput.Focus()
|
||||||
|
m.editInput.Prompt = fmt.Sprintf(" %s: ", authEditableFields[fieldIdx].label)
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
return textinput.Blink
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *authTabModel) SetSize(w, h int) {
|
||||||
|
m.width = w
|
||||||
|
m.height = h
|
||||||
|
m.editInput.Width = w - 20
|
||||||
|
if !m.ready {
|
||||||
|
m.viewport = viewport.New(w, h)
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
m.ready = true
|
||||||
|
} else {
|
||||||
|
m.viewport.Width = w
|
||||||
|
m.viewport.Height = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m authTabModel) View() string {
|
||||||
|
if !m.ready {
|
||||||
|
return T("loading")
|
||||||
|
}
|
||||||
|
return m.viewport.View()
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderContent builds the full tab body: title/help header, one row per
// auth file, and any inline confirmation, edit input, or expanded detail
// section attached to the relevant row, plus a trailing status line.
func (m authTabModel) renderContent() string {
	var sb strings.Builder

	sb.WriteString(titleStyle.Render(T("auth_title")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("auth_help1")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("auth_help2")))
	sb.WriteString("\n")
	sb.WriteString(strings.Repeat("─", m.width))
	sb.WriteString("\n")

	// A fetch error replaces the entire list.
	if m.err != nil {
		sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
		sb.WriteString("\n")
		return sb.String()
	}

	if len(m.files) == 0 {
		sb.WriteString(subtitleStyle.Render(T("no_auth_files")))
		sb.WriteString("\n")
		return sb.String()
	}

	for i, f := range m.files {
		name := getString(f, "name")
		channel := getString(f, "channel")
		email := getString(f, "email")
		disabled := getBool(f, "disabled")

		statusIcon := successStyle.Render("●")
		statusText := T("status_active")
		if disabled {
			statusIcon = lipgloss.NewStyle().Foreground(colorMuted).Render("○")
			statusText = T("status_disabled")
		}

		cursor := "  "
		rowStyle := lipgloss.NewStyle()
		if i == m.cursor {
			cursor = "▸ "
			rowStyle = lipgloss.NewStyle().Bold(true)
		}

		// Truncate long names/emails so the fixed-width columns line up.
		displayName := name
		if len(displayName) > 24 {
			displayName = displayName[:21] + "..."
		}
		displayEmail := email
		if len(displayEmail) > 28 {
			displayEmail = displayEmail[:25] + "..."
		}

		row := fmt.Sprintf("%s%s %-24s %-12s %-28s %s",
			cursor, statusIcon, displayName, channel, displayEmail, statusText)
		sb.WriteString(rowStyle.Render(row))
		sb.WriteString("\n")

		// Delete confirmation
		if m.confirm == i {
			sb.WriteString(warningStyle.Render(fmt.Sprintf("  "+T("confirm_delete"), name)))
			sb.WriteString("\n")
		}

		// Inline edit input
		if m.editing && i == m.cursor {
			sb.WriteString(m.editInput.View())
			sb.WriteString("\n")
			sb.WriteString(helpStyle.Render("  " + T("enter_save") + " • " + T("esc_cancel")))
			sb.WriteString("\n")
		}

		// Expanded detail view
		if m.expanded == i {
			sb.WriteString(m.renderDetail(f))
		}
	}

	if m.status != "" {
		sb.WriteString("\n")
		sb.WriteString(m.status)
		sb.WriteString("\n")
	}

	return sb.String()
}
|
||||||
|
|
||||||
|
// renderDetail renders the boxed key/value detail view for one auth file.
// Non-editable fields with empty values are skipped; editable fields always
// appear (showing a "not set" placeholder) and are marked with a pen icon.
func (m authTabModel) renderDetail(f map[string]any) string {
	var sb strings.Builder

	labelStyle := lipgloss.NewStyle().
		Foreground(lipgloss.Color("111")).
		Bold(true)
	valueStyle := lipgloss.NewStyle().
		Foreground(lipgloss.Color("252"))
	editableMarker := lipgloss.NewStyle().
		Foreground(lipgloss.Color("214")).
		Render(" ✎")

	sb.WriteString(" ┌─────────────────────────────────────────────\n")

	fields := []struct {
		label    string
		key      string
		editable bool
	}{
		{"Name", "name", false},
		{"Channel", "channel", false},
		{"Email", "email", false},
		{"Status", "status", false},
		{"Status Msg", "status_message", false},
		{"File Name", "file_name", false},
		{"Auth Type", "auth_type", false},
		{"Prefix", "prefix", true},
		{"Proxy URL", "proxy_url", true},
		{"Priority", "priority", true},
		{"Project ID", "project_id", false},
		{"Disabled", "disabled", false},
		{"Created", "created_at", false},
		{"Updated", "updated_at", false},
	}

	for _, field := range fields {
		val := getAnyString(f, field.key)
		if val == "" || val == "<nil>" {
			if field.editable {
				val = T("not_set")
			} else {
				continue
			}
		}
		editMark := ""
		if field.editable {
			editMark = editableMarker
		}
		line := fmt.Sprintf(" │ %s %s%s",
			labelStyle.Render(fmt.Sprintf("%-12s:", field.label)),
			valueStyle.Render(val),
			editMark)
		sb.WriteString(line)
		sb.WriteString("\n")
	}

	sb.WriteString(" └─────────────────────────────────────────────\n")
	return sb.String()
}
|
||||||
|
|
||||||
|
// getAnyString converts the value stored under key to its default %v
// string representation; absent keys and nil values yield "".
func getAnyString(m map[string]any, key string) string {
	value, present := m[key]
	if !present || value == nil {
		return ""
	}
	return fmt.Sprint(value)
}
|
||||||
|
|
||||||
|
// max returns the larger of a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
|
||||||
|
|
||||||
|
// handleEditInput processes keys while the inline field editor is active:
// enter commits the value (priority must parse as an int), esc cancels,
// and anything else is typed into the text input.
func (m authTabModel) handleEditInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
	switch msg.String() {
	case "enter":
		// Snapshot what the command closure needs before leaving edit mode;
		// the closure runs later and must not depend on mutated state.
		value := m.editInput.Value()
		fieldKey := authEditableFields[m.editField].key
		fileName := m.editFileName
		m.editing = false
		m.editInput.Blur()
		fields := map[string]any{}
		if fieldKey == "priority" {
			// Priority is numeric on the API side; reject non-integers locally.
			p, err := strconv.Atoi(value)
			if err != nil {
				return m, func() tea.Msg {
					return authActionMsg{err: fmt.Errorf("%s: %s", T("invalid_int"), value)}
				}
			}
			fields[fieldKey] = p
		} else {
			fields[fieldKey] = value
		}
		return m, func() tea.Msg {
			err := m.client.PatchAuthFileFields(fileName, fields)
			if err != nil {
				return authActionMsg{err: err}
			}
			return authActionMsg{action: fmt.Sprintf(T("updated_field"), fieldKey, fileName)}
		}
	case "esc":
		// Abandon the edit and redraw without the input.
		m.editing = false
		m.editInput.Blur()
		m.viewport.SetContent(m.renderContent())
		return m, nil
	default:
		// Forward the keystroke to the text input and re-render inline.
		var cmd tea.Cmd
		m.editInput, cmd = m.editInput.Update(msg)
		m.viewport.SetContent(m.renderContent())
		return m, cmd
	}
}
|
||||||
|
|
||||||
|
func (m authTabModel) handleConfirmInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
|
||||||
|
switch msg.String() {
|
||||||
|
case "y", "Y":
|
||||||
|
idx := m.confirm
|
||||||
|
m.confirm = -1
|
||||||
|
if idx < len(m.files) {
|
||||||
|
name := getString(m.files[idx], "name")
|
||||||
|
return m, func() tea.Msg {
|
||||||
|
err := m.client.DeleteAuthFile(name)
|
||||||
|
if err != nil {
|
||||||
|
return authActionMsg{err: err}
|
||||||
|
}
|
||||||
|
return authActionMsg{action: fmt.Sprintf(T("deleted"), name)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
return m, nil
|
||||||
|
case "n", "N", "esc":
|
||||||
|
m.confirm = -1
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleNormalInput processes key presses when neither an edit nor a
// delete confirmation is in progress: navigation, expand/collapse,
// delete, enable/disable toggle, inline edit shortcuts, and refresh.
func (m authTabModel) handleNormalInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
	switch msg.String() {
	case "j", "down":
		// Move the cursor down, wrapping to the top.
		if len(m.files) > 0 {
			m.cursor = (m.cursor + 1) % len(m.files)
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "k", "up":
		// Move the cursor up, wrapping to the bottom.
		if len(m.files) > 0 {
			m.cursor = (m.cursor - 1 + len(m.files)) % len(m.files)
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "enter", " ":
		// Toggle the expanded detail view for the selected row.
		if m.expanded == m.cursor {
			m.expanded = -1
		} else {
			m.expanded = m.cursor
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case "d", "D":
		// Enter delete-confirmation mode for the selected file.
		if m.cursor < len(m.files) {
			m.confirm = m.cursor
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "e", "E":
		// Toggle the selected file between enabled and disabled.
		if m.cursor < len(m.files) {
			f := m.files[m.cursor]
			name := getString(f, "name")
			disabled := getBool(f, "disabled")
			newDisabled := !disabled
			return m, func() tea.Msg {
				err := m.client.ToggleAuthFile(name, newDisabled)
				if err != nil {
					return authActionMsg{err: err}
				}
				action := T("enabled")
				if newDisabled {
					action = T("disabled")
				}
				return authActionMsg{action: fmt.Sprintf("%s %s", action, name)}
			}
		}
		return m, nil
	case "1":
		return m, m.startEdit(0) // prefix
	case "2":
		return m, m.startEdit(1) // proxy_url
	case "3":
		return m, m.startEdit(2) // priority
	case "r":
		// Manual refresh; clear any stale status first.
		m.status = ""
		return m, m.fetchFiles
	default:
		// Everything else (page keys etc.) scrolls the viewport.
		var cmd tea.Cmd
		m.viewport, cmd = m.viewport.Update(msg)
		return m, cmd
	}
}
|
||||||
20
internal/tui/browser.go
Normal file
20
internal/tui/browser.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
// openBrowser opens the specified URL in the user's default browser.
// The command is started and not waited on; errors reflect launch failure
// only, not what the browser does afterwards.
func openBrowser(url string) error {
	switch runtime.GOOS {
	case "darwin":
		return exec.Command("open", url).Start()
	case "windows":
		return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
	default:
		// linux and any other OS: rely on xdg-open being available.
		return exec.Command("xdg-open", url).Start()
	}
}
|
||||||
400
internal/tui/client.go
Normal file
400
internal/tui/client.go
Normal file
@@ -0,0 +1,400 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client wraps HTTP calls to the management API.
type Client struct {
	baseURL   string       // e.g. "http://127.0.0.1:<port>", no trailing slash
	secretKey string       // bearer token; empty omits the Authorization header
	http      *http.Client // shared HTTP client with a request timeout
}
|
||||||
|
|
||||||
|
// NewClient creates a new management API client.
|
||||||
|
func NewClient(port int, secretKey string) *Client {
|
||||||
|
return &Client{
|
||||||
|
baseURL: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
secretKey: strings.TrimSpace(secretKey),
|
||||||
|
http: &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSecretKey updates management API bearer token used by this client.
// Surrounding whitespace is stripped, matching NewClient's handling.
func (c *Client) SetSecretKey(secretKey string) {
	c.secretKey = strings.TrimSpace(secretKey)
}
|
||||||
|
|
||||||
|
func (c *Client) doRequest(method, path string, body io.Reader) ([]byte, int, error) {
|
||||||
|
url := c.baseURL + path
|
||||||
|
req, err := http.NewRequest(method, url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if c.secretKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.secretKey)
|
||||||
|
}
|
||||||
|
if body != nil {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
|
resp, err := c.http.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, resp.StatusCode, err
|
||||||
|
}
|
||||||
|
return data, resp.StatusCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) get(path string) ([]byte, error) {
|
||||||
|
data, code, err := c.doRequest("GET", path, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if code >= 400 {
|
||||||
|
return nil, fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) put(path string, body io.Reader) ([]byte, error) {
|
||||||
|
data, code, err := c.doRequest("PUT", path, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if code >= 400 {
|
||||||
|
return nil, fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) patch(path string, body io.Reader) ([]byte, error) {
|
||||||
|
data, code, err := c.doRequest("PATCH", path, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if code >= 400 {
|
||||||
|
return nil, fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getJSON fetches a path and unmarshals JSON into a generic map.
|
||||||
|
func (c *Client) getJSON(path string) (map[string]any, error) {
|
||||||
|
data, err := c.get(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var result map[string]any
|
||||||
|
if err := json.Unmarshal(data, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// postJSON sends a JSON body via POST and checks for errors.
|
||||||
|
func (c *Client) postJSON(path string, body any) error {
|
||||||
|
jsonBody, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, code, err := c.doRequest("POST", path, strings.NewReader(string(jsonBody)))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if code >= 400 {
|
||||||
|
return fmt.Errorf("HTTP %d", code)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfig fetches the parsed config from the management API.
func (c *Client) GetConfig() (map[string]any, error) {
	return c.getJSON("/v0/management/config")
}
|
||||||
|
|
||||||
|
// GetConfigYAML fetches the raw config.yaml content.
|
||||||
|
func (c *Client) GetConfigYAML() (string, error) {
|
||||||
|
data, err := c.get("/v0/management/config.yaml")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutConfigYAML uploads new config.yaml content. The raw YAML text is
// sent as the PUT request body.
func (c *Client) PutConfigYAML(yamlContent string) error {
	_, err := c.put("/v0/management/config.yaml", strings.NewReader(yamlContent))
	return err
}
|
||||||
|
|
||||||
|
// GetUsage fetches usage statistics from the management API.
func (c *Client) GetUsage() (map[string]any, error) {
	return c.getJSON("/v0/management/usage")
}
|
||||||
|
|
||||||
|
// GetAuthFiles lists auth credential files.
// API returns {"files": [...]}; the wrapped array is extracted and
// returned as generic maps.
func (c *Client) GetAuthFiles() ([]map[string]any, error) {
	wrapper, err := c.getJSON("/v0/management/auth-files")
	if err != nil {
		return nil, err
	}
	return extractList(wrapper, "files")
}
|
||||||
|
|
||||||
|
// DeleteAuthFile deletes a single auth file by name.
|
||||||
|
func (c *Client) DeleteAuthFile(name string) error {
|
||||||
|
query := url.Values{}
|
||||||
|
query.Set("name", name)
|
||||||
|
path := "/v0/management/auth-files?" + query.Encode()
|
||||||
|
_, code, err := c.doRequest("DELETE", path, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if code >= 400 {
|
||||||
|
return fmt.Errorf("delete failed (HTTP %d)", code)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToggleAuthFile enables or disables an auth file.
|
||||||
|
func (c *Client) ToggleAuthFile(name string, disabled bool) error {
|
||||||
|
body, _ := json.Marshal(map[string]any{"name": name, "disabled": disabled})
|
||||||
|
_, err := c.patch("/v0/management/auth-files/status", strings.NewReader(string(body)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// PatchAuthFileFields updates editable fields on an auth file.
|
||||||
|
func (c *Client) PatchAuthFileFields(name string, fields map[string]any) error {
|
||||||
|
fields["name"] = name
|
||||||
|
body, _ := json.Marshal(fields)
|
||||||
|
_, err := c.patch("/v0/management/auth-files/fields", strings.NewReader(string(body)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLogs fetches log lines from the server. after is the newest
// timestamp already seen (0 requests from the beginning); limit caps the
// number of lines (0 means no explicit limit). It returns the lines, the
// latest timestamp reported by the server (never less than after), and
// any error.
func (c *Client) GetLogs(after int64, limit int) ([]string, int64, error) {
	query := url.Values{}
	if limit > 0 {
		query.Set("limit", strconv.Itoa(limit))
	}
	if after > 0 {
		query.Set("after", strconv.FormatInt(after, 10))
	}

	path := "/v0/management/logs"
	encodedQuery := query.Encode()
	if encodedQuery != "" {
		path += "?" + encodedQuery
	}

	wrapper, err := c.getJSON(path)
	if err != nil {
		return nil, after, err
	}

	lines := []string{}
	if rawLines, ok := wrapper["lines"]; ok && rawLines != nil {
		// Round-trip through JSON to convert the generic []any into []string.
		rawJSON, errMarshal := json.Marshal(rawLines)
		if errMarshal != nil {
			return nil, after, errMarshal
		}
		if errUnmarshal := json.Unmarshal(rawJSON, &lines); errUnmarshal != nil {
			return nil, after, errUnmarshal
		}
	}

	latest := after
	if rawLatest, ok := wrapper["latest-timestamp"]; ok {
		// JSON numbers normally decode as float64; the other cases are
		// defensive against alternate decoder configurations.
		switch value := rawLatest.(type) {
		case float64:
			latest = int64(value)
		case json.Number:
			if parsed, errParse := value.Int64(); errParse == nil {
				latest = parsed
			}
		case int64:
			latest = value
		case int:
			latest = int64(value)
		}
	}
	// Never report a timestamp older than what the caller already has.
	if latest < after {
		latest = after
	}

	return lines, latest, nil
}
|
||||||
|
|
||||||
|
// GetAPIKeys fetches the list of API keys.
|
||||||
|
// API returns {"api-keys": [...]}.
|
||||||
|
func (c *Client) GetAPIKeys() ([]string, error) {
|
||||||
|
wrapper, err := c.getJSON("/v0/management/api-keys")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
arr, ok := wrapper["api-keys"]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
raw, err := json.Marshal(arr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var result []string
|
||||||
|
if err := json.Unmarshal(raw, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAPIKey adds a new API key by sending old=nil, new=key which appends.
func (c *Client) AddAPIKey(key string) error {
	body := map[string]any{"old": nil, "new": key}
	jsonBody, _ := json.Marshal(body) // string/nil values cannot fail to marshal
	_, err := c.patch("/v0/management/api-keys", strings.NewReader(string(jsonBody)))
	return err
}
|
||||||
|
|
||||||
|
// EditAPIKey replaces an API key at the given index.
// NOTE(review): this sends {"index","value"} while AddAPIKey sends
// {"old","new"} to the same endpoint — confirm the server accepts both
// payload shapes.
func (c *Client) EditAPIKey(index int, newValue string) error {
	body := map[string]any{"index": index, "value": newValue}
	jsonBody, _ := json.Marshal(body) // int/string values cannot fail to marshal
	_, err := c.patch("/v0/management/api-keys", strings.NewReader(string(jsonBody)))
	return err
}
|
||||||
|
|
||||||
|
// DeleteAPIKey deletes an API key by index.
|
||||||
|
func (c *Client) DeleteAPIKey(index int) error {
|
||||||
|
_, code, err := c.doRequest("DELETE", fmt.Sprintf("/v0/management/api-keys?index=%d", index), nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if code >= 400 {
|
||||||
|
return fmt.Errorf("delete failed (HTTP %d)", code)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiKeys fetches Gemini API keys.
// API returns {"gemini-api-key": [...]}.
func (c *Client) GetGeminiKeys() ([]map[string]any, error) {
	return c.getWrappedKeyList("/v0/management/gemini-api-key", "gemini-api-key")
}
|
||||||
|
|
||||||
|
// GetClaudeKeys fetches Claude API keys.
|
||||||
|
func (c *Client) GetClaudeKeys() ([]map[string]any, error) {
|
||||||
|
return c.getWrappedKeyList("/v0/management/claude-api-key", "claude-api-key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexKeys fetches Codex API keys.
|
||||||
|
func (c *Client) GetCodexKeys() ([]map[string]any, error) {
|
||||||
|
return c.getWrappedKeyList("/v0/management/codex-api-key", "codex-api-key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVertexKeys fetches Vertex API keys.
|
||||||
|
func (c *Client) GetVertexKeys() ([]map[string]any, error) {
|
||||||
|
return c.getWrappedKeyList("/v0/management/vertex-api-key", "vertex-api-key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAICompat fetches OpenAI compatibility entries.
|
||||||
|
func (c *Client) GetOpenAICompat() ([]map[string]any, error) {
|
||||||
|
return c.getWrappedKeyList("/v0/management/openai-compatibility", "openai-compatibility")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getWrappedKeyList fetches a wrapped list from the API.
|
||||||
|
func (c *Client) getWrappedKeyList(path, key string) ([]map[string]any, error) {
|
||||||
|
wrapper, err := c.getJSON(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return extractList(wrapper, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractList pulls an array of maps from a wrapper object by key.
// A missing or nil entry yields (nil, nil); otherwise the value is
// round-tripped through JSON so the caller always gets []map[string]any.
func extractList(wrapper map[string]any, key string) ([]map[string]any, error) {
	value, present := wrapper[key]
	if !present || value == nil {
		return nil, nil
	}
	encoded, errEncode := json.Marshal(value)
	if errEncode != nil {
		return nil, errEncode
	}
	var items []map[string]any
	if errDecode := json.Unmarshal(encoded, &items); errDecode != nil {
		return nil, errDecode
	}
	return items, nil
}
|
||||||
|
|
||||||
|
// GetDebug fetches the current debug setting.
|
||||||
|
func (c *Client) GetDebug() (bool, error) {
|
||||||
|
wrapper, err := c.getJSON("/v0/management/debug")
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if v, ok := wrapper["debug"]; ok {
|
||||||
|
if b, ok := v.(bool); ok {
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthStatus polls the OAuth session status.
|
||||||
|
// Returns status ("wait", "ok", "error") and optional error message.
|
||||||
|
func (c *Client) GetAuthStatus(state string) (string, string, error) {
|
||||||
|
query := url.Values{}
|
||||||
|
query.Set("state", state)
|
||||||
|
path := "/v0/management/get-auth-status?" + query.Encode()
|
||||||
|
wrapper, err := c.getJSON(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
status := getString(wrapper, "status")
|
||||||
|
errMsg := getString(wrapper, "error")
|
||||||
|
return status, errMsg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- Config field update methods -----
|
||||||
|
|
||||||
|
// PutBoolField updates a boolean config field.
|
||||||
|
func (c *Client) PutBoolField(path string, value bool) error {
|
||||||
|
body, _ := json.Marshal(map[string]any{"value": value})
|
||||||
|
_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutIntField updates an integer config field.
|
||||||
|
func (c *Client) PutIntField(path string, value int) error {
|
||||||
|
body, _ := json.Marshal(map[string]any{"value": value})
|
||||||
|
_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutStringField updates a string config field.
|
||||||
|
func (c *Client) PutStringField(path string, value string) error {
|
||||||
|
body, _ := json.Marshal(map[string]any{"value": value})
|
||||||
|
_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteField sends a DELETE request for a config field.
|
||||||
|
func (c *Client) DeleteField(path string) error {
|
||||||
|
_, _, err := c.doRequest("DELETE", "/v0/management/"+path, nil)
|
||||||
|
return err
|
||||||
|
}
|
||||||
413
internal/tui/config_tab.go
Normal file
413
internal/tui/config_tab.go
Normal file
@@ -0,0 +1,413 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/bubbles/textinput"
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
// configField represents a single editable config field.
type configField struct {
	label    string // display label shown in the UI
	apiPath  string // management API path (e.g. "debug", "proxy-url")
	kind     string // "bool", "int", "string", "readonly"
	value    string // current display value
	rawValue any    // raw value from API (used to pre-fill the editor)
}

// configTabModel displays parsed config with interactive editing.
type configTabModel struct {
	client    *Client         // management API client
	viewport  viewport.Model  // scrollable content area
	fields    []configField   // flattened, ordered list of fields
	cursor    int             // index into fields of the selected row
	editing   bool            // true while the inline text input is active
	textInput textinput.Model // inline editor for int/string fields
	err       error           // last fetch error, shown instead of fields
	message   string          // status message (success/error)
	width     int
	height    int
	ready     bool // viewport has been initialized by SetSize
}

// configDataMsg carries the result of a config fetch.
type configDataMsg struct {
	config map[string]any
	err    error
}

// configUpdateMsg carries the result of a single field update.
type configUpdateMsg struct {
	path  string // apiPath of the updated field
	value any    // value that was sent (nil on parse failure)
	err   error
}
|
||||||
|
|
||||||
|
func newConfigTabModel(client *Client) configTabModel {
|
||||||
|
ti := textinput.New()
|
||||||
|
ti.CharLimit = 256
|
||||||
|
return configTabModel{
|
||||||
|
client: client,
|
||||||
|
textInput: ti,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init kicks off the initial config fetch.
func (m configTabModel) Init() tea.Cmd {
	return m.fetchConfig
}

// fetchConfig loads the config from the server; runs as a tea.Cmd and
// delivers the result as a configDataMsg.
func (m configTabModel) fetchConfig() tea.Msg {
	cfg, err := m.client.GetConfig()
	return configDataMsg{config: cfg, err: err}
}
|
||||||
|
|
||||||
|
// Update handles messages for the config tab: locale changes re-render,
// data/update messages refresh state, and key presses are dispatched by
// edit mode. Anything else falls through to the viewport for scrolling.
func (m configTabModel) Update(msg tea.Msg) (configTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		// Re-render so labels pick up the new locale.
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case configDataMsg:
		if msg.err != nil {
			m.err = msg.err
			m.fields = nil
		} else {
			m.err = nil
			m.fields = m.parseConfig(msg.config)
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case configUpdateMsg:
		if msg.err != nil {
			m.message = errorStyle.Render("✗ " + msg.err.Error())
		} else {
			m.message = successStyle.Render(T("updated_ok"))
		}
		m.viewport.SetContent(m.renderContent())
		// Refresh config from server
		return m, m.fetchConfig

	case tea.KeyMsg:
		// Key handling depends on whether the inline editor is open.
		if m.editing {
			return m.handleEditingKey(msg)
		}
		return m.handleNormalKey(msg)
	}

	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
|
||||||
|
|
||||||
|
// handleNormalKey processes key presses while not editing: "r" refreshes,
// up/down move the cursor, and enter/space acts on the selected field
// (toggle for bool, open the inline editor for int/string).
func (m configTabModel) handleNormalKey(msg tea.KeyMsg) (configTabModel, tea.Cmd) {
	switch msg.String() {
	case "r":
		m.message = ""
		return m, m.fetchConfig
	case "up", "k":
		if m.cursor > 0 {
			m.cursor--
			m.viewport.SetContent(m.renderContent())
			// Ensure cursor is visible
			m.ensureCursorVisible()
		}
		return m, nil
	case "down", "j":
		if m.cursor < len(m.fields)-1 {
			m.cursor++
			m.viewport.SetContent(m.renderContent())
			m.ensureCursorVisible()
		}
		return m, nil
	case "enter", " ":
		if m.cursor >= 0 && m.cursor < len(m.fields) {
			f := m.fields[m.cursor]
			if f.kind == "readonly" {
				return m, nil
			}
			if f.kind == "bool" {
				// Toggle directly
				return m, m.toggleBool(m.cursor)
			}
			// Start editing for int/string
			m.editing = true
			m.textInput.SetValue(configFieldEditValue(f))
			m.textInput.Focus()
			m.viewport.SetContent(m.renderContent())
			return m, textinput.Blink
		}
		return m, nil
	}

	// Unhandled keys scroll the viewport.
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
|
||||||
|
|
||||||
|
// handleEditingKey processes key presses while the inline editor is open:
// enter submits the edit, esc cancels it, anything else is forwarded to
// the text input.
func (m configTabModel) handleEditingKey(msg tea.KeyMsg) (configTabModel, tea.Cmd) {
	switch msg.String() {
	case "enter":
		m.editing = false
		m.textInput.Blur()
		return m, m.submitEdit(m.cursor, m.textInput.Value())
	case "esc":
		m.editing = false
		m.textInput.Blur()
		m.viewport.SetContent(m.renderContent())
		return m, nil
	default:
		var cmd tea.Cmd
		m.textInput, cmd = m.textInput.Update(msg)
		m.viewport.SetContent(m.renderContent())
		return m, cmd
	}
}
|
||||||
|
|
||||||
|
func (m configTabModel) toggleBool(idx int) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
f := m.fields[idx]
|
||||||
|
current := f.value == "true"
|
||||||
|
newValue := !current
|
||||||
|
errPutBool := m.client.PutBoolField(f.apiPath, newValue)
|
||||||
|
return configUpdateMsg{
|
||||||
|
path: f.apiPath,
|
||||||
|
value: newValue,
|
||||||
|
err: errPutBool,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// submitEdit returns a command that parses newValue according to the
// field's kind, PUTs it to the management API, and reports the outcome as
// a configUpdateMsg. Kinds other than "int"/"string" fall through with a
// nil value and nil error.
func (m configTabModel) submitEdit(idx int, newValue string) tea.Cmd {
	return func() tea.Msg {
		f := m.fields[idx]
		var err error
		var value any
		switch f.kind {
		case "int":
			valueInt, errAtoi := strconv.Atoi(newValue)
			if errAtoi != nil {
				// Reject non-numeric input before hitting the API.
				return configUpdateMsg{
					path: f.apiPath,
					err:  fmt.Errorf("%s: %s", T("invalid_int"), newValue),
				}
			}
			value = valueInt
			err = m.client.PutIntField(f.apiPath, valueInt)
		case "string":
			value = newValue
			err = m.client.PutStringField(f.apiPath, newValue)
		}
		return configUpdateMsg{
			path:  f.apiPath,
			value: value,
			err:   err,
		}
	}
}
|
||||||
|
|
||||||
|
func configFieldEditValue(f configField) string {
|
||||||
|
if rawString, ok := f.rawValue.(string); ok {
|
||||||
|
return rawString
|
||||||
|
}
|
||||||
|
return f.value
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSize resizes the viewport, creating and populating it on first call.
func (m *configTabModel) SetSize(w, h int) {
	m.width = w
	m.height = h
	if !m.ready {
		m.viewport = viewport.New(w, h)
		m.viewport.SetContent(m.renderContent())
		m.ready = true
	} else {
		m.viewport.Width = w
		m.viewport.Height = h
	}
}
|
||||||
|
|
||||||
|
// ensureCursorVisible scrolls the viewport so the selected row stays on
// screen.
func (m *configTabModel) ensureCursorVisible() {
	// Each field takes ~1 line, header takes ~4 lines
	// NOTE(review): the +5 offset is an approximation — renderContent also
	// emits section headers and an optional status line, so the computed
	// line can drift from the cursor's true row; confirm if exact tracking
	// is needed.
	targetLine := m.cursor + 5
	if targetLine < m.viewport.YOffset {
		m.viewport.SetYOffset(targetLine)
	}
	if targetLine >= m.viewport.YOffset+m.viewport.Height {
		m.viewport.SetYOffset(targetLine - m.viewport.Height + 1)
	}
}
|
||||||
|
|
||||||
|
// View renders the tab; before the first SetSize there is no viewport yet,
// so a loading placeholder is shown instead.
func (m configTabModel) View() string {
	if !m.ready {
		return T("loading")
	}
	return m.viewport.View()
}
|
||||||
|
|
||||||
|
// renderContent builds the full scrollable body: title, optional status
// message, help lines, then the field list grouped under section headers,
// with the selected row highlighted (or showing the live editor).
func (m configTabModel) renderContent() string {
	var sb strings.Builder

	sb.WriteString(titleStyle.Render(T("config_title")))
	sb.WriteString("\n")

	if m.message != "" {
		sb.WriteString(" " + m.message)
		sb.WriteString("\n")
	}

	sb.WriteString(helpStyle.Render(T("config_help1")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("config_help2")))
	sb.WriteString("\n\n")

	if m.err != nil {
		sb.WriteString(errorStyle.Render(" ⚠ Error: " + m.err.Error()))
		return sb.String()
	}

	if len(m.fields) == 0 {
		sb.WriteString(subtitleStyle.Render(T("no_config")))
		return sb.String()
	}

	currentSection := ""
	for i, f := range m.fields {
		// Section headers
		// Emitted whenever the section changes between consecutive fields,
		// so parseConfig must keep each section's fields contiguous.
		section := fieldSection(f.apiPath)
		if section != currentSection {
			currentSection = section
			sb.WriteString("\n")
			sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(" ── " + section + " "))
			sb.WriteString("\n")
		}

		isSelected := i == m.cursor
		prefix := " "
		if isSelected {
			prefix = "▸ "
		}

		labelStr := lipgloss.NewStyle().
			Foreground(colorInfo).
			Bold(isSelected).
			Width(32).
			Render(f.label)

		var valueStr string
		if m.editing && isSelected {
			// The selected row shows the live text input while editing.
			valueStr = m.textInput.View()
		} else {
			switch f.kind {
			case "bool":
				if f.value == "true" {
					valueStr = successStyle.Render("● ON")
				} else {
					valueStr = lipgloss.NewStyle().Foreground(colorMuted).Render("○ OFF")
				}
			case "readonly":
				valueStr = lipgloss.NewStyle().Foreground(colorSubtext).Render(f.value)
			default:
				valueStr = valueStyle.Render(f.value)
			}
		}

		line := prefix + labelStr + " " + valueStr
		if isSelected && !m.editing {
			line = lipgloss.NewStyle().Background(colorSurface).Render(line)
		}
		sb.WriteString(line + "\n")
	}

	return sb.String()
}
|
||||||
|
|
||||||
|
// parseConfig flattens the config map into the ordered field list shown in
// the UI. configField literals are positional:
// {label, apiPath, kind, value, rawValue}.
func (m configTabModel) parseConfig(cfg map[string]any) []configField {
	var fields []configField

	// Server settings
	fields = append(fields, configField{"Port", "port", "readonly", fmt.Sprintf("%.0f", getFloat(cfg, "port")), nil})
	fields = append(fields, configField{"Host", "host", "readonly", getString(cfg, "host"), nil})
	fields = append(fields, configField{"Debug", "debug", "bool", fmt.Sprintf("%v", getBool(cfg, "debug")), nil})
	fields = append(fields, configField{"Proxy URL", "proxy-url", "string", getString(cfg, "proxy-url"), nil})
	fields = append(fields, configField{"Request Retry", "request-retry", "int", fmt.Sprintf("%.0f", getFloat(cfg, "request-retry")), nil})
	fields = append(fields, configField{"Max Retry Interval (s)", "max-retry-interval", "int", fmt.Sprintf("%.0f", getFloat(cfg, "max-retry-interval")), nil})
	fields = append(fields, configField{"Force Model Prefix", "force-model-prefix", "string", getString(cfg, "force-model-prefix"), nil})

	// Logging
	fields = append(fields, configField{"Logging to File", "logging-to-file", "bool", fmt.Sprintf("%v", getBool(cfg, "logging-to-file")), nil})
	fields = append(fields, configField{"Logs Max Total Size (MB)", "logs-max-total-size-mb", "int", fmt.Sprintf("%.0f", getFloat(cfg, "logs-max-total-size-mb")), nil})
	fields = append(fields, configField{"Error Logs Max Files", "error-logs-max-files", "int", fmt.Sprintf("%.0f", getFloat(cfg, "error-logs-max-files")), nil})
	fields = append(fields, configField{"Usage Stats Enabled", "usage-statistics-enabled", "bool", fmt.Sprintf("%v", getBool(cfg, "usage-statistics-enabled")), nil})
	fields = append(fields, configField{"Request Log", "request-log", "bool", fmt.Sprintf("%v", getBool(cfg, "request-log")), nil})

	// Quota exceeded
	fields = append(fields, configField{"Switch Project on Quota", "quota-exceeded/switch-project", "bool", fmt.Sprintf("%v", getBoolNested(cfg, "quota-exceeded", "switch-project")), nil})
	fields = append(fields, configField{"Switch Preview Model", "quota-exceeded/switch-preview-model", "bool", fmt.Sprintf("%v", getBoolNested(cfg, "quota-exceeded", "switch-preview-model")), nil})

	// Routing
	if routing, ok := cfg["routing"].(map[string]any); ok {
		fields = append(fields, configField{"Routing Strategy", "routing/strategy", "string", getString(routing, "strategy"), nil})
	} else {
		fields = append(fields, configField{"Routing Strategy", "routing/strategy", "string", "", nil})
	}

	// WebSocket auth
	fields = append(fields, configField{"WebSocket Auth", "ws-auth", "bool", fmt.Sprintf("%v", getBool(cfg, "ws-auth")), nil})

	// AMP settings — only shown when the "ampcode" section exists.
	if amp, ok := cfg["ampcode"].(map[string]any); ok {
		upstreamURL := getString(amp, "upstream-url")
		upstreamAPIKey := getString(amp, "upstream-api-key")
		// rawValue carries the unmasked strings so editing pre-fills the
		// real value (see configFieldEditValue); the API key's display
		// value is masked.
		fields = append(fields, configField{"AMP Upstream URL", "ampcode/upstream-url", "string", upstreamURL, upstreamURL})
		fields = append(fields, configField{"AMP Upstream API Key", "ampcode/upstream-api-key", "string", maskIfNotEmpty(upstreamAPIKey), upstreamAPIKey})
		fields = append(fields, configField{"AMP Restrict Mgmt Localhost", "ampcode/restrict-management-to-localhost", "bool", fmt.Sprintf("%v", getBool(amp, "restrict-management-to-localhost")), nil})
	}

	return fields
}
|
||||||
|
|
||||||
|
func fieldSection(apiPath string) string {
|
||||||
|
if strings.HasPrefix(apiPath, "ampcode/") {
|
||||||
|
return T("section_ampcode")
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(apiPath, "quota-exceeded/") {
|
||||||
|
return T("section_quota")
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(apiPath, "routing/") {
|
||||||
|
return T("section_routing")
|
||||||
|
}
|
||||||
|
switch apiPath {
|
||||||
|
case "port", "host", "debug", "proxy-url", "request-retry", "max-retry-interval", "force-model-prefix":
|
||||||
|
return T("section_server")
|
||||||
|
case "logging-to-file", "logs-max-total-size-mb", "error-logs-max-files", "usage-statistics-enabled", "request-log":
|
||||||
|
return T("section_logging")
|
||||||
|
case "ws-auth":
|
||||||
|
return T("section_websocket")
|
||||||
|
default:
|
||||||
|
return T("section_other")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBoolNested(m map[string]any, keys ...string) bool {
|
||||||
|
current := m
|
||||||
|
for i, key := range keys {
|
||||||
|
if i == len(keys)-1 {
|
||||||
|
return getBool(current, key)
|
||||||
|
}
|
||||||
|
if nested, ok := current[key].(map[string]any); ok {
|
||||||
|
current = nested
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskIfNotEmpty masks a secret for display via maskKey; an empty value
// renders as the localized "not set" placeholder instead.
func maskIfNotEmpty(s string) string {
	if s == "" {
		return T("not_set")
	}
	return maskKey(s)
}
|
||||||
360
internal/tui/dashboard.go
Normal file
360
internal/tui/dashboard.go
Normal file
@@ -0,0 +1,360 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dashboardModel displays server info, stats cards, and config overview.
type dashboardModel struct {
	client   *Client        // management API client
	viewport viewport.Model // scrollable content area
	content  string         // last rendered dashboard body
	err      error          // last fetch error
	width    int
	height   int
	ready    bool // viewport has been initialized by SetSize

	// Cached data for re-rendering on locale change
	lastConfig    map[string]any
	lastUsage     map[string]any
	lastAuthFiles []map[string]any
	lastAPIKeys   []string
}

// dashboardDataMsg carries one round of fetched dashboard data; err holds
// the first error encountered while fetching.
type dashboardDataMsg struct {
	config    map[string]any
	usage     map[string]any
	authFiles []map[string]any
	apiKeys   []string
	err       error
}
|
||||||
|
|
||||||
|
// newDashboardModel builds a dashboard bound to the given client.
func newDashboardModel(client *Client) dashboardModel {
	return dashboardModel{
		client: client,
	}
}

// Init kicks off the initial data fetch.
func (m dashboardModel) Init() tea.Cmd {
	return m.fetchData
}
|
||||||
|
|
||||||
|
func (m dashboardModel) fetchData() tea.Msg {
|
||||||
|
cfg, cfgErr := m.client.GetConfig()
|
||||||
|
usage, usageErr := m.client.GetUsage()
|
||||||
|
authFiles, authErr := m.client.GetAuthFiles()
|
||||||
|
apiKeys, keysErr := m.client.GetAPIKeys()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
for _, e := range []error{cfgErr, usageErr, authErr, keysErr} {
|
||||||
|
if e != nil {
|
||||||
|
err = e
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dashboardDataMsg{config: cfg, usage: usage, authFiles: authFiles, apiKeys: apiKeys, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update handles dashboard messages: locale changes re-render cached data,
// data messages refresh the view, "r" refetches, and everything else goes
// to the viewport for scrolling.
func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		// Re-render immediately with cached data using new locale
		m.content = m.renderDashboard(m.lastConfig, m.lastUsage, m.lastAuthFiles, m.lastAPIKeys)
		m.viewport.SetContent(m.content)
		// Also fetch fresh data in background
		return m, m.fetchData

	case dashboardDataMsg:
		if msg.err != nil {
			m.err = msg.err
			m.content = errorStyle.Render("⚠ Error: " + msg.err.Error())
		} else {
			m.err = nil
			// Cache data for locale switching
			m.lastConfig = msg.config
			m.lastUsage = msg.usage
			m.lastAuthFiles = msg.authFiles
			m.lastAPIKeys = msg.apiKeys

			m.content = m.renderDashboard(msg.config, msg.usage, msg.authFiles, msg.apiKeys)
		}
		m.viewport.SetContent(m.content)
		return m, nil

	case tea.KeyMsg:
		if msg.String() == "r" {
			return m, m.fetchData
		}
		var cmd tea.Cmd
		m.viewport, cmd = m.viewport.Update(msg)
		return m, cmd
	}

	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
|
||||||
|
|
||||||
|
// SetSize resizes the viewport, creating it on first call.
func (m *dashboardModel) SetSize(w, h int) {
	m.width = w
	m.height = h
	if !m.ready {
		m.viewport = viewport.New(w, h)
		m.viewport.SetContent(m.content)
		m.ready = true
	} else {
		m.viewport.Width = w
		m.viewport.Height = h
	}
}
|
||||||
|
|
||||||
|
// View renders the dashboard; before the first SetSize there is no
// viewport yet, so a loading placeholder is shown instead.
func (m dashboardModel) View() string {
	if !m.ready {
		return T("loading")
	}
	return m.viewport.View()
}
|
||||||
|
|
||||||
|
// renderDashboard builds the full dashboard body: connection status, four
// stat cards, a config overview, and a per-model usage table.
func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []map[string]any, apiKeys []string) string {
	var sb strings.Builder

	sb.WriteString(titleStyle.Render(T("dashboard_title")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("dashboard_help")))
	sb.WriteString("\n\n")

	// ━━━ Connection Status ━━━
	connStyle := lipgloss.NewStyle().Bold(true).Foreground(colorSuccess)
	sb.WriteString(connStyle.Render(T("connected")))
	sb.WriteString(fmt.Sprintf(" %s", m.client.baseURL))
	sb.WriteString("\n\n")

	// ━━━ Stats Cards ━━━
	// Card width adapts to the terminal width (four cards per row),
	// clamped to a readable minimum of 18.
	cardWidth := 25
	if m.width > 0 {
		cardWidth = (m.width - 6) / 4
		if cardWidth < 18 {
			cardWidth = 18
		}
	}

	cardStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(lipgloss.Color("240")).
		Padding(0, 1).
		Width(cardWidth).
		Height(2)

	// Card 1: API Keys
	keyCount := len(apiKeys)
	card1 := cardStyle.Render(fmt.Sprintf(
		"%s\n%s",
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("🔑 %d", keyCount)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(T("mgmt_keys")),
	))

	// Card 2: Auth Files — counts entries whose "disabled" flag is unset.
	authCount := len(authFiles)
	activeAuth := 0
	for _, f := range authFiles {
		if !getBool(f, "disabled") {
			activeAuth++
		}
	}
	card2 := cardStyle.Render(fmt.Sprintf(
		"%s\n%s",
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("📄 %d", authCount)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (%d %s)", T("auth_files_label"), activeAuth, T("active_suffix"))),
	))

	// Card 3: Total Requests — totals come from the "usage" sub-object.
	totalReqs := int64(0)
	successReqs := int64(0)
	failedReqs := int64(0)
	totalTokens := int64(0)
	if usage != nil {
		if usageMap, ok := usage["usage"].(map[string]any); ok {
			totalReqs = int64(getFloat(usageMap, "total_requests"))
			successReqs = int64(getFloat(usageMap, "success_count"))
			failedReqs = int64(getFloat(usageMap, "failure_count"))
			totalTokens = int64(getFloat(usageMap, "total_tokens"))
		}
	}
	card3 := cardStyle.Render(fmt.Sprintf(
		"%s\n%s",
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(fmt.Sprintf("📈 %d", totalReqs)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (✓%d ✗%d)", T("total_requests"), successReqs, failedReqs)),
	))

	// Card 4: Total Tokens
	tokenStr := formatLargeNumber(totalTokens)
	card4 := cardStyle.Render(fmt.Sprintf(
		"%s\n%s",
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("🔤 %s", tokenStr)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(T("total_tokens")),
	))

	sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
	sb.WriteString("\n\n")

	// ━━━ Current Config ━━━
	sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("current_config")))
	sb.WriteString("\n")
	sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
	sb.WriteString("\n")

	if cfg != nil {
		debug := getBool(cfg, "debug")
		retry := getFloat(cfg, "request-retry")
		proxyURL := getString(cfg, "proxy-url")
		loggingToFile := getBool(cfg, "logging-to-file")
		// Usage statistics default to enabled when the key is absent.
		usageEnabled := true
		if v, ok := cfg["usage-statistics-enabled"]; ok {
			if b, ok2 := v.(bool); ok2 {
				usageEnabled = b
			}
		}

		configItems := []struct {
			label string
			value string
		}{
			{T("debug_mode"), boolEmoji(debug)},
			{T("usage_stats"), boolEmoji(usageEnabled)},
			{T("log_to_file"), boolEmoji(loggingToFile)},
			{T("retry_count"), fmt.Sprintf("%.0f", retry)},
		}
		if proxyURL != "" {
			configItems = append(configItems, struct {
				label string
				value string
			}{T("proxy_url"), proxyURL})
		}

		// Render config items as a compact row
		for _, item := range configItems {
			sb.WriteString(fmt.Sprintf(" %s %s\n",
				labelStyle.Render(item.label+":"),
				valueStyle.Render(item.value)))
		}

		// Routing strategy — falls back to "round-robin" when unset.
		strategy := "round-robin"
		if routing, ok := cfg["routing"].(map[string]any); ok {
			if s := getString(routing, "strategy"); s != "" {
				strategy = s
			}
		}
		sb.WriteString(fmt.Sprintf(" %s %s\n",
			labelStyle.Render(T("routing_strategy")+":"),
			valueStyle.Render(strategy)))
	}

	sb.WriteString("\n")

	// ━━━ Per-Model Usage ━━━
	// NOTE(review): apis and models are Go maps, so row order is random
	// from render to render; sort the keys if stable ordering matters.
	if usage != nil {
		if usageMap, ok := usage["usage"].(map[string]any); ok {
			if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
				sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("model_stats")))
				sb.WriteString("\n")
				sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
				sb.WriteString("\n")

				header := fmt.Sprintf(" %-40s %10s %12s", T("model"), T("requests"), T("tokens"))
				sb.WriteString(tableHeaderStyle.Render(header))
				sb.WriteString("\n")

				for _, apiSnap := range apis {
					if apiMap, ok := apiSnap.(map[string]any); ok {
						if models, ok := apiMap["models"].(map[string]any); ok {
							for model, v := range models {
								if stats, ok := v.(map[string]any); ok {
									reqs := int64(getFloat(stats, "total_requests"))
									toks := int64(getFloat(stats, "total_tokens"))
									row := fmt.Sprintf(" %-40s %10d %12s", truncate(model, 40), reqs, formatLargeNumber(toks))
									sb.WriteString(tableCellStyle.Render(row))
									sb.WriteString("\n")
								}
							}
						}
					}
				}
			}
		}
	}

	return sb.String()
}
|
||||||
|
|
||||||
|
func formatKV(key, value string) string {
|
||||||
|
return fmt.Sprintf(" %s %s\n", labelStyle.Render(key+":"), valueStyle.Render(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// getString returns m[key] when it holds a string; a missing key or a
// non-string value yields "".
func getString(m map[string]any, key string) string {
	s, _ := m[key].(string)
	return s
}
|
||||||
|
|
||||||
|
// getFloat returns m[key] as a float64, accepting either a float64 or a
// json.Number value; anything else (including a missing key) yields 0.
func getFloat(m map[string]any, key string) float64 {
	switch n := m[key].(type) {
	case float64:
		return n
	case json.Number:
		f, _ := n.Float64()
		return f
	default:
		return 0
	}
}
|
||||||
|
|
||||||
|
// getBool returns m[key] when it holds a bool; a missing key or a non-bool
// value yields false.
func getBool(m map[string]any, key string) bool {
	b, _ := m[key].(bool)
	return b
}
|
||||||
|
|
||||||
|
// boolEmoji renders a boolean as the localized yes/no marker.
func boolEmoji(b bool) string {
	if b {
		return T("bool_yes")
	}
	return T("bool_no")
}
|
||||||
|
|
||||||
|
// formatLargeNumber renders n compactly for narrow table columns:
// "3.2B", "2.5M", "1.5K", or the plain decimal below one thousand.
// Negative values are printed as-is.
//
// The billions tier is new: previously e.g. 2_500_000_000 rendered as
// "2500.0M", overflowing the fixed-width token column.
func formatLargeNumber(n int64) string {
	switch {
	case n >= 1_000_000_000:
		return fmt.Sprintf("%.1fB", float64(n)/1_000_000_000)
	case n >= 1_000_000:
		return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
	case n >= 1_000:
		return fmt.Sprintf("%.1fK", float64(n)/1_000)
	default:
		return fmt.Sprintf("%d", n)
	}
}
|
||||||
|
|
||||||
|
// truncate shortens s to at most maxLen bytes, replacing the tail with "..."
// when it does not fit. For maxLen <= 3 there is no room for an ellipsis, so
// the string is hard-cut instead — the original slipped into s[:maxLen-3]
// and paniced for maxLen < 3.
//
// NOTE(review): truncation is byte-based; a multi-byte rune at the cut point
// would be split. Fine for ASCII model names — confirm if non-ASCII appears.
func truncate(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	if maxLen <= 3 {
		if maxLen < 0 {
			maxLen = 0
		}
		return s[:maxLen]
	}
	return s[:maxLen-3] + "..."
}
|
||||||
|
|
||||||
|
// minInt returns the smaller of a and b.
//
// Kept only for existing callers; it delegates to the built-in min
// (Go 1.21+), which this codebase already relies on via the built-in max.
func minInt(a, b int) int {
	return min(a, b)
}
|
||||||
364
internal/tui/i18n.go
Normal file
364
internal/tui/i18n.go
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
// i18n provides a simple internationalization system for the TUI.
// Supported locales: "en" (English, default) and "zh" (Chinese).

// currentLocale holds the active locale code. Note the default here is "en",
// not "zh" — translation tables fall back to English in T anyway.
var currentLocale = "en"
|
||||||
|
|
||||||
|
// SetLocale changes the active locale.
|
||||||
|
func SetLocale(locale string) {
|
||||||
|
if _, ok := locales[locale]; ok {
|
||||||
|
currentLocale = locale
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CurrentLocale returns the active locale code (e.g. "zh" or "en").
func CurrentLocale() string {
	return currentLocale
}
|
||||||
|
|
||||||
|
// ToggleLocale switches between zh and en.
|
||||||
|
func ToggleLocale() {
|
||||||
|
if currentLocale == "zh" {
|
||||||
|
currentLocale = "en"
|
||||||
|
} else {
|
||||||
|
currentLocale = "zh"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// T returns the translated string for the given key.
|
||||||
|
func T(key string) string {
|
||||||
|
if m, ok := locales[currentLocale]; ok {
|
||||||
|
if v, ok := m[key]; ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback to English
|
||||||
|
if m, ok := locales["en"]; ok {
|
||||||
|
if v, ok := m[key]; ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
// locales maps each supported locale code to its translation table.
var locales = map[string]map[string]string{
	"zh": zhStrings,
	"en": enStrings,
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────
// Tab names
// ──────────────────────────────────────────

// zhTabNames and enTabNames are parallel slices: index i names the same tab
// in both languages. Keep them the same length.
var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "使用统计", "日志"}
var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Usage", "Logs"}
|
||||||
|
|
||||||
|
// TabNames returns tab names in the current locale.
|
||||||
|
func TabNames() []string {
|
||||||
|
if currentLocale == "zh" {
|
||||||
|
return zhTabNames
|
||||||
|
}
|
||||||
|
return enTabNames
|
||||||
|
}
|
||||||
|
|
||||||
|
var zhStrings = map[string]string{
|
||||||
|
// ── Common ──
|
||||||
|
"loading": "加载中...",
|
||||||
|
"refresh": "刷新",
|
||||||
|
"save": "保存",
|
||||||
|
"cancel": "取消",
|
||||||
|
"confirm": "确认",
|
||||||
|
"yes": "是",
|
||||||
|
"no": "否",
|
||||||
|
"error": "错误",
|
||||||
|
"success": "成功",
|
||||||
|
"navigate": "导航",
|
||||||
|
"scroll": "滚动",
|
||||||
|
"enter_save": "Enter: 保存",
|
||||||
|
"esc_cancel": "Esc: 取消",
|
||||||
|
"enter_submit": "Enter: 提交",
|
||||||
|
"press_r": "[r] 刷新",
|
||||||
|
"press_scroll": "[↑↓] 滚动",
|
||||||
|
"not_set": "(未设置)",
|
||||||
|
"error_prefix": "⚠ 错误: ",
|
||||||
|
|
||||||
|
// ── Status bar ──
|
||||||
|
"status_left": " CLIProxyAPI 管理终端",
|
||||||
|
"status_right": "Tab/Shift+Tab: 切换 • L: 语言 • q/Ctrl+C: 退出 ",
|
||||||
|
"initializing_tui": "正在初始化...",
|
||||||
|
"auth_gate_title": "🔐 连接管理 API",
|
||||||
|
"auth_gate_help": " 请输入管理密码并按 Enter 连接",
|
||||||
|
"auth_gate_password": "密码",
|
||||||
|
"auth_gate_enter": " Enter: 连接 • q/Ctrl+C: 退出 • L: 语言",
|
||||||
|
"auth_gate_connecting": "正在连接...",
|
||||||
|
"auth_gate_connect_fail": "连接失败:%s",
|
||||||
|
"auth_gate_password_required": "请输入密码",
|
||||||
|
|
||||||
|
// ── Dashboard ──
|
||||||
|
"dashboard_title": "📊 仪表盘",
|
||||||
|
"dashboard_help": " [r] 刷新 • [↑↓] 滚动",
|
||||||
|
"connected": "● 已连接",
|
||||||
|
"mgmt_keys": "管理密钥",
|
||||||
|
"auth_files_label": "认证文件",
|
||||||
|
"active_suffix": "活跃",
|
||||||
|
"total_requests": "请求",
|
||||||
|
"success_label": "成功",
|
||||||
|
"failure_label": "失败",
|
||||||
|
"total_tokens": "总 Tokens",
|
||||||
|
"current_config": "当前配置",
|
||||||
|
"debug_mode": "启用调试模式",
|
||||||
|
"usage_stats": "启用使用统计",
|
||||||
|
"log_to_file": "启用日志记录到文件",
|
||||||
|
"retry_count": "重试次数",
|
||||||
|
"proxy_url": "代理 URL",
|
||||||
|
"routing_strategy": "路由策略",
|
||||||
|
"model_stats": "模型统计",
|
||||||
|
"model": "模型",
|
||||||
|
"requests": "请求数",
|
||||||
|
"tokens": "Tokens",
|
||||||
|
"bool_yes": "是 ✓",
|
||||||
|
"bool_no": "否",
|
||||||
|
|
||||||
|
// ── Config ──
|
||||||
|
"config_title": "⚙ 配置",
|
||||||
|
"config_help1": " [↑↓/jk] 导航 • [Enter/Space] 编辑 • [r] 刷新",
|
||||||
|
"config_help2": " 布尔: Enter 切换 • 文本/数字: Enter 输入, Enter 确认, Esc 取消",
|
||||||
|
"updated_ok": "✓ 更新成功",
|
||||||
|
"no_config": " 未加载配置",
|
||||||
|
"invalid_int": "无效整数",
|
||||||
|
"section_server": "服务器",
|
||||||
|
"section_logging": "日志与统计",
|
||||||
|
"section_quota": "配额超限处理",
|
||||||
|
"section_routing": "路由",
|
||||||
|
"section_websocket": "WebSocket",
|
||||||
|
"section_ampcode": "AMP Code",
|
||||||
|
"section_other": "其他",
|
||||||
|
|
||||||
|
// ── Auth Files ──
|
||||||
|
"auth_title": "🔑 认证文件",
|
||||||
|
"auth_help1": " [↑↓/jk] 导航 • [Enter] 展开 • [e] 启用/停用 • [d] 删除 • [r] 刷新",
|
||||||
|
"auth_help2": " [1] 编辑 prefix • [2] 编辑 proxy_url • [3] 编辑 priority",
|
||||||
|
"no_auth_files": " 无认证文件",
|
||||||
|
"confirm_delete": "⚠ 删除 %s? [y/n]",
|
||||||
|
"deleted": "已删除 %s",
|
||||||
|
"enabled": "已启用",
|
||||||
|
"disabled": "已停用",
|
||||||
|
"updated_field": "已更新 %s 的 %s",
|
||||||
|
"status_active": "活跃",
|
||||||
|
"status_disabled": "已停用",
|
||||||
|
|
||||||
|
// ── API Keys ──
|
||||||
|
"keys_title": "🔐 API 密钥",
|
||||||
|
"keys_help": " [↑↓/jk] 导航 • [a] 添加 • [e] 编辑 • [d] 删除 • [c] 复制 • [r] 刷新",
|
||||||
|
"no_keys": " 无 API Key,按 [a] 添加",
|
||||||
|
"access_keys": "Access API Keys",
|
||||||
|
"confirm_delete_key": "⚠ 确认删除 %s? [y/n]",
|
||||||
|
"key_added": "已添加 API Key",
|
||||||
|
"key_updated": "已更新 API Key",
|
||||||
|
"key_deleted": "已删除 API Key",
|
||||||
|
"copied": "✓ 已复制到剪贴板",
|
||||||
|
"copy_failed": "✗ 复制失败",
|
||||||
|
"new_key_prompt": " New Key: ",
|
||||||
|
"edit_key_prompt": " Edit Key: ",
|
||||||
|
"enter_add": " Enter: 添加 • Esc: 取消",
|
||||||
|
"enter_save_esc": " Enter: 保存 • Esc: 取消",
|
||||||
|
|
||||||
|
// ── OAuth ──
|
||||||
|
"oauth_title": "🔐 OAuth 登录",
|
||||||
|
"oauth_select": " 选择提供商并按 [Enter] 开始 OAuth 登录:",
|
||||||
|
"oauth_help": " [↑↓/jk] 导航 • [Enter] 登录 • [Esc] 清除状态",
|
||||||
|
"oauth_initiating": "⏳ 正在初始化 %s 登录...",
|
||||||
|
"oauth_success": "认证成功! 请刷新 Auth Files 标签查看新凭证。",
|
||||||
|
"oauth_completed": "认证流程已完成。",
|
||||||
|
"oauth_failed": "认证失败",
|
||||||
|
"oauth_timeout": "OAuth 流程超时 (5 分钟)",
|
||||||
|
"oauth_press_esc": " 按 [Esc] 取消",
|
||||||
|
"oauth_auth_url": " 授权链接:",
|
||||||
|
"oauth_remote_hint": " 远程浏览器模式:在浏览器中打开上述链接完成授权后,将回调 URL 粘贴到下方。",
|
||||||
|
"oauth_callback_url": " 回调 URL:",
|
||||||
|
"oauth_press_c": " 按 [c] 输入回调 URL • [Esc] 返回",
|
||||||
|
"oauth_submitting": "⏳ 提交回调中...",
|
||||||
|
"oauth_submit_ok": "✓ 回调已提交,等待处理...",
|
||||||
|
"oauth_submit_fail": "✗ 提交回调失败",
|
||||||
|
"oauth_waiting": " 等待认证中...",
|
||||||
|
|
||||||
|
// ── Usage ──
|
||||||
|
"usage_title": "📈 使用统计",
|
||||||
|
"usage_help": " [r] 刷新 • [↑↓] 滚动",
|
||||||
|
"usage_no_data": " 使用数据不可用",
|
||||||
|
"usage_total_reqs": "总请求数",
|
||||||
|
"usage_total_tokens": "总 Token 数",
|
||||||
|
"usage_success": "成功",
|
||||||
|
"usage_failure": "失败",
|
||||||
|
"usage_total_token_l": "总Token",
|
||||||
|
"usage_rpm": "RPM",
|
||||||
|
"usage_tpm": "TPM",
|
||||||
|
"usage_req_by_hour": "请求趋势 (按小时)",
|
||||||
|
"usage_tok_by_hour": "Token 使用趋势 (按小时)",
|
||||||
|
"usage_req_by_day": "请求趋势 (按天)",
|
||||||
|
"usage_api_detail": "API 详细统计",
|
||||||
|
"usage_input": "输入",
|
||||||
|
"usage_output": "输出",
|
||||||
|
"usage_cached": "缓存",
|
||||||
|
"usage_reasoning": "思考",
|
||||||
|
|
||||||
|
// ── Logs ──
|
||||||
|
"logs_title": "📋 日志",
|
||||||
|
"logs_auto_scroll": "● 自动滚动",
|
||||||
|
"logs_paused": "○ 已暂停",
|
||||||
|
"logs_filter": "过滤",
|
||||||
|
"logs_lines": "行数",
|
||||||
|
"logs_help": " [a] 自动滚动 • [c] 清除 • [1] 全部 [2] info+ [3] warn+ [4] error • [↑↓] 滚动",
|
||||||
|
"logs_waiting": " 等待日志输出...",
|
||||||
|
}
|
||||||
|
|
||||||
|
var enStrings = map[string]string{
|
||||||
|
// ── Common ──
|
||||||
|
"loading": "Loading...",
|
||||||
|
"refresh": "Refresh",
|
||||||
|
"save": "Save",
|
||||||
|
"cancel": "Cancel",
|
||||||
|
"confirm": "Confirm",
|
||||||
|
"yes": "Yes",
|
||||||
|
"no": "No",
|
||||||
|
"error": "Error",
|
||||||
|
"success": "Success",
|
||||||
|
"navigate": "Navigate",
|
||||||
|
"scroll": "Scroll",
|
||||||
|
"enter_save": "Enter: Save",
|
||||||
|
"esc_cancel": "Esc: Cancel",
|
||||||
|
"enter_submit": "Enter: Submit",
|
||||||
|
"press_r": "[r] Refresh",
|
||||||
|
"press_scroll": "[↑↓] Scroll",
|
||||||
|
"not_set": "(not set)",
|
||||||
|
"error_prefix": "⚠ Error: ",
|
||||||
|
|
||||||
|
// ── Status bar ──
|
||||||
|
"status_left": " CLIProxyAPI Management TUI",
|
||||||
|
"status_right": "Tab/Shift+Tab: switch • L: lang • q/Ctrl+C: quit ",
|
||||||
|
"initializing_tui": "Initializing...",
|
||||||
|
"auth_gate_title": "🔐 Connect Management API",
|
||||||
|
"auth_gate_help": " Enter management password and press Enter to connect",
|
||||||
|
"auth_gate_password": "Password",
|
||||||
|
"auth_gate_enter": " Enter: connect • q/Ctrl+C: quit • L: lang",
|
||||||
|
"auth_gate_connecting": "Connecting...",
|
||||||
|
"auth_gate_connect_fail": "Connection failed: %s",
|
||||||
|
"auth_gate_password_required": "password is required",
|
||||||
|
|
||||||
|
// ── Dashboard ──
|
||||||
|
"dashboard_title": "📊 Dashboard",
|
||||||
|
"dashboard_help": " [r] Refresh • [↑↓] Scroll",
|
||||||
|
"connected": "● Connected",
|
||||||
|
"mgmt_keys": "Mgmt Keys",
|
||||||
|
"auth_files_label": "Auth Files",
|
||||||
|
"active_suffix": "active",
|
||||||
|
"total_requests": "Requests",
|
||||||
|
"success_label": "Success",
|
||||||
|
"failure_label": "Failed",
|
||||||
|
"total_tokens": "Total Tokens",
|
||||||
|
"current_config": "Current Config",
|
||||||
|
"debug_mode": "Debug Mode",
|
||||||
|
"usage_stats": "Usage Statistics",
|
||||||
|
"log_to_file": "Log to File",
|
||||||
|
"retry_count": "Retry Count",
|
||||||
|
"proxy_url": "Proxy URL",
|
||||||
|
"routing_strategy": "Routing Strategy",
|
||||||
|
"model_stats": "Model Stats",
|
||||||
|
"model": "Model",
|
||||||
|
"requests": "Requests",
|
||||||
|
"tokens": "Tokens",
|
||||||
|
"bool_yes": "Yes ✓",
|
||||||
|
"bool_no": "No",
|
||||||
|
|
||||||
|
// ── Config ──
|
||||||
|
"config_title": "⚙ Configuration",
|
||||||
|
"config_help1": " [↑↓/jk] Navigate • [Enter/Space] Edit • [r] Refresh",
|
||||||
|
"config_help2": " Bool: Enter to toggle • String/Int: Enter to type, Enter to confirm, Esc to cancel",
|
||||||
|
"updated_ok": "✓ Updated successfully",
|
||||||
|
"no_config": " No configuration loaded",
|
||||||
|
"invalid_int": "invalid integer",
|
||||||
|
"section_server": "Server",
|
||||||
|
"section_logging": "Logging & Stats",
|
||||||
|
"section_quota": "Quota Exceeded Handling",
|
||||||
|
"section_routing": "Routing",
|
||||||
|
"section_websocket": "WebSocket",
|
||||||
|
"section_ampcode": "AMP Code",
|
||||||
|
"section_other": "Other",
|
||||||
|
|
||||||
|
// ── Auth Files ──
|
||||||
|
"auth_title": "🔑 Auth Files",
|
||||||
|
"auth_help1": " [↑↓/jk] Navigate • [Enter] Expand • [e] Enable/Disable • [d] Delete • [r] Refresh",
|
||||||
|
"auth_help2": " [1] Edit prefix • [2] Edit proxy_url • [3] Edit priority",
|
||||||
|
"no_auth_files": " No auth files found",
|
||||||
|
"confirm_delete": "⚠ Delete %s? [y/n]",
|
||||||
|
"deleted": "Deleted %s",
|
||||||
|
"enabled": "Enabled",
|
||||||
|
"disabled": "Disabled",
|
||||||
|
"updated_field": "Updated %s on %s",
|
||||||
|
"status_active": "active",
|
||||||
|
"status_disabled": "disabled",
|
||||||
|
|
||||||
|
// ── API Keys ──
|
||||||
|
"keys_title": "🔐 API Keys",
|
||||||
|
"keys_help": " [↑↓/jk] Navigate • [a] Add • [e] Edit • [d] Delete • [c] Copy • [r] Refresh",
|
||||||
|
"no_keys": " No API Keys. Press [a] to add",
|
||||||
|
"access_keys": "Access API Keys",
|
||||||
|
"confirm_delete_key": "⚠ Delete %s? [y/n]",
|
||||||
|
"key_added": "API Key added",
|
||||||
|
"key_updated": "API Key updated",
|
||||||
|
"key_deleted": "API Key deleted",
|
||||||
|
"copied": "✓ Copied to clipboard",
|
||||||
|
"copy_failed": "✗ Copy failed",
|
||||||
|
"new_key_prompt": " New Key: ",
|
||||||
|
"edit_key_prompt": " Edit Key: ",
|
||||||
|
"enter_add": " Enter: Add • Esc: Cancel",
|
||||||
|
"enter_save_esc": " Enter: Save • Esc: Cancel",
|
||||||
|
|
||||||
|
// ── OAuth ──
|
||||||
|
"oauth_title": "🔐 OAuth Login",
|
||||||
|
"oauth_select": " Select a provider and press [Enter] to start OAuth login:",
|
||||||
|
"oauth_help": " [↑↓/jk] Navigate • [Enter] Login • [Esc] Clear status",
|
||||||
|
"oauth_initiating": "⏳ Initiating %s login...",
|
||||||
|
"oauth_success": "Authentication successful! Refresh Auth Files tab to see the new credential.",
|
||||||
|
"oauth_completed": "Authentication flow completed.",
|
||||||
|
"oauth_failed": "Authentication failed",
|
||||||
|
"oauth_timeout": "OAuth flow timed out (5 minutes)",
|
||||||
|
"oauth_press_esc": " Press [Esc] to cancel",
|
||||||
|
"oauth_auth_url": " Authorization URL:",
|
||||||
|
"oauth_remote_hint": " Remote browser mode: Open the URL above in browser, paste the callback URL below after authorization.",
|
||||||
|
"oauth_callback_url": " Callback URL:",
|
||||||
|
"oauth_press_c": " Press [c] to enter callback URL • [Esc] to go back",
|
||||||
|
"oauth_submitting": "⏳ Submitting callback...",
|
||||||
|
"oauth_submit_ok": "✓ Callback submitted, waiting...",
|
||||||
|
"oauth_submit_fail": "✗ Callback submission failed",
|
||||||
|
"oauth_waiting": " Waiting for authentication...",
|
||||||
|
|
||||||
|
// ── Usage ──
|
||||||
|
"usage_title": "📈 Usage Statistics",
|
||||||
|
"usage_help": " [r] Refresh • [↑↓] Scroll",
|
||||||
|
"usage_no_data": " Usage data not available",
|
||||||
|
"usage_total_reqs": "Total Requests",
|
||||||
|
"usage_total_tokens": "Total Tokens",
|
||||||
|
"usage_success": "Success",
|
||||||
|
"usage_failure": "Failed",
|
||||||
|
"usage_total_token_l": "Total Tokens",
|
||||||
|
"usage_rpm": "RPM",
|
||||||
|
"usage_tpm": "TPM",
|
||||||
|
"usage_req_by_hour": "Requests by Hour",
|
||||||
|
"usage_tok_by_hour": "Token Usage by Hour",
|
||||||
|
"usage_req_by_day": "Requests by Day",
|
||||||
|
"usage_api_detail": "API Detail Statistics",
|
||||||
|
"usage_input": "Input",
|
||||||
|
"usage_output": "Output",
|
||||||
|
"usage_cached": "Cached",
|
||||||
|
"usage_reasoning": "Reasoning",
|
||||||
|
|
||||||
|
// ── Logs ──
|
||||||
|
"logs_title": "📋 Logs",
|
||||||
|
"logs_auto_scroll": "● AUTO-SCROLL",
|
||||||
|
"logs_paused": "○ PAUSED",
|
||||||
|
"logs_filter": "Filter",
|
||||||
|
"logs_lines": "Lines",
|
||||||
|
"logs_help": " [a] Auto-scroll • [c] Clear • [1] All [2] info+ [3] warn+ [4] error • [↑↓] Scroll",
|
||||||
|
"logs_waiting": " Waiting for log output...",
|
||||||
|
}
|
||||||
405
internal/tui/keys_tab.go
Normal file
405
internal/tui/keys_tab.go
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/atotto/clipboard"
|
||||||
|
"github.com/charmbracelet/bubbles/textinput"
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
// keysTabModel displays and manages API keys.
//
// It is a Bubble Tea sub-model: Update returns a new value and View renders
// the cached viewport content.
type keysTabModel struct {
	client   *Client          // management API client
	viewport viewport.Model   // scrollable render surface
	keys     []string         // access API keys (the interactive list)
	gemini   []map[string]any // provider key entries, read-only display
	claude   []map[string]any
	codex    []map[string]any
	vertex   []map[string]any
	openai   []map[string]any // OpenAI-compatible upstream entries
	err      error            // last fetch error, rendered instead of content
	width    int
	height   int
	ready    bool // viewport has been initialized by SetSize
	cursor   int  // selected index into keys
	confirm  int  // -1 = no deletion pending; otherwise index awaiting y/n
	status   string // one-line styled status message

	// Editing / Adding
	editing   bool // editing the key at editIdx
	adding    bool // adding a new key
	editIdx   int
	editInput textinput.Model
}
|
||||||
|
|
||||||
|
// keysDataMsg carries the result of fetchKeys back into Update.
type keysDataMsg struct {
	apiKeys []string
	gemini  []map[string]any
	claude  []map[string]any
	codex   []map[string]any
	vertex  []map[string]any
	openai  []map[string]any
	err     error // non-nil only when the access-key fetch itself failed
}
|
||||||
|
|
||||||
|
// keyActionMsg reports the outcome of an add/edit/delete key action.
type keyActionMsg struct {
	action string // localized success description
	err    error
}
|
||||||
|
|
||||||
|
func newKeysTabModel(client *Client) keysTabModel {
|
||||||
|
ti := textinput.New()
|
||||||
|
ti.CharLimit = 512
|
||||||
|
ti.Prompt = " Key: "
|
||||||
|
return keysTabModel{
|
||||||
|
client: client,
|
||||||
|
confirm: -1,
|
||||||
|
editInput: ti,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init triggers the initial key fetch when the tab starts.
func (m keysTabModel) Init() tea.Cmd {
	return m.fetchKeys
}
|
||||||
|
|
||||||
|
func (m keysTabModel) fetchKeys() tea.Msg {
|
||||||
|
result := keysDataMsg{}
|
||||||
|
apiKeys, err := m.client.GetAPIKeys()
|
||||||
|
if err != nil {
|
||||||
|
result.err = err
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
result.apiKeys = apiKeys
|
||||||
|
result.gemini, _ = m.client.GetGeminiKeys()
|
||||||
|
result.claude, _ = m.client.GetClaudeKeys()
|
||||||
|
result.codex, _ = m.client.GetCodexKeys()
|
||||||
|
result.vertex, _ = m.client.GetVertexKeys()
|
||||||
|
result.openai, _ = m.client.GetOpenAICompat()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m keysTabModel) Update(msg tea.Msg) (keysTabModel, tea.Cmd) {
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case localeChangedMsg:
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
return m, nil
|
||||||
|
case keysDataMsg:
|
||||||
|
if msg.err != nil {
|
||||||
|
m.err = msg.err
|
||||||
|
} else {
|
||||||
|
m.err = nil
|
||||||
|
m.keys = msg.apiKeys
|
||||||
|
m.gemini = msg.gemini
|
||||||
|
m.claude = msg.claude
|
||||||
|
m.codex = msg.codex
|
||||||
|
m.vertex = msg.vertex
|
||||||
|
m.openai = msg.openai
|
||||||
|
if m.cursor >= len(m.keys) {
|
||||||
|
m.cursor = max(0, len(m.keys)-1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
return m, nil
|
||||||
|
|
||||||
|
case keyActionMsg:
|
||||||
|
if msg.err != nil {
|
||||||
|
m.status = errorStyle.Render("✗ " + msg.err.Error())
|
||||||
|
} else {
|
||||||
|
m.status = successStyle.Render("✓ " + msg.action)
|
||||||
|
}
|
||||||
|
m.confirm = -1
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
return m, m.fetchKeys
|
||||||
|
|
||||||
|
case tea.KeyMsg:
|
||||||
|
// ---- Editing / Adding mode ----
|
||||||
|
if m.editing || m.adding {
|
||||||
|
switch msg.String() {
|
||||||
|
case "enter":
|
||||||
|
value := strings.TrimSpace(m.editInput.Value())
|
||||||
|
if value == "" {
|
||||||
|
m.editing = false
|
||||||
|
m.adding = false
|
||||||
|
m.editInput.Blur()
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
isAdding := m.adding
|
||||||
|
editIdx := m.editIdx
|
||||||
|
m.editing = false
|
||||||
|
m.adding = false
|
||||||
|
m.editInput.Blur()
|
||||||
|
if isAdding {
|
||||||
|
return m, func() tea.Msg {
|
||||||
|
err := m.client.AddAPIKey(value)
|
||||||
|
if err != nil {
|
||||||
|
return keyActionMsg{err: err}
|
||||||
|
}
|
||||||
|
return keyActionMsg{action: T("key_added")}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, func() tea.Msg {
|
||||||
|
err := m.client.EditAPIKey(editIdx, value)
|
||||||
|
if err != nil {
|
||||||
|
return keyActionMsg{err: err}
|
||||||
|
}
|
||||||
|
return keyActionMsg{action: T("key_updated")}
|
||||||
|
}
|
||||||
|
case "esc":
|
||||||
|
m.editing = false
|
||||||
|
m.adding = false
|
||||||
|
m.editInput.Blur()
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
return m, nil
|
||||||
|
default:
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.editInput, cmd = m.editInput.Update(msg)
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Delete confirmation ----
|
||||||
|
if m.confirm >= 0 {
|
||||||
|
switch msg.String() {
|
||||||
|
case "y", "Y":
|
||||||
|
idx := m.confirm
|
||||||
|
m.confirm = -1
|
||||||
|
return m, func() tea.Msg {
|
||||||
|
err := m.client.DeleteAPIKey(idx)
|
||||||
|
if err != nil {
|
||||||
|
return keyActionMsg{err: err}
|
||||||
|
}
|
||||||
|
return keyActionMsg{action: T("key_deleted")}
|
||||||
|
}
|
||||||
|
case "n", "N", "esc":
|
||||||
|
m.confirm = -1
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Normal mode ----
|
||||||
|
switch msg.String() {
|
||||||
|
case "j", "down":
|
||||||
|
if len(m.keys) > 0 {
|
||||||
|
m.cursor = (m.cursor + 1) % len(m.keys)
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
case "k", "up":
|
||||||
|
if len(m.keys) > 0 {
|
||||||
|
m.cursor = (m.cursor - 1 + len(m.keys)) % len(m.keys)
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
case "a":
|
||||||
|
// Add new key
|
||||||
|
m.adding = true
|
||||||
|
m.editing = false
|
||||||
|
m.editInput.SetValue("")
|
||||||
|
m.editInput.Prompt = T("new_key_prompt")
|
||||||
|
m.editInput.Focus()
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
return m, textinput.Blink
|
||||||
|
case "e":
|
||||||
|
// Edit selected key
|
||||||
|
if m.cursor < len(m.keys) {
|
||||||
|
m.editing = true
|
||||||
|
m.adding = false
|
||||||
|
m.editIdx = m.cursor
|
||||||
|
m.editInput.SetValue(m.keys[m.cursor])
|
||||||
|
m.editInput.Prompt = T("edit_key_prompt")
|
||||||
|
m.editInput.Focus()
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
return m, textinput.Blink
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
case "d":
|
||||||
|
// Delete selected key
|
||||||
|
if m.cursor < len(m.keys) {
|
||||||
|
m.confirm = m.cursor
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
case "c":
|
||||||
|
// Copy selected key to clipboard
|
||||||
|
if m.cursor < len(m.keys) {
|
||||||
|
key := m.keys[m.cursor]
|
||||||
|
if err := clipboard.WriteAll(key); err != nil {
|
||||||
|
m.status = errorStyle.Render(T("copy_failed") + ": " + err.Error())
|
||||||
|
} else {
|
||||||
|
m.status = successStyle.Render(T("copied"))
|
||||||
|
}
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
case "r":
|
||||||
|
m.status = ""
|
||||||
|
return m, m.fetchKeys
|
||||||
|
default:
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.viewport, cmd = m.viewport.Update(msg)
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.viewport, cmd = m.viewport.Update(msg)
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *keysTabModel) SetSize(w, h int) {
|
||||||
|
m.width = w
|
||||||
|
m.height = h
|
||||||
|
m.editInput.Width = w - 16
|
||||||
|
if !m.ready {
|
||||||
|
m.viewport = viewport.New(w, h)
|
||||||
|
m.viewport.SetContent(m.renderContent())
|
||||||
|
m.ready = true
|
||||||
|
} else {
|
||||||
|
m.viewport.Width = w
|
||||||
|
m.viewport.Height = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// View renders the tab; before SetSize has initialized the viewport it
// shows a localized loading message.
func (m keysTabModel) View() string {
	if !m.ready {
		return T("loading")
	}
	return m.viewport.View()
}
|
||||||
|
|
||||||
|
func (m keysTabModel) renderContent() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
sb.WriteString(titleStyle.Render(T("keys_title")))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
sb.WriteString(helpStyle.Render(T("keys_help")))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
sb.WriteString(strings.Repeat("─", m.width))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
|
||||||
|
if m.err != nil {
|
||||||
|
sb.WriteString(errorStyle.Render(T("error_prefix") + m.err.Error()))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ━━━ Access API Keys (interactive) ━━━
|
||||||
|
sb.WriteString(tableHeaderStyle.Render(fmt.Sprintf(" %s (%d)", T("access_keys"), len(m.keys))))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
|
||||||
|
if len(m.keys) == 0 {
|
||||||
|
sb.WriteString(subtitleStyle.Render(T("no_keys")))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, key := range m.keys {
|
||||||
|
cursor := " "
|
||||||
|
rowStyle := lipgloss.NewStyle()
|
||||||
|
if i == m.cursor {
|
||||||
|
cursor = "▸ "
|
||||||
|
rowStyle = lipgloss.NewStyle().Bold(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
row := fmt.Sprintf("%s%d. %s", cursor, i+1, maskKey(key))
|
||||||
|
sb.WriteString(rowStyle.Render(row))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
|
||||||
|
// Delete confirmation
|
||||||
|
if m.confirm == i {
|
||||||
|
sb.WriteString(warningStyle.Render(fmt.Sprintf(" "+T("confirm_delete_key"), maskKey(key))))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edit input
|
||||||
|
if m.editing && m.editIdx == i {
|
||||||
|
sb.WriteString(m.editInput.View())
|
||||||
|
sb.WriteString("\n")
|
||||||
|
sb.WriteString(helpStyle.Render(T("enter_save_esc")))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add input
|
||||||
|
if m.adding {
|
||||||
|
sb.WriteString("\n")
|
||||||
|
sb.WriteString(m.editInput.View())
|
||||||
|
sb.WriteString("\n")
|
||||||
|
sb.WriteString(helpStyle.Render(T("enter_add")))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString("\n")
|
||||||
|
|
||||||
|
// ━━━ Provider Keys (read-only display) ━━━
|
||||||
|
renderProviderKeys(&sb, "Gemini API Keys", m.gemini)
|
||||||
|
renderProviderKeys(&sb, "Claude API Keys", m.claude)
|
||||||
|
renderProviderKeys(&sb, "Codex API Keys", m.codex)
|
||||||
|
renderProviderKeys(&sb, "Vertex API Keys", m.vertex)
|
||||||
|
|
||||||
|
if len(m.openai) > 0 {
|
||||||
|
renderSection(&sb, "OpenAI Compatibility", len(m.openai))
|
||||||
|
for i, entry := range m.openai {
|
||||||
|
name := getString(entry, "name")
|
||||||
|
baseURL := getString(entry, "base-url")
|
||||||
|
prefix := getString(entry, "prefix")
|
||||||
|
info := name
|
||||||
|
if prefix != "" {
|
||||||
|
info += " (prefix: " + prefix + ")"
|
||||||
|
}
|
||||||
|
if baseURL != "" {
|
||||||
|
info += " → " + baseURL
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf(" %d. %s\n", i+1, info))
|
||||||
|
}
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.status != "" {
|
||||||
|
sb.WriteString(m.status)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderSection(sb *strings.Builder, title string, count int) {
|
||||||
|
header := fmt.Sprintf("%s (%d)", title, count)
|
||||||
|
sb.WriteString(tableHeaderStyle.Render(" " + header))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderProviderKeys(sb *strings.Builder, title string, keys []map[string]any) {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
renderSection(sb, title, len(keys))
|
||||||
|
for i, key := range keys {
|
||||||
|
apiKey := getString(key, "api-key")
|
||||||
|
prefix := getString(key, "prefix")
|
||||||
|
baseURL := getString(key, "base-url")
|
||||||
|
info := maskKey(apiKey)
|
||||||
|
if prefix != "" {
|
||||||
|
info += " (prefix: " + prefix + ")"
|
||||||
|
}
|
||||||
|
if baseURL != "" {
|
||||||
|
info += " → " + baseURL
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf(" %d. %s\n", i+1, info))
|
||||||
|
}
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskKey hides the middle of an API key, keeping the first and last four
// bytes. Keys of eight bytes or fewer are fully starred out.
func maskKey(key string) string {
	n := len(key)
	if n <= 8 {
		return strings.Repeat("*", n)
	}
	return key[:4] + strings.Repeat("*", n-8) + key[n-4:]
}
|
||||||
78
internal/tui/loghook.go
Normal file
78
internal/tui/loghook.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogHook is a logrus hook that captures log entries and sends them to a channel.
// Sends are non-blocking; when the buffer is full the oldest line is dropped
// (see Fire).
type LogHook struct {
	ch        chan string   // buffered line sink, read via Chan
	formatter log.Formatter // guarded by mu
	mu        sync.Mutex
	levels    []log.Level
}
|
||||||
|
|
||||||
|
// NewLogHook creates a new LogHook with a buffered channel of the given size.
|
||||||
|
func NewLogHook(bufSize int) *LogHook {
|
||||||
|
return &LogHook{
|
||||||
|
ch: make(chan string, bufSize),
|
||||||
|
formatter: &log.TextFormatter{DisableColors: true, FullTimestamp: true},
|
||||||
|
levels: log.AllLevels,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFormatter sets a custom formatter for the hook.
// Safe to call concurrently with Fire (both take mu).
func (h *LogHook) SetFormatter(f log.Formatter) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.formatter = f
}
|
||||||
|
|
||||||
|
// Levels returns the log levels this hook should fire on
// (all levels, as set by NewLogHook).
func (h *LogHook) Levels() []log.Level {
	return h.levels
}
|
||||||
|
|
||||||
|
// Fire is called by logrus when a log entry is fired.
|
||||||
|
func (h *LogHook) Fire(entry *log.Entry) error {
|
||||||
|
h.mu.Lock()
|
||||||
|
f := h.formatter
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
var line string
|
||||||
|
if f != nil {
|
||||||
|
b, err := f.Format(entry)
|
||||||
|
if err == nil {
|
||||||
|
line = strings.TrimRight(string(b), "\n\r")
|
||||||
|
} else {
|
||||||
|
line = fmt.Sprintf("[%s] %s", entry.Level, entry.Message)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
line = fmt.Sprintf("[%s] %s", entry.Level, entry.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-blocking send
|
||||||
|
select {
|
||||||
|
case h.ch <- line:
|
||||||
|
default:
|
||||||
|
// Drop oldest if full
|
||||||
|
select {
|
||||||
|
case <-h.ch:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case h.ch <- line:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chan returns the channel to read log lines from. Lines may have been
// dropped under load (see Fire).
func (h *LogHook) Chan() <-chan string {
	return h.ch
}
|
||||||
261
internal/tui/logs_tab.go
Normal file
261
internal/tui/logs_tab.go
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
// logsTabModel displays real-time log lines from hook/API source.
// When hook is non-nil, lines stream from the in-process LogHook; otherwise
// they are polled from the management API on a timer.
type logsTabModel struct {
	client     *Client  // used only in API-polling mode
	hook       *LogHook // in-process log source; preferred when non-nil
	viewport   viewport.Model
	lines      []string // retained lines, trimmed to maxLines
	maxLines   int
	autoScroll bool
	width      int
	height     int
	ready      bool
	filter     string // "", "debug", "info", "warn", "error"
	after      int64  // cursor for incremental API log fetches
	lastErr    error  // last polling error
}
|
||||||
|
|
||||||
|
// logsPollMsg carries the result of one API log poll.
type logsPollMsg struct {
	lines  []string
	latest int64 // new "after" cursor for the next fetch
	err    error
}
|
||||||
|
|
||||||
|
type logsTickMsg struct{}
|
||||||
|
type logLineMsg string
|
||||||
|
|
||||||
|
func newLogsTabModel(client *Client, hook *LogHook) logsTabModel {
|
||||||
|
return logsTabModel{
|
||||||
|
client: client,
|
||||||
|
hook: hook,
|
||||||
|
maxLines: 5000,
|
||||||
|
autoScroll: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m logsTabModel) Init() tea.Cmd {
|
||||||
|
if m.hook != nil {
|
||||||
|
return m.waitForLog
|
||||||
|
}
|
||||||
|
return m.fetchLogs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m logsTabModel) fetchLogs() tea.Msg {
|
||||||
|
lines, latest, err := m.client.GetLogs(m.after, 200)
|
||||||
|
return logsPollMsg{
|
||||||
|
lines: lines,
|
||||||
|
latest: latest,
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForNextPoll schedules a logsTickMsg after a fixed 2-second delay,
// driving the polling loop when no in-process hook is available.
func (m logsTabModel) waitForNextPoll() tea.Cmd {
	return tea.Tick(2*time.Second, func(_ time.Time) tea.Msg {
		return logsTickMsg{}
	})
}
|
||||||
|
|
||||||
|
func (m logsTabModel) waitForLog() tea.Msg {
|
||||||
|
if m.hook == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
line, ok := <-m.hook.Chan()
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return logLineMsg(line)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update handles messages for the logs tab. It appends lines from either
// source (API poll or hook stream), trims the buffer to maxLines, re-renders
// the viewport, and manages auto-scroll and the severity filter via key
// bindings (a=toggle follow, c=clear, 1-4=filter level).
func (m logsTabModel) Update(msg tea.Msg) (logsTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		// Language changed: re-render the header/help text.
		m.viewport.SetContent(m.renderLogs())
		return m, nil
	case logsTickMsg:
		// Poll timer fired; ignored in hook mode (hook pushes lines itself).
		if m.hook != nil {
			return m, nil
		}
		return m, m.fetchLogs
	case logsPollMsg:
		// Poll result; ignored in hook mode.
		if m.hook != nil {
			return m, nil
		}
		if msg.err != nil {
			m.lastErr = msg.err
		} else {
			m.lastErr = nil
			// Advance the cursor so the next poll only fetches newer lines.
			m.after = msg.latest
			if len(msg.lines) > 0 {
				m.lines = append(m.lines, msg.lines...)
				// Trim oldest lines beyond the cap.
				if len(m.lines) > m.maxLines {
					m.lines = m.lines[len(m.lines)-m.maxLines:]
				}
			}
		}
		m.viewport.SetContent(m.renderLogs())
		if m.autoScroll {
			m.viewport.GotoBottom()
		}
		// Schedule the next poll regardless of success/failure.
		return m, m.waitForNextPoll()

	case logLineMsg:
		// One line pushed by the hook; append, trim, and wait for the next.
		m.lines = append(m.lines, string(msg))
		if len(m.lines) > m.maxLines {
			m.lines = m.lines[len(m.lines)-m.maxLines:]
		}
		m.viewport.SetContent(m.renderLogs())
		if m.autoScroll {
			m.viewport.GotoBottom()
		}
		return m, m.waitForLog

	case tea.KeyMsg:
		switch msg.String() {
		case "a":
			// Toggle tail-follow; jump to bottom when re-enabled.
			m.autoScroll = !m.autoScroll
			if m.autoScroll {
				m.viewport.GotoBottom()
			}
			return m, nil
		case "c":
			// Clear the buffer and any displayed error.
			m.lines = nil
			m.lastErr = nil
			m.viewport.SetContent(m.renderLogs())
			return m, nil
		case "1":
			m.filter = ""
			m.viewport.SetContent(m.renderLogs())
			return m, nil
		case "2":
			m.filter = "info"
			m.viewport.SetContent(m.renderLogs())
			return m, nil
		case "3":
			m.filter = "warn"
			m.viewport.SetContent(m.renderLogs())
			return m, nil
		case "4":
			m.filter = "error"
			m.viewport.SetContent(m.renderLogs())
			return m, nil
		default:
			// Any other key goes to the viewport (scrolling).
			wasAtBottom := m.viewport.AtBottom()
			var cmd tea.Cmd
			m.viewport, cmd = m.viewport.Update(msg)
			// If user scrolls up, disable auto-scroll
			if !m.viewport.AtBottom() && wasAtBottom {
				m.autoScroll = false
			}
			// If user scrolls to bottom, re-enable auto-scroll
			if m.viewport.AtBottom() {
				m.autoScroll = true
			}
			return m, cmd
		}
	}

	// Non-key, non-log messages still reach the viewport (e.g. mouse).
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
|
||||||
|
|
||||||
|
func (m *logsTabModel) SetSize(w, h int) {
|
||||||
|
m.width = w
|
||||||
|
m.height = h
|
||||||
|
if !m.ready {
|
||||||
|
m.viewport = viewport.New(w, h)
|
||||||
|
m.viewport.SetContent(m.renderLogs())
|
||||||
|
m.ready = true
|
||||||
|
} else {
|
||||||
|
m.viewport.Width = w
|
||||||
|
m.viewport.Height = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// View renders the logs tab; it shows a localized loading placeholder until
// SetSize has initialized the viewport.
func (m logsTabModel) View() string {
	if !m.ready {
		return T("loading")
	}
	return m.viewport.View()
}
|
||||||
|
|
||||||
|
// renderLogs builds the full textual content of the logs viewport: a header
// line (title, scroll state, active filter, line count), a help line, a
// horizontal rule, an optional error banner, and then every buffered line
// that passes the current severity filter, styled per level.
func (m logsTabModel) renderLogs() string {
	var sb strings.Builder

	// Header badge: green "following" vs yellow "paused".
	scrollStatus := successStyle.Render(T("logs_auto_scroll"))
	if !m.autoScroll {
		scrollStatus = warningStyle.Render(T("logs_paused"))
	}
	filterLabel := "ALL"
	if m.filter != "" {
		// "+" signals "this level and above".
		filterLabel = strings.ToUpper(m.filter) + "+"
	}

	header := fmt.Sprintf(" %s %s %s: %s %s: %d",
		T("logs_title"), scrollStatus, T("logs_filter"), filterLabel, T("logs_lines"), len(m.lines))
	sb.WriteString(titleStyle.Render(header))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("logs_help")))
	sb.WriteString("\n")
	sb.WriteString(strings.Repeat("─", m.width))
	sb.WriteString("\n")

	if m.lastErr != nil {
		sb.WriteString(errorStyle.Render("⚠ Error: " + m.lastErr.Error()))
		sb.WriteString("\n")
	}

	if len(m.lines) == 0 {
		sb.WriteString(subtitleStyle.Render(T("logs_waiting")))
		return sb.String()
	}

	for _, line := range m.lines {
		// Skip lines below the selected severity.
		if m.filter != "" && !m.matchLevel(line) {
			continue
		}
		styled := m.styleLine(line)
		sb.WriteString(styled)
		sb.WriteString("\n")
	}

	return sb.String()
}
|
||||||
|
|
||||||
|
func (m logsTabModel) matchLevel(line string) bool {
|
||||||
|
switch m.filter {
|
||||||
|
case "error":
|
||||||
|
return strings.Contains(line, "[error]") || strings.Contains(line, "[fatal]") || strings.Contains(line, "[panic]")
|
||||||
|
case "warn":
|
||||||
|
return strings.Contains(line, "[warn") || strings.Contains(line, "[error]") || strings.Contains(line, "[fatal]")
|
||||||
|
case "info":
|
||||||
|
return !strings.Contains(line, "[debug]")
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m logsTabModel) styleLine(line string) string {
|
||||||
|
if strings.Contains(line, "[error]") || strings.Contains(line, "[fatal]") {
|
||||||
|
return logErrorStyle.Render(line)
|
||||||
|
}
|
||||||
|
if strings.Contains(line, "[warn") {
|
||||||
|
return logWarnStyle.Render(line)
|
||||||
|
}
|
||||||
|
if strings.Contains(line, "[info") {
|
||||||
|
return logInfoStyle.Render(line)
|
||||||
|
}
|
||||||
|
if strings.Contains(line, "[debug]") {
|
||||||
|
return logDebugStyle.Render(line)
|
||||||
|
}
|
||||||
|
return line
|
||||||
|
}
|
||||||
473
internal/tui/oauth_tab.go
Normal file
473
internal/tui/oauth_tab.go
Normal file
@@ -0,0 +1,473 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/bubbles/textinput"
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
// oauthProvider represents an OAuth provider option.
type oauthProvider struct {
	name    string // display name shown in the provider menu
	apiPath string // management API path
	emoji   string // colored marker rendered next to the name
}

// oauthProviders is the fixed menu of login providers, in display order.
// The cursor in oauthTabModel indexes into this slice.
var oauthProviders = []oauthProvider{
	{"Gemini CLI", "gemini-cli-auth-url", "🟦"},
	{"Claude (Anthropic)", "anthropic-auth-url", "🟧"},
	{"Codex (OpenAI)", "codex-auth-url", "🟩"},
	{"Antigravity", "antigravity-auth-url", "🟪"},
	{"Qwen", "qwen-auth-url", "🟨"},
	{"Kimi", "kimi-auth-url", "🟫"},
	{"IFlow", "iflow-auth-url", "⬜"},
}
|
||||||
|
|
||||||
|
// oauthTabModel handles OAuth login flows.
// It renders a provider menu, drives the browser-based authorization, and —
// for remote/headless setups — accepts a manually pasted callback URL.
type oauthTabModel struct {
	client   *Client        // management API client
	viewport viewport.Model // scrollable content area
	cursor   int            // selected index into oauthProviders
	state    oauthState     // current phase of the login flow
	message  string         // status line (already styled) shown under the title
	err      error          // last error, kept alongside the styled message
	width    int
	height   int
	ready    bool // viewport has been created by SetSize

	// Remote browser mode
	authURL       string // auth URL to display
	authState     string // OAuth state parameter
	providerName  string // current provider name
	callbackInput textinput.Model
	inputActive   bool // true when user is typing callback URL
}
|
||||||
|
|
||||||
|
// oauthState enumerates the phases of the OAuth tab's login flow.
type oauthState int

const (
	oauthIdle    oauthState = iota // provider menu shown, nothing in flight
	oauthPending                   // auth URL requested, waiting for server
	oauthRemote                    // remote browser mode: waiting for manual callback
	oauthSuccess                   // login completed
	oauthError                     // login failed; err/message hold details
)

// Messages

// oauthStartMsg carries the result of requesting an auth URL from the
// management API (or the error that prevented it).
type oauthStartMsg struct {
	url          string
	state        string
	providerName string
	err          error
}

// oauthPollMsg reports the outcome of background session-status polling:
// done with a message, an error, or neither (still waiting).
type oauthPollMsg struct {
	done    bool
	message string
	err     error
}

// oauthCallbackSubmitMsg reports whether posting a manually pasted callback
// URL to the server succeeded.
type oauthCallbackSubmitMsg struct {
	err error
}
|
||||||
|
|
||||||
|
func newOAuthTabModel(client *Client) oauthTabModel {
|
||||||
|
ti := textinput.New()
|
||||||
|
ti.Placeholder = "http://localhost:.../auth/callback?code=...&state=..."
|
||||||
|
ti.CharLimit = 2048
|
||||||
|
ti.Prompt = " 回调 URL: "
|
||||||
|
return oauthTabModel{
|
||||||
|
client: client,
|
||||||
|
callbackInput: ti,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m oauthTabModel) Init() tea.Cmd {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update drives the OAuth tab's state machine. Key handling depends on the
// current phase: while the callback input is focused, keys edit the input;
// in remote mode, "c" refocuses the input and "esc" aborts; while pending,
// only "esc" cancels; when idle, arrows move the provider cursor and
// "enter" starts a login.
func (m oauthTabModel) Update(msg tea.Msg) (oauthTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case oauthStartMsg:
		if msg.err != nil {
			m.state = oauthError
			m.err = msg.err
			m.message = errorStyle.Render("✗ " + msg.err.Error())
			m.viewport.SetContent(m.renderContent())
			return m, nil
		}
		// Auth URL obtained: enter remote mode with the callback input
		// focused so a headless user can paste the redirect immediately.
		m.authURL = msg.url
		m.authState = msg.state
		m.providerName = msg.providerName
		m.state = oauthRemote
		m.callbackInput.SetValue("")
		m.callbackInput.Focus()
		m.inputActive = true
		m.message = ""
		m.viewport.SetContent(m.renderContent())
		// Also start polling in the background
		return m, tea.Batch(textinput.Blink, m.pollOAuthStatus(msg.state))

	case oauthPollMsg:
		if msg.err != nil {
			m.state = oauthError
			m.err = msg.err
			m.message = errorStyle.Render("✗ " + msg.err.Error())
			m.inputActive = false
			m.callbackInput.Blur()
		} else if msg.done {
			m.state = oauthSuccess
			m.message = successStyle.Render("✓ " + msg.message)
			m.inputActive = false
			m.callbackInput.Blur()
		} else {
			// Still waiting; keep the input usable.
			m.message = warningStyle.Render("⏳ " + msg.message)
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case oauthCallbackSubmitMsg:
		if msg.err != nil {
			m.message = errorStyle.Render(T("oauth_submit_fail") + ": " + msg.err.Error())
		} else {
			m.message = successStyle.Render(T("oauth_submit_ok"))
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case tea.KeyMsg:
		// ---- Input active: typing callback URL ----
		if m.inputActive {
			switch msg.String() {
			case "enter":
				callbackURL := m.callbackInput.Value()
				if callbackURL == "" {
					return m, nil
				}
				m.inputActive = false
				m.callbackInput.Blur()
				m.message = warningStyle.Render(T("oauth_submitting"))
				m.viewport.SetContent(m.renderContent())
				return m, m.submitCallback(callbackURL)
			case "esc":
				// Stop typing but stay in remote mode.
				m.inputActive = false
				m.callbackInput.Blur()
				m.viewport.SetContent(m.renderContent())
				return m, nil
			default:
				var cmd tea.Cmd
				m.callbackInput, cmd = m.callbackInput.Update(msg)
				m.viewport.SetContent(m.renderContent())
				return m, cmd
			}
		}

		// ---- Remote mode but not typing ----
		if m.state == oauthRemote {
			switch msg.String() {
			case "c", "C":
				// Re-activate input
				m.inputActive = true
				m.callbackInput.Focus()
				m.viewport.SetContent(m.renderContent())
				return m, textinput.Blink
			case "esc":
				// Abandon the flow and return to the provider menu.
				m.state = oauthIdle
				m.message = ""
				m.authURL = ""
				m.authState = ""
				m.viewport.SetContent(m.renderContent())
				return m, nil
			}
			var cmd tea.Cmd
			m.viewport, cmd = m.viewport.Update(msg)
			return m, cmd
		}

		// ---- Pending (auto polling) ----
		if m.state == oauthPending {
			if msg.String() == "esc" {
				m.state = oauthIdle
				m.message = ""
				m.viewport.SetContent(m.renderContent())
			}
			return m, nil
		}

		// ---- Idle ----
		switch msg.String() {
		case "up", "k":
			if m.cursor > 0 {
				m.cursor--
				m.viewport.SetContent(m.renderContent())
			}
			return m, nil
		case "down", "j":
			if m.cursor < len(oauthProviders)-1 {
				m.cursor++
				m.viewport.SetContent(m.renderContent())
			}
			return m, nil
		case "enter":
			if m.cursor >= 0 && m.cursor < len(oauthProviders) {
				provider := oauthProviders[m.cursor]
				m.state = oauthPending
				m.message = warningStyle.Render(fmt.Sprintf(T("oauth_initiating"), provider.name))
				m.viewport.SetContent(m.renderContent())
				return m, m.startOAuth(provider)
			}
			return m, nil
		case "esc":
			m.state = oauthIdle
			m.message = ""
			m.err = nil
			m.viewport.SetContent(m.renderContent())
			return m, nil
		}

		// Unhandled keys scroll the viewport.
		var cmd tea.Cmd
		m.viewport, cmd = m.viewport.Update(msg)
		return m, cmd
	}

	// Non-key messages (e.g. mouse) also reach the viewport.
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
|
||||||
|
|
||||||
|
// startOAuth asks the management API for the provider's authorization URL
// (in WebUI mode), best-effort opens it in a local browser, and reports the
// URL plus OAuth state back to Update via an oauthStartMsg.
func (m oauthTabModel) startOAuth(provider oauthProvider) tea.Cmd {
	return func() tea.Msg {
		// Call the auth URL endpoint with is_webui=true
		data, err := m.client.getJSON("/v0/management/" + provider.apiPath + "?is_webui=true")
		if err != nil {
			return oauthStartMsg{err: fmt.Errorf("failed to start %s login: %w", provider.name, err)}
		}

		authURL := getString(data, "url")
		state := getString(data, "state")
		if authURL == "" {
			return oauthStartMsg{err: fmt.Errorf("no auth URL returned for %s", provider.name)}
		}

		// Try to open browser (best effort)
		_ = openBrowser(authURL)

		return oauthStartMsg{url: authURL, state: state, providerName: provider.name}
	}
}
|
||||||
|
|
||||||
|
func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
// Determine provider from current context
|
||||||
|
providerKey := ""
|
||||||
|
for _, p := range oauthProviders {
|
||||||
|
if p.name == m.providerName {
|
||||||
|
// Map provider name to the canonical key the API expects
|
||||||
|
switch p.apiPath {
|
||||||
|
case "gemini-cli-auth-url":
|
||||||
|
providerKey = "gemini"
|
||||||
|
case "anthropic-auth-url":
|
||||||
|
providerKey = "anthropic"
|
||||||
|
case "codex-auth-url":
|
||||||
|
providerKey = "codex"
|
||||||
|
case "antigravity-auth-url":
|
||||||
|
providerKey = "antigravity"
|
||||||
|
case "qwen-auth-url":
|
||||||
|
providerKey = "qwen"
|
||||||
|
case "kimi-auth-url":
|
||||||
|
providerKey = "kimi"
|
||||||
|
case "iflow-auth-url":
|
||||||
|
providerKey = "iflow"
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
body := map[string]string{
|
||||||
|
"provider": providerKey,
|
||||||
|
"redirect_url": callbackURL,
|
||||||
|
"state": m.authState,
|
||||||
|
}
|
||||||
|
err := m.client.postJSON("/v0/management/oauth-callback", body)
|
||||||
|
if err != nil {
|
||||||
|
return oauthCallbackSubmitMsg{err: err}
|
||||||
|
}
|
||||||
|
return oauthCallbackSubmitMsg{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pollOAuthStatus returns a command that polls the server-side auth session
// identified by state every 2 seconds until it resolves ("ok"/"error"/other)
// or a 5-minute deadline expires. Transient request errors are ignored and
// retried. Runs in a tea.Cmd goroutine, so blocking here is acceptable.
func (m oauthTabModel) pollOAuthStatus(state string) tea.Cmd {
	return func() tea.Msg {
		// Poll session status for up to 5 minutes
		deadline := time.Now().Add(5 * time.Minute)
		for {
			if time.Now().After(deadline) {
				return oauthPollMsg{done: false, err: fmt.Errorf("%s", T("oauth_timeout"))}
			}

			time.Sleep(2 * time.Second)

			status, errMsg, err := m.client.GetAuthStatus(state)
			if err != nil {
				continue // Ignore transient errors
			}

			switch status {
			case "ok":
				return oauthPollMsg{
					done:    true,
					message: T("oauth_success"),
				}
			case "error":
				return oauthPollMsg{
					done: false,
					err:  fmt.Errorf("%s: %s", T("oauth_failed"), errMsg),
				}
			case "wait":
				continue
			default:
				// Unknown status: treat as completed rather than spinning.
				return oauthPollMsg{
					done:    true,
					message: T("oauth_completed"),
				}
			}
		}
	}
}
|
||||||
|
|
||||||
|
// SetSize records the tab dimensions and resizes the callback input to fit,
// lazily creating the viewport on the first call.
func (m *oauthTabModel) SetSize(w, h int) {
	m.width = w
	m.height = h
	// Leave room for the input prompt and margins.
	m.callbackInput.Width = w - 16
	if !m.ready {
		m.viewport = viewport.New(w, h)
		m.viewport.SetContent(m.renderContent())
		m.ready = true
	} else {
		m.viewport.Width = w
		m.viewport.Height = h
	}
}
|
||||||
|
|
||||||
|
// View renders the OAuth tab; it shows a localized loading placeholder
// until SetSize has initialized the viewport.
func (m oauthTabModel) View() string {
	if !m.ready {
		return T("loading")
	}
	return m.viewport.View()
}
|
||||||
|
|
||||||
|
// renderContent builds the OAuth tab body for the current state: the remote
// callback screen, the pending hint, or (when idle) the provider menu with
// the cursor row highlighted.
func (m oauthTabModel) renderContent() string {
	var sb strings.Builder

	sb.WriteString(titleStyle.Render(T("oauth_title")))
	sb.WriteString("\n\n")

	// Status line (already styled by Update).
	if m.message != "" {
		sb.WriteString(" " + m.message)
		sb.WriteString("\n\n")
	}

	// ---- Remote browser mode ----
	if m.state == oauthRemote {
		sb.WriteString(m.renderRemoteMode())
		return sb.String()
	}

	if m.state == oauthPending {
		sb.WriteString(helpStyle.Render(T("oauth_press_esc")))
		return sb.String()
	}

	sb.WriteString(helpStyle.Render(T("oauth_select")))
	sb.WriteString("\n\n")

	// Provider menu; the cursor row gets an inverse highlight.
	for i, p := range oauthProviders {
		isSelected := i == m.cursor
		prefix := "  "
		if isSelected {
			prefix = "▸ "
		}

		label := fmt.Sprintf("%s %s", p.emoji, p.name)
		if isSelected {
			label = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#FFFFFF")).Background(colorPrimary).Padding(0, 1).Render(label)
		} else {
			label = lipgloss.NewStyle().Foreground(colorText).Padding(0, 1).Render(label)
		}

		sb.WriteString(prefix + label + "\n")
	}

	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("oauth_help")))

	return sb.String()
}
|
||||||
|
|
||||||
|
// renderRemoteMode renders the manual-callback screen: the authorization
// URL (wrapped to the terminal width so it can be copied), a hint, and the
// callback-URL input (or the "press c" prompt when the input is blurred).
func (m oauthTabModel) renderRemoteMode() string {
	var sb strings.Builder

	providerStyle := lipgloss.NewStyle().Bold(true).Foreground(colorHighlight)
	sb.WriteString(providerStyle.Render(fmt.Sprintf(" ✦ %s OAuth", m.providerName)))
	sb.WriteString("\n\n")

	// Auth URL section
	sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorInfo).Render(T("oauth_auth_url")))
	sb.WriteString("\n")

	// Wrap URL to fit terminal width
	urlStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
	maxURLWidth := m.width - 6
	if maxURLWidth < 40 {
		// Keep the URL legible even in very narrow terminals.
		maxURLWidth = 40
	}
	wrappedURL := wrapText(m.authURL, maxURLWidth)
	for _, line := range wrappedURL {
		sb.WriteString("  " + urlStyle.Render(line) + "\n")
	}
	sb.WriteString("\n")

	sb.WriteString(helpStyle.Render(T("oauth_remote_hint")))
	sb.WriteString("\n\n")

	// Callback URL input
	sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorInfo).Render(T("oauth_callback_url")))
	sb.WriteString("\n")

	if m.inputActive {
		sb.WriteString(m.callbackInput.View())
		sb.WriteString("\n")
		sb.WriteString(helpStyle.Render(" " + T("enter_submit") + " • " + T("esc_cancel")))
	} else {
		sb.WriteString(helpStyle.Render(T("oauth_press_c")))
	}

	sb.WriteString("\n\n")
	sb.WriteString(warningStyle.Render(T("oauth_waiting")))

	return sb.String()
}
|
||||||
|
|
||||||
|
// wrapText splits a long string into lines of at most maxWidth characters
// (runes). A non-positive maxWidth returns the input as a single line; an
// empty input (with positive maxWidth) returns no lines.
//
// Fix: the original sliced the string by bytes, so a multi-byte UTF-8 rune
// straddling a boundary was split into invalid byte sequences. Slicing by
// runes is identical for ASCII input and correct for everything else.
func wrapText(s string, maxWidth int) []string {
	if maxWidth <= 0 {
		return []string{s}
	}
	var lines []string
	runes := []rune(s)
	for len(runes) > maxWidth {
		lines = append(lines, string(runes[:maxWidth]))
		runes = runes[maxWidth:]
	}
	if len(runes) > 0 {
		lines = append(lines, string(runes))
	}
	return lines
}
|
||||||
126
internal/tui/styles.go
Normal file
126
internal/tui/styles.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
// Package tui provides a terminal-based management interface for CLIProxyAPI.
|
||||||
|
package tui
|
||||||
|
|
||||||
|
import "github.com/charmbracelet/lipgloss"
|
||||||
|
|
||||||
|
// Color palette
// Centralized lipgloss colors shared by every style in this file
// (dark, Catppuccin-like scheme). Adjust colors here, not inline.
var (
	colorPrimary   = lipgloss.Color("#7C3AED") // violet
	colorSecondary = lipgloss.Color("#6366F1") // indigo
	colorSuccess   = lipgloss.Color("#22C55E") // green
	colorWarning   = lipgloss.Color("#EAB308") // yellow
	colorError     = lipgloss.Color("#EF4444") // red
	colorInfo      = lipgloss.Color("#3B82F6") // blue
	colorMuted     = lipgloss.Color("#6B7280") // gray
	colorBg        = lipgloss.Color("#1E1E2E") // dark bg
	colorSurface   = lipgloss.Color("#313244") // slightly lighter
	colorText      = lipgloss.Color("#CDD6F4") // light text
	colorSubtext   = lipgloss.Color("#A6ADC8") // dimmer text
	colorBorder    = lipgloss.Color("#45475A") // border
	colorHighlight = lipgloss.Color("#F5C2E7") // pink highlight
)
|
||||||
|
|
||||||
|
// Tab bar styles
// Active/inactive tab labels plus the bar strip behind them.
var (
	tabActiveStyle = lipgloss.NewStyle().
			Bold(true).
			Foreground(lipgloss.Color("#FFFFFF")).
			Background(colorPrimary).
			Padding(0, 2)

	tabInactiveStyle = lipgloss.NewStyle().
				Foreground(colorSubtext).
				Background(colorSurface).
				Padding(0, 2)

	tabBarStyle = lipgloss.NewStyle().
			Background(colorSurface).
			PaddingLeft(1).
			PaddingBottom(0)
)

// Content styles
// General-purpose styles used across the tab views (titles, labels,
// status messages, help hints, etc.).
var (
	titleStyle = lipgloss.NewStyle().
			Bold(true).
			Foreground(colorHighlight).
			MarginBottom(1)

	subtitleStyle = lipgloss.NewStyle().
			Foreground(colorSubtext).
			Italic(true)

	// labelStyle is fixed-width so label/value pairs line up in columns.
	labelStyle = lipgloss.NewStyle().
			Foreground(colorInfo).
			Bold(true).
			Width(24)

	valueStyle = lipgloss.NewStyle().
			Foreground(colorText)

	sectionStyle = lipgloss.NewStyle().
			Border(lipgloss.RoundedBorder()).
			BorderForeground(colorBorder).
			Padding(1, 2)

	errorStyle = lipgloss.NewStyle().
			Foreground(colorError).
			Bold(true)

	successStyle = lipgloss.NewStyle().
			Foreground(colorSuccess)

	warningStyle = lipgloss.NewStyle().
			Foreground(colorWarning)

	statusBarStyle = lipgloss.NewStyle().
			Foreground(colorSubtext).
			Background(colorSurface).
			PaddingLeft(1).
			PaddingRight(1)

	helpStyle = lipgloss.NewStyle().
			Foreground(colorMuted)
)

// Log level styles
// One color per severity; selected by logLevelStyle and the logs tab.
var (
	logDebugStyle = lipgloss.NewStyle().Foreground(colorMuted)
	logInfoStyle  = lipgloss.NewStyle().Foreground(colorInfo)
	logWarnStyle  = lipgloss.NewStyle().Foreground(colorWarning)
	logErrorStyle = lipgloss.NewStyle().Foreground(colorError)
)

// Table styles
// Header row, ordinary cells, and the highlighted selected row.
var (
	tableHeaderStyle = lipgloss.NewStyle().
				Bold(true).
				Foreground(colorHighlight).
				BorderBottom(true).
				BorderStyle(lipgloss.NormalBorder()).
				BorderForeground(colorBorder)

	tableCellStyle = lipgloss.NewStyle().
			Foreground(colorText).
			PaddingRight(2)

	tableSelectedStyle = lipgloss.NewStyle().
				Foreground(lipgloss.Color("#FFFFFF")).
				Background(colorPrimary).
				Bold(true)
)
|
||||||
|
|
||||||
|
// logLevelStyle maps a logrus level name to its display style; unknown
// levels fall back to the info style.
func logLevelStyle(level string) lipgloss.Style {
	switch level {
	case "debug":
		return logDebugStyle
	case "info":
		return logInfoStyle
	case "warn", "warning":
		return logWarnStyle
	case "error", "fatal", "panic":
		return logErrorStyle
	default:
		return logInfoStyle
	}
}
|
||||||
364
internal/tui/usage_tab.go
Normal file
364
internal/tui/usage_tab.go
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
// usageTabModel displays usage statistics with charts and breakdowns.
type usageTabModel struct {
	client   *Client        // management API client
	viewport viewport.Model // scrollable content area
	usage    map[string]any // raw usage payload from GetUsage
	err      error          // last fetch error, shown in the body
	width    int
	height   int
	ready    bool // viewport has been created by SetSize
}

// usageDataMsg carries the result (or error) of one GetUsage fetch.
type usageDataMsg struct {
	usage map[string]any
	err   error
}
|
||||||
|
|
||||||
|
func newUsageTabModel(client *Client) usageTabModel {
|
||||||
|
return usageTabModel{
|
||||||
|
client: client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init triggers the first usage fetch when the tab starts.
func (m usageTabModel) Init() tea.Cmd {
	return m.fetchData
}

// fetchData fetches the usage payload from the management API and wraps
// the result (or error) in a usageDataMsg.
func (m usageTabModel) fetchData() tea.Msg {
	usage, err := m.client.GetUsage()
	return usageDataMsg{usage: usage, err: err}
}
|
||||||
|
|
||||||
|
// Update handles messages for the usage tab: it stores fetch results,
// re-renders on locale changes, refreshes on "r", and forwards everything
// else to the viewport for scrolling.
func (m usageTabModel) Update(msg tea.Msg) (usageTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case usageDataMsg:
		if msg.err != nil {
			// Keep the previous usage data; only surface the error.
			m.err = msg.err
		} else {
			m.err = nil
			m.usage = msg.usage
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case tea.KeyMsg:
		if msg.String() == "r" {
			// Manual refresh.
			return m, m.fetchData
		}
		var cmd tea.Cmd
		m.viewport, cmd = m.viewport.Update(msg)
		return m, cmd
	}

	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
|
||||||
|
|
||||||
|
// SetSize records the tab dimensions, lazily creating the viewport on the
// first call and resizing it thereafter.
func (m *usageTabModel) SetSize(w, h int) {
	m.width = w
	m.height = h
	if !m.ready {
		m.viewport = viewport.New(w, h)
		m.viewport.SetContent(m.renderContent())
		m.ready = true
	} else {
		m.viewport.Width = w
		m.viewport.Height = h
	}
}
|
||||||
|
|
||||||
|
// View renders the usage tab; it shows a localized loading placeholder
// until SetSize has initialized the viewport.
func (m usageTabModel) View() string {
	if !m.ready {
		return T("loading")
	}
	return m.viewport.View()
}
|
||||||
|
|
||||||
|
// renderContent builds the full usage-tab body: title/help header,
// error or empty-data short-circuits, four overview cards (requests,
// tokens, RPM, TPM), hourly/daily bar charts, and a per-API table with
// per-model rows and token-type breakdowns.
func (m usageTabModel) renderContent() string {
	var sb strings.Builder

	sb.WriteString(titleStyle.Render(T("usage_title")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("usage_help")))
	sb.WriteString("\n\n")

	// A fetch error replaces the entire body.
	if m.err != nil {
		sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
		sb.WriteString("\n")
		return sb.String()
	}

	// No snapshot loaded yet.
	if m.usage == nil {
		sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
		sb.WriteString("\n")
		return sb.String()
	}

	// The payload nests the stats under a "usage" key; a missing or
	// wrongly-typed value is treated the same as no data.
	usageMap, _ := m.usage["usage"].(map[string]any)
	if usageMap == nil {
		sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
		sb.WriteString("\n")
		return sb.String()
	}

	// Numeric fields arrive as JSON numbers (float64); truncate to int64.
	totalReqs := int64(getFloat(usageMap, "total_requests"))
	successCnt := int64(getFloat(usageMap, "success_count"))
	failureCnt := int64(getFloat(usageMap, "failure_count"))
	totalTokens := int64(getFloat(usageMap, "total_tokens"))

	// ━━━ Overview Cards ━━━
	// Four cards share the width; the -6 accounts for the separators
	// between them, with a floor so narrow terminals stay readable.
	cardWidth := 20
	if m.width > 0 {
		cardWidth = (m.width - 6) / 4
		if cardWidth < 16 {
			cardWidth = 16
		}
	}
	cardStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(lipgloss.Color("240")).
		Padding(0, 1).
		Width(cardWidth).
		Height(3)

	// Total Requests
	card1 := cardStyle.Copy().BorderForeground(lipgloss.Color("111")).Render(fmt.Sprintf(
		"%s\n%s\n%s",
		lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_reqs")),
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("%d", totalReqs)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("● %s: %d ● %s: %d", T("usage_success"), successCnt, T("usage_failure"), failureCnt)),
	))

	// Total Tokens
	card2 := cardStyle.Copy().BorderForeground(lipgloss.Color("214")).Render(fmt.Sprintf(
		"%s\n%s\n%s",
		lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_tokens")),
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(formatLargeNumber(totalTokens)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_token_l"), formatLargeNumber(totalTokens))),
	))

	// RPM
	// Approximate requests-per-minute: total requests averaged over the
	// hour buckets that actually saw traffic, not over wall-clock time.
	rpm := float64(0)
	if totalReqs > 0 {
		if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
			rpm = float64(totalReqs) / float64(len(rByH)) / 60.0
		}
	}
	card3 := cardStyle.Copy().BorderForeground(lipgloss.Color("76")).Render(fmt.Sprintf(
		"%s\n%s\n%s",
		lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_rpm")),
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("%.2f", rpm)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %d", T("usage_total_reqs"), totalReqs)),
	))

	// TPM
	// Same averaging scheme as RPM, using the tokens-by-hour buckets.
	tpm := float64(0)
	if totalTokens > 0 {
		if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
			tpm = float64(totalTokens) / float64(len(tByH)) / 60.0
		}
	}
	card4 := cardStyle.Copy().BorderForeground(lipgloss.Color("170")).Render(fmt.Sprintf(
		"%s\n%s\n%s",
		lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_tpm")),
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("%.2f", tpm)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_tokens"), formatLargeNumber(totalTokens))),
	))

	sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
	sb.WriteString("\n\n")

	// ━━━ Requests by Hour (ASCII bar chart) ━━━
	if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
		sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_hour")))
		sb.WriteString("\n")
		sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
		sb.WriteString("\n")
		sb.WriteString(renderBarChart(rByH, m.width-6, lipgloss.Color("111")))
		sb.WriteString("\n")
	}

	// ━━━ Tokens by Hour ━━━
	if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
		sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_tok_by_hour")))
		sb.WriteString("\n")
		sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
		sb.WriteString("\n")
		sb.WriteString(renderBarChart(tByH, m.width-6, lipgloss.Color("214")))
		sb.WriteString("\n")
	}

	// ━━━ Requests by Day ━━━
	if rByD, ok := usageMap["requests_by_day"].(map[string]any); ok && len(rByD) > 0 {
		sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_day")))
		sb.WriteString("\n")
		sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
		sb.WriteString("\n")
		sb.WriteString(renderBarChart(rByD, m.width-6, lipgloss.Color("76")))
		sb.WriteString("\n")
	}

	// ━━━ API Detail Stats ━━━
	// NOTE(review): map iteration order is random, so API rows appear in a
	// different order on each refresh — confirm whether stable ordering is wanted.
	if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
		sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_api_detail")))
		sb.WriteString("\n")
		sb.WriteString(strings.Repeat("─", minInt(m.width, 80)))
		sb.WriteString("\n")

		header := fmt.Sprintf("  %-30s %10s %12s", "API", T("requests"), T("tokens"))
		sb.WriteString(tableHeaderStyle.Render(header))
		sb.WriteString("\n")

		for apiName, apiSnap := range apis {
			if apiMap, ok := apiSnap.(map[string]any); ok {
				apiReqs := int64(getFloat(apiMap, "total_requests"))
				apiToks := int64(getFloat(apiMap, "total_tokens"))

				// API names may be keys/secrets; mask before display.
				row := fmt.Sprintf("  %-30s %10d %12s",
					truncate(maskKey(apiName), 30), apiReqs, formatLargeNumber(apiToks))
				sb.WriteString(lipgloss.NewStyle().Bold(true).Render(row))
				sb.WriteString("\n")

				// Per-model breakdown
				if models, ok := apiMap["models"].(map[string]any); ok {
					for model, v := range models {
						if stats, ok := v.(map[string]any); ok {
							mReqs := int64(getFloat(stats, "total_requests"))
							mToks := int64(getFloat(stats, "total_tokens"))
							mRow := fmt.Sprintf("    ├─ %-28s %10d %12s",
								truncate(model, 28), mReqs, formatLargeNumber(mToks))
							sb.WriteString(tableCellStyle.Render(mRow))
							sb.WriteString("\n")

							// Token type breakdown from details
							sb.WriteString(m.renderTokenBreakdown(stats))
						}
					}
				}
			}
		}
	}

	sb.WriteString("\n")
	return sb.String()
}
|
||||||
|
|
||||||
|
// renderTokenBreakdown aggregates input/output/cached/reasoning tokens from model details.
|
||||||
|
func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
|
||||||
|
details, ok := modelStats["details"]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
detailList, ok := details.([]any)
|
||||||
|
if !ok || len(detailList) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var inputTotal, outputTotal, cachedTotal, reasoningTotal int64
|
||||||
|
for _, d := range detailList {
|
||||||
|
dm, ok := d.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tokens, ok := dm["tokens"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
inputTotal += int64(getFloat(tokens, "input_tokens"))
|
||||||
|
outputTotal += int64(getFloat(tokens, "output_tokens"))
|
||||||
|
cachedTotal += int64(getFloat(tokens, "cached_tokens"))
|
||||||
|
reasoningTotal += int64(getFloat(tokens, "reasoning_tokens"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if inputTotal == 0 && outputTotal == 0 && cachedTotal == 0 && reasoningTotal == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := []string{}
|
||||||
|
if inputTotal > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_input"), formatLargeNumber(inputTotal)))
|
||||||
|
}
|
||||||
|
if outputTotal > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_output"), formatLargeNumber(outputTotal)))
|
||||||
|
}
|
||||||
|
if cachedTotal > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_cached"), formatLargeNumber(cachedTotal)))
|
||||||
|
}
|
||||||
|
if reasoningTotal > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_reasoning"), formatLargeNumber(reasoningTotal)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(" │ %s\n",
|
||||||
|
lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " ")))
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderBarChart renders a simple ASCII horizontal bar chart.
|
||||||
|
func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
|
||||||
|
if maxBarWidth < 10 {
|
||||||
|
maxBarWidth = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort keys
|
||||||
|
keys := make([]string, 0, len(data))
|
||||||
|
for k := range data {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
// Find max value
|
||||||
|
maxVal := float64(0)
|
||||||
|
for _, k := range keys {
|
||||||
|
v := getFloat(data, k)
|
||||||
|
if v > maxVal {
|
||||||
|
maxVal = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if maxVal == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
barStyle := lipgloss.NewStyle().Foreground(barColor)
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
labelWidth := 12
|
||||||
|
barAvail := maxBarWidth - labelWidth - 12
|
||||||
|
if barAvail < 5 {
|
||||||
|
barAvail = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, k := range keys {
|
||||||
|
v := getFloat(data, k)
|
||||||
|
barLen := int(v / maxVal * float64(barAvail))
|
||||||
|
if barLen < 1 && v > 0 {
|
||||||
|
barLen = 1
|
||||||
|
}
|
||||||
|
bar := strings.Repeat("█", barLen)
|
||||||
|
label := k
|
||||||
|
if len(label) > labelWidth {
|
||||||
|
label = label[:labelWidth]
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf(" %-*s %s %s\n",
|
||||||
|
labelWidth, label,
|
||||||
|
barStyle.Render(bar),
|
||||||
|
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%.0f", v)),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
@@ -184,6 +184,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||||
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||||
}
|
}
|
||||||
|
if o.Websockets != n.Websockets {
|
||||||
|
changes = append(changes, fmt.Sprintf("codex[%d].websockets: %t -> %t", i, o.Websockets, n.Websockets))
|
||||||
|
}
|
||||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||||
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
|
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -164,6 +164,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
|
|||||||
if ck.BaseURL != "" {
|
if ck.BaseURL != "" {
|
||||||
attrs["base_url"] = ck.BaseURL
|
attrs["base_url"] = ck.BaseURL
|
||||||
}
|
}
|
||||||
|
if ck.Websockets {
|
||||||
|
attrs["websockets"] = "true"
|
||||||
|
}
|
||||||
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
|
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
|
||||||
attrs["models_hash"] = hash
|
attrs["models_hash"] = hash
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -231,10 +231,11 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
|
|||||||
Config: &config.Config{
|
Config: &config.Config{
|
||||||
CodexKey: []config.CodexKey{
|
CodexKey: []config.CodexKey{
|
||||||
{
|
{
|
||||||
APIKey: "codex-key-123",
|
APIKey: "codex-key-123",
|
||||||
Prefix: "dev",
|
Prefix: "dev",
|
||||||
BaseURL: "https://api.openai.com",
|
BaseURL: "https://api.openai.com",
|
||||||
ProxyURL: "http://proxy.local",
|
ProxyURL: "http://proxy.local",
|
||||||
|
Websockets: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -259,6 +260,9 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
|
|||||||
if auths[0].ProxyURL != "http://proxy.local" {
|
if auths[0].ProxyURL != "http://proxy.local" {
|
||||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||||
}
|
}
|
||||||
|
if auths[0].Attributes["websockets"] != "true" {
|
||||||
|
t.Errorf("expected websockets=true, got %s", auths[0].Attributes["websockets"])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {
|
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {
|
||||||
|
|||||||
@@ -112,12 +112,13 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
|
|
||||||
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
@@ -165,7 +166,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
|
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
stopKeepAlive()
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
@@ -194,6 +195,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
@@ -225,7 +227,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
|||||||
// This allows proper cleanup and cancellation of ongoing requests
|
// This allows proper cleanup and cancellation of ongoing requests
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
|
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
c.Header("Cache-Control", "no-cache")
|
c.Header("Cache-Control", "no-cache")
|
||||||
@@ -257,6 +259,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
|||||||
if !ok {
|
if !ok {
|
||||||
// Stream closed without data? Send DONE or just headers.
|
// Stream closed without data? Send DONE or just headers.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
return
|
return
|
||||||
@@ -264,6 +267,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
|||||||
|
|
||||||
// Success! Set headers now.
|
// Success! Set headers now.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
|
|
||||||
// Write the first chunk
|
// Write the first chunk
|
||||||
if len(chunk) > 0 {
|
if len(chunk) > 0 {
|
||||||
|
|||||||
@@ -159,7 +159,8 @@ func (h *GeminiCLIAPIHandler) handleInternalStreamGenerateContent(c *gin.Context
|
|||||||
modelName := modelResult.String()
|
modelName := modelResult.String()
|
||||||
|
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan)
|
h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -172,12 +173,13 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
|
|||||||
modelName := modelResult.String()
|
modelName := modelResult.String()
|
||||||
|
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
|||||||
}
|
}
|
||||||
|
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
|
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
@@ -223,6 +223,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
|||||||
if alt == "" {
|
if alt == "" {
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
return
|
return
|
||||||
@@ -232,6 +233,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
|||||||
if alt == "" {
|
if alt == "" {
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
|
|
||||||
// Write first chunk
|
// Write first chunk
|
||||||
if alt == "" {
|
if alt == "" {
|
||||||
@@ -262,12 +264,13 @@ func (h *GeminiAPIHandler) handleCountTokens(c *gin.Context, modelName string, r
|
|||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
alt := h.GetAlt(c)
|
alt := h.GetAlt(c)
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
@@ -286,13 +289,14 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
|
|||||||
alt := h.GetAlt(c)
|
alt := h.GetAlt(c)
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
stopKeepAlive()
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,6 +52,45 @@ const (
|
|||||||
defaultStreamingBootstrapRetries = 0
|
defaultStreamingBootstrapRetries = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type pinnedAuthContextKey struct{}
|
||||||
|
type selectedAuthCallbackContextKey struct{}
|
||||||
|
type executionSessionContextKey struct{}
|
||||||
|
|
||||||
|
// WithPinnedAuthID returns a child context that requests execution on a specific auth ID.
|
||||||
|
func WithPinnedAuthID(ctx context.Context, authID string) context.Context {
|
||||||
|
authID = strings.TrimSpace(authID)
|
||||||
|
if authID == "" {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, pinnedAuthContextKey{}, authID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSelectedAuthIDCallback returns a child context that receives the selected auth ID.
|
||||||
|
func WithSelectedAuthIDCallback(ctx context.Context, callback func(string)) context.Context {
|
||||||
|
if callback == nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, selectedAuthCallbackContextKey{}, callback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithExecutionSessionID returns a child context tagged with a long-lived execution session ID.
|
||||||
|
func WithExecutionSessionID(ctx context.Context, sessionID string) context.Context {
|
||||||
|
sessionID = strings.TrimSpace(sessionID)
|
||||||
|
if sessionID == "" {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, executionSessionContextKey{}, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
|
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
|
||||||
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
|
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
|
||||||
func BuildErrorResponseBody(status int, errText string) []byte {
|
func BuildErrorResponseBody(status int, errText string) []byte {
|
||||||
@@ -140,6 +179,12 @@ func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
|||||||
return retries
|
return retries
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PassthroughHeadersEnabled returns whether upstream response headers should be forwarded to clients.
|
||||||
|
// Default is false.
|
||||||
|
func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
|
||||||
|
return cfg != nil && cfg.PassthroughHeaders
|
||||||
|
}
|
||||||
|
|
||||||
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||||
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
||||||
// It is forwarded as execution metadata; when absent we generate a UUID.
|
// It is forwarded as execution metadata; when absent we generate a UUID.
|
||||||
@@ -152,7 +197,59 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
|||||||
if key == "" {
|
if key == "" {
|
||||||
key = uuid.NewString()
|
key = uuid.NewString()
|
||||||
}
|
}
|
||||||
return map[string]any{idempotencyKeyMetadataKey: key}
|
|
||||||
|
meta := map[string]any{idempotencyKeyMetadataKey: key}
|
||||||
|
if pinnedAuthID := pinnedAuthIDFromContext(ctx); pinnedAuthID != "" {
|
||||||
|
meta[coreexecutor.PinnedAuthMetadataKey] = pinnedAuthID
|
||||||
|
}
|
||||||
|
if selectedCallback := selectedAuthIDCallbackFromContext(ctx); selectedCallback != nil {
|
||||||
|
meta[coreexecutor.SelectedAuthCallbackMetadataKey] = selectedCallback
|
||||||
|
}
|
||||||
|
if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" {
|
||||||
|
meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID
|
||||||
|
}
|
||||||
|
return meta
|
||||||
|
}
|
||||||
|
|
||||||
|
func pinnedAuthIDFromContext(ctx context.Context) string {
|
||||||
|
if ctx == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
raw := ctx.Value(pinnedAuthContextKey{})
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(v)
|
||||||
|
case []byte:
|
||||||
|
return strings.TrimSpace(string(v))
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectedAuthIDCallbackFromContext(ctx context.Context) func(string) {
|
||||||
|
if ctx == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw := ctx.Value(selectedAuthCallbackContextKey{})
|
||||||
|
if callback, ok := raw.(func(string)); ok && callback != nil {
|
||||||
|
return callback
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func executionSessionIDFromContext(ctx context.Context) string {
|
||||||
|
if ctx == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
raw := ctx.Value(executionSessionContextKey{})
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(v)
|
||||||
|
case []byte:
|
||||||
|
return strings.TrimSpace(string(v))
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BaseAPIHandler contains the handlers for API endpoints.
|
// BaseAPIHandler contains the handlers for API endpoints.
|
||||||
@@ -371,10 +468,10 @@ func appendAPIResponse(c *gin.Context, data []byte) {
|
|||||||
|
|
||||||
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
||||||
// This path is the only supported execution route.
|
// This path is the only supported execution route.
|
||||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
|
||||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
return nil, errMsg
|
return nil, nil, errMsg
|
||||||
}
|
}
|
||||||
reqMeta := requestExecutionMetadata(ctx)
|
reqMeta := requestExecutionMetadata(ctx)
|
||||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||||
@@ -407,17 +504,20 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
|||||||
addon = hdr.Clone()
|
addon = hdr.Clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||||
}
|
}
|
||||||
return resp.Payload, nil
|
if !PassthroughHeadersEnabled(h.Cfg) {
|
||||||
|
return resp.Payload, nil, nil
|
||||||
|
}
|
||||||
|
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
||||||
// This path is the only supported execution route.
|
// This path is the only supported execution route.
|
||||||
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
|
||||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
return nil, errMsg
|
return nil, nil, errMsg
|
||||||
}
|
}
|
||||||
reqMeta := requestExecutionMetadata(ctx)
|
reqMeta := requestExecutionMetadata(ctx)
|
||||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||||
@@ -450,20 +550,24 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
|||||||
addon = hdr.Clone()
|
addon = hdr.Clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||||
}
|
}
|
||||||
return resp.Payload, nil
|
if !PassthroughHeadersEnabled(h.Cfg) {
|
||||||
|
return resp.Payload, nil, nil
|
||||||
|
}
|
||||||
|
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
||||||
// This path is the only supported execution route.
|
// This path is the only supported execution route.
|
||||||
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
|
// The returned http.Header carries upstream response headers captured before streaming begins.
|
||||||
|
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
|
||||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||||
errChan <- errMsg
|
errChan <- errMsg
|
||||||
close(errChan)
|
close(errChan)
|
||||||
return nil, errChan
|
return nil, nil, errChan
|
||||||
}
|
}
|
||||||
reqMeta := requestExecutionMetadata(ctx)
|
reqMeta := requestExecutionMetadata(ctx)
|
||||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||||
@@ -482,7 +586,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
SourceFormat: sdktranslator.FromString(handlerType),
|
SourceFormat: sdktranslator.FromString(handlerType),
|
||||||
}
|
}
|
||||||
opts.Metadata = reqMeta
|
opts.Metadata = reqMeta
|
||||||
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||||
status := http.StatusInternalServerError
|
status := http.StatusInternalServerError
|
||||||
@@ -499,8 +603,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
}
|
}
|
||||||
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||||
close(errChan)
|
close(errChan)
|
||||||
return nil, errChan
|
return nil, nil, errChan
|
||||||
}
|
}
|
||||||
|
passthroughHeadersEnabled := PassthroughHeadersEnabled(h.Cfg)
|
||||||
|
// Capture upstream headers from the initial connection synchronously before the goroutine starts.
|
||||||
|
// Keep a mutable map so bootstrap retries can replace it before first payload is sent.
|
||||||
|
var upstreamHeaders http.Header
|
||||||
|
if passthroughHeadersEnabled {
|
||||||
|
upstreamHeaders = cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
|
||||||
|
if upstreamHeaders == nil {
|
||||||
|
upstreamHeaders = make(http.Header)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chunks := streamResult.Chunks
|
||||||
dataChan := make(chan []byte)
|
dataChan := make(chan []byte)
|
||||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -574,9 +689,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
if !sentPayload {
|
if !sentPayload {
|
||||||
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
|
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
|
||||||
bootstrapRetries++
|
bootstrapRetries++
|
||||||
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||||
if retryErr == nil {
|
if retryErr == nil {
|
||||||
chunks = retryChunks
|
if passthroughHeadersEnabled {
|
||||||
|
replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
|
||||||
|
}
|
||||||
|
chunks = retryResult.Chunks
|
||||||
continue outer
|
continue outer
|
||||||
}
|
}
|
||||||
streamErr = retryErr
|
streamErr = retryErr
|
||||||
@@ -599,6 +717,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(chunk.Payload) > 0 {
|
if len(chunk.Payload) > 0 {
|
||||||
|
if handlerType == "openai-response" {
|
||||||
|
if err := validateSSEDataJSON(chunk.Payload); err != nil {
|
||||||
|
_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
sentPayload = true
|
sentPayload = true
|
||||||
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
|
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
|
||||||
return
|
return
|
||||||
@@ -607,7 +731,36 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return dataChan, errChan
|
return dataChan, upstreamHeaders, errChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateSSEDataJSON(chunk []byte) error {
|
||||||
|
for _, line := range bytes.Split(chunk, []byte("\n")) {
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := bytes.TrimSpace(line[5:])
|
||||||
|
if len(data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if bytes.Equal(data, []byte("[DONE]")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if json.Valid(data) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
const max = 512
|
||||||
|
preview := data
|
||||||
|
if len(preview) > max {
|
||||||
|
preview = preview[:max]
|
||||||
|
}
|
||||||
|
return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func statusFromError(err error) int {
|
func statusFromError(err error) int {
|
||||||
@@ -667,13 +820,33 @@ func cloneBytes(src []byte) []byte {
|
|||||||
return dst
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cloneHeader(src http.Header) http.Header {
|
||||||
|
if src == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
dst := make(http.Header, len(src))
|
||||||
|
for key, values := range src {
|
||||||
|
dst[key] = append([]string(nil), values...)
|
||||||
|
}
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func replaceHeader(dst http.Header, src http.Header) {
|
||||||
|
for key := range dst {
|
||||||
|
delete(dst, key)
|
||||||
|
}
|
||||||
|
for key, values := range src {
|
||||||
|
dst[key] = append([]string(nil), values...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
|
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
|
||||||
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
|
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
|
||||||
status := http.StatusInternalServerError
|
status := http.StatusInternalServerError
|
||||||
if msg != nil && msg.StatusCode > 0 {
|
if msg != nil && msg.StatusCode > 0 {
|
||||||
status = msg.StatusCode
|
status = msg.StatusCode
|
||||||
}
|
}
|
||||||
if msg != nil && msg.Addon != nil {
|
if msg != nil && msg.Addon != nil && PassthroughHeadersEnabled(h.Cfg) {
|
||||||
for key, values := range msg.Addon {
|
for key, values := range msg.Addon {
|
||||||
if len(values) == 0 {
|
if len(values) == 0 {
|
||||||
continue
|
continue
|
||||||
|
|||||||
68
sdk/api/handlers/handlers_error_response_test.go
Normal file
68
sdk/api/handlers/handlers_error_response_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWriteErrorResponse_AddonHeadersDisabledByDefault(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(nil, nil)
|
||||||
|
handler.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
Error: errors.New("rate limit"),
|
||||||
|
Addon: http.Header{
|
||||||
|
"Retry-After": {"30"},
|
||||||
|
"X-Request-Id": {"req-1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
if got := recorder.Header().Get("Retry-After"); got != "" {
|
||||||
|
t.Fatalf("Retry-After should be empty when passthrough is disabled, got %q", got)
|
||||||
|
}
|
||||||
|
if got := recorder.Header().Get("X-Request-Id"); got != "" {
|
||||||
|
t.Fatalf("X-Request-Id should be empty when passthrough is disabled, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteErrorResponse_AddonHeadersEnabled(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.Writer.Header().Set("X-Request-Id", "old-value")
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{PassthroughHeaders: true}, nil)
|
||||||
|
handler.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
Error: errors.New("rate limit"),
|
||||||
|
Addon: http.Header{
|
||||||
|
"Retry-After": {"30"},
|
||||||
|
"X-Request-Id": {"new-1", "new-2"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
if got := recorder.Header().Get("Retry-After"); got != "30" {
|
||||||
|
t.Fatalf("Retry-After = %q, want %q", got, "30")
|
||||||
|
}
|
||||||
|
if got := recorder.Header().Values("X-Request-Id"); !reflect.DeepEqual(got, []string{"new-1", "new-2"}) {
|
||||||
|
t.Fatalf("X-Request-Id = %#v, want %#v", got, []string{"new-1", "new-2"})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,7 +23,7 @@ func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreex
|
|||||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
e.mu.Lock()
|
e.mu.Lock()
|
||||||
e.calls++
|
e.calls++
|
||||||
call := e.calls
|
call := e.calls
|
||||||
@@ -40,12 +40,18 @@ func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth,
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
close(ch)
|
close(ch)
|
||||||
return ch, nil
|
return &coreexecutor.StreamResult{
|
||||||
|
Headers: http.Header{"X-Upstream-Attempt": {"1"}},
|
||||||
|
Chunks: ch,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||||
close(ch)
|
close(ch)
|
||||||
return ch, nil
|
return &coreexecutor.StreamResult{
|
||||||
|
Headers: http.Header{"X-Upstream-Attempt": {"2"}},
|
||||||
|
Chunks: ch,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
@@ -81,7 +87,7 @@ func (e *payloadThenErrorStreamExecutor) Execute(context.Context, *coreauth.Auth
|
|||||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
e.mu.Lock()
|
e.mu.Lock()
|
||||||
e.calls++
|
e.calls++
|
||||||
e.mu.Unlock()
|
e.mu.Unlock()
|
||||||
@@ -97,7 +103,7 @@ func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreaut
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
close(ch)
|
close(ch)
|
||||||
return ch, nil
|
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
@@ -122,6 +128,113 @@ func (e *payloadThenErrorStreamExecutor) Calls() int {
|
|||||||
return e.calls
|
return e.calls
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type authAwareStreamExecutor struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
calls int
|
||||||
|
authIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type invalidJSONStreamExecutor struct{}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
|
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||||
|
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")}
|
||||||
|
close(ch)
|
||||||
|
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, &coreauth.Error{
|
||||||
|
Code: "not_implemented",
|
||||||
|
Message: "HttpRequest not implemented",
|
||||||
|
HTTPStatus: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
|
_ = ctx
|
||||||
|
_ = req
|
||||||
|
_ = opts
|
||||||
|
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||||
|
|
||||||
|
authID := ""
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
e.mu.Lock()
|
||||||
|
e.calls++
|
||||||
|
e.authIDs = append(e.authIDs, authID)
|
||||||
|
e.mu.Unlock()
|
||||||
|
|
||||||
|
if authID == "auth1" {
|
||||||
|
ch <- coreexecutor.StreamChunk{
|
||||||
|
Err: &coreauth.Error{
|
||||||
|
Code: "unauthorized",
|
||||||
|
Message: "unauthorized",
|
||||||
|
Retryable: false,
|
||||||
|
HTTPStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
close(ch)
|
||||||
|
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||||
|
close(ch)
|
||||||
|
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, &coreauth.Error{
|
||||||
|
Code: "not_implemented",
|
||||||
|
Message: "HttpRequest not implemented",
|
||||||
|
HTTPStatus: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) Calls() int {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
return e.calls
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) AuthIDs() []string {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
out := make([]string, len(e.authIDs))
|
||||||
|
copy(out, e.authIDs)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||||
executor := &failOnceStreamExecutor{}
|
executor := &failOnceStreamExecutor{}
|
||||||
manager := coreauth.NewManager(nil, nil, nil)
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
@@ -154,12 +267,78 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
|||||||
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||||
|
PassthroughHeaders: true,
|
||||||
|
Streaming: sdkconfig.StreamingConfig{
|
||||||
|
BootstrapRetries: 1,
|
||||||
|
},
|
||||||
|
}, manager)
|
||||||
|
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
|
if dataChan == nil || errChan == nil {
|
||||||
|
t.Fatalf("expected non-nil channels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []byte
|
||||||
|
for chunk := range dataChan {
|
||||||
|
got = append(got, chunk...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for msg := range errChan {
|
||||||
|
if msg != nil {
|
||||||
|
t.Fatalf("unexpected error: %+v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(got) != "ok" {
|
||||||
|
t.Fatalf("expected payload ok, got %q", string(got))
|
||||||
|
}
|
||||||
|
if executor.Calls() != 2 {
|
||||||
|
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
|
||||||
|
}
|
||||||
|
upstreamAttemptHeader := upstreamHeaders.Get("X-Upstream-Attempt")
|
||||||
|
if upstreamAttemptHeader != "2" {
|
||||||
|
t.Fatalf("expected upstream header from retry attempt, got %q", upstreamAttemptHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteStreamWithAuthManager_HeaderPassthroughDisabledByDefault(t *testing.T) {
|
||||||
|
executor := &failOnceStreamExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth1 := &coreauth.Auth{
|
||||||
|
ID: "auth1",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test1@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth1): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth2 := &coreauth.Auth{
|
||||||
|
ID: "auth2",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test2@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth2): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||||
|
})
|
||||||
|
|
||||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||||
Streaming: sdkconfig.StreamingConfig{
|
Streaming: sdkconfig.StreamingConfig{
|
||||||
BootstrapRetries: 1,
|
BootstrapRetries: 1,
|
||||||
},
|
},
|
||||||
}, manager)
|
}, manager)
|
||||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
if dataChan == nil || errChan == nil {
|
if dataChan == nil || errChan == nil {
|
||||||
t.Fatalf("expected non-nil channels")
|
t.Fatalf("expected non-nil channels")
|
||||||
}
|
}
|
||||||
@@ -168,7 +347,6 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
|||||||
for chunk := range dataChan {
|
for chunk := range dataChan {
|
||||||
got = append(got, chunk...)
|
got = append(got, chunk...)
|
||||||
}
|
}
|
||||||
|
|
||||||
for msg := range errChan {
|
for msg := range errChan {
|
||||||
if msg != nil {
|
if msg != nil {
|
||||||
t.Fatalf("unexpected error: %+v", msg)
|
t.Fatalf("unexpected error: %+v", msg)
|
||||||
@@ -178,8 +356,8 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
|||||||
if string(got) != "ok" {
|
if string(got) != "ok" {
|
||||||
t.Fatalf("expected payload ok, got %q", string(got))
|
t.Fatalf("expected payload ok, got %q", string(got))
|
||||||
}
|
}
|
||||||
if executor.Calls() != 2 {
|
if upstreamHeaders != nil {
|
||||||
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
|
t.Fatalf("expected nil upstream headers when passthrough is disabled, got %#v", upstreamHeaders)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,7 +398,7 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
|||||||
BootstrapRetries: 1,
|
BootstrapRetries: 1,
|
||||||
},
|
},
|
||||||
}, manager)
|
}, manager)
|
||||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
if dataChan == nil || errChan == nil {
|
if dataChan == nil || errChan == nil {
|
||||||
t.Fatalf("expected non-nil channels")
|
t.Fatalf("expected non-nil channels")
|
||||||
}
|
}
|
||||||
@@ -252,3 +430,180 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
|||||||
t.Fatalf("expected 1 stream attempt, got %d", executor.Calls())
|
t.Fatalf("expected 1 stream attempt, got %d", executor.Calls())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) {
|
||||||
|
executor := &authAwareStreamExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth1 := &coreauth.Auth{
|
||||||
|
ID: "auth1",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test1@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth1): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth2 := &coreauth.Auth{
|
||||||
|
ID: "auth2",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test2@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth2): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||||
|
Streaming: sdkconfig.StreamingConfig{
|
||||||
|
BootstrapRetries: 1,
|
||||||
|
},
|
||||||
|
}, manager)
|
||||||
|
ctx := WithPinnedAuthID(context.Background(), "auth1")
|
||||||
|
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
|
if dataChan == nil || errChan == nil {
|
||||||
|
t.Fatalf("expected non-nil channels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []byte
|
||||||
|
for chunk := range dataChan {
|
||||||
|
got = append(got, chunk...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotErr error
|
||||||
|
for msg := range errChan {
|
||||||
|
if msg != nil && msg.Error != nil {
|
||||||
|
gotErr = msg.Error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(got) != 0 {
|
||||||
|
t.Fatalf("expected empty payload, got %q", string(got))
|
||||||
|
}
|
||||||
|
if gotErr == nil {
|
||||||
|
t.Fatalf("expected terminal error, got nil")
|
||||||
|
}
|
||||||
|
authIDs := executor.AuthIDs()
|
||||||
|
if len(authIDs) == 0 {
|
||||||
|
t.Fatalf("expected at least one upstream attempt")
|
||||||
|
}
|
||||||
|
for _, authID := range authIDs {
|
||||||
|
if authID != "auth1" {
|
||||||
|
t.Fatalf("expected all attempts on auth1, got sequence %v", authIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *testing.T) {
|
||||||
|
executor := &authAwareStreamExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth2 := &coreauth.Auth{
|
||||||
|
ID: "auth2",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test2@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth2): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||||
|
Streaming: sdkconfig.StreamingConfig{
|
||||||
|
BootstrapRetries: 0,
|
||||||
|
},
|
||||||
|
}, manager)
|
||||||
|
|
||||||
|
selectedAuthID := ""
|
||||||
|
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
|
||||||
|
selectedAuthID = authID
|
||||||
|
})
|
||||||
|
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
|
if dataChan == nil || errChan == nil {
|
||||||
|
t.Fatalf("expected non-nil channels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []byte
|
||||||
|
for chunk := range dataChan {
|
||||||
|
got = append(got, chunk...)
|
||||||
|
}
|
||||||
|
for msg := range errChan {
|
||||||
|
if msg != nil {
|
||||||
|
t.Fatalf("unexpected error: %+v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(got) != "ok" {
|
||||||
|
t.Fatalf("expected payload ok, got %q", string(got))
|
||||||
|
}
|
||||||
|
if selectedAuthID != "auth2" {
|
||||||
|
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
|
||||||
|
executor := &invalidJSONStreamExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth1 := &coreauth.Auth{
|
||||||
|
ID: "auth1",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test1@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth1): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||||
|
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
|
if dataChan == nil || errChan == nil {
|
||||||
|
t.Fatalf("expected non-nil channels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []byte
|
||||||
|
for chunk := range dataChan {
|
||||||
|
got = append(got, chunk...)
|
||||||
|
}
|
||||||
|
if len(got) != 0 {
|
||||||
|
t.Fatalf("expected empty payload, got %q", string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
gotErr := false
|
||||||
|
for msg := range errChan {
|
||||||
|
if msg == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if msg.StatusCode != http.StatusBadGateway {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
|
||||||
|
}
|
||||||
|
if msg.Error == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
gotErr = true
|
||||||
|
}
|
||||||
|
if !gotErr {
|
||||||
|
t.Fatalf("expected terminal error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
80
sdk/api/handlers/header_filter.go
Normal file
80
sdk/api/handlers/header_filter.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// hopByHopHeaders enumerates headers that must never be copied from an
// upstream response to the downstream client: the RFC 7230 Section 6.1
// hop-by-hop set, security-sensitive headers, and headers CPA handlers
// manage themselves. Keys are in canonical form (http.CanonicalHeaderKey).
var hopByHopHeaders = map[string]struct{}{
	// RFC 7230 Section 6.1 hop-by-hop headers.
	"Connection":          {},
	"Keep-Alive":          {},
	"Proxy-Authenticate":  {},
	"Proxy-Authorization": {},
	"Te":                  {},
	"Trailer":             {},
	"Transfer-Encoding":   {},
	"Upgrade":             {},
	// Security-sensitive: upstream cookies must not leak downstream.
	"Set-Cookie": {},
	// CPA-managed: handlers set these themselves, never the upstream.
	"Content-Length":   {},
	"Content-Encoding": {},
}
|
||||||
|
|
||||||
|
// FilterUpstreamHeaders returns a copy of src with hop-by-hop and security-sensitive
|
||||||
|
// headers removed. Returns nil if src is nil or empty after filtering.
|
||||||
|
func FilterUpstreamHeaders(src http.Header) http.Header {
|
||||||
|
if src == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
connectionScoped := connectionScopedHeaders(src)
|
||||||
|
dst := make(http.Header)
|
||||||
|
for key, values := range src {
|
||||||
|
canonicalKey := http.CanonicalHeaderKey(key)
|
||||||
|
if _, blocked := hopByHopHeaders[canonicalKey]; blocked {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, scoped := connectionScoped[canonicalKey]; scoped {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dst[key] = values
|
||||||
|
}
|
||||||
|
if len(dst) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectionScopedHeaders(src http.Header) map[string]struct{} {
|
||||||
|
scoped := make(map[string]struct{})
|
||||||
|
for _, rawValue := range src.Values("Connection") {
|
||||||
|
for _, token := range strings.Split(rawValue, ",") {
|
||||||
|
headerName := strings.TrimSpace(token)
|
||||||
|
if headerName == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
scoped[http.CanonicalHeaderKey(headerName)] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scoped
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteUpstreamHeaders writes filtered upstream headers to the gin response writer.
|
||||||
|
// Headers already set by CPA (e.g., Content-Type) are NOT overwritten.
|
||||||
|
func WriteUpstreamHeaders(dst http.Header, src http.Header) {
|
||||||
|
if src == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for key, values := range src {
|
||||||
|
// Don't overwrite headers already set by CPA handlers
|
||||||
|
if dst.Get(key) != "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, v := range values {
|
||||||
|
dst.Add(key, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
55
sdk/api/handlers/header_filter_test.go
Normal file
55
sdk/api/handlers/header_filter_test.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders(t *testing.T) {
|
||||||
|
src := http.Header{}
|
||||||
|
src.Add("Connection", "keep-alive, x-hop-a, x-hop-b")
|
||||||
|
src.Add("Connection", "x-hop-c")
|
||||||
|
src.Set("Keep-Alive", "timeout=5")
|
||||||
|
src.Set("X-Hop-A", "a")
|
||||||
|
src.Set("X-Hop-B", "b")
|
||||||
|
src.Set("X-Hop-C", "c")
|
||||||
|
src.Set("X-Request-Id", "req-1")
|
||||||
|
src.Set("Set-Cookie", "session=secret")
|
||||||
|
|
||||||
|
filtered := FilterUpstreamHeaders(src)
|
||||||
|
if filtered == nil {
|
||||||
|
t.Fatalf("expected filtered headers, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := filtered.Get("X-Request-Id")
|
||||||
|
if requestID != "req-1" {
|
||||||
|
t.Fatalf("expected X-Request-Id to be preserved, got %q", requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
blockedHeaderKeys := []string{
|
||||||
|
"Connection",
|
||||||
|
"Keep-Alive",
|
||||||
|
"X-Hop-A",
|
||||||
|
"X-Hop-B",
|
||||||
|
"X-Hop-C",
|
||||||
|
"Set-Cookie",
|
||||||
|
}
|
||||||
|
for _, key := range blockedHeaderKeys {
|
||||||
|
value := filtered.Get(key)
|
||||||
|
if value != "" {
|
||||||
|
t.Fatalf("expected %s to be removed, got %q", key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked(t *testing.T) {
|
||||||
|
src := http.Header{}
|
||||||
|
src.Add("Connection", "x-hop-a")
|
||||||
|
src.Set("X-Hop-A", "a")
|
||||||
|
src.Set("Set-Cookie", "session=secret")
|
||||||
|
|
||||||
|
filtered := FilterUpstreamHeaders(src)
|
||||||
|
if filtered != nil {
|
||||||
|
t.Fatalf("expected nil when all headers are filtered, got %#v", filtered)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -420,6 +420,7 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
|
|||||||
|
|
||||||
// Check if this chunk has any meaningful content
|
// Check if this chunk has any meaningful content
|
||||||
hasContent := false
|
hasContent := false
|
||||||
|
hasUsage := root.Get("usage").Exists()
|
||||||
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
|
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
|
||||||
chatChoices.ForEach(func(_, choice gjson.Result) bool {
|
chatChoices.ForEach(func(_, choice gjson.Result) bool {
|
||||||
// Check if delta has content or finish_reason
|
// Check if delta has content or finish_reason
|
||||||
@@ -438,8 +439,8 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no meaningful content, return nil to indicate this chunk should be skipped
|
// If no meaningful content and no usage, return nil to indicate this chunk should be skipped
|
||||||
if !hasContent {
|
if !hasContent && !hasUsage {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -498,6 +499,11 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
|
|||||||
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
|
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy usage if present
|
||||||
|
if usage := root.Get("usage"); usage.Exists() {
|
||||||
|
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
|
||||||
|
}
|
||||||
|
|
||||||
return []byte(out)
|
return []byte(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -513,12 +519,13 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
@@ -528,12 +535,13 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponseViaResponses(c *gin.Context
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
converted := convertResponsesObjectToChatCompletion(cliCtx, modelName, originalChatJSON, rawJSON, resp)
|
converted := convertResponsesObjectToChatCompletion(cliCtx, modelName, originalChatJSON, rawJSON, resp)
|
||||||
if converted == nil {
|
if converted == nil {
|
||||||
h.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
h.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||||
@@ -569,7 +577,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||||
|
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
@@ -602,6 +610,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
|||||||
if !ok {
|
if !ok {
|
||||||
// Stream closed without data? Send DONE or just headers.
|
// Stream closed without data? Send DONE or just headers.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
@@ -610,6 +619,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
|||||||
|
|
||||||
// Success! Commit to streaming headers.
|
// Success! Commit to streaming headers.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
|
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
@@ -635,7 +645,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, r
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
||||||
var param any
|
var param any
|
||||||
|
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
@@ -666,6 +676,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, r
|
|||||||
case chunk, ok := <-dataChan:
|
case chunk, ok := <-dataChan:
|
||||||
if !ok {
|
if !ok {
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
@@ -673,6 +684,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, r
|
|||||||
}
|
}
|
||||||
|
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
writeConvertedResponsesChunk(c, cliCtx, modelName, originalChatJSON, rawJSON, chunk, ¶m)
|
writeConvertedResponsesChunk(c, cliCtx, modelName, originalChatJSON, rawJSON, chunk, ¶m)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|
||||||
@@ -698,13 +710,14 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
|
|||||||
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||||
stopKeepAlive()
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
completionsResp := convertChatCompletionsResponseToCompletions(resp)
|
completionsResp := convertChatCompletionsResponseToCompletions(resp)
|
||||||
_, _ = c.Writer.Write(completionsResp)
|
_, _ = c.Writer.Write(completionsResp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
@@ -735,7 +748,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||||
|
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
@@ -766,6 +779,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
|||||||
case chunk, ok := <-dataChan:
|
case chunk, ok := <-dataChan:
|
||||||
if !ok {
|
if !ok {
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
@@ -774,6 +788,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
|||||||
|
|
||||||
// Success! Set headers.
|
// Success! Set headers.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
|
|
||||||
// Write the first chunk
|
// Write the first chunk
|
||||||
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ func (e *compactCaptureExecutor) Execute(ctx context.Context, auth *coreauth.Aut
|
|||||||
return coreexecutor.Response{Payload: []byte(`{"ok":true}`)}, nil
|
return coreexecutor.Response{Payload: []byte(`{"ok":true}`)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -139,13 +139,14 @@ func (h *OpenAIResponsesAPIHandler) Compact(c *gin.Context) {
|
|||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact")
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact")
|
||||||
stopKeepAlive()
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
@@ -164,13 +165,14 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
|
|||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||||
|
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
stopKeepAlive()
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
@@ -180,12 +182,13 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponseViaChat(c *gin.Con
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(chatJSON, "model").String()
|
modelName := gjson.GetBytes(chatJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
var param any
|
var param any
|
||||||
converted := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(cliCtx, modelName, originalResponsesJSON, originalResponsesJSON, resp, ¶m)
|
converted := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(cliCtx, modelName, originalResponsesJSON, originalResponsesJSON, resp, ¶m)
|
||||||
if converted == "" {
|
if converted == "" {
|
||||||
@@ -223,7 +226,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
|||||||
// New core execution path
|
// New core execution path
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
|
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
@@ -256,6 +259,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
|||||||
if !ok {
|
if !ok {
|
||||||
// Stream closed without data? Send headers and done.
|
// Stream closed without data? Send headers and done.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
@@ -264,6 +268,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
|||||||
|
|
||||||
// Success! Set headers.
|
// Success! Set headers.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
|
|
||||||
// Write first chunk logic (matching forwardResponsesStream)
|
// Write first chunk logic (matching forwardResponsesStream)
|
||||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||||
@@ -294,7 +299,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Contex
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(chatJSON, "model").String()
|
modelName := gjson.GetBytes(chatJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
||||||
var param any
|
var param any
|
||||||
|
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
@@ -324,6 +329,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Contex
|
|||||||
case chunk, ok := <-dataChan:
|
case chunk, ok := <-dataChan:
|
||||||
if !ok {
|
if !ok {
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
@@ -331,6 +337,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Contex
|
|||||||
}
|
}
|
||||||
|
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
writeChatAsResponsesChunk(c, cliCtx, modelName, originalResponsesJSON, chunk, ¶m)
|
writeChatAsResponsesChunk(c, cliCtx, modelName, originalResponsesJSON, chunk, ¶m)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|
||||||
@@ -411,8 +418,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
|
|||||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||||
errText = errMsg.Error.Error()
|
errText = errMsg.Error.Error()
|
||||||
}
|
}
|
||||||
body := handlers.BuildErrorResponseBody(status, errText)
|
chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0)
|
||||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
|
||||||
},
|
},
|
||||||
WriteDone: func() {
|
WriteDone: func() {
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
|
||||||
|
h := NewOpenAIResponsesAPIHandler(base)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected gin writer to implement http.Flusher")
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make(chan []byte)
|
||||||
|
errs := make(chan *interfaces.ErrorMessage, 1)
|
||||||
|
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
|
||||||
|
close(errs)
|
||||||
|
|
||||||
|
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
|
||||||
|
body := recorder.Body.String()
|
||||||
|
if !strings.Contains(body, `"type":"error"`) {
|
||||||
|
t.Fatalf("expected responses error chunk, got: %q", body)
|
||||||
|
}
|
||||||
|
if strings.Contains(body, `"error":{`) {
|
||||||
|
t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
662
sdk/api/handlers/openai/openai_responses_websocket.go
Normal file
662
sdk/api/handlers/openai/openai_responses_websocket.go
Normal file
@@ -0,0 +1,662 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
wsRequestTypeCreate = "response.create"
|
||||||
|
wsRequestTypeAppend = "response.append"
|
||||||
|
wsEventTypeError = "error"
|
||||||
|
wsEventTypeCompleted = "response.completed"
|
||||||
|
wsEventTypeDone = "response.done"
|
||||||
|
wsDoneMarker = "[DONE]"
|
||||||
|
wsTurnStateHeader = "x-codex-turn-state"
|
||||||
|
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||||
|
wsPayloadLogMaxSize = 2048
|
||||||
|
)
|
||||||
|
|
||||||
|
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||||
|
ReadBufferSize: 4096,
|
||||||
|
WriteBufferSize: 4096,
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponsesWebsocket handles websocket requests for /v1/responses.
|
||||||
|
// It accepts `response.create` and `response.append` requests and streams
|
||||||
|
// response events back as JSON websocket text messages.
|
||||||
|
func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||||
|
conn, err := responsesWebsocketUpgrader.Upgrade(c.Writer, c.Request, websocketUpgradeHeaders(c.Request))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
passthroughSessionID := uuid.NewString()
|
||||||
|
clientRemoteAddr := ""
|
||||||
|
if c != nil && c.Request != nil {
|
||||||
|
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
|
||||||
|
}
|
||||||
|
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
|
||||||
|
var wsTerminateErr error
|
||||||
|
var wsBodyLog strings.Builder
|
||||||
|
defer func() {
|
||||||
|
if wsTerminateErr != nil {
|
||||||
|
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
|
||||||
|
} else {
|
||||||
|
log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
|
||||||
|
}
|
||||||
|
if h != nil && h.AuthManager != nil {
|
||||||
|
h.AuthManager.CloseExecutionSession(passthroughSessionID)
|
||||||
|
log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
|
||||||
|
}
|
||||||
|
setWebsocketRequestBody(c, wsBodyLog.String())
|
||||||
|
if errClose := conn.Close(); errClose != nil {
|
||||||
|
log.Warnf("responses websocket: close connection error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var lastRequest []byte
|
||||||
|
lastResponseOutput := []byte("[]")
|
||||||
|
pinnedAuthID := ""
|
||||||
|
|
||||||
|
for {
|
||||||
|
msgType, payload, errReadMessage := conn.ReadMessage()
|
||||||
|
if errReadMessage != nil {
|
||||||
|
wsTerminateErr = errReadMessage
|
||||||
|
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
|
||||||
|
if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||||
|
log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
|
||||||
|
} else {
|
||||||
|
// log.Warnf("responses websocket: read message failed id=%s error=%v", passthroughSessionID, errReadMessage)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// log.Infof(
|
||||||
|
// "responses websocket: downstream_in id=%s type=%d event=%s payload=%s",
|
||||||
|
// passthroughSessionID,
|
||||||
|
// msgType,
|
||||||
|
// websocketPayloadEventType(payload),
|
||||||
|
// websocketPayloadPreview(payload),
|
||||||
|
// )
|
||||||
|
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
||||||
|
|
||||||
|
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
|
||||||
|
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||||
|
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||||
|
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestJSON []byte
|
||||||
|
var updatedLastRequest []byte
|
||||||
|
var errMsg *interfaces.ErrorMessage
|
||||||
|
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode(
|
||||||
|
payload,
|
||||||
|
lastRequest,
|
||||||
|
lastResponseOutput,
|
||||||
|
allowIncrementalInputWithPreviousResponseID,
|
||||||
|
)
|
||||||
|
if errMsg != nil {
|
||||||
|
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||||
|
markAPIResponseTimestamp(c)
|
||||||
|
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||||
|
appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
|
||||||
|
log.Infof(
|
||||||
|
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||||
|
passthroughSessionID,
|
||||||
|
websocket.TextMessage,
|
||||||
|
websocketPayloadEventType(errorPayload),
|
||||||
|
websocketPayloadPreview(errorPayload),
|
||||||
|
)
|
||||||
|
if errWrite != nil {
|
||||||
|
log.Warnf(
|
||||||
|
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||||
|
passthroughSessionID,
|
||||||
|
websocketPayloadEventType(errorPayload),
|
||||||
|
errWrite,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lastRequest = updatedLastRequest
|
||||||
|
|
||||||
|
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||||
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
|
cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx)
|
||||||
|
cliCtx = handlers.WithExecutionSessionID(cliCtx, passthroughSessionID)
|
||||||
|
if pinnedAuthID != "" {
|
||||||
|
cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
|
||||||
|
} else {
|
||||||
|
cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
|
||||||
|
pinnedAuthID = strings.TrimSpace(authID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||||
|
|
||||||
|
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
|
||||||
|
if errForward != nil {
|
||||||
|
wsTerminateErr = errForward
|
||||||
|
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
|
||||||
|
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lastResponseOutput = completedOutput
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func websocketUpgradeHeaders(req *http.Request) http.Header {
|
||||||
|
headers := http.Header{}
|
||||||
|
if req == nil {
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep the same sticky turn-state across reconnects when provided by the client.
|
||||||
|
turnState := strings.TrimSpace(req.Header.Get(wsTurnStateHeader))
|
||||||
|
if turnState != "" {
|
||||||
|
headers.Set(wsTurnStateHeader, turnState)
|
||||||
|
}
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeResponsesWebsocketRequest normalizes an inbound websocket frame
// using the default mode, which permits incremental input together with a
// previous_response_id. See normalizeResponsesWebsocketRequestWithMode for the
// full contract.
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
	return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
}
|
||||||
|
|
||||||
|
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||||
|
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
|
||||||
|
switch requestType {
|
||||||
|
case wsRequestTypeCreate:
|
||||||
|
// log.Infof("responses websocket: response.create request")
|
||||||
|
if len(lastRequest) == 0 {
|
||||||
|
return normalizeResponseCreateRequest(rawJSON)
|
||||||
|
}
|
||||||
|
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
||||||
|
case wsRequestTypeAppend:
|
||||||
|
// log.Infof("responses websocket: response.append request")
|
||||||
|
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
||||||
|
default:
|
||||||
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("unsupported websocket request type: %s", requestType),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||||
|
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||||
|
if errDelete != nil {
|
||||||
|
normalized = bytes.Clone(rawJSON)
|
||||||
|
}
|
||||||
|
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||||
|
if !gjson.GetBytes(normalized, "input").Exists() {
|
||||||
|
normalized, _ = sjson.SetRawBytes(normalized, "input", []byte("[]"))
|
||||||
|
}
|
||||||
|
|
||||||
|
modelName := strings.TrimSpace(gjson.GetBytes(normalized, "model").String())
|
||||||
|
if modelName == "" {
|
||||||
|
return nil, nil, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("missing model in response.create request"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return normalized, bytes.Clone(normalized), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||||
|
if len(lastRequest) == 0 {
|
||||||
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("websocket request received before response.create"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nextInput := gjson.GetBytes(rawJSON, "input")
|
||||||
|
if !nextInput.Exists() || !nextInput.IsArray() {
|
||||||
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("websocket request requires array field: input"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
|
||||||
|
// Do not expand it into a full input transcript; upstream expects the incremental payload.
|
||||||
|
if allowIncrementalInputWithPreviousResponseID {
|
||||||
|
if prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()); prev != "" {
|
||||||
|
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||||
|
if errDelete != nil {
|
||||||
|
normalized = bytes.Clone(rawJSON)
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(normalized, "model").Exists() {
|
||||||
|
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||||
|
if modelName != "" {
|
||||||
|
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
||||||
|
instructions := gjson.GetBytes(lastRequest, "instructions")
|
||||||
|
if instructions.Exists() {
|
||||||
|
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||||
|
return normalized, bytes.Clone(normalized), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
existingInput := gjson.GetBytes(lastRequest, "input")
|
||||||
|
mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
|
||||||
|
if errMerge != nil {
|
||||||
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("invalid previous response output: %w", errMerge),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
|
||||||
|
if errMerge != nil {
|
||||||
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("invalid request input: %w", errMerge),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||||
|
if errDelete != nil {
|
||||||
|
normalized = bytes.Clone(rawJSON)
|
||||||
|
}
|
||||||
|
normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
|
||||||
|
var errSet error
|
||||||
|
normalized, errSet = sjson.SetRawBytes(normalized, "input", []byte(mergedInput))
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("failed to merge websocket input: %w", errSet),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(normalized, "model").Exists() {
|
||||||
|
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||||
|
if modelName != "" {
|
||||||
|
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
||||||
|
instructions := gjson.GetBytes(lastRequest, "instructions")
|
||||||
|
if instructions.Exists() {
|
||||||
|
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||||
|
return normalized, bytes.Clone(normalized), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// websocketUpstreamSupportsIncrementalInput reports whether the selected
// upstream auth advertises websocket v2 incremental-input support. A parseable
// "websockets" attribute (string bool) takes precedence; otherwise the
// "websockets" metadata entry is consulted, accepting either a bool or a
// parseable bool string. Anything else defaults to false.
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
	// Indexing a nil map is safe and yields "", so no length guard is needed.
	if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
		if parsed, errParse := strconv.ParseBool(raw); errParse == nil {
			return parsed
		}
	}
	switch value := metadata["websockets"].(type) {
	case bool:
		return value
	case string:
		if parsed, errParse := strconv.ParseBool(strings.TrimSpace(value)); errParse == nil {
			return parsed
		}
	}
	return false
}
|
||||||
|
|
||||||
|
// mergeJSONArrayRaw concatenates two JSON arrays given as raw text and returns
// the combined array re-encoded as JSON. Blank input on either side is treated
// as an empty array; invalid JSON yields an error.
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
	// decode parses one raw array, defaulting blank input to "[]".
	decode := func(raw string) ([]json.RawMessage, error) {
		raw = strings.TrimSpace(raw)
		if raw == "" {
			raw = "[]"
		}
		var items []json.RawMessage
		if err := json.Unmarshal([]byte(raw), &items); err != nil {
			return nil, err
		}
		return items, nil
	}

	head, err := decode(existingRaw)
	if err != nil {
		return "", err
	}
	tail, err := decode(appendRaw)
	if err != nil {
		return "", err
	}

	out, err := json.Marshal(append(head, tail...))
	if err != nil {
		return "", err
	}
	return string(out), nil
}
|
||||||
|
|
||||||
|
func normalizeJSONArrayRaw(raw []byte) string {
|
||||||
|
trimmed := strings.TrimSpace(string(raw))
|
||||||
|
if trimmed == "" {
|
||||||
|
return "[]"
|
||||||
|
}
|
||||||
|
result := gjson.Parse(trimmed)
|
||||||
|
if result.Type == gjson.JSON && result.IsArray() {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
return "[]"
|
||||||
|
}
|
||||||
|
|
||||||
|
// forwardResponsesWebsocket pumps one upstream streaming execution back to the
// downstream websocket client. It converts each upstream chunk into JSON event
// payloads, rewrites the terminal "completed" event type into the websocket
// "done" event, and writes every payload to conn as a text frame. It returns
// the completed response's "output" array (or "[]" when none was seen) plus a
// terminal error, if any. cancel is invoked on every return path so the
// upstream execution is always released.
func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
	c *gin.Context,
	conn *websocket.Conn,
	cancel handlers.APIHandlerCancelFunc,
	data <-chan []byte,
	errs <-chan *interfaces.ErrorMessage,
	wsBodyLog *strings.Builder,
	sessionID string,
) ([]byte, error) {
	completed := false
	completedOutput := []byte("[]")

	for {
		select {
		case <-c.Request.Context().Done():
			// Downstream HTTP connection went away; propagate the cause upstream.
			cancel(c.Request.Context().Err())
			return completedOutput, c.Request.Context().Err()
		case errMsg, ok := <-errs:
			if !ok {
				// Error channel closed: disable this case and keep draining data.
				errs = nil
				continue
			}
			if errMsg != nil {
				// Surface the upstream error to the client as a websocket error event.
				h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
				markAPIResponseTimestamp(c)
				errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
				appendWebsocketEvent(wsBodyLog, "response", errorPayload)
				log.Infof(
					"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
					sessionID,
					websocket.TextMessage,
					websocketPayloadEventType(errorPayload),
					websocketPayloadPreview(errorPayload),
				)
				if errWrite != nil {
					cancel(errMsg.Error)
					return completedOutput, errWrite
				}
			}
			// A received error (even when written successfully) ends this turn.
			if errMsg != nil {
				cancel(errMsg.Error)
			} else {
				cancel(nil)
			}
			return completedOutput, nil
		case chunk, ok := <-data:
			if !ok {
				if !completed {
					// Stream ended without a completed event; report it to the client
					// as a timeout-style error before closing out the turn.
					errMsg := &interfaces.ErrorMessage{
						StatusCode: http.StatusRequestTimeout,
						Error:      fmt.Errorf("stream closed before response.completed"),
					}
					h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
					markAPIResponseTimestamp(c)
					errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
					appendWebsocketEvent(wsBodyLog, "response", errorPayload)
					log.Infof(
						"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
						sessionID,
						websocket.TextMessage,
						websocketPayloadEventType(errorPayload),
						websocketPayloadPreview(errorPayload),
					)
					if errWrite != nil {
						log.Warnf(
							"responses websocket: downstream_out write failed id=%s event=%s error=%v",
							sessionID,
							websocketPayloadEventType(errorPayload),
							errWrite,
						)
						cancel(errMsg.Error)
						return completedOutput, errWrite
					}
					cancel(errMsg.Error)
					return completedOutput, nil
				}
				cancel(nil)
				return completedOutput, nil
			}

			payloads := websocketJSONPayloadsFromChunk(chunk)
			for i := range payloads {
				eventType := gjson.GetBytes(payloads[i], "type").String()
				if eventType == wsEventTypeCompleted {
					// Rewrite the terminal event type for the websocket protocol and
					// capture its output array for the next turn's transcript merge.
					payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)

					completed = true
					completedOutput = responseCompletedOutputFromPayload(payloads[i])
				}
				markAPIResponseTimestamp(c)
				appendWebsocketEvent(wsBodyLog, "response", payloads[i])
				if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
					log.Warnf(
						"responses websocket: downstream_out write failed id=%s event=%s error=%v",
						sessionID,
						websocketPayloadEventType(payloads[i]),
						errWrite,
					)
					cancel(errWrite)
					return completedOutput, errWrite
				}
			}
		}
	}
}
|
||||||
|
|
||||||
|
func responseCompletedOutputFromPayload(payload []byte) []byte {
|
||||||
|
output := gjson.GetBytes(payload, "response.output")
|
||||||
|
if output.Exists() && output.IsArray() {
|
||||||
|
return bytes.Clone([]byte(output.Raw))
|
||||||
|
}
|
||||||
|
return []byte("[]")
|
||||||
|
}
|
||||||
|
|
||||||
|
func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
|
||||||
|
payloads := make([][]byte, 0, 2)
|
||||||
|
lines := bytes.Split(chunk, []byte("\n"))
|
||||||
|
for i := range lines {
|
||||||
|
line := bytes.TrimSpace(lines[i])
|
||||||
|
if len(line) == 0 || bytes.HasPrefix(line, []byte("event:")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(line, []byte("data:")) {
|
||||||
|
line = bytes.TrimSpace(line[len("data:"):])
|
||||||
|
}
|
||||||
|
if len(line) == 0 || bytes.Equal(line, []byte(wsDoneMarker)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if json.Valid(line) {
|
||||||
|
payloads = append(payloads, bytes.Clone(line))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(payloads) > 0 {
|
||||||
|
return payloads
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := bytes.TrimSpace(chunk)
|
||||||
|
if bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||||
|
trimmed = bytes.TrimSpace(trimmed[len("data:"):])
|
||||||
|
}
|
||||||
|
if len(trimmed) > 0 && !bytes.Equal(trimmed, []byte(wsDoneMarker)) && json.Valid(trimmed) {
|
||||||
|
payloads = append(payloads, bytes.Clone(trimmed))
|
||||||
|
}
|
||||||
|
return payloads
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeResponsesWebsocketError converts an ErrorMessage into the websocket
// error-event envelope and writes it to conn as a text frame. The returned
// payload (for logging) is built even when the write fails. The envelope
// carries the HTTP-style status, optional response headers from errMsg.Addon,
// and an "error" object taken from the canonical error response body, falling
// back to a generic server_error with the message text.
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) {
	// Default to 500 with the generic status text, then refine from errMsg.
	status := http.StatusInternalServerError
	errText := http.StatusText(status)
	if errMsg != nil {
		if errMsg.StatusCode > 0 {
			status = errMsg.StatusCode
			errText = http.StatusText(status)
		}
		// A non-blank concrete error message overrides the generic status text.
		if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
			errText = errMsg.Error.Error()
		}
	}

	body := handlers.BuildErrorResponseBody(status, errText)
	payload := map[string]any{
		"type":   wsEventTypeError,
		"status": status,
	}

	// Forward any upstream response headers (first value of each) for diagnostics.
	if errMsg != nil && errMsg.Addon != nil {
		headers := map[string]any{}
		for key, values := range errMsg.Addon {
			if len(values) == 0 {
				continue
			}
			headers[key] = values[0]
		}
		if len(headers) > 0 {
			payload["headers"] = headers
		}
	}

	// Prefer the structured "error" object from the canonical body when present.
	if len(body) > 0 && json.Valid(body) {
		var decoded map[string]any
		if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
			if inner, ok := decoded["error"]; ok {
				payload["error"] = inner
			} else {
				payload["error"] = decoded
			}
		}
	}

	// Guarantee an "error" object even when the body could not be decoded.
	if _, ok := payload["error"]; !ok {
		payload["error"] = map[string]any{
			"type":    "server_error",
			"message": errText,
		}
	}

	data, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	return data, conn.WriteMessage(websocket.TextMessage, data)
}
|
||||||
|
|
||||||
|
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
||||||
|
if builder == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
trimmedPayload := bytes.TrimSpace(payload)
|
||||||
|
if len(trimmedPayload) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if builder.Len() > 0 {
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
builder.WriteString("websocket.")
|
||||||
|
builder.WriteString(eventType)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
builder.Write(trimmedPayload)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func websocketPayloadEventType(payload []byte) string {
|
||||||
|
eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
|
||||||
|
if eventType == "" {
|
||||||
|
return "-"
|
||||||
|
}
|
||||||
|
return eventType
|
||||||
|
}
|
||||||
|
|
||||||
|
func websocketPayloadPreview(payload []byte) string {
|
||||||
|
trimmedPayload := bytes.TrimSpace(payload)
|
||||||
|
if len(trimmedPayload) == 0 {
|
||||||
|
return "<empty>"
|
||||||
|
}
|
||||||
|
preview := trimmedPayload
|
||||||
|
if len(preview) > wsPayloadLogMaxSize {
|
||||||
|
preview = preview[:wsPayloadLogMaxSize]
|
||||||
|
}
|
||||||
|
previewText := strings.ReplaceAll(string(preview), "\n", "\\n")
|
||||||
|
previewText = strings.ReplaceAll(previewText, "\r", "\\r")
|
||||||
|
if len(trimmedPayload) > wsPayloadLogMaxSize {
|
||||||
|
return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload))
|
||||||
|
}
|
||||||
|
return previewText
|
||||||
|
}
|
||||||
|
|
||||||
|
func setWebsocketRequestBody(c *gin.Context, body string) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
trimmedBody := strings.TrimSpace(body)
|
||||||
|
if trimmedBody == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(wsRequestBodyKey, []byte(trimmedBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
func markAPIResponseTimestamp(c *gin.Context) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, exists := c.Get("API_RESPONSE_TIMESTAMP"); exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set("API_RESPONSE_TIMESTAMP", time.Now())
|
||||||
|
}
|
||||||
249
sdk/api/handlers/openai/openai_responses_websocket_test.go
Normal file
249
sdk/api/handlers/openai/openai_responses_websocket_test.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNormalizeResponsesWebsocketRequestCreate verifies the first
// response.create frame is normalized: the "type" field is stripped,
// stream is forced on, the model is preserved, and the returned snapshot
// mirrors the normalized request.
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
	raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)

	normalized, last, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	if gjson.GetBytes(normalized, "type").Exists() {
		t.Fatalf("normalized create request must not include type field")
	}
	if !gjson.GetBytes(normalized, "stream").Bool() {
		t.Fatalf("normalized create request must force stream=true")
	}
	if gjson.GetBytes(normalized, "model").String() != "test-model" {
		t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
	}
	if !bytes.Equal(last, normalized) {
		t.Fatalf("last request snapshot should match normalized request")
	}
}
|
||||||
|
|
||||||
|
// TestNormalizeResponsesWebsocketRequestCreateWithHistory verifies that a
// follow-up response.create (no previous_response_id) expands into a full
// transcript: last request input + last response output + new input, in that
// order, with the model inherited from the previous request.
func TestNormalizeResponsesWebsocketRequestCreateWithHistory(t *testing.T) {
	lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
	lastResponseOutput := []byte(`[
		{"type":"function_call","id":"fc-1","call_id":"call-1"},
		{"type":"message","id":"assistant-1"}
	]`)
	raw := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)

	normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	if gjson.GetBytes(normalized, "type").Exists() {
		t.Fatalf("normalized subsequent create request must not include type field")
	}
	if gjson.GetBytes(normalized, "model").String() != "test-model" {
		t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
	}

	input := gjson.GetBytes(normalized, "input").Array()
	if len(input) != 4 {
		t.Fatalf("merged input len = %d, want 4", len(input))
	}
	if input[0].Get("id").String() != "msg-1" ||
		input[1].Get("id").String() != "fc-1" ||
		input[2].Get("id").String() != "assistant-1" ||
		input[3].Get("id").String() != "tool-out-1" {
		t.Fatalf("unexpected merged input order")
	}
	if !bytes.Equal(next, normalized) {
		t.Fatalf("next request snapshot should match normalized request")
	}
}
|
||||||
|
|
||||||
|
// TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental
// verifies incremental mode: previous_response_id is preserved, input stays
// incremental (not expanded into a transcript), and model/instructions are
// inherited from the previous request when omitted.
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *testing.T) {
	lastRequest := []byte(`{"model":"test-model","stream":true,"instructions":"be helpful","input":[{"type":"message","id":"msg-1"}]}`)
	lastResponseOutput := []byte(`[
		{"type":"function_call","id":"fc-1","call_id":"call-1"},
		{"type":"message","id":"assistant-1"}
	]`)
	raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)

	normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	if gjson.GetBytes(normalized, "type").Exists() {
		t.Fatalf("normalized request must not include type field")
	}
	if gjson.GetBytes(normalized, "previous_response_id").String() != "resp-1" {
		t.Fatalf("previous_response_id must be preserved in incremental mode")
	}
	input := gjson.GetBytes(normalized, "input").Array()
	if len(input) != 1 {
		t.Fatalf("incremental input len = %d, want 1", len(input))
	}
	if input[0].Get("id").String() != "tool-out-1" {
		t.Fatalf("unexpected incremental input item id: %s", input[0].Get("id").String())
	}
	if gjson.GetBytes(normalized, "model").String() != "test-model" {
		t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
	}
	if gjson.GetBytes(normalized, "instructions").String() != "be helpful" {
		t.Fatalf("unexpected instructions: %s", gjson.GetBytes(normalized, "instructions").String())
	}
	if !bytes.Equal(next, normalized) {
		t.Fatalf("next request snapshot should match normalized request")
	}
}
|
||||||
|
|
||||||
|
// TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncrementalDisabled
// verifies that when incremental mode is off, previous_response_id is removed
// and the request falls back to the full-transcript merge.
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncrementalDisabled(t *testing.T) {
	lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
	lastResponseOutput := []byte(`[
		{"type":"function_call","id":"fc-1","call_id":"call-1"},
		{"type":"message","id":"assistant-1"}
	]`)
	raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)

	normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	if gjson.GetBytes(normalized, "previous_response_id").Exists() {
		t.Fatalf("previous_response_id must be removed when incremental mode is disabled")
	}
	input := gjson.GetBytes(normalized, "input").Array()
	if len(input) != 4 {
		t.Fatalf("merged input len = %d, want 4", len(input))
	}
	if input[0].Get("id").String() != "msg-1" ||
		input[1].Get("id").String() != "fc-1" ||
		input[2].Get("id").String() != "assistant-1" ||
		input[3].Get("id").String() != "tool-out-1" {
		t.Fatalf("unexpected merged input order")
	}
	if !bytes.Equal(next, normalized) {
		t.Fatalf("next request snapshot should match normalized request")
	}
}
|
||||||
|
|
||||||
|
// TestNormalizeResponsesWebsocketRequestAppend verifies that response.append
// frames merge the prior transcript with the appended items in order.
func TestNormalizeResponsesWebsocketRequestAppend(t *testing.T) {
	lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
	lastResponseOutput := []byte(`[
		{"type":"message","id":"assistant-1"},
		{"type":"function_call_output","id":"tool-out-1"}
	]`)
	raw := []byte(`{"type":"response.append","input":[{"type":"message","id":"msg-2"},{"type":"message","id":"msg-3"}]}`)

	normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	input := gjson.GetBytes(normalized, "input").Array()
	if len(input) != 5 {
		t.Fatalf("merged input len = %d, want 5", len(input))
	}
	if input[0].Get("id").String() != "msg-1" ||
		input[1].Get("id").String() != "assistant-1" ||
		input[2].Get("id").String() != "tool-out-1" ||
		input[3].Get("id").String() != "msg-2" ||
		input[4].Get("id").String() != "msg-3" {
		t.Fatalf("unexpected merged input order")
	}
	if !bytes.Equal(next, normalized) {
		t.Fatalf("next request snapshot should match normalized append request")
	}
}
|
||||||
|
|
||||||
|
// TestNormalizeResponsesWebsocketRequestAppendWithoutCreate verifies that a
// response.append arriving before any response.create is rejected with 400.
func TestNormalizeResponsesWebsocketRequestAppendWithoutCreate(t *testing.T) {
	raw := []byte(`{"type":"response.append","input":[]}`)

	_, _, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil)
	if errMsg == nil {
		t.Fatalf("expected error for append without previous request")
	}
	if errMsg.StatusCode != http.StatusBadRequest {
		t.Fatalf("status = %d, want %d", errMsg.StatusCode, http.StatusBadRequest)
	}
}
|
||||||
|
|
||||||
|
// TestWebsocketJSONPayloadsFromChunk verifies SSE parsing: "event:" lines and
// the [DONE] marker are skipped and only the JSON data payload survives.
func TestWebsocketJSONPayloadsFromChunk(t *testing.T) {
	chunk := []byte("event: response.created\n\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\ndata: [DONE]\n")

	payloads := websocketJSONPayloadsFromChunk(chunk)
	if len(payloads) != 1 {
		t.Fatalf("payloads len = %d, want 1", len(payloads))
	}
	if gjson.GetBytes(payloads[0], "type").String() != "response.created" {
		t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String())
	}
}
|
||||||
|
|
||||||
|
// TestWebsocketJSONPayloadsFromPlainJSONChunk verifies the fallback path: a
// chunk that is one bare JSON object (no SSE framing) yields one payload.
func TestWebsocketJSONPayloadsFromPlainJSONChunk(t *testing.T) {
	chunk := []byte(`{"type":"response.completed","response":{"id":"resp-1"}}`)

	payloads := websocketJSONPayloadsFromChunk(chunk)
	if len(payloads) != 1 {
		t.Fatalf("payloads len = %d, want 1", len(payloads))
	}
	if gjson.GetBytes(payloads[0], "type").String() != "response.completed" {
		t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String())
	}
}
|
||||||
|
|
||||||
|
func TestResponseCompletedOutputFromPayload(t *testing.T) {
|
||||||
|
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"message","id":"out-1"}]}}`)
|
||||||
|
|
||||||
|
output := responseCompletedOutputFromPayload(payload)
|
||||||
|
items := gjson.ParseBytes(output).Array()
|
||||||
|
if len(items) != 1 {
|
||||||
|
t.Fatalf("output len = %d, want 1", len(items))
|
||||||
|
}
|
||||||
|
if items[0].Get("id").String() != "out-1" {
|
||||||
|
t.Fatalf("unexpected output id: %s", items[0].Get("id").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendWebsocketEvent(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
|
||||||
|
appendWebsocketEvent(&builder, "request", []byte(" {\"type\":\"response.create\"}\n"))
|
||||||
|
appendWebsocketEvent(&builder, "response", []byte("{\"type\":\"response.created\"}"))
|
||||||
|
|
||||||
|
got := builder.String()
|
||||||
|
if !strings.Contains(got, "websocket.request\n{\"type\":\"response.create\"}\n") {
|
||||||
|
t.Fatalf("request event not found in body: %s", got)
|
||||||
|
}
|
||||||
|
if !strings.Contains(got, "websocket.response\n{\"type\":\"response.created\"}\n") {
|
||||||
|
t.Fatalf("response event not found in body: %s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetWebsocketRequestBody(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
setWebsocketRequestBody(c, " \n ")
|
||||||
|
if _, exists := c.Get(wsRequestBodyKey); exists {
|
||||||
|
t.Fatalf("request body key should not be set for empty body")
|
||||||
|
}
|
||||||
|
|
||||||
|
setWebsocketRequestBody(c, "event body")
|
||||||
|
value, exists := c.Get(wsRequestBodyKey)
|
||||||
|
if !exists {
|
||||||
|
t.Fatalf("request body key not set")
|
||||||
|
}
|
||||||
|
bodyBytes, ok := value.([]byte)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("request body key type mismatch")
|
||||||
|
}
|
||||||
|
if string(bodyBytes) != "event body" {
|
||||||
|
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
|
||||||
|
}
|
||||||
|
}
|
||||||
119
sdk/api/handlers/openai_responses_stream_error.go
Normal file
119
sdk/api/handlers/openai_responses_stream_error.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// openAIResponsesStreamErrorChunk is the JSON shape of an error event emitted
// on the OpenAI Responses streaming channel. Unlike the {"error":{...}} HTTP
// error body, streaming clients require a top-level "type" field, so the error
// fields live at the top level of the chunk.
type openAIResponsesStreamErrorChunk struct {
	// Type is always "error" for this chunk (set by the builder below).
	Type string `json:"type"`
	// Code is a machine-readable error code (e.g. "rate_limit_exceeded").
	Code string `json:"code"`
	// Message is the human-readable error description.
	Message string `json:"message"`
	// SequenceNumber orders this event within the stream.
	SequenceNumber int `json:"sequence_number"`
}
|
||||||
|
|
||||||
|
// openAIResponsesStreamErrorCode maps an HTTP status code to the
// machine-readable error code string used in Responses streaming error chunks.
// Specific statuses get dedicated codes; otherwise 5xx maps to
// "internal_server_error", 4xx to "invalid_request_error", and anything else
// to "unknown_error".
func openAIResponsesStreamErrorCode(status int) string {
	switch {
	case status == http.StatusUnauthorized:
		return "invalid_api_key"
	case status == http.StatusForbidden:
		return "insufficient_quota"
	case status == http.StatusTooManyRequests:
		return "rate_limit_exceeded"
	case status == http.StatusNotFound:
		return "model_not_found"
	case status == http.StatusRequestTimeout:
		return "request_timeout"
	case status >= http.StatusInternalServerError:
		return "internal_server_error"
	case status >= http.StatusBadRequest:
		return "invalid_request_error"
	default:
		return "unknown_error"
	}
}
|
||||||
|
|
||||||
|
// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk.
|
||||||
|
//
|
||||||
|
// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for
|
||||||
|
// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union
|
||||||
|
// of chunks that requires a top-level `type` field.
|
||||||
|
func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte {
|
||||||
|
if status <= 0 {
|
||||||
|
status = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
if sequenceNumber < 0 {
|
||||||
|
sequenceNumber = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
message := strings.TrimSpace(errText)
|
||||||
|
if message == "" {
|
||||||
|
message = http.StatusText(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := openAIResponsesStreamErrorCode(status)
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(errText)
|
||||||
|
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(trimmed), &payload); err == nil {
|
||||||
|
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" {
|
||||||
|
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||||
|
message = strings.TrimSpace(m)
|
||||||
|
}
|
||||||
|
if v, ok := payload["code"]; ok && v != nil {
|
||||||
|
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||||
|
code = strings.TrimSpace(c)
|
||||||
|
} else {
|
||||||
|
code = strings.TrimSpace(fmt.Sprint(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 {
|
||||||
|
sequenceNumber = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if e, ok := payload["error"].(map[string]any); ok {
|
||||||
|
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||||
|
message = strings.TrimSpace(m)
|
||||||
|
}
|
||||||
|
if v, ok := e["code"]; ok && v != nil {
|
||||||
|
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||||
|
code = strings.TrimSpace(c)
|
||||||
|
} else {
|
||||||
|
code = strings.TrimSpace(fmt.Sprint(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(code) == "" {
|
||||||
|
code = "unknown_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(openAIResponsesStreamErrorChunk{
|
||||||
|
Type: "error",
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
SequenceNumber: sequenceNumber,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extremely defensive fallback.
|
||||||
|
data, _ = json.Marshal(openAIResponsesStreamErrorChunk{
|
||||||
|
Type: "error",
|
||||||
|
Code: "internal_server_error",
|
||||||
|
Message: message,
|
||||||
|
SequenceNumber: sequenceNumber,
|
||||||
|
})
|
||||||
|
if len(data) > 0 {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`)
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user