mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-15 18:23:44 +00:00
Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
462a70541e | ||
|
|
2407c1f4af | ||
|
|
2c743c8f0b | ||
|
|
9f2c278ee6 | ||
|
|
d605985f45 | ||
|
|
d52b28b147 | ||
|
|
4afe1f42ca | ||
|
|
7481c0eaa0 | ||
|
|
024bc25b2c | ||
|
|
ffdfad8482 | ||
|
|
b91ee8d008 | ||
|
|
6586f08584 | ||
|
|
f49e887fe6 | ||
|
|
344066fd11 | ||
|
|
bcb8092488 | ||
|
|
1efade8bdb | ||
|
|
a5b3ff11fd | ||
|
|
084558f200 | ||
|
|
b602eae215 | ||
|
|
d02bf9c243 | ||
|
|
26a5f67df2 | ||
|
|
600fd42a83 | ||
|
|
670685139a | ||
|
|
52b6306388 | ||
|
|
f957b8948c | ||
|
|
cd0b14dd2d | ||
|
|
894703a484 | ||
|
|
521ec6f1b8 | ||
|
|
b0c5d9640a | ||
|
|
ef8e94e992 | ||
|
|
9df96a4bb4 | ||
|
|
28a428ae2f | ||
|
|
b326ec3641 | ||
|
|
fcecbc7d46 | ||
|
|
f4007f53ba | ||
|
|
d08a2453f7 | ||
|
|
3f53eea1e0 | ||
|
|
5a812a1e93 | ||
|
|
5e624cc7b1 | ||
|
|
f3d1cc8dc1 | ||
|
|
e889efeda7 | ||
|
|
0a3a95521c | ||
|
|
e0be6c5786 |
@@ -28,3 +28,4 @@ bin/*
|
||||
.claude/*
|
||||
.vscode/*
|
||||
.serena/*
|
||||
.bmad/*
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -32,6 +32,7 @@ GEMINI.md
|
||||
.vscode/*
|
||||
.claude/*
|
||||
.serena/*
|
||||
.bmad/*
|
||||
.mcp/cache/
|
||||
|
||||
# macOS
|
||||
|
||||
@@ -25,6 +25,9 @@ remote-management:
|
||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||
disable-control-panel: false
|
||||
|
||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
|
||||
# Authentication directory (supports ~ for home directory)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
@@ -50,6 +53,9 @@ usage-statistics-enabled: false
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
@@ -67,6 +73,7 @@ ws-auth: false
|
||||
# Gemini API keys
|
||||
# gemini-api-key:
|
||||
# - api-key: "AIzaSy...01"
|
||||
# prefix: "test" # optional: require calls like "test/gemini-3-pro-preview" to target this credential
|
||||
# base-url: "https://generativelanguage.googleapis.com"
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -81,6 +88,7 @@ ws-auth: false
|
||||
# Codex API keys
|
||||
# codex-api-key:
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/gpt-5-codex" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom codex API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -95,6 +103,7 @@ ws-auth: false
|
||||
# claude-api-key:
|
||||
# - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/claude-sonnet-latest" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom claude API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -121,6 +130,7 @@ ws-auth: false
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||
# prefix: "test" # optional: require calls like "test/kimi-k2" to target this provider's credentials
|
||||
# base-url: "https://openrouter.ai/api/v1" # The base URL of the provider.
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -135,6 +145,7 @@ ws-auth: false
|
||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||
# vertex-api-key:
|
||||
# - api-key: "vk-123..." # x-goog-api-key header
|
||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||
# headers:
|
||||
|
||||
12
go.mod
12
go.mod
@@ -18,10 +18,10 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tiktoken-go/tokenizer v0.7.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/net v0.46.0
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/term v0.37.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -69,9 +69,9 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sync v0.17.0 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
||||
24
go.sum
24
go.sum
@@ -160,23 +160,23 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
|
||||
@@ -146,6 +146,9 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
m := &AmpModule{enabled: true}
|
||||
ms := NewMultiSourceSecretWithPath("", p, time.Minute)
|
||||
m.secretSource = ms
|
||||
m.lastConfig = &config.AmpCode{
|
||||
UpstreamAPIKey: "old-key",
|
||||
}
|
||||
|
||||
// Warm the cache
|
||||
if _, err := ms.Get(context.Background()); err != nil {
|
||||
@@ -157,7 +160,7 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update config - should invalidate cache
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x"}}); err != nil {
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x", UpstreamAPIKey: "new-key"}}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -267,7 +267,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
v1betaAmp := provider.Group("/v1beta")
|
||||
{
|
||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1betaAmp.POST("/models/:action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,9 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
m.setProxy(proxy)
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
m.registerManagementRoutes(r, base)
|
||||
m.registerManagementRoutes(r, base, nil)
|
||||
srv := httptest.NewServer(r)
|
||||
defer srv.Close()
|
||||
|
||||
managementPaths := []struct {
|
||||
path string
|
||||
@@ -63,11 +65,17 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
for _, path := range managementPaths {
|
||||
t.Run(path.path, func(t *testing.T) {
|
||||
proxyCalled = false
|
||||
req := httptest.NewRequest(path.method, path.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
req, err := http.NewRequest(path.method, srv.URL+path.path, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
t.Fatalf("route %s not registered", path.path)
|
||||
}
|
||||
if !proxyCalled {
|
||||
|
||||
@@ -230,13 +230,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
||||
|
||||
// Create server instance
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s := &Server{
|
||||
engine: engine,
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
|
||||
cfg: cfg,
|
||||
accessManager: accessManager,
|
||||
requestLogger: requestLogger,
|
||||
@@ -334,8 +330,8 @@ func (s *Server) setupRoutes() {
|
||||
v1beta.Use(AuthMiddleware(s.accessManager))
|
||||
{
|
||||
v1beta.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1beta.POST("/models/:action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
|
||||
// Root endpoint
|
||||
@@ -628,7 +624,7 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
||||
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
@@ -938,17 +934,11 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
// Save YAML snapshot for next comparison
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s.handlers.SetOpenAICompatProviders(providerNames)
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
staticDir := managementasset.StaticDir(s.configFilePath)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
if s.mgmt != nil {
|
||||
s.mgmt.SetConfig(cfg)
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
|
||||
// Config represents the application's configuration, loaded from a YAML file.
|
||||
type Config struct {
|
||||
config.SDKConfig `yaml:",inline"`
|
||||
@@ -116,6 +118,9 @@ type RemoteManagement struct {
|
||||
SecretKey string `yaml:"secret-key"`
|
||||
// DisableControlPanel skips serving and syncing the bundled management UI when true.
|
||||
DisableControlPanel bool `yaml:"disable-control-panel"`
|
||||
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
|
||||
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
|
||||
PanelGitHubRepository string `yaml:"panel-github-repository"`
|
||||
}
|
||||
|
||||
// QuotaExceeded defines the behavior when API quota limits are exceeded.
|
||||
@@ -194,6 +199,9 @@ type ClaudeKey struct {
|
||||
// APIKey is the authentication key for accessing Claude API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Claude API endpoint.
|
||||
// If empty, the default Claude API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
@@ -226,6 +234,9 @@ type CodexKey struct {
|
||||
// APIKey is the authentication key for accessing Codex API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Codex API endpoint.
|
||||
// If empty, the default Codex API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
@@ -246,6 +257,9 @@ type GeminiKey struct {
|
||||
// APIKey is the authentication key for accessing Gemini API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL optionally overrides the Gemini API endpoint.
|
||||
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||
|
||||
@@ -294,6 +308,9 @@ type OpenAICompatibility struct {
|
||||
// Name is the identifier for this OpenAI compatibility configuration.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the external OpenAI-compatible API endpoint.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
|
||||
@@ -369,6 +386,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
|
||||
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
||||
if optional {
|
||||
@@ -405,6 +423,11 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
_ = SaveConfigPreserveCommentsUpdateNestedScalar(configFile, []string{"remote-management", "secret-key"}, hashed)
|
||||
}
|
||||
|
||||
cfg.RemoteManagement.PanelGitHubRepository = strings.TrimSpace(cfg.RemoteManagement.PanelGitHubRepository)
|
||||
if cfg.RemoteManagement.PanelGitHubRepository == "" {
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
}
|
||||
|
||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||
syncInlineAccessProvider(&cfg)
|
||||
|
||||
@@ -456,6 +479,7 @@ func (cfg *Config) SanitizeOpenAICompatibility() {
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
e := cfg.OpenAICompatibility[i]
|
||||
e.Name = strings.TrimSpace(e.Name)
|
||||
e.Prefix = normalizeModelPrefix(e.Prefix)
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
if e.BaseURL == "" {
|
||||
@@ -476,6 +500,7 @@ func (cfg *Config) SanitizeCodexKeys() {
|
||||
out := make([]CodexKey, 0, len(cfg.CodexKey))
|
||||
for i := range cfg.CodexKey {
|
||||
e := cfg.CodexKey[i]
|
||||
e.Prefix = normalizeModelPrefix(e.Prefix)
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
e.ExcludedModels = NormalizeExcludedModels(e.ExcludedModels)
|
||||
@@ -494,6 +519,7 @@ func (cfg *Config) SanitizeClaudeKeys() {
|
||||
}
|
||||
for i := range cfg.ClaudeKey {
|
||||
entry := &cfg.ClaudeKey[i]
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||
}
|
||||
@@ -530,6 +556,7 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
@@ -543,6 +570,18 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
cfg.GeminiKey = out
|
||||
}
|
||||
|
||||
func normalizeModelPrefix(prefix string) string {
|
||||
trimmed := strings.TrimSpace(prefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(trimmed, "/") {
|
||||
return ""
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func syncInlineAccessProvider(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
|
||||
@@ -13,6 +13,9 @@ type VertexCompatKey struct {
|
||||
// Maps to the x-goog-api-key header.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
||||
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
||||
@@ -53,6 +56,7 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
if entry.BaseURL == "" {
|
||||
// BaseURL is required for Vertex API key entries
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -23,10 +24,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
managementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
// ManagementFileName exposes the control panel asset filename.
|
||||
@@ -97,7 +98,7 @@ func runAutoUpdater(ctx context.Context) {
|
||||
|
||||
configPath, _ := schedulerConfigPath.Load().(string)
|
||||
staticDir := StaticDir(configPath)
|
||||
EnsureLatestManagementHTML(ctx, staticDir, cfg.ProxyURL)
|
||||
EnsureLatestManagementHTML(ctx, staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
|
||||
runOnce()
|
||||
@@ -181,7 +182,7 @@ func FilePath(configFilePath string) string {
|
||||
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
|
||||
// The function is designed to run in a background goroutine and will never panic.
|
||||
// It enforces a 3-hour rate limit to avoid frequent checks on config/auth file changes.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string) {
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
@@ -214,6 +215,7 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
return
|
||||
}
|
||||
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
@@ -225,7 +227,7 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
localHash = ""
|
||||
}
|
||||
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client)
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return
|
||||
@@ -254,8 +256,44 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
}
|
||||
|
||||
func fetchLatestAsset(ctx context.Context, client *http.Client) (*releaseAsset, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, managementReleaseURL, nil)
|
||||
func resolveReleaseURL(repo string) string {
|
||||
repo = strings.TrimSpace(repo)
|
||||
if repo == "" {
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(repo)
|
||||
if err != nil || parsed.Host == "" {
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
host := strings.ToLower(parsed.Host)
|
||||
parsed.Path = strings.TrimSuffix(parsed.Path, "/")
|
||||
|
||||
if host == "api.github.com" {
|
||||
if !strings.HasSuffix(strings.ToLower(parsed.Path), "/releases/latest") {
|
||||
parsed.Path = parsed.Path + "/releases/latest"
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
if host == "github.com" {
|
||||
parts := strings.Split(strings.Trim(parsed.Path, "/"), "/")
|
||||
if len(parts) >= 2 && parts[0] != "" && parts[1] != "" {
|
||||
repoName := strings.TrimSuffix(parts[1], ".git")
|
||||
return fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", parts[0], repoName)
|
||||
}
|
||||
}
|
||||
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
func fetchLatestAsset(ctx context.Context, client *http.Client, releaseURL string) (*releaseAsset, string, error) {
|
||||
if strings.TrimSpace(releaseURL) == "" {
|
||||
releaseURL = defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, releaseURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("create release request: %w", err)
|
||||
}
|
||||
|
||||
@@ -630,6 +630,13 @@ func GetQwenModels() []*ModelInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// iFlowThinkingSupport is a shared ThinkingSupport configuration for iFlow models
|
||||
// that support thinking mode via chat_template_kwargs.enable_thinking (boolean toggle).
|
||||
// Uses level-based configuration so standard normalization flows apply before conversion.
|
||||
var iFlowThinkingSupport = &ThinkingSupport{
|
||||
Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"},
|
||||
}
|
||||
|
||||
// GetIFlowModels returns supported models for iFlow OAuth accounts.
|
||||
func GetIFlowModels() []*ModelInfo {
|
||||
entries := []struct {
|
||||
@@ -645,9 +652,9 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000},
|
||||
@@ -655,10 +662,10 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||
{ID: "deepseek-v3", DisplayName: "DeepSeek-V3-671B", Description: "DeepSeek V3 671B", Created: 1734307200},
|
||||
{ID: "qwen3-32b", DisplayName: "Qwen3-32B", Description: "Qwen3 32B", Created: 1747094400},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
|
||||
@@ -66,6 +66,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
@@ -157,6 +158,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||
@@ -442,3 +444,21 @@ func ensureToolsArray(body []byte) []byte {
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// applyIFlowThinkingConfig converts normalized reasoning_effort to iFlow chat_template_kwargs.enable_thinking.
|
||||
// This should be called after NormalizeThinkingConfig has processed the payload.
|
||||
// iFlow only supports boolean enable_thinking, so any non-"none" effort enables thinking.
|
||||
func applyIFlowThinkingConfig(body []byte) []byte {
|
||||
effort := gjson.GetBytes(body, "reasoning_effort")
|
||||
if !effort.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
val := strings.ToLower(strings.TrimSpace(effort.String()))
|
||||
enableThinking := val != "none" && val != ""
|
||||
|
||||
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
|
||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -166,16 +166,17 @@ type KiroExecutor struct {
|
||||
// This is critical because OpenAI and Claude formats have different tool structures:
|
||||
// - OpenAI: tools[].function.name, tools[].function.description
|
||||
// - Claude: tools[].name, tools[].description
|
||||
// headers parameter allows checking Anthropic-Beta header for thinking mode detection.
|
||||
// Returns the serialized JSON payload and a boolean indicating whether thinking mode was injected.
|
||||
func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, sourceFormat sdktranslator.Format) ([]byte, bool) {
|
||||
func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, sourceFormat sdktranslator.Format, headers http.Header) ([]byte, bool) {
|
||||
switch sourceFormat.String() {
|
||||
case "openai":
|
||||
log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String())
|
||||
return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly)
|
||||
return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
|
||||
default:
|
||||
// Default to Claude format (also handles "claude", "kiro", etc.)
|
||||
log.Debugf("kiro: using Claude payload builder for source format: %s", sourceFormat.String())
|
||||
return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly)
|
||||
return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,7 +250,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
|
||||
// Rebuild payload with the correct origin for this endpoint
|
||||
// Each endpoint requires its matching Origin value in the request body
|
||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from)
|
||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||
|
||||
log.Debugf("kiro: trying endpoint %d/%d: %s (Name: %s, Origin: %s)",
|
||||
endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin)
|
||||
@@ -359,7 +360,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
auth = refreshedAuth
|
||||
accessToken, profileArn = kiroCredentials(auth)
|
||||
// Rebuild payload with new profile ARN if changed
|
||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from)
|
||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||
log.Infof("kiro: token refreshed successfully, retrying request")
|
||||
continue
|
||||
}
|
||||
@@ -416,7 +417,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
if refreshedAuth != nil {
|
||||
auth = refreshedAuth
|
||||
accessToken, profileArn = kiroCredentials(auth)
|
||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from)
|
||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||
log.Infof("kiro: token refreshed for 403, retrying request")
|
||||
continue
|
||||
}
|
||||
@@ -555,10 +556,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
|
||||
// Rebuild payload with the correct origin for this endpoint
|
||||
// Each endpoint requires its matching Origin value in the request body
|
||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from)
|
||||
// Kiro API always returns <thinking> tags regardless of whether thinking mode was requested
|
||||
// So we always enable thinking parsing for Kiro responses
|
||||
thinkingEnabled := true
|
||||
kiroPayload, thinkingEnabled := buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||
|
||||
log.Debugf("kiro: stream trying endpoint %d/%d: %s (Name: %s, Origin: %s)",
|
||||
endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin)
|
||||
@@ -681,7 +679,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
auth = refreshedAuth
|
||||
accessToken, profileArn = kiroCredentials(auth)
|
||||
// Rebuild payload with new profile ARN if changed
|
||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from)
|
||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||
log.Infof("kiro: token refreshed successfully, retrying stream request")
|
||||
continue
|
||||
}
|
||||
@@ -738,7 +736,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
if refreshedAuth != nil {
|
||||
auth = refreshedAuth
|
||||
accessToken, profileArn = kiroCredentials(auth)
|
||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from)
|
||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||
log.Infof("kiro: token refreshed for 403, retrying stream request")
|
||||
continue
|
||||
}
|
||||
@@ -1702,6 +1700,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
pendingEndTagChars := 0 // Number of chars that might be start of </thinking>
|
||||
isThinkingBlockOpen := false // Track if thinking content block is open
|
||||
thinkingBlockIndex := -1 // Index of the thinking content block
|
||||
var accumulatedThinkingContent strings.Builder // Accumulate thinking content for signature generation
|
||||
|
||||
// Code block state tracking for heuristic thinking tag parsing
|
||||
// When inside a markdown code block, <thinking> tags should NOT be parsed
|
||||
@@ -1847,6 +1846,8 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
}
|
||||
// Accumulate thinking content for signature generation
|
||||
accumulatedThinkingContent.WriteString(pendingText)
|
||||
} else {
|
||||
// Output as regular text
|
||||
if !isTextBlockOpen {
|
||||
@@ -2390,6 +2391,8 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
}
|
||||
// Accumulate thinking content for signature generation
|
||||
accumulatedThinkingContent.WriteString(thinkContent)
|
||||
}
|
||||
|
||||
// Note: Partial tag handling is done via pendingEndTagChars
|
||||
@@ -2397,7 +2400,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
|
||||
// Close thinking block
|
||||
if isThinkingBlockOpen {
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(thinkingBlockIndex)
|
||||
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
@@ -2405,6 +2408,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
}
|
||||
}
|
||||
isThinkingBlockOpen = false
|
||||
accumulatedThinkingContent.Reset() // Reset for potential next thinking block
|
||||
}
|
||||
|
||||
inThinkBlock = false
|
||||
@@ -2450,6 +2454,8 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
}
|
||||
// Accumulate thinking content for signature generation
|
||||
accumulatedThinkingContent.WriteString(contentToEmit)
|
||||
}
|
||||
|
||||
remaining = ""
|
||||
@@ -2592,6 +2598,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
// Handle tool uses in response (with deduplication)
|
||||
for _, tu := range toolUses {
|
||||
toolUseID := kirocommon.GetString(tu, "toolUseId")
|
||||
toolName := kirocommon.GetString(tu, "name")
|
||||
|
||||
// Check for duplicate
|
||||
if processedIDs[toolUseID] {
|
||||
@@ -2615,7 +2622,6 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
|
||||
// Emit tool_use content block
|
||||
contentBlockIndex++
|
||||
toolName := kirocommon.GetString(tu, "name")
|
||||
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
@@ -2888,7 +2894,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
if calculatedInputTokens > 0 {
|
||||
localEstimate := totalUsage.InputTokens
|
||||
totalUsage.InputTokens = calculatedInputTokens
|
||||
log.Infof("kiro: using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)",
|
||||
log.Debugf("kiro: using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)",
|
||||
upstreamContextPercentage, calculatedInputTokens, localEstimate)
|
||||
}
|
||||
}
|
||||
@@ -2897,7 +2903,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
|
||||
// Log upstream usage information if received
|
||||
if hasUpstreamUsage {
|
||||
log.Infof("kiro: upstream usage - credits: %.4f, context: %.2f%%, final tokens - input: %d, output: %d, total: %d",
|
||||
log.Debugf("kiro: upstream usage - credits: %.4f, context: %.2f%%, final tokens - input: %d, output: %d, total: %d",
|
||||
upstreamCreditUsage, upstreamContextPercentage,
|
||||
totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens)
|
||||
}
|
||||
|
||||
@@ -72,13 +72,7 @@ func ApplyReasoningEffortMetadata(payload []byte, metadata map[string]any, model
|
||||
// Fallback: numeric thinking_budget suffix for level-based (OpenAI-style) models.
|
||||
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(baseModel, *budget); ok && effort != "" {
|
||||
if *budget == 0 && effort == "none" && util.ModelUsesThinkingLevels(baseModel) {
|
||||
if _, supported := util.NormalizeReasoningEffortLevel(baseModel, effort); !supported {
|
||||
return StripThinkingFields(payload, false)
|
||||
}
|
||||
}
|
||||
|
||||
if effort, ok := util.ThinkingBudgetToEffort(baseModel, *budget); ok && effort != "" {
|
||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
||||
return updated
|
||||
}
|
||||
@@ -273,7 +267,7 @@ func StripThinkingFields(payload []byte, effortOnly bool) []byte {
|
||||
"reasoning.effort",
|
||||
}
|
||||
if !effortOnly {
|
||||
fieldsToRemove = append([]string{"reasoning"}, fieldsToRemove...)
|
||||
fieldsToRemove = append([]string{"reasoning", "thinking"}, fieldsToRemove...)
|
||||
}
|
||||
out := payload
|
||||
for _, field := range fieldsToRemove {
|
||||
|
||||
@@ -7,10 +7,8 @@ package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
client "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -42,27 +40,30 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// system instruction
|
||||
var systemInstruction *client.Content
|
||||
systemInstructionJSON := ""
|
||||
hasSystemInstruction := false
|
||||
systemResult := gjson.GetBytes(rawJSON, "system")
|
||||
if systemResult.IsArray() {
|
||||
systemResults := systemResult.Array()
|
||||
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
|
||||
systemInstructionJSON = `{"role":"user","parts":[]}`
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemPromptResult := systemResults[i]
|
||||
systemTypePromptResult := systemPromptResult.Get("type")
|
||||
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||
systemPrompt := systemPromptResult.Get("text").String()
|
||||
systemPart := client.Part{Text: systemPrompt}
|
||||
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
|
||||
partJSON := `{}`
|
||||
if systemPrompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", systemPrompt)
|
||||
}
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", partJSON)
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
}
|
||||
if len(systemInstruction.Parts) == 0 {
|
||||
systemInstruction = nil
|
||||
}
|
||||
}
|
||||
|
||||
// contents
|
||||
contents := make([]client.Content, 0)
|
||||
contentsJSON := "[]"
|
||||
hasContents := false
|
||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messageResults := messagesResult.Array()
|
||||
@@ -76,7 +77,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
clientContent := client.Content{Role: role, Parts: []client.Part{}}
|
||||
clientContentJSON := `{"role":"","parts":[]}`
|
||||
clientContentJSON, _ = sjson.Set(clientContentJSON, "role", role)
|
||||
contentsResult := messageResult.Get("content")
|
||||
if contentsResult.IsArray() {
|
||||
contentResults := contentsResult.Array()
|
||||
@@ -90,25 +92,39 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if signatureResult.Exists() {
|
||||
signature = signatureResult.String()
|
||||
}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt, Thought: true, ThoughtSignature: signature})
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.Set(partJSON, "thought", true)
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
}
|
||||
if signature != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
prompt := contentResult.Get("text").String()
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
|
||||
partJSON := `{}`
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
functionID := contentResult.Get("id").String()
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||
if strings.Contains(modelName, "claude") {
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
|
||||
})
|
||||
} else {
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
|
||||
ThoughtSignature: geminiCLIClaudeThoughtSignature,
|
||||
})
|
||||
if gjson.Valid(functionArgs) {
|
||||
argsResult := gjson.Parse(functionArgs)
|
||||
if argsResult.IsObject() {
|
||||
partJSON := `{}`
|
||||
if !strings.Contains(modelName, "claude") {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", geminiCLIClaudeThoughtSignature)
|
||||
}
|
||||
if functionID != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
|
||||
}
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsResult.Raw)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||
@@ -117,37 +133,74 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-")
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
functionResponse := client.FunctionResponse{ID: toolCallID, Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
|
||||
functionResponseResult := contentResult.Get("content")
|
||||
|
||||
functionResponseJSON := `{}`
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "id", toolCallID)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "name", funcName)
|
||||
|
||||
responseData := ""
|
||||
if functionResponseResult.Type == gjson.String {
|
||||
responseData = functionResponseResult.String()
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
||||
} else if functionResponseResult.IsArray() {
|
||||
frResults := functionResponseResult.Array()
|
||||
if len(frResults) == 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
}
|
||||
|
||||
} else if functionResponseResult.IsObject() {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionResponse", functionResponseJSON)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
|
||||
sourceResult := contentResult.Get("source")
|
||||
if sourceResult.Get("type").String() == "base64" {
|
||||
inlineData := &client.InlineData{
|
||||
MimeType: sourceResult.Get("media_type").String(),
|
||||
Data: sourceResult.Get("data").String(),
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
|
||||
}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{InlineData: inlineData})
|
||||
if data := sourceResult.Get("data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "inlineData", inlineDataJSON)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
}
|
||||
}
|
||||
contents = append(contents, clientContent)
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
} else if contentsResult.Type == gjson.String {
|
||||
prompt := contentsResult.String()
|
||||
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
|
||||
partJSON := `{}`
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tools
|
||||
var tools []client.ToolDeclaration
|
||||
toolsJSON := ""
|
||||
toolDeclCount := 0
|
||||
toolsResult := gjson.GetBytes(rawJSON, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
tools = make([]client.ToolDeclaration, 1)
|
||||
tools[0].FunctionDeclarations = make([]any, 0)
|
||||
toolsJSON = `[{"functionDeclarations":[]}]`
|
||||
toolsResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolsResults); i++ {
|
||||
toolResult := toolsResults[i]
|
||||
@@ -158,30 +211,23 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
||||
tool, _ = sjson.Delete(tool, "strict")
|
||||
tool, _ = sjson.Delete(tool, "input_examples")
|
||||
var toolDeclaration any
|
||||
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
|
||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
|
||||
}
|
||||
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
|
||||
toolDeclCount++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tools = make([]client.ToolDeclaration, 0)
|
||||
}
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
out := `{"model":"","request":{"contents":[]}}`
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
if systemInstruction != nil {
|
||||
b, _ := json.Marshal(systemInstruction)
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", string(b))
|
||||
if hasSystemInstruction {
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
b, _ := json.Marshal(contents)
|
||||
out, _ = sjson.SetRaw(out, "request.contents", string(b))
|
||||
if hasContents {
|
||||
out, _ = sjson.SetRaw(out, "request.contents", contentsJSON)
|
||||
}
|
||||
if len(tools) > 0 && len(tools[0].FunctionDeclarations) > 0 {
|
||||
b, _ := json.Marshal(tools)
|
||||
out, _ = sjson.SetRaw(out, "request.tools", string(b))
|
||||
if toolDeclCount > 0 {
|
||||
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
|
||||
@@ -222,62 +222,61 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if role == "assistant" {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node := []byte(`{"role":"model","parts":[{"text":""}]}`)
|
||||
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if !content.Exists() || content.Type == gjson.Null {
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
p++
|
||||
}
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
// Handle non-JSON output gracefully (matches dev branch approach)
|
||||
if resp != "null" {
|
||||
parsed := gjson.Parse(resp)
|
||||
if parsed.Type == gjson.JSON {
|
||||
toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw))
|
||||
} else {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", resp)
|
||||
}
|
||||
}
|
||||
pp++
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
// Handle non-JSON output gracefully (matches dev branch approach)
|
||||
if resp != "null" {
|
||||
parsed := gjson.Parse(resp)
|
||||
if parsed.Type == gjson.JSON {
|
||||
toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw))
|
||||
} else {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", resp)
|
||||
}
|
||||
}
|
||||
pp++
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -361,18 +360,3 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
// itoa converts int to string without strconv import for few usages.
|
||||
func itoa(i int) string { return fmt.Sprintf("%d", i) }
|
||||
|
||||
// quoteIfNeeded ensures a string is valid JSON value (quotes plain text), pass-through for JSON objects/arrays.
|
||||
func quoteIfNeeded(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "\"\""
|
||||
}
|
||||
if len(s) > 0 && (s[0] == '{' || s[0] == '[') {
|
||||
return s
|
||||
}
|
||||
// escape quotes minimally
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return "\"" + s + "\""
|
||||
}
|
||||
|
||||
@@ -219,15 +219,20 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
// Convert thinking.budget_tokens to reasoning.effort for level-based models
|
||||
reasoningEffort := "medium" // default
|
||||
if thinking := rootResult.Get("thinking"); thinking.Exists() && thinking.IsObject() {
|
||||
if thinking.Get("type").String() == "enabled" {
|
||||
switch thinking.Get("type").String() {
|
||||
case "enabled":
|
||||
if util.ModelUsesThinkingLevels(modelName) {
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
case "disabled":
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, 0); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
template, _ = sjson.Set(template, "reasoning.effort", reasoningEffort)
|
||||
|
||||
@@ -253,7 +253,7 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
if util.ModelUsesThinkingLevels(modelName) {
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,52 +205,52 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if role == "assistant" {
|
||||
p := 0
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node := []byte(`{"role":"model","parts":[{"text":""}]}`)
|
||||
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if !content.Exists() || content.Type == gjson.Null {
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
p++
|
||||
}
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -334,18 +334,3 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
|
||||
// itoa converts int to string without strconv import for few usages.
|
||||
func itoa(i int) string { return fmt.Sprintf("%d", i) }
|
||||
|
||||
// quoteIfNeeded ensures a string is valid JSON value (quotes plain text), pass-through for JSON objects/arrays.
|
||||
func quoteIfNeeded(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "\"\""
|
||||
}
|
||||
if len(s) > 0 && (s[0] == '{' || s[0] == '[') {
|
||||
return s
|
||||
}
|
||||
// escape quotes minimally
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return "\"" + s + "\""
|
||||
}
|
||||
|
||||
@@ -207,15 +207,16 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
} else if role == "assistant" {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node := []byte(`{"role":"model","parts":[{"text":""}]}`)
|
||||
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
p++
|
||||
} else if content.IsArray() {
|
||||
// Assistant multimodal content (e.g. text + image) -> single model content with parts
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
for _, item := range content.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "text":
|
||||
@@ -237,47 +238,45 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
} else if !content.Exists() || content.Type == gjson.Null {
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
}
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -363,18 +362,3 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// itoa converts int to string without strconv import for few usages.
|
||||
func itoa(i int) string { return fmt.Sprintf("%d", i) }
|
||||
|
||||
// quoteIfNeeded ensures a string is valid JSON value (quotes plain text), pass-through for JSON objects/arrays.
|
||||
func quoteIfNeeded(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "\"\""
|
||||
}
|
||||
if len(s) > 0 && (s[0] == '{' || s[0] == '[') {
|
||||
return s
|
||||
}
|
||||
// escape quotes minimally
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return "\"" + s + "\""
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package claude
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
@@ -33,6 +34,7 @@ type KiroInferenceConfig struct {
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
// KiroConversationState holds the conversation context
|
||||
type KiroConversationState struct {
|
||||
ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field
|
||||
@@ -134,9 +136,11 @@ func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bo
|
||||
// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
|
||||
// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
|
||||
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
|
||||
// Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint.
|
||||
// headers parameter allows checking Anthropic-Beta header for thinking mode detection.
|
||||
// metadata parameter is kept for API compatibility but no longer used for thinking configuration.
|
||||
// Supports thinking mode - when enabled, injects thinking tags into system prompt.
|
||||
// Returns the payload and a boolean indicating whether thinking mode was injected.
|
||||
func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) ([]byte, bool) {
|
||||
func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, headers http.Header, metadata map[string]any) ([]byte, bool) {
|
||||
// Extract max_tokens for potential use in inferenceConfig
|
||||
// Handle -1 as "use maximum" (Kiro max output is ~32000 tokens)
|
||||
const kiroMaxOutputTokens = 32000
|
||||
@@ -181,26 +185,9 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA
|
||||
// Extract system prompt
|
||||
systemPrompt := extractSystemPrompt(claudeBody)
|
||||
|
||||
// Check for thinking mode using the comprehensive IsThinkingEnabled function
|
||||
// This supports Claude API format, OpenAI reasoning_effort, and AMP/Cursor format
|
||||
thinkingEnabled := IsThinkingEnabled(claudeBody)
|
||||
_, budgetTokens := checkThinkingMode(claudeBody) // Get budget tokens from Claude format if available
|
||||
if budgetTokens <= 0 {
|
||||
// Calculate budgetTokens based on max_tokens if available
|
||||
// Use 50% of max_tokens for thinking, with min 8000 and max 24000
|
||||
if maxTokens > 0 {
|
||||
budgetTokens = maxTokens / 2
|
||||
if budgetTokens < 8000 {
|
||||
budgetTokens = 8000
|
||||
}
|
||||
if budgetTokens > 24000 {
|
||||
budgetTokens = 24000
|
||||
}
|
||||
log.Debugf("kiro: budgetTokens calculated from max_tokens: %d (max_tokens=%d)", budgetTokens, maxTokens)
|
||||
} else {
|
||||
budgetTokens = 16000 // Default budget tokens
|
||||
}
|
||||
}
|
||||
// Check for thinking mode using the comprehensive IsThinkingEnabledWithHeaders function
|
||||
// This supports Claude API format, OpenAI reasoning_effort, AMP/Cursor format, and Anthropic-Beta header
|
||||
thinkingEnabled := IsThinkingEnabledWithHeaders(claudeBody, headers)
|
||||
|
||||
// Inject timestamp context
|
||||
timestamp := time.Now().Format("2006-01-02 15:04:05 MST")
|
||||
@@ -231,19 +218,26 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA
|
||||
log.Debugf("kiro: injected tool_choice hint into system prompt")
|
||||
}
|
||||
|
||||
// Inject thinking hint when thinking mode is enabled
|
||||
if thinkingEnabled {
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
dynamicThinkingHint := fmt.Sprintf("<thinking_mode>interleaved</thinking_mode><max_thinking_length>%d</max_thinking_length>", budgetTokens)
|
||||
systemPrompt += dynamicThinkingHint
|
||||
log.Debugf("kiro: injected dynamic thinking hint into system prompt, max_thinking_length: %d", budgetTokens)
|
||||
}
|
||||
|
||||
// Convert Claude tools to Kiro format
|
||||
kiroTools := convertClaudeToolsToKiro(tools)
|
||||
|
||||
// Thinking mode implementation:
|
||||
// Kiro API doesn't accept max_tokens for thinking. Instead, thinking mode is enabled
|
||||
// by injecting <thinking_mode> and <max_thinking_length> tags into the system prompt.
|
||||
// We use a fixed max_thinking_length value since Kiro handles the actual budget internally.
|
||||
if thinkingEnabled {
|
||||
thinkingHint := `<thinking_mode>interleaved</thinking_mode>
|
||||
<max_thinking_length>200000</max_thinking_length>
|
||||
|
||||
IMPORTANT: You MUST use <thinking>...</thinking> tags to show your reasoning process before providing your final response. Think step by step inside the thinking tags.`
|
||||
if systemPrompt != "" {
|
||||
systemPrompt = thinkingHint + "\n\n" + systemPrompt
|
||||
} else {
|
||||
systemPrompt = thinkingHint
|
||||
}
|
||||
log.Infof("kiro: injected thinking prompt, has_tools: %v", len(kiroTools) > 0)
|
||||
}
|
||||
|
||||
// Process messages and build history
|
||||
history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin)
|
||||
|
||||
@@ -280,6 +274,7 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA
|
||||
}
|
||||
|
||||
// Build inferenceConfig if we have any inference parameters
|
||||
// Note: Kiro API doesn't actually use max_tokens for thinking budget
|
||||
var inferenceConfig *KiroInferenceConfig
|
||||
if maxTokens > 0 || hasTemperature || hasTopP {
|
||||
inferenceConfig = &KiroInferenceConfig{}
|
||||
@@ -350,7 +345,7 @@ func extractSystemPrompt(claudeBody []byte) string {
|
||||
// checkThinkingMode checks if thinking mode is enabled in the Claude request
|
||||
func checkThinkingMode(claudeBody []byte) (bool, int64) {
|
||||
thinkingEnabled := false
|
||||
var budgetTokens int64 = 16000
|
||||
var budgetTokens int64 = 24000
|
||||
|
||||
thinkingField := gjson.GetBytes(claudeBody, "thinking")
|
||||
if thinkingField.Exists() {
|
||||
@@ -373,6 +368,32 @@ func checkThinkingMode(claudeBody []byte) (bool, int64) {
|
||||
return thinkingEnabled, budgetTokens
|
||||
}
|
||||
|
||||
// hasThinkingTagInBody checks if the request body already contains thinking configuration tags.
|
||||
// This is used to prevent duplicate injection when client (e.g., AMP/Cursor) already includes thinking config.
|
||||
func hasThinkingTagInBody(body []byte) bool {
|
||||
bodyStr := string(body)
|
||||
return strings.Contains(bodyStr, "<thinking_mode>") || strings.Contains(bodyStr, "<max_thinking_length>")
|
||||
}
|
||||
|
||||
|
||||
// IsThinkingEnabledFromHeader checks if thinking mode is enabled via Anthropic-Beta header.
|
||||
// Claude CLI uses "Anthropic-Beta: interleaved-thinking-2025-05-14" to enable thinking.
|
||||
func IsThinkingEnabledFromHeader(headers http.Header) bool {
|
||||
if headers == nil {
|
||||
return false
|
||||
}
|
||||
betaHeader := headers.Get("Anthropic-Beta")
|
||||
if betaHeader == "" {
|
||||
return false
|
||||
}
|
||||
// Check for interleaved-thinking beta feature
|
||||
if strings.Contains(betaHeader, "interleaved-thinking") {
|
||||
log.Debugf("kiro: thinking mode enabled via Anthropic-Beta header: %s", betaHeader)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsThinkingEnabled is a public wrapper to check if thinking mode is enabled.
|
||||
// This is used by the executor to determine whether to parse <thinking> tags in responses.
|
||||
// When thinking is NOT enabled in the request, <thinking> tags in responses should be
|
||||
@@ -383,6 +404,21 @@ func checkThinkingMode(claudeBody []byte) (bool, int64) {
|
||||
// - OpenAI format: reasoning_effort parameter
|
||||
// - AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
|
||||
func IsThinkingEnabled(body []byte) bool {
|
||||
return IsThinkingEnabledWithHeaders(body, nil)
|
||||
}
|
||||
|
||||
// IsThinkingEnabledWithHeaders checks if thinking mode is enabled from body or headers.
|
||||
// This is the comprehensive check that supports all thinking detection methods:
|
||||
// - Claude API format: thinking.type = "enabled"
|
||||
// - OpenAI format: reasoning_effort parameter
|
||||
// - AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
|
||||
// - Anthropic-Beta header: interleaved-thinking-2025-05-14
|
||||
func IsThinkingEnabledWithHeaders(body []byte, headers http.Header) bool {
|
||||
// Check Anthropic-Beta header first (Claude Code uses this)
|
||||
if IsThinkingEnabledFromHeader(headers) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check Claude API format first (thinking.type = "enabled")
|
||||
enabled, _ := checkThinkingMode(body)
|
||||
if enabled {
|
||||
@@ -771,4 +807,4 @@ func BuildAssistantMessageStruct(msg gjson.Result) KiroAssistantResponseMessage
|
||||
Content: contentBuilder.String(),
|
||||
ToolUses: toolUses,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
@@ -14,6 +16,18 @@ import (
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
)
|
||||
|
||||
// generateThinkingSignature generates a signature for thinking content.
|
||||
// This is required by Claude API for thinking blocks in non-streaming responses.
|
||||
// The signature is a base64-encoded hash of the thinking content.
|
||||
func generateThinkingSignature(thinkingContent string) string {
|
||||
if thinkingContent == "" {
|
||||
return ""
|
||||
}
|
||||
// Generate a deterministic signature based on content hash
|
||||
hash := sha256.Sum256([]byte(thinkingContent))
|
||||
return base64.StdEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Local references to kirocommon constants for thinking block parsing
|
||||
var (
|
||||
thinkingStartTag = kirocommon.ThinkingStartTag
|
||||
@@ -149,9 +163,12 @@ func ExtractThinkingFromContent(content string) []map[string]interface{} {
|
||||
if endIdx == -1 {
|
||||
// No closing tag found, treat rest as thinking content (incomplete response)
|
||||
if strings.TrimSpace(remaining) != "" {
|
||||
// Generate signature for thinking content (required by Claude API)
|
||||
signature := generateThinkingSignature(remaining)
|
||||
blocks = append(blocks, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": remaining,
|
||||
"type": "thinking",
|
||||
"thinking": remaining,
|
||||
"signature": signature,
|
||||
})
|
||||
log.Warnf("kiro: extractThinkingFromContent - missing closing </thinking> tag")
|
||||
}
|
||||
@@ -161,9 +178,12 @@ func ExtractThinkingFromContent(content string) []map[string]interface{} {
|
||||
// Extract thinking content between tags
|
||||
thinkContent := remaining[:endIdx]
|
||||
if strings.TrimSpace(thinkContent) != "" {
|
||||
// Generate signature for thinking content (required by Claude API)
|
||||
signature := generateThinkingSignature(thinkContent)
|
||||
blocks = append(blocks, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": thinkContent,
|
||||
"type": "thinking",
|
||||
"thinking": thinkContent,
|
||||
"signature": signature,
|
||||
})
|
||||
log.Debugf("kiro: extractThinkingFromContent - extracted thinking block (len: %d)", len(thinkContent))
|
||||
}
|
||||
|
||||
@@ -99,6 +99,16 @@ func BuildClaudeContentBlockStopEvent(index int) []byte {
|
||||
return []byte("event: content_block_stop\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// BuildClaudeThinkingBlockStopEvent creates a content_block_stop SSE event for thinking blocks.
|
||||
func BuildClaudeThinkingBlockStopEvent(index int) []byte {
|
||||
event := map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": index,
|
||||
}
|
||||
result, _ := json.Marshal(event)
|
||||
return []byte("event: content_block_stop\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// BuildClaudeMessageDeltaEvent creates the message_delta event with stop_reason and usage
|
||||
func BuildClaudeMessageDeltaEvent(stopReason string, usageInfo usage.Detail) []byte {
|
||||
deltaEvent := map[string]interface{}{
|
||||
|
||||
@@ -187,6 +187,7 @@ func ConvertKiroNonStreamToOpenAI(ctx context.Context, model string, originalReq
|
||||
|
||||
// Extract content
|
||||
var content string
|
||||
var reasoningContent string
|
||||
var toolUses []KiroToolUse
|
||||
var stopReason string
|
||||
|
||||
@@ -202,7 +203,8 @@ func ConvertKiroNonStreamToOpenAI(ctx context.Context, model string, originalReq
|
||||
case "text":
|
||||
content += block.Get("text").String()
|
||||
case "thinking":
|
||||
// Skip thinking blocks for OpenAI format (or convert to reasoning_content if needed)
|
||||
// Convert thinking blocks to reasoning_content for OpenAI format
|
||||
reasoningContent += block.Get("thinking").String()
|
||||
case "tool_use":
|
||||
toolUseID := block.Get("id").String()
|
||||
toolName := block.Get("name").String()
|
||||
@@ -233,8 +235,8 @@ func ConvertKiroNonStreamToOpenAI(ctx context.Context, model string, originalReq
|
||||
}
|
||||
usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens
|
||||
|
||||
// Build OpenAI response
|
||||
openaiResponse := BuildOpenAIResponse(content, toolUses, model, usageInfo, stopReason)
|
||||
// Build OpenAI response with reasoning_content support
|
||||
openaiResponse := BuildOpenAIResponseWithReasoning(content, reasoningContent, toolUses, model, usageInfo, stopReason)
|
||||
return string(openaiResponse)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,13 @@ package openai
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
kiroclaude "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude"
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -133,8 +135,10 @@ func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bo
|
||||
// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
|
||||
// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
|
||||
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
|
||||
// headers parameter allows checking Anthropic-Beta header for thinking mode detection.
|
||||
// metadata parameter is kept for API compatibility but no longer used for thinking configuration.
|
||||
// Returns the payload and a boolean indicating whether thinking mode was injected.
|
||||
func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) ([]byte, bool) {
|
||||
func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, headers http.Header, metadata map[string]any) ([]byte, bool) {
|
||||
// Extract max_tokens for potential use in inferenceConfig
|
||||
// Handle -1 as "use maximum" (Kiro max output is ~32000 tokens)
|
||||
const kiroMaxOutputTokens = 32000
|
||||
@@ -219,35 +223,30 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s
|
||||
log.Debugf("kiro-openai: injected response_format hint into system prompt")
|
||||
}
|
||||
|
||||
// Check for thinking mode and inject thinking hint
|
||||
// Supports OpenAI reasoning_effort parameter and model name hints
|
||||
thinkingEnabled, budgetTokens := checkThinkingModeFromOpenAI(openaiBody)
|
||||
if thinkingEnabled {
|
||||
// Adjust budgetTokens based on max_tokens if not explicitly set by reasoning_effort
|
||||
// Use 50% of max_tokens for thinking, with min 8000 and max 24000
|
||||
if maxTokens > 0 && budgetTokens == 16000 { // 16000 is the default, meaning not explicitly set
|
||||
calculatedBudget := maxTokens / 2
|
||||
if calculatedBudget < 8000 {
|
||||
calculatedBudget = 8000
|
||||
}
|
||||
if calculatedBudget > 24000 {
|
||||
calculatedBudget = 24000
|
||||
}
|
||||
budgetTokens = calculatedBudget
|
||||
log.Debugf("kiro-openai: budgetTokens calculated from max_tokens: %d (max_tokens=%d)", budgetTokens, maxTokens)
|
||||
}
|
||||
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n"
|
||||
}
|
||||
dynamicThinkingHint := fmt.Sprintf("<thinking_mode>interleaved</thinking_mode><max_thinking_length>%d</max_thinking_length>", budgetTokens)
|
||||
systemPrompt += dynamicThinkingHint
|
||||
log.Debugf("kiro-openai: injected dynamic thinking hint into system prompt, max_thinking_length: %d", budgetTokens)
|
||||
}
|
||||
// Check for thinking mode
|
||||
// Supports OpenAI reasoning_effort parameter, model name hints, and Anthropic-Beta header
|
||||
thinkingEnabled := checkThinkingModeFromOpenAIWithHeaders(openaiBody, headers)
|
||||
|
||||
// Convert OpenAI tools to Kiro format
|
||||
kiroTools := convertOpenAIToolsToKiro(tools)
|
||||
|
||||
// Thinking mode implementation:
|
||||
// Kiro API doesn't accept max_tokens for thinking. Instead, thinking mode is enabled
|
||||
// by injecting <thinking_mode> and <max_thinking_length> tags into the system prompt.
|
||||
// We use a fixed max_thinking_length value since Kiro handles the actual budget internally.
|
||||
if thinkingEnabled {
|
||||
thinkingHint := `<thinking_mode>interleaved</thinking_mode>
|
||||
<max_thinking_length>200000</max_thinking_length>
|
||||
|
||||
IMPORTANT: You MUST use <thinking>...</thinking> tags to show your reasoning process before providing your final response. Think step by step inside the thinking tags.`
|
||||
if systemPrompt != "" {
|
||||
systemPrompt = thinkingHint + "\n\n" + systemPrompt
|
||||
} else {
|
||||
systemPrompt = thinkingHint
|
||||
}
|
||||
log.Debugf("kiro-openai: injected thinking prompt")
|
||||
}
|
||||
|
||||
// Process messages and build history
|
||||
history, currentUserMsg, currentToolResults := processOpenAIMessages(messages, modelID, origin)
|
||||
|
||||
@@ -284,6 +283,7 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s
|
||||
}
|
||||
|
||||
// Build inferenceConfig if we have any inference parameters
|
||||
// Note: Kiro API doesn't actually use max_tokens for thinking budget
|
||||
var inferenceConfig *KiroInferenceConfig
|
||||
if maxTokens > 0 || hasTemperature || hasTopP {
|
||||
inferenceConfig = &KiroInferenceConfig{}
|
||||
@@ -682,13 +682,28 @@ func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResul
|
||||
}
|
||||
|
||||
// checkThinkingModeFromOpenAI checks if thinking mode is enabled in the OpenAI request.
|
||||
// Returns (thinkingEnabled, budgetTokens).
|
||||
// Returns thinkingEnabled.
|
||||
// Supports:
|
||||
// - reasoning_effort parameter (low/medium/high/auto)
|
||||
// - Model name containing "thinking" or "reason"
|
||||
// - <thinking_mode> tag in system prompt (AMP/Cursor format)
|
||||
func checkThinkingModeFromOpenAI(openaiBody []byte) (bool, int64) {
|
||||
var budgetTokens int64 = 16000 // Default budget
|
||||
func checkThinkingModeFromOpenAI(openaiBody []byte) bool {
|
||||
return checkThinkingModeFromOpenAIWithHeaders(openaiBody, nil)
|
||||
}
|
||||
|
||||
// checkThinkingModeFromOpenAIWithHeaders checks if thinking mode is enabled in the OpenAI request.
|
||||
// Returns thinkingEnabled.
|
||||
// Supports:
|
||||
// - Anthropic-Beta header with interleaved-thinking (Claude CLI)
|
||||
// - reasoning_effort parameter (low/medium/high/auto)
|
||||
// - Model name containing "thinking" or "reason"
|
||||
// - <thinking_mode> tag in system prompt (AMP/Cursor format)
|
||||
func checkThinkingModeFromOpenAIWithHeaders(openaiBody []byte, headers http.Header) bool {
|
||||
// Check Anthropic-Beta header first (Claude CLI uses this)
|
||||
if kiroclaude.IsThinkingEnabledFromHeader(headers) {
|
||||
log.Debugf("kiro-openai: thinking mode enabled via Anthropic-Beta header")
|
||||
return true
|
||||
}
|
||||
|
||||
// Check OpenAI format: reasoning_effort parameter
|
||||
// Valid values: "low", "medium", "high", "auto" (not "none")
|
||||
@@ -697,18 +712,7 @@ func checkThinkingModeFromOpenAI(openaiBody []byte) (bool, int64) {
|
||||
effort := reasoningEffort.String()
|
||||
if effort != "" && effort != "none" {
|
||||
log.Debugf("kiro-openai: thinking mode enabled via reasoning_effort: %s", effort)
|
||||
// Adjust budget based on effort level
|
||||
switch effort {
|
||||
case "low":
|
||||
budgetTokens = 8000
|
||||
case "medium":
|
||||
budgetTokens = 16000
|
||||
case "high":
|
||||
budgetTokens = 32000
|
||||
case "auto":
|
||||
budgetTokens = 16000
|
||||
}
|
||||
return true, budgetTokens
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -725,17 +729,7 @@ func checkThinkingModeFromOpenAI(openaiBody []byte) (bool, int64) {
|
||||
thinkingMode := bodyStr[startIdx : startIdx+endIdx]
|
||||
if thinkingMode == "interleaved" || thinkingMode == "enabled" {
|
||||
log.Debugf("kiro-openai: thinking mode enabled via AMP/Cursor format: %s", thinkingMode)
|
||||
// Try to extract max_thinking_length if present
|
||||
if maxLenStart := strings.Index(bodyStr, "<max_thinking_length>"); maxLenStart >= 0 {
|
||||
maxLenStart += len("<max_thinking_length>")
|
||||
if maxLenEnd := strings.Index(bodyStr[maxLenStart:], "</max_thinking_length>"); maxLenEnd >= 0 {
|
||||
maxLenStr := bodyStr[maxLenStart : maxLenStart+maxLenEnd]
|
||||
if parsed, err := fmt.Sscanf(maxLenStr, "%d", &budgetTokens); err == nil && parsed == 1 {
|
||||
log.Debugf("kiro-openai: extracted max_thinking_length: %d", budgetTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true, budgetTokens
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -746,13 +740,21 @@ func checkThinkingModeFromOpenAI(openaiBody []byte) (bool, int64) {
|
||||
modelLower := strings.ToLower(model)
|
||||
if strings.Contains(modelLower, "thinking") || strings.Contains(modelLower, "-reason") {
|
||||
log.Debugf("kiro-openai: thinking mode enabled via model name hint: %s", model)
|
||||
return true, budgetTokens
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debugf("kiro-openai: no thinking mode detected in OpenAI request")
|
||||
return false, budgetTokens
|
||||
return false
|
||||
}
|
||||
|
||||
// hasThinkingTagInBody checks if the request body already contains thinking configuration tags.
|
||||
// This is used to prevent duplicate injection when client (e.g., AMP/Cursor) already includes thinking config.
|
||||
func hasThinkingTagInBody(body []byte) bool {
|
||||
bodyStr := string(body)
|
||||
return strings.Contains(bodyStr, "<thinking_mode>") || strings.Contains(bodyStr, "<max_thinking_length>")
|
||||
}
|
||||
|
||||
|
||||
// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint.
|
||||
// OpenAI tool_choice values:
|
||||
// - "none": Don't use any tools
|
||||
@@ -845,4 +847,4 @@ func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult {
|
||||
}
|
||||
}
|
||||
return unique
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,12 +21,25 @@ var functionCallIDCounter uint64
|
||||
// Supports tool_calls when tools are present in the response.
|
||||
// stopReason is passed from upstream; fallback logic applied if empty.
|
||||
func BuildOpenAIResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
|
||||
return BuildOpenAIResponseWithReasoning(content, "", toolUses, model, usageInfo, stopReason)
|
||||
}
|
||||
|
||||
// BuildOpenAIResponseWithReasoning constructs an OpenAI Chat Completions-compatible response with reasoning_content support.
|
||||
// Supports tool_calls when tools are present in the response.
|
||||
// reasoningContent is included as reasoning_content field in the message when present.
|
||||
// stopReason is passed from upstream; fallback logic applied if empty.
|
||||
func BuildOpenAIResponseWithReasoning(content, reasoningContent string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
|
||||
// Build the message object
|
||||
message := map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
}
|
||||
|
||||
// Add reasoning_content if present (for thinking/reasoning models)
|
||||
if reasoningContent != "" {
|
||||
message["reasoning_content"] = reasoningContent
|
||||
}
|
||||
|
||||
// Add tool_calls if present
|
||||
if len(toolUses) > 0 {
|
||||
var toolCalls []map[string]interface{}
|
||||
|
||||
@@ -63,10 +63,22 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
|
||||
// Thinking: Convert Claude thinking.budget_tokens to OpenAI reasoning_effort
|
||||
if thinking := root.Get("thinking"); thinking.Exists() && thinking.IsObject() {
|
||||
if thinkingType := thinking.Get("type"); thinkingType.Exists() && thinkingType.String() == "enabled" {
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
if thinkingType := thinking.Get("type"); thinkingType.Exists() {
|
||||
switch thinkingType.String() {
|
||||
case "enabled":
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
} else {
|
||||
// No budget_tokens specified, default to "auto" for enabled thinking
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, -1); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
case "disabled":
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, 0); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,9 +128,10 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
param.CreatedAt = root.Get("created").Int()
|
||||
}
|
||||
|
||||
// Check if this is the first chunk (has role)
|
||||
// Emit message_start on the very first chunk, regardless of whether it has a role field.
|
||||
// Some providers (like Copilot) may send tool_calls in the first chunk without a role field.
|
||||
if delta := root.Get("choices.0.delta"); delta.Exists() {
|
||||
if role := delta.Get("role"); role.Exists() && role.String() == "assistant" && !param.MessageStarted {
|
||||
if !param.MessageStarted {
|
||||
// Send message_start event
|
||||
messageStart := map[string]interface{}{
|
||||
"type": "message_start",
|
||||
|
||||
@@ -83,7 +83,7 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
package util
|
||||
|
||||
// OpenAIThinkingBudgetToEffort maps a numeric thinking budget (tokens)
|
||||
// into an OpenAI-style reasoning effort level for level-based models.
|
||||
//
|
||||
// Ranges:
|
||||
// - 0 -> "none"
|
||||
// - -1 -> "auto"
|
||||
// - 1..1024 -> "low"
|
||||
// - 1025..8192 -> "medium"
|
||||
// - 8193..24576 -> "high"
|
||||
// - 24577.. -> highest supported level for the model (defaults to "xhigh")
|
||||
//
|
||||
// Negative values other than -1 are treated as unsupported.
|
||||
func OpenAIThinkingBudgetToEffort(model string, budget int) (string, bool) {
|
||||
switch {
|
||||
case budget == -1:
|
||||
return "auto", true
|
||||
case budget < -1:
|
||||
return "", false
|
||||
case budget == 0:
|
||||
return "none", true
|
||||
case budget > 0 && budget <= 1024:
|
||||
return "low", true
|
||||
case budget <= 8192:
|
||||
return "medium", true
|
||||
case budget <= 24576:
|
||||
return "high", true
|
||||
case budget > 24576:
|
||||
if levels := GetModelThinkingLevels(model); len(levels) > 0 {
|
||||
return levels[len(levels)-1], true
|
||||
}
|
||||
return "xhigh", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
@@ -118,3 +118,83 @@ func IsOpenAICompatibilityModel(model string) bool {
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(info.Type), "openai-compatibility")
|
||||
}
|
||||
|
||||
// ThinkingEffortToBudget maps a reasoning effort level to a numeric thinking budget (tokens),
|
||||
// clamping the result to the model's supported range.
|
||||
//
|
||||
// Mappings (values are normalized to model's supported range):
|
||||
// - "none" -> 0
|
||||
// - "auto" -> -1
|
||||
// - "minimal" -> 512
|
||||
// - "low" -> 1024
|
||||
// - "medium" -> 8192
|
||||
// - "high" -> 24576
|
||||
// - "xhigh" -> 32768
|
||||
//
|
||||
// Returns false when the effort level is empty or unsupported.
|
||||
func ThinkingEffortToBudget(model, effort string) (int, bool) {
|
||||
if effort == "" {
|
||||
return 0, false
|
||||
}
|
||||
normalized, ok := NormalizeReasoningEffortLevel(model, effort)
|
||||
if !ok {
|
||||
normalized = strings.ToLower(strings.TrimSpace(effort))
|
||||
}
|
||||
switch normalized {
|
||||
case "none":
|
||||
return 0, true
|
||||
case "auto":
|
||||
return NormalizeThinkingBudget(model, -1), true
|
||||
case "minimal":
|
||||
return NormalizeThinkingBudget(model, 512), true
|
||||
case "low":
|
||||
return NormalizeThinkingBudget(model, 1024), true
|
||||
case "medium":
|
||||
return NormalizeThinkingBudget(model, 8192), true
|
||||
case "high":
|
||||
return NormalizeThinkingBudget(model, 24576), true
|
||||
case "xhigh":
|
||||
return NormalizeThinkingBudget(model, 32768), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ThinkingBudgetToEffort maps a numeric thinking budget (tokens)
|
||||
// to a reasoning effort level for level-based models.
|
||||
//
|
||||
// Mappings:
|
||||
// - 0 -> "none" (or lowest supported level if model doesn't support "none")
|
||||
// - -1 -> "auto"
|
||||
// - 1..1024 -> "low"
|
||||
// - 1025..8192 -> "medium"
|
||||
// - 8193..24576 -> "high"
|
||||
// - 24577.. -> highest supported level for the model (defaults to "xhigh")
|
||||
//
|
||||
// Returns false when the budget is unsupported (negative values other than -1).
|
||||
func ThinkingBudgetToEffort(model string, budget int) (string, bool) {
|
||||
switch {
|
||||
case budget == -1:
|
||||
return "auto", true
|
||||
case budget < -1:
|
||||
return "", false
|
||||
case budget == 0:
|
||||
if levels := GetModelThinkingLevels(model); len(levels) > 0 {
|
||||
return levels[0], true
|
||||
}
|
||||
return "none", true
|
||||
case budget > 0 && budget <= 1024:
|
||||
return "low", true
|
||||
case budget <= 8192:
|
||||
return "medium", true
|
||||
case budget <= 24576:
|
||||
return "high", true
|
||||
case budget > 24576:
|
||||
if levels := GetModelThinkingLevels(model); len(levels) > 0 {
|
||||
return levels[len(levels)-1], true
|
||||
}
|
||||
return "xhigh", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,36 +201,6 @@ func ReasoningEffortFromMetadata(metadata map[string]any) (string, bool) {
|
||||
return "", true
|
||||
}
|
||||
|
||||
// ThinkingEffortToBudget maps reasoning effort levels to approximate budgets,
|
||||
// clamping the result to the model's supported range.
|
||||
func ThinkingEffortToBudget(model, effort string) (int, bool) {
|
||||
if effort == "" {
|
||||
return 0, false
|
||||
}
|
||||
normalized, ok := NormalizeReasoningEffortLevel(model, effort)
|
||||
if !ok {
|
||||
normalized = strings.ToLower(strings.TrimSpace(effort))
|
||||
}
|
||||
switch normalized {
|
||||
case "none":
|
||||
return 0, true
|
||||
case "auto":
|
||||
return NormalizeThinkingBudget(model, -1), true
|
||||
case "minimal":
|
||||
return NormalizeThinkingBudget(model, 512), true
|
||||
case "low":
|
||||
return NormalizeThinkingBudget(model, 1024), true
|
||||
case "medium":
|
||||
return NormalizeThinkingBudget(model, 8192), true
|
||||
case "high":
|
||||
return NormalizeThinkingBudget(model, 24576), true
|
||||
case "xhigh":
|
||||
return NormalizeThinkingBudget(model, 32768), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveOriginalModel returns the original model name stored in metadata (if present),
|
||||
// otherwise falls back to the provided model.
|
||||
func ResolveOriginalModel(model string, metadata map[string]any) string {
|
||||
|
||||
303
internal/watcher/diff/config_diff.go
Normal file
303
internal/watcher/diff/config_diff.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// BuildConfigChangeDetails computes a redacted, human-readable list of config changes.
|
||||
// Secrets are never printed; only structural or non-sensitive fields are surfaced.
|
||||
func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
changes := make([]string, 0, 16)
|
||||
if oldCfg == nil || newCfg == nil {
|
||||
return changes
|
||||
}
|
||||
|
||||
// Simple scalars
|
||||
if oldCfg.Port != newCfg.Port {
|
||||
changes = append(changes, fmt.Sprintf("port: %d -> %d", oldCfg.Port, newCfg.Port))
|
||||
}
|
||||
if oldCfg.AuthDir != newCfg.AuthDir {
|
||||
changes = append(changes, fmt.Sprintf("auth-dir: %s -> %s", oldCfg.AuthDir, newCfg.AuthDir))
|
||||
}
|
||||
if oldCfg.Debug != newCfg.Debug {
|
||||
changes = append(changes, fmt.Sprintf("debug: %t -> %t", oldCfg.Debug, newCfg.Debug))
|
||||
}
|
||||
if oldCfg.LoggingToFile != newCfg.LoggingToFile {
|
||||
changes = append(changes, fmt.Sprintf("logging-to-file: %t -> %t", oldCfg.LoggingToFile, newCfg.LoggingToFile))
|
||||
}
|
||||
if oldCfg.UsageStatisticsEnabled != newCfg.UsageStatisticsEnabled {
|
||||
changes = append(changes, fmt.Sprintf("usage-statistics-enabled: %t -> %t", oldCfg.UsageStatisticsEnabled, newCfg.UsageStatisticsEnabled))
|
||||
}
|
||||
if oldCfg.DisableCooling != newCfg.DisableCooling {
|
||||
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
|
||||
}
|
||||
if oldCfg.RequestLog != newCfg.RequestLog {
|
||||
changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog))
|
||||
}
|
||||
if oldCfg.RequestRetry != newCfg.RequestRetry {
|
||||
changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
|
||||
}
|
||||
if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
|
||||
changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
|
||||
}
|
||||
if oldCfg.ProxyURL != newCfg.ProxyURL {
|
||||
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", formatProxyURL(oldCfg.ProxyURL), formatProxyURL(newCfg.ProxyURL)))
|
||||
}
|
||||
if oldCfg.WebsocketAuth != newCfg.WebsocketAuth {
|
||||
changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth))
|
||||
}
|
||||
if oldCfg.ForceModelPrefix != newCfg.ForceModelPrefix {
|
||||
changes = append(changes, fmt.Sprintf("force-model-prefix: %t -> %t", oldCfg.ForceModelPrefix, newCfg.ForceModelPrefix))
|
||||
}
|
||||
|
||||
// Quota-exceeded behavior
|
||||
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
||||
changes = append(changes, fmt.Sprintf("quota-exceeded.switch-project: %t -> %t", oldCfg.QuotaExceeded.SwitchProject, newCfg.QuotaExceeded.SwitchProject))
|
||||
}
|
||||
if oldCfg.QuotaExceeded.SwitchPreviewModel != newCfg.QuotaExceeded.SwitchPreviewModel {
|
||||
changes = append(changes, fmt.Sprintf("quota-exceeded.switch-preview-model: %t -> %t", oldCfg.QuotaExceeded.SwitchPreviewModel, newCfg.QuotaExceeded.SwitchPreviewModel))
|
||||
}
|
||||
|
||||
// API keys (redacted) and counts
|
||||
if len(oldCfg.APIKeys) != len(newCfg.APIKeys) {
|
||||
changes = append(changes, fmt.Sprintf("api-keys count: %d -> %d", len(oldCfg.APIKeys), len(newCfg.APIKeys)))
|
||||
} else if !reflect.DeepEqual(trimStrings(oldCfg.APIKeys), trimStrings(newCfg.APIKeys)) {
|
||||
changes = append(changes, "api-keys: values updated (count unchanged, redacted)")
|
||||
}
|
||||
if len(oldCfg.GeminiKey) != len(newCfg.GeminiKey) {
|
||||
changes = append(changes, fmt.Sprintf("gemini-api-key count: %d -> %d", len(oldCfg.GeminiKey), len(newCfg.GeminiKey)))
|
||||
} else {
|
||||
for i := range oldCfg.GeminiKey {
|
||||
o := oldCfg.GeminiKey[i]
|
||||
n := newCfg.GeminiKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Claude keys (do not print key material)
|
||||
if len(oldCfg.ClaudeKey) != len(newCfg.ClaudeKey) {
|
||||
changes = append(changes, fmt.Sprintf("claude-api-key count: %d -> %d", len(oldCfg.ClaudeKey), len(newCfg.ClaudeKey)))
|
||||
} else {
|
||||
for i := range oldCfg.ClaudeKey {
|
||||
o := oldCfg.ClaudeKey[i]
|
||||
n := newCfg.ClaudeKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Codex keys (do not print key material)
|
||||
if len(oldCfg.CodexKey) != len(newCfg.CodexKey) {
|
||||
changes = append(changes, fmt.Sprintf("codex-api-key count: %d -> %d", len(oldCfg.CodexKey), len(newCfg.CodexKey)))
|
||||
} else {
|
||||
for i := range oldCfg.CodexKey {
|
||||
o := oldCfg.CodexKey[i]
|
||||
n := newCfg.CodexKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AmpCode settings (redacted where needed)
|
||||
oldAmpURL := strings.TrimSpace(oldCfg.AmpCode.UpstreamURL)
|
||||
newAmpURL := strings.TrimSpace(newCfg.AmpCode.UpstreamURL)
|
||||
if oldAmpURL != newAmpURL {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.upstream-url: %s -> %s", oldAmpURL, newAmpURL))
|
||||
}
|
||||
oldAmpKey := strings.TrimSpace(oldCfg.AmpCode.UpstreamAPIKey)
|
||||
newAmpKey := strings.TrimSpace(newCfg.AmpCode.UpstreamAPIKey)
|
||||
switch {
|
||||
case oldAmpKey == "" && newAmpKey != "":
|
||||
changes = append(changes, "ampcode.upstream-api-key: added")
|
||||
case oldAmpKey != "" && newAmpKey == "":
|
||||
changes = append(changes, "ampcode.upstream-api-key: removed")
|
||||
case oldAmpKey != newAmpKey:
|
||||
changes = append(changes, "ampcode.upstream-api-key: updated")
|
||||
}
|
||||
if oldCfg.AmpCode.RestrictManagementToLocalhost != newCfg.AmpCode.RestrictManagementToLocalhost {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.restrict-management-to-localhost: %t -> %t", oldCfg.AmpCode.RestrictManagementToLocalhost, newCfg.AmpCode.RestrictManagementToLocalhost))
|
||||
}
|
||||
oldMappings := SummarizeAmpModelMappings(oldCfg.AmpCode.ModelMappings)
|
||||
newMappings := SummarizeAmpModelMappings(newCfg.AmpCode.ModelMappings)
|
||||
if oldMappings.hash != newMappings.hash {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.model-mappings: updated (%d -> %d entries)", oldMappings.count, newMappings.count))
|
||||
}
|
||||
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
|
||||
}
|
||||
|
||||
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
||||
changes = append(changes, entries...)
|
||||
}
|
||||
|
||||
// Remote management (never print the key)
|
||||
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote))
|
||||
}
|
||||
if oldCfg.RemoteManagement.DisableControlPanel != newCfg.RemoteManagement.DisableControlPanel {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.disable-control-panel: %t -> %t", oldCfg.RemoteManagement.DisableControlPanel, newCfg.RemoteManagement.DisableControlPanel))
|
||||
}
|
||||
oldPanelRepo := strings.TrimSpace(oldCfg.RemoteManagement.PanelGitHubRepository)
|
||||
newPanelRepo := strings.TrimSpace(newCfg.RemoteManagement.PanelGitHubRepository)
|
||||
if oldPanelRepo != newPanelRepo {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.panel-github-repository: %s -> %s", oldPanelRepo, newPanelRepo))
|
||||
}
|
||||
if oldCfg.RemoteManagement.SecretKey != newCfg.RemoteManagement.SecretKey {
|
||||
switch {
|
||||
case oldCfg.RemoteManagement.SecretKey == "" && newCfg.RemoteManagement.SecretKey != "":
|
||||
changes = append(changes, "remote-management.secret-key: created")
|
||||
case oldCfg.RemoteManagement.SecretKey != "" && newCfg.RemoteManagement.SecretKey == "":
|
||||
changes = append(changes, "remote-management.secret-key: deleted")
|
||||
default:
|
||||
changes = append(changes, "remote-management.secret-key: updated")
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI compatibility providers (summarized)
|
||||
if compat := DiffOpenAICompatibility(oldCfg.OpenAICompatibility, newCfg.OpenAICompatibility); len(compat) > 0 {
|
||||
changes = append(changes, "openai-compatibility:")
|
||||
for _, c := range compat {
|
||||
changes = append(changes, " "+c)
|
||||
}
|
||||
}
|
||||
|
||||
// Vertex-compatible API keys
|
||||
if len(oldCfg.VertexCompatAPIKey) != len(newCfg.VertexCompatAPIKey) {
|
||||
changes = append(changes, fmt.Sprintf("vertex-api-key count: %d -> %d", len(oldCfg.VertexCompatAPIKey), len(newCfg.VertexCompatAPIKey)))
|
||||
} else {
|
||||
for i := range oldCfg.VertexCompatAPIKey {
|
||||
o := oldCfg.VertexCompatAPIKey[i]
|
||||
n := newCfg.VertexCompatAPIKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].api-key: updated", i))
|
||||
}
|
||||
oldModels := SummarizeVertexModels(o.Models)
|
||||
newModels := SummarizeVertexModels(n.Models)
|
||||
if oldModels.hash != newModels.hash {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].headers: updated", i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return changes
|
||||
}
|
||||
|
||||
func trimStrings(in []string) []string {
|
||||
out := make([]string, len(in))
|
||||
for i := range in {
|
||||
out[i] = strings.TrimSpace(in[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func equalStringMap(a, b map[string]string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for k, v := range a {
|
||||
if b[k] != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatProxyURL(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "<none>"
|
||||
}
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return "<redacted>"
|
||||
}
|
||||
host := strings.TrimSpace(parsed.Host)
|
||||
scheme := strings.TrimSpace(parsed.Scheme)
|
||||
if host == "" {
|
||||
// Allow host:port style without scheme.
|
||||
parsed2, err2 := url.Parse("http://" + trimmed)
|
||||
if err2 == nil {
|
||||
host = strings.TrimSpace(parsed2.Host)
|
||||
}
|
||||
scheme = ""
|
||||
}
|
||||
if host == "" {
|
||||
return "<redacted>"
|
||||
}
|
||||
if scheme == "" {
|
||||
return host
|
||||
}
|
||||
return scheme + "://" + host
|
||||
}
|
||||
529
internal/watcher/diff/config_diff_test.go
Normal file
529
internal/watcher/diff/config_diff_test.go
Normal file
@@ -0,0 +1,529 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestBuildConfigChangeDetails(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
Port: 8080,
|
||||
AuthDir: "/tmp/auth-old",
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model"}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://old-upstream",
|
||||
ModelMappings: []config.AmpModelMapping{{From: "from-old", To: "to-old"}},
|
||||
RestrictManagementToLocalhost: false,
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: false,
|
||||
SecretKey: "old",
|
||||
DisableControlPanel: false,
|
||||
PanelGitHubRepository: "repo-old",
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"providerA": {"m1"},
|
||||
},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "compat-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
newCfg := &config.Config{
|
||||
Port: 9090,
|
||||
AuthDir: "/tmp/auth-new",
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model", "extra"}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://new-upstream",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "from-old", To: "to-old"},
|
||||
{From: "from-new", To: "to-new"},
|
||||
},
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: true,
|
||||
SecretKey: "new",
|
||||
DisableControlPanel: true,
|
||||
PanelGitHubRepository: "repo-new",
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"providerA": {"m1", "m2"},
|
||||
"providerB": {"x"},
|
||||
},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "compat-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}, {Name: "m2"}},
|
||||
},
|
||||
{
|
||||
Name: "compat-b",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
|
||||
expectContains(t, details, "port: 8080 -> 9090")
|
||||
expectContains(t, details, "auth-dir: /tmp/auth-old -> /tmp/auth-new")
|
||||
expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "ampcode.upstream-url: http://old-upstream -> http://new-upstream")
|
||||
expectContains(t, details, "ampcode.model-mappings: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "remote-management.allow-remote: false -> true")
|
||||
expectContains(t, details, "remote-management.secret-key: updated")
|
||||
expectContains(t, details, "oauth-excluded-models[providera]: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "oauth-excluded-models[providerb]: added (1 entries)")
|
||||
expectContains(t, details, "openai-compatibility:")
|
||||
expectContains(t, details, " provider added: compat-b (api-keys=1, models=0)")
|
||||
expectContains(t, details, " provider updated: compat-a (models 1 -> 2)")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_NoChanges(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Port: 8080,
|
||||
}
|
||||
if details := BuildConfigChangeDetails(cfg, cfg); len(details) != 0 {
|
||||
t.Fatalf("expected no change entries, got %v", details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_GeminiVertexHeadersAndForceMappings(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", BaseURL: "http://v-old", Models: []config.VertexCompatModel{{Name: "m1"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
|
||||
ForceModelMappings: false,
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"a", "b"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", BaseURL: "http://v-new", Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "c"}},
|
||||
ForceModelMappings: true,
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, details, "gemini[0].headers: updated")
|
||||
expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "ampcode.model-mappings: updated (1 -> 1 entries)")
|
||||
expectContains(t, details, "ampcode.force-model-mappings: false -> true")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_ModelPrefixes(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Prefix: "old-g", BaseURL: "http://g", ProxyURL: "http://gp"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", Prefix: "old-c", BaseURL: "http://c", ProxyURL: "http://cp"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x1", Prefix: "old-x", BaseURL: "http://x", ProxyURL: "http://xp"},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", Prefix: "old-v", BaseURL: "http://v", ProxyURL: "http://vp"},
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Prefix: "new-g", BaseURL: "http://g", ProxyURL: "http://gp"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", Prefix: "new-c", BaseURL: "http://c", ProxyURL: "http://cp"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x1", Prefix: "new-x", BaseURL: "http://x", ProxyURL: "http://xp"},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", Prefix: "new-v", BaseURL: "http://v", ProxyURL: "http://vp"},
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "gemini[0].prefix: old-g -> new-g")
|
||||
expectContains(t, changes, "claude[0].prefix: old-c -> new-c")
|
||||
expectContains(t, changes, "codex[0].prefix: old-x -> new-x")
|
||||
expectContains(t, changes, "vertex[0].prefix: old-v -> new-v")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_NilSafe(t *testing.T) {
|
||||
if details := BuildConfigChangeDetails(nil, &config.Config{}); len(details) != 0 {
|
||||
t.Fatalf("expected empty change list when old nil, got %v", details)
|
||||
}
|
||||
if details := BuildConfigChangeDetails(&config.Config{}, nil); len(details) != 0 {
|
||||
t.Fatalf("expected empty change list when new nil, got %v", details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_SecretsAndCounts(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
APIKeys: []string{"a"},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "",
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
APIKeys: []string{"a", "b", "c"},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "new-key",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "new-secret",
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, details, "api-keys count: 1 -> 3")
|
||||
expectContains(t, details, "ampcode.upstream-api-key: added")
|
||||
expectContains(t, details, "remote-management.secret-key: created")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
Port: 1000,
|
||||
AuthDir: "/old",
|
||||
Debug: false,
|
||||
LoggingToFile: false,
|
||||
UsageStatisticsEnabled: false,
|
||||
DisableCooling: false,
|
||||
RequestRetry: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}},
|
||||
CodexKey: []config.CodexKey{{APIKey: "x1"}},
|
||||
AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false},
|
||||
RemoteManagement: config.RemoteManagement{DisableControlPanel: false, PanelGitHubRepository: "old/repo", SecretKey: "keep"},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
ProxyURL: "http://old-proxy",
|
||||
APIKeys: []string{"key-1"},
|
||||
ForceModelPrefix: false,
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
Port: 2000,
|
||||
AuthDir: "/new",
|
||||
Debug: true,
|
||||
LoggingToFile: true,
|
||||
UsageStatisticsEnabled: true,
|
||||
DisableCooling: true,
|
||||
RequestRetry: 2,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", BaseURL: "http://new", ProxyURL: "http://p", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
|
||||
{APIKey: "c2"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x1", BaseURL: "http://x", ProxyURL: "http://px", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"b"}},
|
||||
{APIKey: "x2"},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
DisableControlPanel: true,
|
||||
PanelGitHubRepository: "new/repo",
|
||||
SecretKey: "",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{" key-1 ", "key-2"},
|
||||
ForceModelPrefix: true,
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, details, "debug: false -> true")
|
||||
expectContains(t, details, "logging-to-file: false -> true")
|
||||
expectContains(t, details, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, details, "disable-cooling: false -> true")
|
||||
expectContains(t, details, "request-log: false -> true")
|
||||
expectContains(t, details, "request-retry: 1 -> 2")
|
||||
expectContains(t, details, "max-retry-interval: 1 -> 3")
|
||||
expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, details, "ws-auth: false -> true")
|
||||
expectContains(t, details, "force-model-prefix: false -> true")
|
||||
expectContains(t, details, "quota-exceeded.switch-project: false -> true")
|
||||
expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true")
|
||||
expectContains(t, details, "api-keys count: 1 -> 2")
|
||||
expectContains(t, details, "claude-api-key count: 1 -> 2")
|
||||
expectContains(t, details, "codex-api-key count: 1 -> 2")
|
||||
expectContains(t, details, "ampcode.restrict-management-to-localhost: false -> true")
|
||||
expectContains(t, details, "ampcode.upstream-api-key: removed")
|
||||
expectContains(t, details, "remote-management.disable-control-panel: false -> true")
|
||||
expectContains(t, details, "remote-management.panel-github-repository: old/repo -> new/repo")
|
||||
expectContains(t, details, "remote-management.secret-key: deleted")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
Port: 1,
|
||||
AuthDir: "/a",
|
||||
Debug: false,
|
||||
LoggingToFile: false,
|
||||
UsageStatisticsEnabled: false,
|
||||
DisableCooling: false,
|
||||
RequestRetry: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g-old", BaseURL: "http://g-old", ProxyURL: "http://gp-old", Headers: map[string]string{"A": "1"}},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c-old", BaseURL: "http://c-old", ProxyURL: "http://cp-old", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"x"}},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x-old", BaseURL: "http://x-old", ProxyURL: "http://xp-old", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"x"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v-old", BaseURL: "http://v-old", ProxyURL: "http://vp-old", Headers: map[string]string{"H": "1"}, Models: []config.VertexCompatModel{{Name: "m1"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://amp-old",
|
||||
UpstreamAPIKey: "old-key",
|
||||
RestrictManagementToLocalhost: false,
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
|
||||
ForceModelMappings: false,
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: false,
|
||||
DisableControlPanel: false,
|
||||
PanelGitHubRepository: "old/repo",
|
||||
SecretKey: "old",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
ProxyURL: "http://old-proxy",
|
||||
APIKeys: []string{" keyA "},
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{"p1": {"a"}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "prov-old",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
Port: 2,
|
||||
AuthDir: "/b",
|
||||
Debug: true,
|
||||
LoggingToFile: true,
|
||||
UsageStatisticsEnabled: true,
|
||||
DisableCooling: true,
|
||||
RequestRetry: 2,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g-new", BaseURL: "http://g-new", ProxyURL: "http://gp-new", Headers: map[string]string{"A": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c-new", BaseURL: "http://c-new", ProxyURL: "http://cp-new", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x-new", BaseURL: "http://x-new", ProxyURL: "http://xp-new", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v-new", BaseURL: "http://v-new", ProxyURL: "http://vp-new", Headers: map[string]string{"H": "2"}, Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://amp-new",
|
||||
UpstreamAPIKey: "",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "c"}},
|
||||
ForceModelMappings: true,
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: true,
|
||||
DisableControlPanel: true,
|
||||
PanelGitHubRepository: "new/repo",
|
||||
SecretKey: "",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{"keyB"},
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{"p1": {"b", "c"}, "p2": {"d"}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "prov-old",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
{APIKey: "k2"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}, {Name: "m2"}},
|
||||
},
|
||||
{
|
||||
Name: "prov-new",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k3"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "port: 1 -> 2")
|
||||
expectContains(t, changes, "auth-dir: /a -> /b")
|
||||
expectContains(t, changes, "debug: false -> true")
|
||||
expectContains(t, changes, "logging-to-file: false -> true")
|
||||
expectContains(t, changes, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, changes, "disable-cooling: false -> true")
|
||||
expectContains(t, changes, "request-retry: 1 -> 2")
|
||||
expectContains(t, changes, "max-retry-interval: 1 -> 3")
|
||||
expectContains(t, changes, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, changes, "ws-auth: false -> true")
|
||||
expectContains(t, changes, "quota-exceeded.switch-project: false -> true")
|
||||
expectContains(t, changes, "quota-exceeded.switch-preview-model: false -> true")
|
||||
expectContains(t, changes, "api-keys: values updated (count unchanged, redacted)")
|
||||
expectContains(t, changes, "gemini[0].base-url: http://g-old -> http://g-new")
|
||||
expectContains(t, changes, "gemini[0].proxy-url: http://gp-old -> http://gp-new")
|
||||
expectContains(t, changes, "gemini[0].api-key: updated")
|
||||
expectContains(t, changes, "gemini[0].headers: updated")
|
||||
expectContains(t, changes, "gemini[0].excluded-models: updated (0 -> 2 entries)")
|
||||
expectContains(t, changes, "claude[0].base-url: http://c-old -> http://c-new")
|
||||
expectContains(t, changes, "claude[0].proxy-url: http://cp-old -> http://cp-new")
|
||||
expectContains(t, changes, "claude[0].api-key: updated")
|
||||
expectContains(t, changes, "claude[0].headers: updated")
|
||||
expectContains(t, changes, "claude[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "codex[0].base-url: http://x-old -> http://x-new")
|
||||
expectContains(t, changes, "codex[0].proxy-url: http://xp-old -> http://xp-new")
|
||||
expectContains(t, changes, "codex[0].api-key: updated")
|
||||
expectContains(t, changes, "codex[0].headers: updated")
|
||||
expectContains(t, changes, "codex[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "vertex[0].base-url: http://v-old -> http://v-new")
|
||||
expectContains(t, changes, "vertex[0].proxy-url: http://vp-old -> http://vp-new")
|
||||
expectContains(t, changes, "vertex[0].api-key: updated")
|
||||
expectContains(t, changes, "vertex[0].models: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "vertex[0].headers: updated")
|
||||
expectContains(t, changes, "ampcode.upstream-url: http://amp-old -> http://amp-new")
|
||||
expectContains(t, changes, "ampcode.upstream-api-key: removed")
|
||||
expectContains(t, changes, "ampcode.restrict-management-to-localhost: false -> true")
|
||||
expectContains(t, changes, "ampcode.model-mappings: updated (1 -> 1 entries)")
|
||||
expectContains(t, changes, "ampcode.force-model-mappings: false -> true")
|
||||
expectContains(t, changes, "oauth-excluded-models[p1]: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "oauth-excluded-models[p2]: added (1 entries)")
|
||||
expectContains(t, changes, "remote-management.allow-remote: false -> true")
|
||||
expectContains(t, changes, "remote-management.disable-control-panel: false -> true")
|
||||
expectContains(t, changes, "remote-management.panel-github-repository: old/repo -> new/repo")
|
||||
expectContains(t, changes, "remote-management.secret-key: deleted")
|
||||
expectContains(t, changes, "openai-compatibility:")
|
||||
}
|
||||
|
||||
func TestFormatProxyURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "empty", in: "", want: "<none>"},
|
||||
{name: "invalid", in: "http://[::1", want: "<redacted>"},
|
||||
{name: "fullURLRedactsUserinfoAndPath", in: "http://user:pass@example.com:8080/path?x=1#frag", want: "http://example.com:8080"},
|
||||
{name: "socks5RedactsUserinfoAndPath", in: "socks5://user:pass@192.168.1.1:1080/path?x=1", want: "socks5://192.168.1.1:1080"},
|
||||
{name: "socks5HostPort", in: "socks5://proxy.example.com:1080/", want: "socks5://proxy.example.com:1080"},
|
||||
{name: "hostPortNoScheme", in: "example.com:1234/path?x=1", want: "example.com:1234"},
|
||||
{name: "relativePathRedacted", in: "/just/path", want: "<redacted>"},
|
||||
{name: "schemeAndHost", in: "https://example.com", want: "https://example.com"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := formatProxyURL(tt.in); got != tt.want {
|
||||
t.Fatalf("expected %q, got %q", tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_SecretAndUpstreamUpdates(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "old",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "old",
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "new",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "new",
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "ampcode.upstream-api-key: updated")
|
||||
expectContains(t, changes, "remote-management.secret-key: updated")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_CountBranches(t *testing.T) {
|
||||
oldCfg := &config.Config{}
|
||||
newCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{{APIKey: "g"}},
|
||||
ClaudeKey: []config.ClaudeKey{{APIKey: "c"}},
|
||||
CodexKey: []config.CodexKey{{APIKey: "x"}},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v", BaseURL: "http://v"},
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "gemini-api-key count: 0 -> 1")
|
||||
expectContains(t, changes, "claude-api-key count: 0 -> 1")
|
||||
expectContains(t, changes, "codex-api-key count: 0 -> 1")
|
||||
expectContains(t, changes, "vertex-api-key count: 0 -> 1")
|
||||
}
|
||||
|
||||
func TestTrimStrings(t *testing.T) {
|
||||
out := trimStrings([]string{" a ", "b", " c"})
|
||||
if len(out) != 3 || out[0] != "a" || out[1] != "b" || out[2] != "c" {
|
||||
t.Fatalf("unexpected trimmed strings: %v", out)
|
||||
}
|
||||
}
|
||||
102
internal/watcher/diff/model_hash.go
Normal file
102
internal/watcher/diff/model_hash.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// ComputeOpenAICompatModelsHash returns a stable hash for OpenAI-compat models.
|
||||
// Used to detect model list changes during hot reload.
|
||||
func ComputeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeVertexCompatModelsHash returns a stable hash for Vertex-compatible models.
|
||||
func ComputeVertexCompatModelsHash(models []config.VertexCompatModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeClaudeModelsHash returns a stable hash for Claude model aliases.
|
||||
func ComputeClaudeModelsHash(models []config.ClaudeModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeExcludedModelsHash returns a normalized hash for excluded model lists.
|
||||
func ComputeExcludedModelsHash(excluded []string) string {
|
||||
if len(excluded) == 0 {
|
||||
return ""
|
||||
}
|
||||
normalized := make([]string, 0, len(excluded))
|
||||
for _, entry := range excluded {
|
||||
if trimmed := strings.TrimSpace(entry); trimmed != "" {
|
||||
normalized = append(normalized, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return ""
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
data, _ := json.Marshal(normalized)
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func normalizeModelPairs(collect func(out func(key string))) []string {
|
||||
seen := make(map[string]struct{})
|
||||
keys := make([]string, 0)
|
||||
collect(func(key string) {
|
||||
if _, exists := seen[key]; exists {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
keys = append(keys, key)
|
||||
})
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func hashJoined(keys []string) string {
|
||||
if len(keys) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(strings.Join(keys, "\n")))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
159
internal/watcher/diff/model_hash_test.go
Normal file
159
internal/watcher/diff/model_hash_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_Deterministic(t *testing.T) {
|
||||
models := []config.OpenAICompatibilityModel{
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
{Name: "gpt-3.5-turbo"},
|
||||
}
|
||||
hash1 := ComputeOpenAICompatModelsHash(models)
|
||||
hash2 := ComputeOpenAICompatModelsHash(models)
|
||||
if hash1 == "" {
|
||||
t.Fatal("hash should not be empty")
|
||||
}
|
||||
if hash1 != hash2 {
|
||||
t.Fatalf("hash should be deterministic, got %s vs %s", hash1, hash2)
|
||||
}
|
||||
changed := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: "gpt-4"}, {Name: "gpt-4.1"}})
|
||||
if hash1 == changed {
|
||||
t.Fatal("hash should change when model list changes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_NormalizesAndDedups(t *testing.T) {
|
||||
a := []config.OpenAICompatibilityModel{
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
{Name: " "},
|
||||
{Name: "GPT-4", Alias: "GPT4"},
|
||||
{Alias: "a1"},
|
||||
}
|
||||
b := []config.OpenAICompatibilityModel{
|
||||
{Alias: "A1"},
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
}
|
||||
h1 := ComputeOpenAICompatModelsHash(a)
|
||||
h2 := ComputeOpenAICompatModelsHash(b)
|
||||
if h1 == "" || h2 == "" {
|
||||
t.Fatal("expected non-empty hashes for non-empty model sets")
|
||||
}
|
||||
if h1 != h2 {
|
||||
t.Fatalf("expected normalized hashes to match, got %s / %s", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeVertexCompatModelsHash_DifferentInputs(t *testing.T) {
|
||||
models := []config.VertexCompatModel{{Name: "gemini-pro", Alias: "pro"}}
|
||||
hash1 := ComputeVertexCompatModelsHash(models)
|
||||
hash2 := ComputeVertexCompatModelsHash([]config.VertexCompatModel{{Name: "gemini-1.5-pro", Alias: "pro"}})
|
||||
if hash1 == "" || hash2 == "" {
|
||||
t.Fatal("hashes should not be empty for non-empty models")
|
||||
}
|
||||
if hash1 == hash2 {
|
||||
t.Fatal("hash should differ when model content differs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeVertexCompatModelsHash_IgnoresBlankAndOrder(t *testing.T) {
|
||||
a := []config.VertexCompatModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
{Name: " "},
|
||||
{Name: "M1", Alias: "A1"},
|
||||
}
|
||||
b := []config.VertexCompatModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
}
|
||||
if h1, h2 := ComputeVertexCompatModelsHash(a), ComputeVertexCompatModelsHash(b); h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeClaudeModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil models, got %q", got)
|
||||
}
|
||||
if got := ComputeClaudeModelsHash([]config.ClaudeModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
|
||||
a := []config.ClaudeModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
{Name: " "},
|
||||
{Name: "M1", Alias: "A1"},
|
||||
}
|
||||
b := []config.ClaudeModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
}
|
||||
if h1, h2 := ComputeClaudeModelsHash(a), ComputeClaudeModelsHash(b); h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeExcludedModelsHash_Normalizes(t *testing.T) {
|
||||
hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"})
|
||||
hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"})
|
||||
if hash1 == "" || hash2 == "" {
|
||||
t.Fatal("hash should not be empty for non-empty input")
|
||||
}
|
||||
if hash1 != hash2 {
|
||||
t.Fatalf("hash should be order/space insensitive for same multiset, got %s vs %s", hash1, hash2)
|
||||
}
|
||||
hash3 := ComputeExcludedModelsHash([]string{"c"})
|
||||
if hash1 == hash3 {
|
||||
t.Fatal("hash should differ for different normalized sets")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeOpenAICompatModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil input, got %q", got)
|
||||
}
|
||||
if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: " "}, {Alias: ""}}); got != "" {
|
||||
t.Fatalf("expected empty hash for blank models, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeVertexCompatModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeVertexCompatModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil input, got %q", got)
|
||||
}
|
||||
if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{{Name: " "}}); got != "" {
|
||||
t.Fatalf("expected empty hash for blank models, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeExcludedModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeExcludedModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil input, got %q", got)
|
||||
}
|
||||
if got := ComputeExcludedModelsHash([]string{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
if got := ComputeExcludedModelsHash([]string{" ", ""}); got != "" {
|
||||
t.Fatalf("expected empty hash for whitespace-only entries, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_Deterministic(t *testing.T) {
|
||||
models := []config.ClaudeModel{{Name: "a", Alias: "A"}, {Name: "b"}}
|
||||
h1 := ComputeClaudeModelsHash(models)
|
||||
h2 := ComputeClaudeModelsHash(models)
|
||||
if h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected deterministic hash, got %s / %s", h1, h2)
|
||||
}
|
||||
if h3 := ComputeClaudeModelsHash([]config.ClaudeModel{{Name: "a"}}); h3 == h1 {
|
||||
t.Fatalf("expected different hash when models change, got %s", h3)
|
||||
}
|
||||
}
|
||||
151
internal/watcher/diff/oauth_excluded.go
Normal file
151
internal/watcher/diff/oauth_excluded.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
type ExcludedModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeExcludedModels normalizes and hashes an excluded-model list.
|
||||
func SummarizeExcludedModels(list []string) ExcludedModelsSummary {
|
||||
if len(list) == 0 {
|
||||
return ExcludedModelsSummary{}
|
||||
}
|
||||
seen := make(map[string]struct{}, len(list))
|
||||
normalized := make([]string, 0, len(list))
|
||||
for _, entry := range list {
|
||||
if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" {
|
||||
if _, exists := seen[trimmed]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
return ExcludedModelsSummary{
|
||||
hash: ComputeExcludedModelsHash(normalized),
|
||||
count: len(normalized),
|
||||
}
|
||||
}
|
||||
|
||||
// SummarizeOAuthExcludedModels summarizes OAuth excluded models per provider.
|
||||
func SummarizeOAuthExcludedModels(entries map[string][]string) map[string]ExcludedModelsSummary {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]ExcludedModelsSummary, len(entries))
|
||||
for k, v := range entries {
|
||||
key := strings.ToLower(strings.TrimSpace(k))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = SummarizeExcludedModels(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// DiffOAuthExcludedModelChanges compares OAuth excluded models maps.
|
||||
func DiffOAuthExcludedModelChanges(oldMap, newMap map[string][]string) ([]string, []string) {
|
||||
oldSummary := SummarizeOAuthExcludedModels(oldMap)
|
||||
newSummary := SummarizeOAuthExcludedModels(newMap)
|
||||
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
|
||||
for k := range oldSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
for k := range newSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
changes := make([]string, 0, len(keys))
|
||||
affected := make([]string, 0, len(keys))
|
||||
for key := range keys {
|
||||
oldInfo, okOld := oldSummary[key]
|
||||
newInfo, okNew := newSummary[key]
|
||||
switch {
|
||||
case okOld && !okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: removed", key))
|
||||
affected = append(affected, key)
|
||||
case !okOld && okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: added (%d entries)", key, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
case okOld && okNew && oldInfo.hash != newInfo.hash:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(changes)
|
||||
sort.Strings(affected)
|
||||
return changes, affected
|
||||
}
|
||||
|
||||
type AmpModelMappingsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeAmpModelMappings hashes Amp model mappings for change detection.
|
||||
func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) AmpModelMappingsSummary {
|
||||
if len(mappings) == 0 {
|
||||
return AmpModelMappingsSummary{}
|
||||
}
|
||||
entries := make([]string, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
to := strings.TrimSpace(mapping.To)
|
||||
if from == "" && to == "" {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, from+"->"+to)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return AmpModelMappingsSummary{}
|
||||
}
|
||||
sort.Strings(entries)
|
||||
sum := sha256.Sum256([]byte(strings.Join(entries, "|")))
|
||||
return AmpModelMappingsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(entries),
|
||||
}
|
||||
}
|
||||
|
||||
type VertexModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeVertexModels hashes vertex-compatible models for change detection.
|
||||
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
names := make([]string, 0, len(models))
|
||||
for _, m := range models {
|
||||
name := strings.TrimSpace(m.Name)
|
||||
alias := strings.TrimSpace(m.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
if alias != "" {
|
||||
name = alias
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
sort.Strings(names)
|
||||
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
|
||||
return VertexModelsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(names),
|
||||
}
|
||||
}
|
||||
109
internal/watcher/diff/oauth_excluded_test.go
Normal file
109
internal/watcher/diff/oauth_excluded_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestSummarizeExcludedModels_NormalizesAndDedupes(t *testing.T) {
|
||||
summary := SummarizeExcludedModels([]string{"A", " a ", "B", "b"})
|
||||
if summary.count != 2 {
|
||||
t.Fatalf("expected 2 unique entries, got %d", summary.count)
|
||||
}
|
||||
if summary.hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if empty := SummarizeExcludedModels(nil); empty.count != 0 || empty.hash != "" {
|
||||
t.Fatalf("expected empty summary for nil input, got %+v", empty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiffOAuthExcludedModelChanges(t *testing.T) {
|
||||
oldMap := map[string][]string{
|
||||
"ProviderA": {"model-1", "model-2"},
|
||||
"providerB": {"x"},
|
||||
}
|
||||
newMap := map[string][]string{
|
||||
"providerA": {"model-1", "model-3"},
|
||||
"providerC": {"y"},
|
||||
}
|
||||
|
||||
changes, affected := DiffOAuthExcludedModelChanges(oldMap, newMap)
|
||||
expectContains(t, changes, "oauth-excluded-models[providera]: updated (2 -> 2 entries)")
|
||||
expectContains(t, changes, "oauth-excluded-models[providerb]: removed")
|
||||
expectContains(t, changes, "oauth-excluded-models[providerc]: added (1 entries)")
|
||||
|
||||
if len(affected) != 3 {
|
||||
t.Fatalf("expected 3 affected providers, got %d", len(affected))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeAmpModelMappings(t *testing.T) {
|
||||
summary := SummarizeAmpModelMappings([]config.AmpModelMapping{
|
||||
{From: "a", To: "A"},
|
||||
{From: "b", To: "B"},
|
||||
{From: " ", To: " "}, // ignored
|
||||
})
|
||||
if summary.count != 2 {
|
||||
t.Fatalf("expected 2 entries, got %d", summary.count)
|
||||
}
|
||||
if summary.hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if empty := SummarizeAmpModelMappings(nil); empty.count != 0 || empty.hash != "" {
|
||||
t.Fatalf("expected empty summary for nil input, got %+v", empty)
|
||||
}
|
||||
if blank := SummarizeAmpModelMappings([]config.AmpModelMapping{{From: " ", To: " "}}); blank.count != 0 || blank.hash != "" {
|
||||
t.Fatalf("expected blank mappings ignored, got %+v", blank)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeOAuthExcludedModels_NormalizesKeys(t *testing.T) {
|
||||
out := SummarizeOAuthExcludedModels(map[string][]string{
|
||||
"ProvA": {"X"},
|
||||
"": {"ignored"},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected only non-empty key summary, got %d", len(out))
|
||||
}
|
||||
if _, ok := out["prova"]; !ok {
|
||||
t.Fatalf("expected normalized key 'prova', got keys %v", out)
|
||||
}
|
||||
if out["prova"].count != 1 || out["prova"].hash == "" {
|
||||
t.Fatalf("unexpected summary %+v", out["prova"])
|
||||
}
|
||||
if outEmpty := SummarizeOAuthExcludedModels(nil); outEmpty != nil {
|
||||
t.Fatalf("expected nil map for nil input, got %v", outEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeVertexModels(t *testing.T) {
|
||||
summary := SummarizeVertexModels([]config.VertexCompatModel{
|
||||
{Name: "m1"},
|
||||
{Name: " ", Alias: "alias"},
|
||||
{}, // ignored
|
||||
})
|
||||
if summary.count != 2 {
|
||||
t.Fatalf("expected 2 vertex models, got %d", summary.count)
|
||||
}
|
||||
if summary.hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if empty := SummarizeVertexModels(nil); empty.count != 0 || empty.hash != "" {
|
||||
t.Fatalf("expected empty summary for nil input, got %+v", empty)
|
||||
}
|
||||
if blank := SummarizeVertexModels([]config.VertexCompatModel{{Name: " "}}); blank.count != 0 || blank.hash != "" {
|
||||
t.Fatalf("expected blank model ignored, got %+v", blank)
|
||||
}
|
||||
}
|
||||
|
||||
func expectContains(t *testing.T, list []string, target string) {
|
||||
t.Helper()
|
||||
for _, entry := range list {
|
||||
if entry == target {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected list to contain %q, got %#v", target, list)
|
||||
}
|
||||
183
internal/watcher/diff/openai_compat.go
Normal file
183
internal/watcher/diff/openai_compat.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// DiffOpenAICompatibility produces human-readable change descriptions.
|
||||
func DiffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string {
|
||||
changes := make([]string, 0)
|
||||
oldMap := make(map[string]config.OpenAICompatibility, len(oldList))
|
||||
oldLabels := make(map[string]string, len(oldList))
|
||||
for idx, entry := range oldList {
|
||||
key, label := openAICompatKey(entry, idx)
|
||||
oldMap[key] = entry
|
||||
oldLabels[key] = label
|
||||
}
|
||||
newMap := make(map[string]config.OpenAICompatibility, len(newList))
|
||||
newLabels := make(map[string]string, len(newList))
|
||||
for idx, entry := range newList {
|
||||
key, label := openAICompatKey(entry, idx)
|
||||
newMap[key] = entry
|
||||
newLabels[key] = label
|
||||
}
|
||||
keySet := make(map[string]struct{}, len(oldMap)+len(newMap))
|
||||
for key := range oldMap {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
for key := range newMap {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
orderedKeys := make([]string, 0, len(keySet))
|
||||
for key := range keySet {
|
||||
orderedKeys = append(orderedKeys, key)
|
||||
}
|
||||
sort.Strings(orderedKeys)
|
||||
for _, key := range orderedKeys {
|
||||
oldEntry, oldOk := oldMap[key]
|
||||
newEntry, newOk := newMap[key]
|
||||
label := oldLabels[key]
|
||||
if label == "" {
|
||||
label = newLabels[key]
|
||||
}
|
||||
switch {
|
||||
case !oldOk:
|
||||
changes = append(changes, fmt.Sprintf("provider added: %s (api-keys=%d, models=%d)", label, countAPIKeys(newEntry), countOpenAIModels(newEntry.Models)))
|
||||
case !newOk:
|
||||
changes = append(changes, fmt.Sprintf("provider removed: %s (api-keys=%d, models=%d)", label, countAPIKeys(oldEntry), countOpenAIModels(oldEntry.Models)))
|
||||
default:
|
||||
if detail := describeOpenAICompatibilityUpdate(oldEntry, newEntry); detail != "" {
|
||||
changes = append(changes, fmt.Sprintf("provider updated: %s %s", label, detail))
|
||||
}
|
||||
}
|
||||
}
|
||||
return changes
|
||||
}
|
||||
|
||||
func describeOpenAICompatibilityUpdate(oldEntry, newEntry config.OpenAICompatibility) string {
|
||||
oldKeyCount := countAPIKeys(oldEntry)
|
||||
newKeyCount := countAPIKeys(newEntry)
|
||||
oldModelCount := countOpenAIModels(oldEntry.Models)
|
||||
newModelCount := countOpenAIModels(newEntry.Models)
|
||||
details := make([]string, 0, 3)
|
||||
if oldKeyCount != newKeyCount {
|
||||
details = append(details, fmt.Sprintf("api-keys %d -> %d", oldKeyCount, newKeyCount))
|
||||
}
|
||||
if oldModelCount != newModelCount {
|
||||
details = append(details, fmt.Sprintf("models %d -> %d", oldModelCount, newModelCount))
|
||||
}
|
||||
if !equalStringMap(oldEntry.Headers, newEntry.Headers) {
|
||||
details = append(details, "headers updated")
|
||||
}
|
||||
if len(details) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "(" + strings.Join(details, ", ") + ")"
|
||||
}
|
||||
|
||||
func countAPIKeys(entry config.OpenAICompatibility) int {
|
||||
count := 0
|
||||
for _, keyEntry := range entry.APIKeyEntries {
|
||||
if strings.TrimSpace(keyEntry.APIKey) != "" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func countOpenAIModels(models []config.OpenAICompatibilityModel) int {
|
||||
count := 0
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func openAICompatKey(entry config.OpenAICompatibility, index int) (string, string) {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
if name != "" {
|
||||
return "name:" + name, name
|
||||
}
|
||||
base := strings.TrimSpace(entry.BaseURL)
|
||||
if base != "" {
|
||||
return "base:" + base, base
|
||||
}
|
||||
for _, model := range entry.Models {
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if alias == "" {
|
||||
alias = strings.TrimSpace(model.Name)
|
||||
}
|
||||
if alias != "" {
|
||||
return "alias:" + alias, alias
|
||||
}
|
||||
}
|
||||
sig := openAICompatSignature(entry)
|
||||
if sig == "" {
|
||||
return fmt.Sprintf("index:%d", index), fmt.Sprintf("entry-%d", index+1)
|
||||
}
|
||||
short := sig
|
||||
if len(short) > 8 {
|
||||
short = short[:8]
|
||||
}
|
||||
return "sig:" + sig, "compat-" + short
|
||||
}
|
||||
|
||||
func openAICompatSignature(entry config.OpenAICompatibility) string {
|
||||
var parts []string
|
||||
|
||||
if v := strings.TrimSpace(entry.Name); v != "" {
|
||||
parts = append(parts, "name="+strings.ToLower(v))
|
||||
}
|
||||
if v := strings.TrimSpace(entry.BaseURL); v != "" {
|
||||
parts = append(parts, "base="+v)
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(entry.Models))
|
||||
for _, model := range entry.Models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
models = append(models, strings.ToLower(name)+"|"+strings.ToLower(alias))
|
||||
}
|
||||
if len(models) > 0 {
|
||||
sort.Strings(models)
|
||||
parts = append(parts, "models="+strings.Join(models, ","))
|
||||
}
|
||||
|
||||
if len(entry.Headers) > 0 {
|
||||
keys := make([]string, 0, len(entry.Headers))
|
||||
for k := range entry.Headers {
|
||||
if trimmed := strings.TrimSpace(k); trimmed != "" {
|
||||
keys = append(keys, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
sort.Strings(keys)
|
||||
parts = append(parts, "headers="+strings.Join(keys, ","))
|
||||
}
|
||||
}
|
||||
|
||||
// Intentionally exclude API key material; only count non-empty entries.
|
||||
if count := countAPIKeys(entry); count > 0 {
|
||||
parts = append(parts, fmt.Sprintf("api_keys=%d", count))
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(strings.Join(parts, "|")))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
187
internal/watcher/diff/openai_compat_test.go
Normal file
187
internal/watcher/diff/openai_compat_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestDiffOpenAICompatibility(t *testing.T) {
|
||||
oldList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-a"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
newList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-a"},
|
||||
{APIKey: "key-b"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
{Name: "m2"},
|
||||
},
|
||||
Headers: map[string]string{"X-Test": "1"},
|
||||
},
|
||||
{
|
||||
Name: "provider-b",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-b"}},
|
||||
},
|
||||
}
|
||||
|
||||
changes := DiffOpenAICompatibility(oldList, newList)
|
||||
expectContains(t, changes, "provider added: provider-b (api-keys=1, models=0)")
|
||||
expectContains(t, changes, "provider updated: provider-a (api-keys 1 -> 2, models 1 -> 2, headers updated)")
|
||||
}
|
||||
|
||||
func TestDiffOpenAICompatibility_RemovedAndUnchanged(t *testing.T) {
|
||||
oldList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-a"}},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
}
|
||||
newList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-a"}},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
}
|
||||
if changes := DiffOpenAICompatibility(oldList, newList); len(changes) != 0 {
|
||||
t.Fatalf("expected no changes, got %v", changes)
|
||||
}
|
||||
|
||||
newList = nil
|
||||
changes := DiffOpenAICompatibility(oldList, newList)
|
||||
expectContains(t, changes, "provider removed: provider-a (api-keys=1, models=1)")
|
||||
}
|
||||
|
||||
func TestOpenAICompatKeyFallbacks(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{
|
||||
BaseURL: "http://base",
|
||||
Models: []config.OpenAICompatibilityModel{{Alias: "alias-only"}},
|
||||
}
|
||||
key, label := openAICompatKey(entry, 0)
|
||||
if key != "base:http://base" || label != "http://base" {
|
||||
t.Fatalf("expected base key, got %s/%s", key, label)
|
||||
}
|
||||
|
||||
entry.BaseURL = ""
|
||||
key, label = openAICompatKey(entry, 1)
|
||||
if key != "alias:alias-only" || label != "alias-only" {
|
||||
t.Fatalf("expected alias fallback, got %s/%s", key, label)
|
||||
}
|
||||
|
||||
entry.Models = nil
|
||||
key, label = openAICompatKey(entry, 2)
|
||||
if key != "index:2" || label != "entry-3" {
|
||||
t.Fatalf("expected index fallback, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatKey_UsesName(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{Name: "My-Provider"}
|
||||
key, label := openAICompatKey(entry, 0)
|
||||
if key != "name:My-Provider" || label != "My-Provider" {
|
||||
t.Fatalf("expected name key, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatKey_SignatureFallbackWhenOnlyAPIKeys(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k1"}, {APIKey: "k2"}},
|
||||
}
|
||||
key, label := openAICompatKey(entry, 0)
|
||||
if !strings.HasPrefix(key, "sig:") || !strings.HasPrefix(label, "compat-") {
|
||||
t.Fatalf("expected signature key, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatSignature_EmptyReturnsEmpty(t *testing.T) {
|
||||
if got := openAICompatSignature(config.OpenAICompatibility{}); got != "" {
|
||||
t.Fatalf("expected empty signature, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatSignature_StableAndNormalized(t *testing.T) {
|
||||
a := config.OpenAICompatibility{
|
||||
Name: " Provider ",
|
||||
BaseURL: "http://base",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
{Name: " "},
|
||||
{Alias: "A1"},
|
||||
},
|
||||
Headers: map[string]string{
|
||||
"X-Test": "1",
|
||||
" ": "ignored",
|
||||
},
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
{APIKey: " "},
|
||||
},
|
||||
}
|
||||
b := config.OpenAICompatibility{
|
||||
Name: "provider",
|
||||
BaseURL: "http://base",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Alias: "a1"},
|
||||
{Name: "m1"},
|
||||
},
|
||||
Headers: map[string]string{
|
||||
"x-test": "2",
|
||||
},
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k2"},
|
||||
},
|
||||
}
|
||||
|
||||
sigA := openAICompatSignature(a)
|
||||
sigB := openAICompatSignature(b)
|
||||
if sigA == "" || sigB == "" {
|
||||
t.Fatalf("expected non-empty signatures, got %q / %q", sigA, sigB)
|
||||
}
|
||||
if sigA != sigB {
|
||||
t.Fatalf("expected normalized signatures to match, got %s / %s", sigA, sigB)
|
||||
}
|
||||
|
||||
c := b
|
||||
c.Models = append(c.Models, config.OpenAICompatibilityModel{Name: "m2"})
|
||||
if sigC := openAICompatSignature(c); sigC == sigB {
|
||||
t.Fatalf("expected signature to change when models change, got %s", sigC)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountOpenAIModelsSkipsBlanks(t *testing.T) {
|
||||
models := []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
{Name: ""},
|
||||
{Alias: ""},
|
||||
{Name: " "},
|
||||
{Alias: "a1"},
|
||||
}
|
||||
if got := countOpenAIModels(models); got != 2 {
|
||||
t.Fatalf("expected 2 counted models, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatKeyUsesModelNameWhenAliasEmpty(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "model-name"}},
|
||||
}
|
||||
key, label := openAICompatKey(entry, 5)
|
||||
if key != "alias:model-name" || label != "model-name" {
|
||||
t.Fatalf("expected model-name fallback, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
294
internal/watcher/synthesizer/config.go
Normal file
294
internal/watcher/synthesizer/config.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// ConfigSynthesizer generates Auth entries from configuration API keys.
|
||||
// It handles Gemini, Claude, Codex, OpenAI-compat, and Vertex-compat providers.
|
||||
type ConfigSynthesizer struct{}
|
||||
|
||||
// NewConfigSynthesizer creates a new ConfigSynthesizer instance.
|
||||
func NewConfigSynthesizer() *ConfigSynthesizer {
|
||||
return &ConfigSynthesizer{}
|
||||
}
|
||||
|
||||
// Synthesize generates Auth entries from config API keys.
|
||||
func (s *ConfigSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, error) {
|
||||
out := make([]*coreauth.Auth, 0, 32)
|
||||
if ctx == nil || ctx.Config == nil {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Gemini API Keys
|
||||
out = append(out, s.synthesizeGeminiKeys(ctx)...)
|
||||
// Claude API Keys
|
||||
out = append(out, s.synthesizeClaudeKeys(ctx)...)
|
||||
// Codex API Keys
|
||||
out = append(out, s.synthesizeCodexKeys(ctx)...)
|
||||
// OpenAI-compat
|
||||
out = append(out, s.synthesizeOpenAICompat(ctx)...)
|
||||
// Vertex-compat
|
||||
out = append(out, s.synthesizeVertexCompat(ctx)...)
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// synthesizeGeminiKeys creates Auth entries for Gemini API keys.
|
||||
func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(cfg.GeminiKey))
|
||||
for i := range cfg.GeminiKey {
|
||||
entry := cfg.GeminiKey[i]
|
||||
key := strings.TrimSpace(entry.APIKey)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSpace(entry.Prefix)
|
||||
base := strings.TrimSpace(entry.BaseURL)
|
||||
proxyURL := strings.TrimSpace(entry.ProxyURL)
|
||||
id, token := idGen.Next("gemini:apikey", key, base)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:gemini[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
addConfigHeadersToAttrs(entry.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: "gemini",
|
||||
Label: "gemini-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, entry.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// synthesizeClaudeKeys creates Auth entries for Claude API keys.
|
||||
func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(cfg.ClaudeKey))
|
||||
for i := range cfg.ClaudeKey {
|
||||
ck := cfg.ClaudeKey[i]
|
||||
key := strings.TrimSpace(ck.APIKey)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSpace(ck.Prefix)
|
||||
base := strings.TrimSpace(ck.BaseURL)
|
||||
id, token := idGen.Next("claude:apikey", key, base)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:claude[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
if hash := diff.ComputeClaudeModelsHash(ck.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(ck.Headers, attrs)
|
||||
proxyURL := strings.TrimSpace(ck.ProxyURL)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: "claude",
|
||||
Label: "claude-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// synthesizeCodexKeys creates Auth entries for Codex API keys.
|
||||
func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(cfg.CodexKey))
|
||||
for i := range cfg.CodexKey {
|
||||
ck := cfg.CodexKey[i]
|
||||
key := strings.TrimSpace(ck.APIKey)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSpace(ck.Prefix)
|
||||
id, token := idGen.Next("codex:apikey", key, ck.BaseURL)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:codex[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if ck.BaseURL != "" {
|
||||
attrs["base_url"] = ck.BaseURL
|
||||
}
|
||||
addConfigHeadersToAttrs(ck.Headers, attrs)
|
||||
proxyURL := strings.TrimSpace(ck.ProxyURL)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: "codex",
|
||||
Label: "codex-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// synthesizeOpenAICompat creates Auth entries for OpenAI-compatible providers.
|
||||
func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0)
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
compat := &cfg.OpenAICompatibility[i]
|
||||
prefix := strings.TrimSpace(compat.Prefix)
|
||||
providerName := strings.ToLower(strings.TrimSpace(compat.Name))
|
||||
if providerName == "" {
|
||||
providerName = "openai-compatibility"
|
||||
}
|
||||
base := strings.TrimSpace(compat.BaseURL)
|
||||
|
||||
// Handle new APIKeyEntries format (preferred)
|
||||
createdEntries := 0
|
||||
for j := range compat.APIKeyEntries {
|
||||
entry := &compat.APIKeyEntries[j]
|
||||
key := strings.TrimSpace(entry.APIKey)
|
||||
proxyURL := strings.TrimSpace(entry.ProxyURL)
|
||||
idKind := fmt.Sprintf("openai-compatibility:%s", providerName)
|
||||
id, token := idGen.Next(idKind, key, base, proxyURL)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:%s[%s]", providerName, token),
|
||||
"base_url": base,
|
||||
"compat_name": compat.Name,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: compat.Name,
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
out = append(out, a)
|
||||
createdEntries++
|
||||
}
|
||||
// Fallback: create entry without API key if no APIKeyEntries
|
||||
if createdEntries == 0 {
|
||||
idKind := fmt.Sprintf("openai-compatibility:%s", providerName)
|
||||
id, token := idGen.Next(idKind, base)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:%s[%s]", providerName, token),
|
||||
"base_url": base,
|
||||
"compat_name": compat.Name,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: compat.Name,
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// synthesizeVertexCompat creates Auth entries for Vertex-compatible providers.
|
||||
func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(cfg.VertexCompatAPIKey))
|
||||
for i := range cfg.VertexCompatAPIKey {
|
||||
compat := &cfg.VertexCompatAPIKey[i]
|
||||
providerName := "vertex"
|
||||
base := strings.TrimSpace(compat.BaseURL)
|
||||
|
||||
key := strings.TrimSpace(compat.APIKey)
|
||||
prefix := strings.TrimSpace(compat.Prefix)
|
||||
proxyURL := strings.TrimSpace(compat.ProxyURL)
|
||||
idKind := "vertex:apikey"
|
||||
id, token := idGen.Next(idKind, key, base, proxyURL)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:vertex-apikey[%s]", token),
|
||||
"base_url": base,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
if hash := diff.ComputeVertexCompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: "vertex-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, nil, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
613
internal/watcher/synthesizer/config_test.go
Normal file
613
internal/watcher/synthesizer/config_test.go
Normal file
@@ -0,0 +1,613 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestNewConfigSynthesizer(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
if synth == nil {
|
||||
t.Fatal("expected non-nil synthesizer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_Synthesize_NilContext(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
auths, err := synth.Synthesize(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_Synthesize_NilConfig(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: nil,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_GeminiKeys(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
geminiKeys []config.GeminiKey
|
||||
wantLen int
|
||||
validate func(*testing.T, []*coreauth.Auth)
|
||||
}{
|
||||
{
|
||||
name: "single gemini key",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{APIKey: "test-key-123", Prefix: "team-a"},
|
||||
},
|
||||
wantLen: 1,
|
||||
validate: func(t *testing.T, auths []*coreauth.Auth) {
|
||||
if auths[0].Provider != "gemini" {
|
||||
t.Errorf("expected provider gemini, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Prefix != "team-a" {
|
||||
t.Errorf("expected prefix team-a, got %s", auths[0].Prefix)
|
||||
}
|
||||
if auths[0].Label != "gemini-apikey" {
|
||||
t.Errorf("expected label gemini-apikey, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].Attributes["api_key"] != "test-key-123" {
|
||||
t.Errorf("expected api_key test-key-123, got %s", auths[0].Attributes["api_key"])
|
||||
}
|
||||
if auths[0].Status != coreauth.StatusActive {
|
||||
t.Errorf("expected status active, got %s", auths[0].Status)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "gemini key with base url and proxy",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{
|
||||
APIKey: "api-key",
|
||||
BaseURL: "https://custom.api.com",
|
||||
ProxyURL: "http://proxy.local:8080",
|
||||
Prefix: "custom",
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
validate: func(t *testing.T, auths []*coreauth.Auth) {
|
||||
if auths[0].Attributes["base_url"] != "https://custom.api.com" {
|
||||
t.Errorf("expected base_url https://custom.api.com, got %s", auths[0].Attributes["base_url"])
|
||||
}
|
||||
if auths[0].ProxyURL != "http://proxy.local:8080" {
|
||||
t.Errorf("expected proxy_url http://proxy.local:8080, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "gemini key with headers",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{
|
||||
APIKey: "api-key",
|
||||
Headers: map[string]string{"X-Custom": "value"},
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
validate: func(t *testing.T, auths []*coreauth.Auth) {
|
||||
if auths[0].Attributes["header:X-Custom"] != "value" {
|
||||
t.Errorf("expected header:X-Custom=value, got %s", auths[0].Attributes["header:X-Custom"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty api key skipped",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{APIKey: ""},
|
||||
{APIKey: " "},
|
||||
{APIKey: "valid-key"},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple gemini keys",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{APIKey: "key-1", Prefix: "a"},
|
||||
{APIKey: "key-2", Prefix: "b"},
|
||||
{APIKey: "key-3", Prefix: "c"},
|
||||
},
|
||||
wantLen: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
GeminiKey: tt.geminiKeys,
|
||||
},
|
||||
Now: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != tt.wantLen {
|
||||
t.Fatalf("expected %d auths, got %d", tt.wantLen, len(auths))
|
||||
}
|
||||
|
||||
if tt.validate != nil && len(auths) > 0 {
|
||||
tt.validate(t, auths)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_ClaudeKeys(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{
|
||||
APIKey: "sk-ant-api-xxx",
|
||||
Prefix: "main",
|
||||
BaseURL: "https://api.anthropic.com",
|
||||
Models: []config.ClaudeModel{
|
||||
{Name: "claude-3-opus"},
|
||||
{Name: "claude-3-sonnet"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "claude" {
|
||||
t.Errorf("expected provider claude, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Label != "claude-apikey" {
|
||||
t.Errorf("expected label claude-apikey, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].Prefix != "main" {
|
||||
t.Errorf("expected prefix main, got %s", auths[0].Prefix)
|
||||
}
|
||||
if auths[0].Attributes["api_key"] != "sk-ant-api-xxx" {
|
||||
t.Errorf("expected api_key sk-ant-api-xxx, got %s", auths[0].Attributes["api_key"])
|
||||
}
|
||||
if _, ok := auths[0].Attributes["models_hash"]; !ok {
|
||||
t.Error("expected models_hash in attributes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_ClaudeKeys_SkipsEmptyAndHeaders(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: ""}, // empty, should be skipped
|
||||
{APIKey: " "}, // whitespace, should be skipped
|
||||
{APIKey: "valid-key", Headers: map[string]string{"X-Custom": "value"}},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth (empty keys skipped), got %d", len(auths))
|
||||
}
|
||||
if auths[0].Attributes["header:X-Custom"] != "value" {
|
||||
t.Errorf("expected header:X-Custom=value, got %s", auths[0].Attributes["header:X-Custom"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_CodexKeys(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
CodexKey: []config.CodexKey{
|
||||
{
|
||||
APIKey: "codex-key-123",
|
||||
Prefix: "dev",
|
||||
BaseURL: "https://api.openai.com",
|
||||
ProxyURL: "http://proxy.local",
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "codex" {
|
||||
t.Errorf("expected provider codex, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Label != "codex-apikey" {
|
||||
t.Errorf("expected label codex-apikey, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: ""}, // empty, should be skipped
|
||||
{APIKey: " "}, // whitespace, should be skipped
|
||||
{APIKey: "valid-key", Headers: map[string]string{"Authorization": "Bearer xyz"}},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth (empty keys skipped), got %d", len(auths))
|
||||
}
|
||||
if auths[0].Attributes["header:Authorization"] != "Bearer xyz" {
|
||||
t.Errorf("expected header:Authorization=Bearer xyz, got %s", auths[0].Attributes["header:Authorization"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_OpenAICompat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
compat []config.OpenAICompatibility
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
name: "with APIKeyEntries",
|
||||
compat: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "CustomProvider",
|
||||
BaseURL: "https://custom.api.com",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-1"},
|
||||
{APIKey: "key-2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantLen: 2,
|
||||
},
|
||||
{
|
||||
name: "empty APIKeyEntries included (legacy)",
|
||||
compat: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "EmptyKeys",
|
||||
BaseURL: "https://empty.api.com",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: ""},
|
||||
{APIKey: " "},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantLen: 2,
|
||||
},
|
||||
{
|
||||
name: "without APIKeyEntries (fallback)",
|
||||
compat: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "NoKeyProvider",
|
||||
BaseURL: "https://no-key.api.com",
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
name: "empty name defaults",
|
||||
compat: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "",
|
||||
BaseURL: "https://default.api.com",
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
OpenAICompatibility: tt.compat,
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != tt.wantLen {
|
||||
t.Fatalf("expected %d auths, got %d", tt.wantLen, len(auths))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_VertexCompat(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{
|
||||
APIKey: "vertex-key-123",
|
||||
BaseURL: "https://vertex.googleapis.com",
|
||||
Prefix: "vertex-prod",
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "vertex" {
|
||||
t.Errorf("expected provider vertex, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Label != "vertex-apikey" {
|
||||
t.Errorf("expected label vertex-apikey, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].Prefix != "vertex-prod" {
|
||||
t.Errorf("expected prefix vertex-prod, got %s", auths[0].Prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_VertexCompat_SkipsEmptyAndHeaders(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "", BaseURL: "https://vertex.api"}, // empty key creates auth without api_key attr
|
||||
{APIKey: " ", BaseURL: "https://vertex.api"}, // whitespace key creates auth without api_key attr
|
||||
{APIKey: "valid-key", BaseURL: "https://vertex.api", Headers: map[string]string{"X-Vertex": "test"}},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Vertex compat doesn't skip empty keys - it creates auths without api_key attribute
|
||||
if len(auths) != 3 {
|
||||
t.Fatalf("expected 3 auths, got %d", len(auths))
|
||||
}
|
||||
// First two should not have api_key attribute
|
||||
if _, ok := auths[0].Attributes["api_key"]; ok {
|
||||
t.Error("expected first auth to not have api_key attribute")
|
||||
}
|
||||
if _, ok := auths[1].Attributes["api_key"]; ok {
|
||||
t.Error("expected second auth to not have api_key attribute")
|
||||
}
|
||||
// Third should have headers
|
||||
if auths[2].Attributes["header:X-Vertex"] != "test" {
|
||||
t.Errorf("expected header:X-Vertex=test, got %s", auths[2].Attributes["header:X-Vertex"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_OpenAICompat_WithModelsHash(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "TestProvider",
|
||||
BaseURL: "https://test.api.com",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "model-a"},
|
||||
{Name: "model-b"},
|
||||
},
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-with-models"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
if _, ok := auths[0].Attributes["models_hash"]; !ok {
|
||||
t.Error("expected models_hash in attributes")
|
||||
}
|
||||
if auths[0].Attributes["api_key"] != "key-with-models" {
|
||||
t.Errorf("expected api_key key-with-models, got %s", auths[0].Attributes["api_key"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_OpenAICompat_FallbackWithModels(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "NoKeyWithModels",
|
||||
BaseURL: "https://nokey.api.com",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "model-x"},
|
||||
},
|
||||
Headers: map[string]string{"X-API": "header-value"},
|
||||
// No APIKeyEntries - should use fallback path
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
if _, ok := auths[0].Attributes["models_hash"]; !ok {
|
||||
t.Error("expected models_hash in fallback path")
|
||||
}
|
||||
if auths[0].Attributes["header:X-API"] != "header-value" {
|
||||
t.Errorf("expected header:X-API=header-value, got %s", auths[0].Attributes["header:X-API"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_VertexCompat_WithModels(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{
|
||||
APIKey: "vertex-key",
|
||||
BaseURL: "https://vertex.api",
|
||||
Models: []config.VertexCompatModel{
|
||||
{Name: "gemini-pro", Alias: "pro"},
|
||||
{Name: "gemini-ultra", Alias: "ultra"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
if _, ok := auths[0].Attributes["models_hash"]; !ok {
|
||||
t.Error("expected models_hash in vertex auth with models")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_IDStability(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "stable-key", Prefix: "test"},
|
||||
},
|
||||
}
|
||||
|
||||
// Generate IDs twice with fresh generators
|
||||
synth1 := NewConfigSynthesizer()
|
||||
ctx1 := &SynthesisContext{
|
||||
Config: cfg,
|
||||
Now: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths1, _ := synth1.Synthesize(ctx1)
|
||||
|
||||
synth2 := NewConfigSynthesizer()
|
||||
ctx2 := &SynthesisContext{
|
||||
Config: cfg,
|
||||
Now: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths2, _ := synth2.Synthesize(ctx2)
|
||||
|
||||
if auths1[0].ID != auths2[0].ID {
|
||||
t.Errorf("same config should produce same ID: got %q and %q", auths1[0].ID, auths2[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_AllProviders(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "gemini-key"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "claude-key"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "codex-key"},
|
||||
},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{Name: "compat", BaseURL: "https://compat.api"},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "vertex-key", BaseURL: "https://vertex.api"},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 5 {
|
||||
t.Fatalf("expected 5 auths, got %d", len(auths))
|
||||
}
|
||||
|
||||
providers := make(map[string]bool)
|
||||
for _, a := range auths {
|
||||
providers[a.Provider] = true
|
||||
}
|
||||
|
||||
expected := []string{"gemini", "claude", "codex", "compat", "vertex"}
|
||||
for _, p := range expected {
|
||||
if !providers[p] {
|
||||
t.Errorf("expected provider %s not found", p)
|
||||
}
|
||||
}
|
||||
}
|
||||
19
internal/watcher/synthesizer/context.go
Normal file
19
internal/watcher/synthesizer/context.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// SynthesisContext provides the context needed for auth synthesis.
|
||||
type SynthesisContext struct {
|
||||
// Config is the current configuration
|
||||
Config *config.Config
|
||||
// AuthDir is the directory containing auth files
|
||||
AuthDir string
|
||||
// Now is the current time for timestamps
|
||||
Now time.Time
|
||||
// IDGenerator generates stable IDs for auth entries
|
||||
IDGenerator *StableIDGenerator
|
||||
}
|
||||
224
internal/watcher/synthesizer/file.go
Normal file
224
internal/watcher/synthesizer/file.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// FileSynthesizer generates Auth entries from OAuth JSON files.
|
||||
// It handles file-based authentication and Gemini virtual auth generation.
|
||||
type FileSynthesizer struct{}
|
||||
|
||||
// NewFileSynthesizer creates a new FileSynthesizer instance.
|
||||
func NewFileSynthesizer() *FileSynthesizer {
|
||||
return &FileSynthesizer{}
|
||||
}
|
||||
|
||||
// Synthesize generates Auth entries from auth files in the auth directory.
|
||||
func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, error) {
|
||||
out := make([]*coreauth.Auth, 0, 16)
|
||||
if ctx == nil || ctx.AuthDir == "" {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(ctx.AuthDir)
|
||||
if err != nil {
|
||||
// Not an error if directory doesn't exist
|
||||
return out, nil
|
||||
}
|
||||
|
||||
now := ctx.Now
|
||||
cfg := ctx.Config
|
||||
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
||||
continue
|
||||
}
|
||||
full := filepath.Join(ctx.AuthDir, name)
|
||||
data, errRead := os.ReadFile(full)
|
||||
if errRead != nil || len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
var metadata map[string]any
|
||||
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
|
||||
continue
|
||||
}
|
||||
t, _ := metadata["type"].(string)
|
||||
if t == "" {
|
||||
continue
|
||||
}
|
||||
provider := strings.ToLower(t)
|
||||
if provider == "gemini" {
|
||||
provider = "gemini-cli"
|
||||
}
|
||||
label := provider
|
||||
if email, _ := metadata["email"].(string); email != "" {
|
||||
label = email
|
||||
}
|
||||
// Use relative path under authDir as ID to stay consistent with the file-based token store
|
||||
id := full
|
||||
if rel, errRel := filepath.Rel(ctx.AuthDir, full); errRel == nil && rel != "" {
|
||||
id = rel
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if p, ok := metadata["proxy_url"].(string); ok {
|
||||
proxyURL = p
|
||||
}
|
||||
|
||||
prefix := ""
|
||||
if rawPrefix, ok := metadata["prefix"].(string); ok {
|
||||
trimmed := strings.TrimSpace(rawPrefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed != "" && !strings.Contains(trimmed, "/") {
|
||||
prefix = trimmed
|
||||
}
|
||||
}
|
||||
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
Label: label,
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"source": full,
|
||||
"path": full,
|
||||
},
|
||||
ProxyURL: proxyURL,
|
||||
Metadata: metadata,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
|
||||
if provider == "gemini-cli" {
|
||||
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||
for _, v := range virtuals {
|
||||
ApplyAuthExcludedModelsMeta(v, cfg, nil, "oauth")
|
||||
}
|
||||
out = append(out, a)
|
||||
out = append(out, virtuals...)
|
||||
continue
|
||||
}
|
||||
}
|
||||
out = append(out, a)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// SynthesizeGeminiVirtualAuths creates virtual Auth entries for multi-project Gemini credentials.
|
||||
// It disables the primary auth and creates one virtual auth per project.
|
||||
func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth {
|
||||
if primary == nil || metadata == nil {
|
||||
return nil
|
||||
}
|
||||
projects := splitGeminiProjectIDs(metadata)
|
||||
if len(projects) <= 1 {
|
||||
return nil
|
||||
}
|
||||
email, _ := metadata["email"].(string)
|
||||
shared := geminicli.NewSharedCredential(primary.ID, email, metadata, projects)
|
||||
primary.Disabled = true
|
||||
primary.Status = coreauth.StatusDisabled
|
||||
primary.Runtime = shared
|
||||
if primary.Attributes == nil {
|
||||
primary.Attributes = make(map[string]string)
|
||||
}
|
||||
primary.Attributes["gemini_virtual_primary"] = "true"
|
||||
primary.Attributes["virtual_children"] = strings.Join(projects, ",")
|
||||
source := primary.Attributes["source"]
|
||||
authPath := primary.Attributes["path"]
|
||||
originalProvider := primary.Provider
|
||||
if originalProvider == "" {
|
||||
originalProvider = "gemini-cli"
|
||||
}
|
||||
label := primary.Label
|
||||
if label == "" {
|
||||
label = originalProvider
|
||||
}
|
||||
virtuals := make([]*coreauth.Auth, 0, len(projects))
|
||||
for _, projectID := range projects {
|
||||
attrs := map[string]string{
|
||||
"runtime_only": "true",
|
||||
"gemini_virtual_parent": primary.ID,
|
||||
"gemini_virtual_project": projectID,
|
||||
}
|
||||
if source != "" {
|
||||
attrs["source"] = source
|
||||
}
|
||||
if authPath != "" {
|
||||
attrs["path"] = authPath
|
||||
}
|
||||
metadataCopy := map[string]any{
|
||||
"email": email,
|
||||
"project_id": projectID,
|
||||
"virtual": true,
|
||||
"virtual_parent_id": primary.ID,
|
||||
"type": metadata["type"],
|
||||
}
|
||||
proxy := strings.TrimSpace(primary.ProxyURL)
|
||||
if proxy != "" {
|
||||
metadataCopy["proxy_url"] = proxy
|
||||
}
|
||||
virtual := &coreauth.Auth{
|
||||
ID: buildGeminiVirtualID(primary.ID, projectID),
|
||||
Provider: originalProvider,
|
||||
Label: fmt.Sprintf("%s [%s]", label, projectID),
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: attrs,
|
||||
Metadata: metadataCopy,
|
||||
ProxyURL: primary.ProxyURL,
|
||||
Prefix: primary.Prefix,
|
||||
CreatedAt: primary.CreatedAt,
|
||||
UpdatedAt: primary.UpdatedAt,
|
||||
Runtime: geminicli.NewVirtualCredential(projectID, shared),
|
||||
}
|
||||
virtuals = append(virtuals, virtual)
|
||||
}
|
||||
return virtuals
|
||||
}
|
||||
|
||||
// splitGeminiProjectIDs extracts and deduplicates project IDs from metadata.
|
||||
func splitGeminiProjectIDs(metadata map[string]any) []string {
|
||||
raw, _ := metadata["project_id"].(string)
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(trimmed, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, part := range parts {
|
||||
id := strings.TrimSpace(part)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
result = append(result, id)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// buildGeminiVirtualID constructs a virtual auth ID from base ID and project ID.
|
||||
func buildGeminiVirtualID(baseID, projectID string) string {
|
||||
project := strings.TrimSpace(projectID)
|
||||
if project == "" {
|
||||
project = "project"
|
||||
}
|
||||
replacer := strings.NewReplacer("/", "_", "\\", "_", " ", "_")
|
||||
return fmt.Sprintf("%s::%s", baseID, replacer.Replace(project))
|
||||
}
|
||||
612
internal/watcher/synthesizer/file_test.go
Normal file
612
internal/watcher/synthesizer/file_test.go
Normal file
@@ -0,0 +1,612 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestNewFileSynthesizer(t *testing.T) {
|
||||
synth := NewFileSynthesizer()
|
||||
if synth == nil {
|
||||
t.Fatal("expected non-nil synthesizer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_NilContext(t *testing.T) {
|
||||
synth := NewFileSynthesizer()
|
||||
auths, err := synth.Synthesize(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_EmptyAuthDir(t *testing.T) {
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: "",
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_NonExistentDir(t *testing.T) {
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: "/non/existent/path",
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a valid auth file
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"email": "test@example.com",
|
||||
"proxy_url": "http://proxy.local",
|
||||
"prefix": "test-prefix",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "claude-auth.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "claude" {
|
||||
t.Errorf("expected provider claude, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Label != "test@example.com" {
|
||||
t.Errorf("expected label test@example.com, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].Prefix != "test-prefix" {
|
||||
t.Errorf("expected prefix test-prefix, got %s", auths[0].Prefix)
|
||||
}
|
||||
if auths[0].ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
if auths[0].Status != coreauth.StatusActive {
|
||||
t.Errorf("expected status active, got %s", auths[0].Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_GeminiProviderMapping(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Gemini type should be mapped to gemini-cli
|
||||
authData := map[string]any{
|
||||
"type": "gemini",
|
||||
"email": "gemini@example.com",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "gemini-auth.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "gemini-cli" {
|
||||
t.Errorf("gemini should be mapped to gemini-cli, got %s", auths[0].Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_SkipsInvalidFiles(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create various invalid files
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "not-json.txt"), []byte("text content"), 0644)
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "invalid.json"), []byte("not valid json"), 0644)
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "empty.json"), []byte(""), 0644)
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "no-type.json"), []byte(`{"email": "test@example.com"}`), 0644)
|
||||
|
||||
// Create one valid file
|
||||
validData, _ := json.Marshal(map[string]any{"type": "claude", "email": "valid@example.com"})
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "valid.json"), validData, 0644)
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("only valid auth file should be processed, got %d", len(auths))
|
||||
}
|
||||
if auths[0].Label != "valid@example.com" {
|
||||
t.Errorf("expected label valid@example.com, got %s", auths[0].Label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_SkipsDirectories(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a subdirectory with a json file inside
|
||||
subDir := filepath.Join(tempDir, "subdir.json")
|
||||
err := os.Mkdir(subDir, 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create subdir: %v", err)
|
||||
}
|
||||
|
||||
// Create a valid file in root
|
||||
validData, _ := json.Marshal(map[string]any{"type": "claude"})
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "valid.json"), validData, 0644)
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_RelativeID(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authData := map[string]any{"type": "claude"}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "my-auth.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
// ID should be relative path
|
||||
if auths[0].ID != "my-auth.json" {
|
||||
t.Errorf("expected ID my-auth.json, got %s", auths[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_PrefixValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
wantPrefix string
|
||||
}{
|
||||
{"valid prefix", "myprefix", "myprefix"},
|
||||
{"prefix with slashes trimmed", "/myprefix/", "myprefix"},
|
||||
{"prefix with spaces trimmed", " myprefix ", "myprefix"},
|
||||
{"prefix with internal slash rejected", "my/prefix", ""},
|
||||
{"empty prefix", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"prefix": tt.prefix,
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644)
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
if auths[0].Prefix != tt.wantPrefix {
|
||||
t.Errorf("expected prefix %q, got %q", tt.wantPrefix, auths[0].Prefix)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_NilInputs(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
if SynthesizeGeminiVirtualAuths(nil, nil, now) != nil {
|
||||
t.Error("expected nil for nil primary")
|
||||
}
|
||||
if SynthesizeGeminiVirtualAuths(&coreauth.Auth{}, nil, now) != nil {
|
||||
t.Error("expected nil for nil metadata")
|
||||
}
|
||||
if SynthesizeGeminiVirtualAuths(nil, map[string]any{}, now) != nil {
|
||||
t.Error("expected nil for nil primary with metadata")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_SingleProject(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "test-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "single-project",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
if virtuals != nil {
|
||||
t.Error("single project should not create virtuals")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
Prefix: "test-prefix",
|
||||
ProxyURL: "http://proxy.local",
|
||||
Attributes: map[string]string{
|
||||
"source": "test-source",
|
||||
"path": "/path/to/auth",
|
||||
},
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "project-a, project-b, project-c",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 3 {
|
||||
t.Fatalf("expected 3 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
|
||||
// Check primary is disabled
|
||||
if !primary.Disabled {
|
||||
t.Error("expected primary to be disabled")
|
||||
}
|
||||
if primary.Status != coreauth.StatusDisabled {
|
||||
t.Errorf("expected primary status disabled, got %s", primary.Status)
|
||||
}
|
||||
if primary.Attributes["gemini_virtual_primary"] != "true" {
|
||||
t.Error("expected gemini_virtual_primary=true")
|
||||
}
|
||||
if !strings.Contains(primary.Attributes["virtual_children"], "project-a") {
|
||||
t.Error("expected virtual_children to contain project-a")
|
||||
}
|
||||
|
||||
// Check virtuals
|
||||
projectIDs := []string{"project-a", "project-b", "project-c"}
|
||||
for i, v := range virtuals {
|
||||
if v.Provider != "gemini-cli" {
|
||||
t.Errorf("expected provider gemini-cli, got %s", v.Provider)
|
||||
}
|
||||
if v.Status != coreauth.StatusActive {
|
||||
t.Errorf("expected status active, got %s", v.Status)
|
||||
}
|
||||
if v.Prefix != "test-prefix" {
|
||||
t.Errorf("expected prefix test-prefix, got %s", v.Prefix)
|
||||
}
|
||||
if v.ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL)
|
||||
}
|
||||
if v.Attributes["runtime_only"] != "true" {
|
||||
t.Error("expected runtime_only=true")
|
||||
}
|
||||
if v.Attributes["gemini_virtual_parent"] != "primary-id" {
|
||||
t.Errorf("expected gemini_virtual_parent=primary-id, got %s", v.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
if v.Attributes["gemini_virtual_project"] != projectIDs[i] {
|
||||
t.Errorf("expected gemini_virtual_project=%s, got %s", projectIDs[i], v.Attributes["gemini_virtual_project"])
|
||||
}
|
||||
if !strings.Contains(v.Label, "["+projectIDs[i]+"]") {
|
||||
t.Errorf("expected label to contain [%s], got %s", projectIDs[i], v.Label)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_EmptyProviderAndLabel(t *testing.T) {
|
||||
now := time.Now()
|
||||
// Test with empty Provider and Label to cover fallback branches
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "", // empty provider - should default to gemini-cli
|
||||
Label: "", // empty label - should default to provider
|
||||
Attributes: map[string]string{},
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "proj-a, proj-b",
|
||||
"email": "user@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 2 {
|
||||
t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
|
||||
// Check that empty provider defaults to gemini-cli
|
||||
if virtuals[0].Provider != "gemini-cli" {
|
||||
t.Errorf("expected provider gemini-cli (default), got %s", virtuals[0].Provider)
|
||||
}
|
||||
// Check that empty label defaults to provider
|
||||
if !strings.Contains(virtuals[0].Label, "gemini-cli") {
|
||||
t.Errorf("expected label to contain gemini-cli, got %s", virtuals[0].Label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_NilPrimaryAttributes(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
Attributes: nil, // nil attributes
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "proj-a, proj-b",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 2 {
|
||||
t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
// Nil attributes should be initialized
|
||||
if primary.Attributes == nil {
|
||||
t.Error("expected primary.Attributes to be initialized")
|
||||
}
|
||||
if primary.Attributes["gemini_virtual_primary"] != "true" {
|
||||
t.Error("expected gemini_virtual_primary=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitGeminiProjectIDs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata map[string]any
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "single project",
|
||||
metadata: map[string]any{"project_id": "proj-a"},
|
||||
want: []string{"proj-a"},
|
||||
},
|
||||
{
|
||||
name: "multiple projects",
|
||||
metadata: map[string]any{"project_id": "proj-a, proj-b, proj-c"},
|
||||
want: []string{"proj-a", "proj-b", "proj-c"},
|
||||
},
|
||||
{
|
||||
name: "with duplicates",
|
||||
metadata: map[string]any{"project_id": "proj-a, proj-b, proj-a"},
|
||||
want: []string{"proj-a", "proj-b"},
|
||||
},
|
||||
{
|
||||
name: "with empty parts",
|
||||
metadata: map[string]any{"project_id": "proj-a, , proj-b, "},
|
||||
want: []string{"proj-a", "proj-b"},
|
||||
},
|
||||
{
|
||||
name: "empty project_id",
|
||||
metadata: map[string]any{"project_id": ""},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "no project_id",
|
||||
metadata: map[string]any{},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
metadata: map[string]any{"project_id": " "},
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := splitGeminiProjectIDs(tt.metadata)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("expected %v, got %v", tt.want, got)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("expected %v, got %v", tt.want, got)
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a gemini auth file with multiple projects
|
||||
authData := map[string]any{
|
||||
"type": "gemini",
|
||||
"email": "multi@example.com",
|
||||
"project_id": "project-a, project-b, project-c",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "gemini-multi.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Should have 4 auths: 1 primary (disabled) + 3 virtuals
|
||||
if len(auths) != 4 {
|
||||
t.Fatalf("expected 4 auths (1 primary + 3 virtuals), got %d", len(auths))
|
||||
}
|
||||
|
||||
// First auth should be the primary (disabled)
|
||||
primary := auths[0]
|
||||
if !primary.Disabled {
|
||||
t.Error("expected primary to be disabled")
|
||||
}
|
||||
if primary.Status != coreauth.StatusDisabled {
|
||||
t.Errorf("expected primary status disabled, got %s", primary.Status)
|
||||
}
|
||||
|
||||
// Remaining auths should be virtuals
|
||||
for i := 1; i < 4; i++ {
|
||||
v := auths[i]
|
||||
if v.Status != coreauth.StatusActive {
|
||||
t.Errorf("expected virtual %d to be active, got %s", i, v.Status)
|
||||
}
|
||||
if v.Attributes["gemini_virtual_parent"] != primary.ID {
|
||||
t.Errorf("expected virtual %d parent to be %s, got %s", i, primary.ID, v.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildGeminiVirtualID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
baseID string
|
||||
projectID string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
baseID: "auth.json",
|
||||
projectID: "my-project",
|
||||
want: "auth.json::my-project",
|
||||
},
|
||||
{
|
||||
name: "with slashes",
|
||||
baseID: "path/to/auth.json",
|
||||
projectID: "project/with/slashes",
|
||||
want: "path/to/auth.json::project_with_slashes",
|
||||
},
|
||||
{
|
||||
name: "with spaces",
|
||||
baseID: "auth.json",
|
||||
projectID: "my project",
|
||||
want: "auth.json::my_project",
|
||||
},
|
||||
{
|
||||
name: "empty project",
|
||||
baseID: "auth.json",
|
||||
projectID: "",
|
||||
want: "auth.json::project",
|
||||
},
|
||||
{
|
||||
name: "whitespace project",
|
||||
baseID: "auth.json",
|
||||
projectID: " ",
|
||||
want: "auth.json::project",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := buildGeminiVirtualID(tt.baseID, tt.projectID)
|
||||
if got != tt.want {
|
||||
t.Errorf("expected %q, got %q", tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
110
internal/watcher/synthesizer/helpers.go
Normal file
110
internal/watcher/synthesizer/helpers.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// StableIDGenerator generates stable, deterministic IDs for auth entries.
|
||||
// It uses SHA256 hashing with collision handling via counters.
|
||||
// It is not safe for concurrent use.
|
||||
type StableIDGenerator struct {
|
||||
counters map[string]int
|
||||
}
|
||||
|
||||
// NewStableIDGenerator creates a new StableIDGenerator instance.
|
||||
func NewStableIDGenerator() *StableIDGenerator {
|
||||
return &StableIDGenerator{counters: make(map[string]int)}
|
||||
}
|
||||
|
||||
// Next generates a stable ID based on the kind and parts.
|
||||
// Returns the full ID (kind:hash) and the short hash portion.
|
||||
func (g *StableIDGenerator) Next(kind string, parts ...string) (string, string) {
|
||||
if g == nil {
|
||||
return kind + ":000000000000", "000000000000"
|
||||
}
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(kind))
|
||||
for _, part := range parts {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
hasher.Write([]byte{0})
|
||||
hasher.Write([]byte(trimmed))
|
||||
}
|
||||
digest := hex.EncodeToString(hasher.Sum(nil))
|
||||
if len(digest) < 12 {
|
||||
digest = fmt.Sprintf("%012s", digest)
|
||||
}
|
||||
short := digest[:12]
|
||||
key := kind + ":" + short
|
||||
index := g.counters[key]
|
||||
g.counters[key] = index + 1
|
||||
if index > 0 {
|
||||
short = fmt.Sprintf("%s-%d", short, index)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", kind, short), short
|
||||
}
|
||||
|
||||
// ApplyAuthExcludedModelsMeta applies excluded models metadata to an auth entry.
|
||||
// It computes a hash of excluded models and sets the auth_kind attribute.
|
||||
func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
|
||||
if auth == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
authKindKey := strings.ToLower(strings.TrimSpace(authKind))
|
||||
seen := make(map[string]struct{})
|
||||
add := func(list []string) {
|
||||
for _, entry := range list {
|
||||
if trimmed := strings.TrimSpace(entry); trimmed != "" {
|
||||
key := strings.ToLower(trimmed)
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
if authKindKey == "apikey" {
|
||||
add(perKey)
|
||||
} else if cfg.OAuthExcludedModels != nil {
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
add(cfg.OAuthExcludedModels[providerKey])
|
||||
}
|
||||
combined := make([]string, 0, len(seen))
|
||||
for k := range seen {
|
||||
combined = append(combined, k)
|
||||
}
|
||||
sort.Strings(combined)
|
||||
hash := diff.ComputeExcludedModelsHash(combined)
|
||||
if auth.Attributes == nil {
|
||||
auth.Attributes = make(map[string]string)
|
||||
}
|
||||
if hash != "" {
|
||||
auth.Attributes["excluded_models_hash"] = hash
|
||||
}
|
||||
if authKind != "" {
|
||||
auth.Attributes["auth_kind"] = authKind
|
||||
}
|
||||
}
|
||||
|
||||
// addConfigHeadersToAttrs adds header configuration to auth attributes.
|
||||
// Headers are prefixed with "header:" in the attributes map.
|
||||
func addConfigHeadersToAttrs(headers map[string]string, attrs map[string]string) {
|
||||
if len(headers) == 0 || attrs == nil {
|
||||
return
|
||||
}
|
||||
for hk, hv := range headers {
|
||||
key := strings.TrimSpace(hk)
|
||||
val := strings.TrimSpace(hv)
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
attrs["header:"+key] = val
|
||||
}
|
||||
}
|
||||
264
internal/watcher/synthesizer/helpers_test.go
Normal file
264
internal/watcher/synthesizer/helpers_test.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestNewStableIDGenerator(t *testing.T) {
|
||||
gen := NewStableIDGenerator()
|
||||
if gen == nil {
|
||||
t.Fatal("expected non-nil generator")
|
||||
}
|
||||
if gen.counters == nil {
|
||||
t.Fatal("expected non-nil counters map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStableIDGenerator_Next(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
kind string
|
||||
parts []string
|
||||
wantPrefix string
|
||||
}{
|
||||
{
|
||||
name: "basic gemini apikey",
|
||||
kind: "gemini:apikey",
|
||||
parts: []string{"test-key", ""},
|
||||
wantPrefix: "gemini:apikey:",
|
||||
},
|
||||
{
|
||||
name: "claude with base url",
|
||||
kind: "claude:apikey",
|
||||
parts: []string{"sk-ant-xxx", "https://api.anthropic.com"},
|
||||
wantPrefix: "claude:apikey:",
|
||||
},
|
||||
{
|
||||
name: "empty parts",
|
||||
kind: "codex:apikey",
|
||||
parts: []string{},
|
||||
wantPrefix: "codex:apikey:",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gen := NewStableIDGenerator()
|
||||
id, short := gen.Next(tt.kind, tt.parts...)
|
||||
|
||||
if !strings.Contains(id, tt.wantPrefix) {
|
||||
t.Errorf("expected id to contain %q, got %q", tt.wantPrefix, id)
|
||||
}
|
||||
if short == "" {
|
||||
t.Error("expected non-empty short id")
|
||||
}
|
||||
if len(short) != 12 {
|
||||
t.Errorf("expected short id length 12, got %d", len(short))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStableIDGenerator_Stability(t *testing.T) {
|
||||
gen1 := NewStableIDGenerator()
|
||||
gen2 := NewStableIDGenerator()
|
||||
|
||||
id1, _ := gen1.Next("gemini:apikey", "test-key", "https://api.example.com")
|
||||
id2, _ := gen2.Next("gemini:apikey", "test-key", "https://api.example.com")
|
||||
|
||||
if id1 != id2 {
|
||||
t.Errorf("same inputs should produce same ID: got %q and %q", id1, id2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStableIDGenerator_CollisionHandling(t *testing.T) {
|
||||
gen := NewStableIDGenerator()
|
||||
|
||||
id1, short1 := gen.Next("gemini:apikey", "same-key")
|
||||
id2, short2 := gen.Next("gemini:apikey", "same-key")
|
||||
|
||||
if id1 == id2 {
|
||||
t.Error("collision should be handled with suffix")
|
||||
}
|
||||
if short1 == short2 {
|
||||
t.Error("short ids should differ")
|
||||
}
|
||||
if !strings.Contains(short2, "-1") {
|
||||
t.Errorf("second short id should contain -1 suffix, got %q", short2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStableIDGenerator_NilReceiver(t *testing.T) {
|
||||
var gen *StableIDGenerator = nil
|
||||
id, short := gen.Next("test:kind", "part")
|
||||
|
||||
if id != "test:kind:000000000000" {
|
||||
t.Errorf("expected test:kind:000000000000, got %q", id)
|
||||
}
|
||||
if short != "000000000000" {
|
||||
t.Errorf("expected 000000000000, got %q", short)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAuthExcludedModelsMeta(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
auth *coreauth.Auth
|
||||
cfg *config.Config
|
||||
perKey []string
|
||||
authKind string
|
||||
wantHash bool
|
||||
wantKind string
|
||||
}{
|
||||
{
|
||||
name: "apikey with excluded models",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: make(map[string]string),
|
||||
},
|
||||
cfg: &config.Config{},
|
||||
perKey: []string{"model-a", "model-b"},
|
||||
authKind: "apikey",
|
||||
wantHash: true,
|
||||
wantKind: "apikey",
|
||||
},
|
||||
{
|
||||
name: "oauth with provider excluded models",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "claude",
|
||||
Attributes: make(map[string]string),
|
||||
},
|
||||
cfg: &config.Config{
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"claude": {"claude-2.0"},
|
||||
},
|
||||
},
|
||||
perKey: nil,
|
||||
authKind: "oauth",
|
||||
wantHash: true,
|
||||
wantKind: "oauth",
|
||||
},
|
||||
{
|
||||
name: "nil auth",
|
||||
auth: nil,
|
||||
cfg: &config.Config{},
|
||||
},
|
||||
{
|
||||
name: "nil config",
|
||||
auth: &coreauth.Auth{Provider: "test"},
|
||||
cfg: nil,
|
||||
authKind: "apikey",
|
||||
},
|
||||
{
|
||||
name: "nil attributes initialized",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: nil,
|
||||
},
|
||||
cfg: &config.Config{},
|
||||
perKey: []string{"model-x"},
|
||||
authKind: "apikey",
|
||||
wantHash: true,
|
||||
wantKind: "apikey",
|
||||
},
|
||||
{
|
||||
name: "apikey with duplicate excluded models",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: make(map[string]string),
|
||||
},
|
||||
cfg: &config.Config{},
|
||||
perKey: []string{"model-a", "MODEL-A", "model-b", "model-a"},
|
||||
authKind: "apikey",
|
||||
wantHash: true,
|
||||
wantKind: "apikey",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ApplyAuthExcludedModelsMeta(tt.auth, tt.cfg, tt.perKey, tt.authKind)
|
||||
|
||||
if tt.auth != nil && tt.cfg != nil {
|
||||
if tt.wantHash {
|
||||
if _, ok := tt.auth.Attributes["excluded_models_hash"]; !ok {
|
||||
t.Error("expected excluded_models_hash in attributes")
|
||||
}
|
||||
}
|
||||
if tt.wantKind != "" {
|
||||
if got := tt.auth.Attributes["auth_kind"]; got != tt.wantKind {
|
||||
t.Errorf("expected auth_kind=%s, got %s", tt.wantKind, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddConfigHeadersToAttrs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
attrs map[string]string
|
||||
want map[string]string
|
||||
}{
|
||||
{
|
||||
name: "basic headers",
|
||||
headers: map[string]string{
|
||||
"Authorization": "Bearer token",
|
||||
"X-Custom": "value",
|
||||
},
|
||||
attrs: map[string]string{"existing": "key"},
|
||||
want: map[string]string{
|
||||
"existing": "key",
|
||||
"header:Authorization": "Bearer token",
|
||||
"header:X-Custom": "value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty headers",
|
||||
headers: map[string]string{},
|
||||
attrs: map[string]string{"existing": "key"},
|
||||
want: map[string]string{"existing": "key"},
|
||||
},
|
||||
{
|
||||
name: "nil headers",
|
||||
headers: nil,
|
||||
attrs: map[string]string{"existing": "key"},
|
||||
want: map[string]string{"existing": "key"},
|
||||
},
|
||||
{
|
||||
name: "nil attrs",
|
||||
headers: map[string]string{"key": "value"},
|
||||
attrs: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "skip empty keys and values",
|
||||
headers: map[string]string{
|
||||
"": "value",
|
||||
"key": "",
|
||||
" ": "value",
|
||||
"valid": "valid-value",
|
||||
},
|
||||
attrs: make(map[string]string),
|
||||
want: map[string]string{
|
||||
"header:valid": "valid-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addConfigHeadersToAttrs(tt.headers, tt.attrs)
|
||||
if !reflect.DeepEqual(tt.attrs, tt.want) {
|
||||
t.Errorf("expected %v, got %v", tt.want, tt.attrs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
16
internal/watcher/synthesizer/interface.go
Normal file
16
internal/watcher/synthesizer/interface.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Package synthesizer provides auth synthesis strategies for the watcher package.
|
||||
// It implements the Strategy pattern to support multiple auth sources:
|
||||
// - ConfigSynthesizer: generates Auth entries from config API keys
|
||||
// - FileSynthesizer: generates Auth entries from OAuth JSON files
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// AuthSynthesizer defines the interface for generating Auth entries from various sources.
|
||||
type AuthSynthesizer interface {
|
||||
// Synthesize generates Auth entries from the given context.
|
||||
// Returns a slice of Auth pointers and any error encountered.
|
||||
Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, error)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
613
internal/watcher/watcher_test.go
Normal file
613
internal/watcher/watcher_test.go
Normal file
@@ -0,0 +1,613 @@
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestApplyAuthExcludedModelsMeta_APIKey(t *testing.T) {
|
||||
auth := &coreauth.Auth{Attributes: map[string]string{}}
|
||||
cfg := &config.Config{}
|
||||
perKey := []string{" Model-1 ", "model-2"}
|
||||
|
||||
synthesizer.ApplyAuthExcludedModelsMeta(auth, cfg, perKey, "apikey")
|
||||
|
||||
expected := diff.ComputeExcludedModelsHash([]string{"model-1", "model-2"})
|
||||
if got := auth.Attributes["excluded_models_hash"]; got != expected {
|
||||
t.Fatalf("expected hash %s, got %s", expected, got)
|
||||
}
|
||||
if got := auth.Attributes["auth_kind"]; got != "apikey" {
|
||||
t.Fatalf("expected auth_kind=apikey, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAuthExcludedModelsMeta_OAuthProvider(t *testing.T) {
|
||||
auth := &coreauth.Auth{
|
||||
Provider: "TestProv",
|
||||
Attributes: map[string]string{},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"testprov": {"A", "b"},
|
||||
},
|
||||
}
|
||||
|
||||
synthesizer.ApplyAuthExcludedModelsMeta(auth, cfg, nil, "oauth")
|
||||
|
||||
expected := diff.ComputeExcludedModelsHash([]string{"a", "b"})
|
||||
if got := auth.Attributes["excluded_models_hash"]; got != expected {
|
||||
t.Fatalf("expected hash %s, got %s", expected, got)
|
||||
}
|
||||
if got := auth.Attributes["auth_kind"]; got != "oauth" {
|
||||
t.Fatalf("expected auth_kind=oauth, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAPIKeyClientsCounts(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{{APIKey: "g1"}, {APIKey: "g2"}},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}},
|
||||
CodexKey: []config.CodexKey{{APIKey: "x1"}, {APIKey: "x2"}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "o1"}, {APIKey: "o2"}}},
|
||||
},
|
||||
}
|
||||
|
||||
gemini, vertex, claude, codex, compat := BuildAPIKeyClients(cfg)
|
||||
if gemini != 2 || vertex != 1 || claude != 1 || codex != 2 || compat != 2 {
|
||||
t.Fatalf("unexpected counts: %d %d %d %d %d", gemini, vertex, claude, codex, compat)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAuthStripsTemporalFields(t *testing.T) {
|
||||
now := time.Now()
|
||||
auth := &coreauth.Auth{
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
LastRefreshedAt: now,
|
||||
NextRefreshAfter: now,
|
||||
Quota: coreauth.QuotaState{
|
||||
NextRecoverAt: now,
|
||||
},
|
||||
Runtime: map[string]any{"k": "v"},
|
||||
}
|
||||
|
||||
normalized := normalizeAuth(auth)
|
||||
if !normalized.CreatedAt.IsZero() || !normalized.UpdatedAt.IsZero() || !normalized.LastRefreshedAt.IsZero() || !normalized.NextRefreshAfter.IsZero() {
|
||||
t.Fatal("expected time fields to be zeroed")
|
||||
}
|
||||
if normalized.Runtime != nil {
|
||||
t.Fatal("expected runtime to be nil")
|
||||
}
|
||||
if !normalized.Quota.NextRecoverAt.IsZero() {
|
||||
t.Fatal("expected quota.NextRecoverAt to be zeroed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchProvider(t *testing.T) {
|
||||
if _, ok := matchProvider("OpenAI", []string{"openai", "claude"}); !ok {
|
||||
t.Fatal("expected match to succeed ignoring case")
|
||||
}
|
||||
if _, ok := matchProvider("missing", []string{"openai"}); ok {
|
||||
t.Fatal("expected match to fail for unknown provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshotCoreAuths_ConfigAndAuthFiles(t *testing.T) {
|
||||
authDir := t.TempDir()
|
||||
metadata := map[string]any{
|
||||
"type": "gemini",
|
||||
"email": "user@example.com",
|
||||
"project_id": "proj-a, proj-b",
|
||||
"proxy_url": "https://proxy",
|
||||
}
|
||||
authFile := filepath.Join(authDir, "gemini.json")
|
||||
data, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal metadata: %v", err)
|
||||
}
|
||||
if err = os.WriteFile(authFile, data, 0o644); err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
AuthDir: authDir,
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{
|
||||
APIKey: "g-key",
|
||||
BaseURL: "https://gemini",
|
||||
ExcludedModels: []string{"Model-A", "model-b"},
|
||||
Headers: map[string]string{"X-Req": "1"},
|
||||
},
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"gemini-cli": {"Foo", "bar"},
|
||||
},
|
||||
}
|
||||
|
||||
w := &Watcher{authDir: authDir}
|
||||
w.SetConfig(cfg)
|
||||
|
||||
auths := w.SnapshotCoreAuths()
|
||||
if len(auths) != 4 {
|
||||
t.Fatalf("expected 4 auth entries (1 config + 1 primary + 2 virtual), got %d", len(auths))
|
||||
}
|
||||
|
||||
var geminiAPIKeyAuth *coreauth.Auth
|
||||
var geminiPrimary *coreauth.Auth
|
||||
virtuals := make([]*coreauth.Auth, 0)
|
||||
for _, a := range auths {
|
||||
switch {
|
||||
case a.Provider == "gemini" && a.Attributes["api_key"] == "g-key":
|
||||
geminiAPIKeyAuth = a
|
||||
case a.Attributes["gemini_virtual_primary"] == "true":
|
||||
geminiPrimary = a
|
||||
case strings.TrimSpace(a.Attributes["gemini_virtual_parent"]) != "":
|
||||
virtuals = append(virtuals, a)
|
||||
}
|
||||
}
|
||||
if geminiAPIKeyAuth == nil {
|
||||
t.Fatal("expected synthesized Gemini API key auth")
|
||||
}
|
||||
expectedAPIKeyHash := diff.ComputeExcludedModelsHash([]string{"Model-A", "model-b"})
|
||||
if geminiAPIKeyAuth.Attributes["excluded_models_hash"] != expectedAPIKeyHash {
|
||||
t.Fatalf("expected API key excluded hash %s, got %s", expectedAPIKeyHash, geminiAPIKeyAuth.Attributes["excluded_models_hash"])
|
||||
}
|
||||
if geminiAPIKeyAuth.Attributes["auth_kind"] != "apikey" {
|
||||
t.Fatalf("expected auth_kind=apikey, got %s", geminiAPIKeyAuth.Attributes["auth_kind"])
|
||||
}
|
||||
|
||||
if geminiPrimary == nil {
|
||||
t.Fatal("expected primary gemini-cli auth from file")
|
||||
}
|
||||
if !geminiPrimary.Disabled || geminiPrimary.Status != coreauth.StatusDisabled {
|
||||
t.Fatal("expected primary gemini-cli auth to be disabled when virtual auths are synthesized")
|
||||
}
|
||||
expectedOAuthHash := diff.ComputeExcludedModelsHash([]string{"Foo", "bar"})
|
||||
if geminiPrimary.Attributes["excluded_models_hash"] != expectedOAuthHash {
|
||||
t.Fatalf("expected OAuth excluded hash %s, got %s", expectedOAuthHash, geminiPrimary.Attributes["excluded_models_hash"])
|
||||
}
|
||||
if geminiPrimary.Attributes["auth_kind"] != "oauth" {
|
||||
t.Fatalf("expected auth_kind=oauth, got %s", geminiPrimary.Attributes["auth_kind"])
|
||||
}
|
||||
|
||||
if len(virtuals) != 2 {
|
||||
t.Fatalf("expected 2 virtual auths, got %d", len(virtuals))
|
||||
}
|
||||
for _, v := range virtuals {
|
||||
if v.Attributes["gemini_virtual_parent"] != geminiPrimary.ID {
|
||||
t.Fatalf("virtual auth missing parent link to %s", geminiPrimary.ID)
|
||||
}
|
||||
if v.Attributes["excluded_models_hash"] != expectedOAuthHash {
|
||||
t.Fatalf("expected virtual excluded hash %s, got %s", expectedOAuthHash, v.Attributes["excluded_models_hash"])
|
||||
}
|
||||
if v.Status != coreauth.StatusActive {
|
||||
t.Fatalf("expected virtual auth to be active, got %s", v.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadConfigIfChanged_TriggersOnChangeAndSkipsUnchanged(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", err)
|
||||
}
|
||||
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
writeConfig := func(port int, allowRemote bool) {
|
||||
cfg := &config.Config{
|
||||
Port: port,
|
||||
AuthDir: authDir,
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: allowRemote,
|
||||
},
|
||||
}
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal config: %v", err)
|
||||
}
|
||||
if err = os.WriteFile(configPath, data, 0o644); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
writeConfig(8080, false)
|
||||
|
||||
reloads := 0
|
||||
w := &Watcher{
|
||||
configPath: configPath,
|
||||
authDir: authDir,
|
||||
reloadCallback: func(*config.Config) { reloads++ },
|
||||
}
|
||||
|
||||
w.reloadConfigIfChanged()
|
||||
if reloads != 1 {
|
||||
t.Fatalf("expected first reload to trigger callback once, got %d", reloads)
|
||||
}
|
||||
|
||||
// Same content should be skipped by hash check.
|
||||
w.reloadConfigIfChanged()
|
||||
if reloads != 1 {
|
||||
t.Fatalf("expected unchanged config to be skipped, callback count %d", reloads)
|
||||
}
|
||||
|
||||
writeConfig(9090, true)
|
||||
w.reloadConfigIfChanged()
|
||||
if reloads != 2 {
|
||||
t.Fatalf("expected changed config to trigger reload, callback count %d", reloads)
|
||||
}
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
if w.config == nil || w.config.Port != 9090 || !w.config.RemoteManagement.AllowRemote {
|
||||
t.Fatalf("expected config to be updated after reload, got %+v", w.config)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartAndStopSuccess(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir), 0o644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
var reloads int32
|
||||
w, err := NewWatcher(configPath, authDir, func(*config.Config) {
|
||||
atomic.AddInt32(&reloads, 1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create watcher: %v", err)
|
||||
}
|
||||
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := w.Start(ctx); err != nil {
|
||||
t.Fatalf("expected Start to succeed: %v", err)
|
||||
}
|
||||
cancel()
|
||||
if err := w.Stop(); err != nil {
|
||||
t.Fatalf("expected Stop to succeed: %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected one reload callback, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartFailsWhenConfigMissing(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(tmpDir, "missing-config.yaml")
|
||||
|
||||
w, err := NewWatcher(configPath, authDir, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create watcher: %v", err)
|
||||
}
|
||||
defer w.Stop()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := w.Start(ctx); err == nil {
|
||||
t.Fatal("expected Start to fail for missing config file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatchRuntimeAuthUpdateEnqueuesAndUpdatesState(t *testing.T) {
|
||||
queue := make(chan AuthUpdate, 4)
|
||||
w := &Watcher{}
|
||||
w.SetAuthUpdateQueue(queue)
|
||||
defer w.stopDispatch()
|
||||
|
||||
auth := &coreauth.Auth{ID: "auth-1", Provider: "test"}
|
||||
if ok := w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionAdd, Auth: auth}); !ok {
|
||||
t.Fatal("expected DispatchRuntimeAuthUpdate to enqueue")
|
||||
}
|
||||
|
||||
select {
|
||||
case update := <-queue:
|
||||
if update.Action != AuthUpdateActionAdd || update.Auth.ID != "auth-1" {
|
||||
t.Fatalf("unexpected update: %+v", update)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for auth update")
|
||||
}
|
||||
|
||||
if ok := w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionDelete, ID: "auth-1"}); !ok {
|
||||
t.Fatal("expected delete update to enqueue")
|
||||
}
|
||||
select {
|
||||
case update := <-queue:
|
||||
if update.Action != AuthUpdateActionDelete || update.ID != "auth-1" {
|
||||
t.Fatalf("unexpected delete update: %+v", update)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for delete update")
|
||||
}
|
||||
w.clientsMutex.RLock()
|
||||
if _, exists := w.runtimeAuths["auth-1"]; exists {
|
||||
w.clientsMutex.RUnlock()
|
||||
t.Fatal("expected runtime auth to be cleared after delete")
|
||||
}
|
||||
w.clientsMutex.RUnlock()
|
||||
}
|
||||
|
||||
func TestAddOrUpdateClientSkipsUnchanged(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "sample.json")
|
||||
if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to create auth file: %v", err)
|
||||
}
|
||||
data, _ := os.ReadFile(authFile)
|
||||
sum := sha256.Sum256(data)
|
||||
|
||||
var reloads int32
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
lastAuthHashes: make(map[string]string),
|
||||
reloadCallback: func(*config.Config) {
|
||||
atomic.AddInt32(&reloads, 1)
|
||||
},
|
||||
}
|
||||
w.SetConfig(&config.Config{AuthDir: tmpDir})
|
||||
// Use normalizeAuthPath to match how addOrUpdateClient stores the key
|
||||
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(sum[:])
|
||||
|
||||
w.addOrUpdateClient(authFile)
|
||||
if got := atomic.LoadInt32(&reloads); got != 0 {
|
||||
t.Fatalf("expected no reload for unchanged file, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddOrUpdateClientTriggersReloadAndHash(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "sample.json")
|
||||
if err := os.WriteFile(authFile, []byte(`{"type":"demo","api_key":"k"}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to create auth file: %v", err)
|
||||
}
|
||||
|
||||
var reloads int32
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
lastAuthHashes: make(map[string]string),
|
||||
reloadCallback: func(*config.Config) {
|
||||
atomic.AddInt32(&reloads, 1)
|
||||
},
|
||||
}
|
||||
w.SetConfig(&config.Config{AuthDir: tmpDir})
|
||||
|
||||
w.addOrUpdateClient(authFile)
|
||||
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected reload callback once, got %d", got)
|
||||
}
|
||||
// Use normalizeAuthPath to match how addOrUpdateClient stores the key
|
||||
normalized := w.normalizeAuthPath(authFile)
|
||||
if _, ok := w.lastAuthHashes[normalized]; !ok {
|
||||
t.Fatalf("expected hash to be stored for %s", normalized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveClientRemovesHash(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "sample.json")
|
||||
var reloads int32
|
||||
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
lastAuthHashes: make(map[string]string),
|
||||
reloadCallback: func(*config.Config) {
|
||||
atomic.AddInt32(&reloads, 1)
|
||||
},
|
||||
}
|
||||
w.SetConfig(&config.Config{AuthDir: tmpDir})
|
||||
// Use normalizeAuthPath to set up the hash with the correct key format
|
||||
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
|
||||
|
||||
w.removeClient(authFile)
|
||||
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
||||
t.Fatal("expected hash to be removed after deletion")
|
||||
}
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected reload callback once, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldDebounceRemove(t *testing.T) {
|
||||
w := &Watcher{}
|
||||
path := filepath.Clean("test.json")
|
||||
|
||||
if w.shouldDebounceRemove(path, time.Now()) {
|
||||
t.Fatal("first call should not debounce")
|
||||
}
|
||||
if !w.shouldDebounceRemove(path, time.Now()) {
|
||||
t.Fatal("second call within window should debounce")
|
||||
}
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
w.lastRemoveTimes = map[string]time.Time{path: time.Now().Add(-2 * authRemoveDebounceWindow)}
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
if w.shouldDebounceRemove(path, time.Now()) {
|
||||
t.Fatal("call after window should not debounce")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFileUnchangedUsesHash(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "sample.json")
|
||||
content := []byte(`{"type":"demo"}`)
|
||||
if err := os.WriteFile(authFile, content, 0o644); err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
w := &Watcher{lastAuthHashes: make(map[string]string)}
|
||||
unchanged, err := w.authFileUnchanged(authFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if unchanged {
|
||||
t.Fatal("expected first check to report changed")
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(content)
|
||||
// Use normalizeAuthPath to match how authFileUnchanged looks up the key
|
||||
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(sum[:])
|
||||
|
||||
unchanged, err = w.authFileUnchanged(authFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !unchanged {
|
||||
t.Fatal("expected hash match to report unchanged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadClientsCachesAuthHashes(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "one.json")
|
||||
if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
config: &config.Config{AuthDir: tmpDir},
|
||||
}
|
||||
|
||||
w.reloadClients(true, nil, false)
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
if len(w.lastAuthHashes) != 1 {
|
||||
t.Fatalf("expected hash cache for one auth file, got %d", len(w.lastAuthHashes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadClientsLogsConfigDiffs(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
oldCfg := &config.Config{AuthDir: tmpDir, Port: 1, Debug: false}
|
||||
newCfg := &config.Config{AuthDir: tmpDir, Port: 2, Debug: true}
|
||||
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
config: oldCfg,
|
||||
}
|
||||
w.SetConfig(oldCfg)
|
||||
w.oldConfigYaml, _ = yaml.Marshal(oldCfg)
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
w.config = newCfg
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
w.reloadClients(false, nil, false)
|
||||
}
|
||||
|
||||
func TestSetAuthUpdateQueueNilResetsDispatch(t *testing.T) {
|
||||
w := &Watcher{}
|
||||
queue := make(chan AuthUpdate, 1)
|
||||
w.SetAuthUpdateQueue(queue)
|
||||
if w.dispatchCond == nil || w.dispatchCancel == nil {
|
||||
t.Fatal("expected dispatch to be initialized")
|
||||
}
|
||||
w.SetAuthUpdateQueue(nil)
|
||||
if w.dispatchCancel != nil {
|
||||
t.Fatal("expected dispatch cancel to be cleared when queue nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopConfigReloadTimerSafeWhenNil(t *testing.T) {
|
||||
w := &Watcher{}
|
||||
w.stopConfigReloadTimer()
|
||||
w.configReloadMu.Lock()
|
||||
w.configReloadTimer = time.AfterFunc(10*time.Millisecond, func() {})
|
||||
w.configReloadMu.Unlock()
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
w.stopConfigReloadTimer()
|
||||
}
|
||||
|
||||
func TestHandleEventRemovesAuthFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "remove.json")
|
||||
if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
if err := os.Remove(authFile); err != nil {
|
||||
t.Fatalf("failed to remove auth file pre-check: %v", err)
|
||||
}
|
||||
|
||||
var reloads int32
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
config: &config.Config{AuthDir: tmpDir},
|
||||
lastAuthHashes: make(map[string]string),
|
||||
reloadCallback: func(*config.Config) {
|
||||
atomic.AddInt32(&reloads, 1)
|
||||
},
|
||||
}
|
||||
// Use normalizeAuthPath to set up the hash with the correct key format
|
||||
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
|
||||
|
||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
||||
|
||||
if atomic.LoadInt32(&reloads) != 1 {
|
||||
t.Fatalf("expected reload callback once, got %d", reloads)
|
||||
}
|
||||
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
||||
t.Fatal("expected hash entry to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatchAuthUpdatesFlushesQueue(t *testing.T) {
|
||||
queue := make(chan AuthUpdate, 4)
|
||||
w := &Watcher{}
|
||||
w.SetAuthUpdateQueue(queue)
|
||||
defer w.stopDispatch()
|
||||
|
||||
w.dispatchAuthUpdates([]AuthUpdate{
|
||||
{Action: AuthUpdateActionAdd, ID: "a"},
|
||||
{Action: AuthUpdateActionModify, ID: "b"},
|
||||
})
|
||||
|
||||
got := make([]AuthUpdate, 0, 2)
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case u := <-queue:
|
||||
got = append(got, u)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("timed out waiting for update %d", i)
|
||||
}
|
||||
}
|
||||
if len(got) != 2 || got[0].ID != "a" || got[1].ID != "b" {
|
||||
t.Fatalf("unexpected updates order/content: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func hexString(data []byte) string {
|
||||
return strings.ToLower(fmt.Sprintf("%x", data))
|
||||
}
|
||||
@@ -84,7 +84,8 @@ func (h *GeminiAPIHandler) GeminiGetHandler(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
switch request.Action {
|
||||
action := strings.TrimPrefix(request.Action, "/")
|
||||
switch action {
|
||||
case "gemini-3-pro-preview":
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"name": "models/gemini-3-pro-preview",
|
||||
@@ -189,7 +190,7 @@ func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
action := strings.Split(request.Action, ":")
|
||||
action := strings.Split(strings.TrimPrefix(request.Action, "/"), ":")
|
||||
if len(action) != 2 {
|
||||
c.JSON(http.StatusNotFound, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
@@ -50,27 +49,6 @@ type BaseAPIHandler struct {
|
||||
|
||||
// Cfg holds the current application configuration.
|
||||
Cfg *config.SDKConfig
|
||||
|
||||
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
||||
openAICompatProviders []string
|
||||
openAICompatMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// GetOpenAICompatProviders safely returns a copy of the provider names
|
||||
func (h *BaseAPIHandler) GetOpenAICompatProviders() []string {
|
||||
h.openAICompatMutex.RLock()
|
||||
defer h.openAICompatMutex.RUnlock()
|
||||
result := make([]string, len(h.openAICompatProviders))
|
||||
copy(result, h.openAICompatProviders)
|
||||
return result
|
||||
}
|
||||
|
||||
// SetOpenAICompatProviders safely sets the provider names
|
||||
func (h *BaseAPIHandler) SetOpenAICompatProviders(providers []string) {
|
||||
h.openAICompatMutex.Lock()
|
||||
defer h.openAICompatMutex.Unlock()
|
||||
h.openAICompatProviders = make([]string, len(providers))
|
||||
copy(h.openAICompatProviders, providers)
|
||||
}
|
||||
|
||||
// NewBaseAPIHandlers creates a new API handlers instance.
|
||||
@@ -82,12 +60,11 @@ func (h *BaseAPIHandler) SetOpenAICompatProviders(providers []string) {
|
||||
//
|
||||
// Returns:
|
||||
// - *BaseAPIHandler: A new API handlers instance
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
|
||||
h := &BaseAPIHandler{
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
}
|
||||
h.SetOpenAICompatProviders(openAICompatProviders)
|
||||
return h
|
||||
}
|
||||
|
||||
@@ -362,30 +339,19 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
|
||||
providerName, extractedModelName, isDynamic := h.parseDynamicModel(resolvedModelName)
|
||||
|
||||
targetModelName := resolvedModelName
|
||||
if isDynamic {
|
||||
targetModelName = extractedModelName
|
||||
}
|
||||
|
||||
// Normalize the model name to handle dynamic thinking suffixes before determining the provider.
|
||||
normalizedModel, metadata = normalizeModelMetadata(targetModelName)
|
||||
normalizedModel, metadata = normalizeModelMetadata(resolvedModelName)
|
||||
|
||||
if isDynamic {
|
||||
providers = []string{providerName}
|
||||
} else {
|
||||
// For non-dynamic models, use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 && metadata != nil {
|
||||
if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok {
|
||||
if originalModel, okStr := originalRaw.(string); okStr {
|
||||
originalModel = strings.TrimSpace(originalModel)
|
||||
if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) {
|
||||
if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 {
|
||||
providers = altProviders
|
||||
normalizedModel = originalModel
|
||||
}
|
||||
// Use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 && metadata != nil {
|
||||
if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok {
|
||||
if originalModel, okStr := originalRaw.(string); okStr {
|
||||
originalModel = strings.TrimSpace(originalModel)
|
||||
if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) {
|
||||
if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 {
|
||||
providers = altProviders
|
||||
normalizedModel = originalModel
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -403,30 +369,6 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
||||
return providers, normalizedModel, metadata, nil
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
|
||||
var providerPart, modelPart string
|
||||
for _, sep := range []string{"://"} {
|
||||
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
|
||||
providerPart = parts[0]
|
||||
modelPart = parts[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if providerPart == "" {
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
// Check if the provider is a configured openai-compatibility provider
|
||||
for _, pName := range h.GetOpenAICompatProviders() {
|
||||
if pName == providerPart {
|
||||
return providerPart, modelPart, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
func cloneBytes(src []byte) []byte {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -363,10 +363,11 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
@@ -396,8 +397,10 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
resp, errExec := executor.Execute(execCtx, auth, req, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
@@ -420,10 +423,11 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
@@ -453,8 +457,10 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, req, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
@@ -477,10 +483,11 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
if provider == "" {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
@@ -510,14 +517,16 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, req, opts)
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: false, Error: rerr}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errStream
|
||||
@@ -535,18 +544,66 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: false, Error: rerr})
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
out <- chunk
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: true})
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteModelForAuth(model string, metadata map[string]any, auth *Auth) (string, map[string]any) {
|
||||
if auth == nil || model == "" {
|
||||
return model, metadata
|
||||
}
|
||||
prefix := strings.TrimSpace(auth.Prefix)
|
||||
if prefix == "" {
|
||||
return model, metadata
|
||||
}
|
||||
needle := prefix + "/"
|
||||
if !strings.HasPrefix(model, needle) {
|
||||
return model, metadata
|
||||
}
|
||||
rewritten := strings.TrimPrefix(model, needle)
|
||||
return rewritten, stripPrefixFromMetadata(metadata, needle)
|
||||
}
|
||||
|
||||
func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]any {
|
||||
if len(metadata) == 0 || needle == "" {
|
||||
return metadata
|
||||
}
|
||||
keys := []string{
|
||||
util.ThinkingOriginalModelMetadataKey,
|
||||
util.GeminiOriginalModelMetadataKey,
|
||||
}
|
||||
var out map[string]any
|
||||
for _, key := range keys {
|
||||
raw, ok := metadata[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
value, okStr := raw.(string)
|
||||
if !okStr || !strings.HasPrefix(value, needle) {
|
||||
continue
|
||||
}
|
||||
if out == nil {
|
||||
out = make(map[string]any, len(metadata))
|
||||
for k, v := range metadata {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
out[key] = strings.TrimPrefix(value, needle)
|
||||
}
|
||||
if out == nil {
|
||||
return metadata
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) normalizeProviders(providers []string) []string {
|
||||
if len(providers) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -19,6 +19,8 @@ type Auth struct {
|
||||
Index uint64 `json:"-"`
|
||||
// Provider is the upstream provider key (e.g. "gemini", "claude").
|
||||
Provider string `json:"provider"`
|
||||
// Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview").
|
||||
Prefix string `json:"prefix,omitempty"`
|
||||
// FileName stores the relative or absolute path of the backing auth file.
|
||||
FileName string `json:"-"`
|
||||
// Storage holds the token persistence implementation used during login flows.
|
||||
|
||||
@@ -796,7 +796,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if providerKey == "" {
|
||||
providerKey = "openai-compatibility"
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms)
|
||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
|
||||
} else {
|
||||
// Ensure stale registrations are cleared when model list becomes empty.
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
@@ -816,7 +816,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if key == "" {
|
||||
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, models)
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -996,6 +996,48 @@ func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
|
||||
return filtered
|
||||
}
|
||||
|
||||
func applyModelPrefixes(models []*ModelInfo, prefix string, forceModelPrefix bool) []*ModelInfo {
|
||||
trimmedPrefix := strings.TrimSpace(prefix)
|
||||
if trimmedPrefix == "" || len(models) == 0 {
|
||||
return models
|
||||
}
|
||||
|
||||
out := make([]*ModelInfo, 0, len(models)*2)
|
||||
seen := make(map[string]struct{}, len(models)*2)
|
||||
|
||||
addModel := func(model *ModelInfo) {
|
||||
if model == nil {
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(model.ID)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
if _, exists := seen[id]; exists {
|
||||
return
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, model)
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
baseID := strings.TrimSpace(model.ID)
|
||||
if baseID == "" {
|
||||
continue
|
||||
}
|
||||
if !forceModelPrefix || trimmedPrefix == baseID {
|
||||
addModel(model)
|
||||
}
|
||||
clone := *model
|
||||
clone.ID = trimmedPrefix + "/" + baseID
|
||||
addModel(&clone)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// matchWildcard performs case-insensitive wildcard matching where '*' matches any substring.
|
||||
func matchWildcard(pattern, value string) bool {
|
||||
if pattern == "" {
|
||||
|
||||
@@ -9,6 +9,11 @@ type SDKConfig struct {
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
||||
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
||||
// credentials as well.
|
||||
ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"`
|
||||
|
||||
// RequestLog enables or disables detailed request logging functionality.
|
||||
RequestLog bool `yaml:"request-log" json:"request-log"`
|
||||
|
||||
|
||||
@@ -295,7 +295,7 @@ func TestThinkingConversionsAcrossProtocolsAndModels(t *testing.T) {
|
||||
}
|
||||
// Check numeric budget fallback for allowCompat
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if mapped, okMap := util.OpenAIThinkingBudgetToEffort(normalizedModel, *budget); okMap && mapped != "" {
|
||||
if mapped, okMap := util.ThinkingBudgetToEffort(normalizedModel, *budget); okMap && mapped != "" {
|
||||
return true, mapped, false
|
||||
}
|
||||
}
|
||||
@@ -308,7 +308,7 @@ func TestThinkingConversionsAcrossProtocolsAndModels(t *testing.T) {
|
||||
effort, ok := util.ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || strings.TrimSpace(effort) == "" {
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if mapped, okMap := util.OpenAIThinkingBudgetToEffort(normalizedModel, *budget); okMap {
|
||||
if mapped, okMap := util.ThinkingBudgetToEffort(normalizedModel, *budget); okMap {
|
||||
effort = mapped
|
||||
ok = true
|
||||
}
|
||||
@@ -336,7 +336,7 @@ func TestThinkingConversionsAcrossProtocolsAndModels(t *testing.T) {
|
||||
return false, "", true
|
||||
}
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if mapped, okMap := util.OpenAIThinkingBudgetToEffort(normalizedModel, *budget); okMap && mapped != "" {
|
||||
if mapped, okMap := util.ThinkingBudgetToEffort(normalizedModel, *budget); okMap && mapped != "" {
|
||||
mapped = strings.ToLower(strings.TrimSpace(mapped))
|
||||
if normalized, okLevel := util.NormalizeReasoningEffortLevel(normalizedModel, mapped); okLevel {
|
||||
return true, normalized, false
|
||||
@@ -609,7 +609,7 @@ func TestRawPayloadThinkingConversions(t *testing.T) {
|
||||
return true, normalized, false
|
||||
}
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
if mapped, okM := util.OpenAIThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
if mapped, okM := util.ThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
return true, mapped, false
|
||||
}
|
||||
}
|
||||
@@ -625,7 +625,7 @@ func TestRawPayloadThinkingConversions(t *testing.T) {
|
||||
return false, "", true // invalid level
|
||||
}
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
if mapped, okM := util.OpenAIThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
if mapped, okM := util.ThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
// Check if the mapped effort is valid for this model
|
||||
if _, validLevel := util.NormalizeReasoningEffortLevel(model, mapped); !validLevel {
|
||||
return true, mapped, true // expect validation error
|
||||
@@ -646,7 +646,7 @@ func TestRawPayloadThinkingConversions(t *testing.T) {
|
||||
return false, "", true
|
||||
}
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
if mapped, okM := util.OpenAIThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
if mapped, okM := util.ThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
// Check if the mapped effort is valid for this model
|
||||
if _, validLevel := util.NormalizeReasoningEffortLevel(model, mapped); !validLevel {
|
||||
return true, mapped, true // expect validation error
|
||||
@@ -721,7 +721,7 @@ func TestRawPayloadThinkingConversions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIThinkingBudgetToEffortRanges(t *testing.T) {
|
||||
func TestThinkingBudgetToEffort(t *testing.T) {
|
||||
cleanup := registerCoreModels(t)
|
||||
defer cleanup()
|
||||
|
||||
@@ -733,7 +733,7 @@ func TestOpenAIThinkingBudgetToEffortRanges(t *testing.T) {
|
||||
ok bool
|
||||
}{
|
||||
{name: "dynamic-auto", model: "gpt-5", budget: -1, want: "auto", ok: true},
|
||||
{name: "zero-none", model: "gpt-5", budget: 0, want: "none", ok: true},
|
||||
{name: "zero-none", model: "gpt-5", budget: 0, want: "minimal", ok: true},
|
||||
{name: "low-min", model: "gpt-5", budget: 1, want: "low", ok: true},
|
||||
{name: "low-max", model: "gpt-5", budget: 1024, want: "low", ok: true},
|
||||
{name: "medium-min", model: "gpt-5", budget: 1025, want: "medium", ok: true},
|
||||
@@ -741,14 +741,14 @@ func TestOpenAIThinkingBudgetToEffortRanges(t *testing.T) {
|
||||
{name: "high-min", model: "gpt-5", budget: 8193, want: "high", ok: true},
|
||||
{name: "high-max", model: "gpt-5", budget: 24576, want: "high", ok: true},
|
||||
{name: "over-max-clamps-to-highest", model: "gpt-5", budget: 64000, want: "high", ok: true},
|
||||
{name: "over-max-xhigh-model", model: "gpt-5.2", budget: 50000, want: "xhigh", ok: true},
|
||||
{name: "over-max-xhigh-model", model: "gpt-5.2", budget: 64000, want: "xhigh", ok: true},
|
||||
{name: "negative-unsupported", model: "gpt-5", budget: -5, want: "", ok: false},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
cs := cs
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
got, ok := util.OpenAIThinkingBudgetToEffort(cs.model, cs.budget)
|
||||
got, ok := util.ThinkingBudgetToEffort(cs.model, cs.budget)
|
||||
if ok != cs.ok {
|
||||
t.Fatalf("ok mismatch for model=%s budget=%d: expect %v got %v", cs.model, cs.budget, cs.ok, ok)
|
||||
}
|
||||
@@ -758,3 +758,41 @@ func TestOpenAIThinkingBudgetToEffortRanges(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkingEffortToBudget(t *testing.T) {
|
||||
cleanup := registerCoreModels(t)
|
||||
defer cleanup()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
effort string
|
||||
want int
|
||||
ok bool
|
||||
}{
|
||||
{name: "none", model: "gemini-2.5-pro", effort: "none", want: 0, ok: true},
|
||||
{name: "auto", model: "gemini-2.5-pro", effort: "auto", want: -1, ok: true},
|
||||
{name: "minimal", model: "gemini-2.5-pro", effort: "minimal", want: 512, ok: true},
|
||||
{name: "low", model: "gemini-2.5-pro", effort: "low", want: 1024, ok: true},
|
||||
{name: "medium", model: "gemini-2.5-pro", effort: "medium", want: 8192, ok: true},
|
||||
{name: "high", model: "gemini-2.5-pro", effort: "high", want: 24576, ok: true},
|
||||
{name: "xhigh", model: "gemini-2.5-pro", effort: "xhigh", want: 32768, ok: true},
|
||||
{name: "empty-unsupported", model: "gemini-2.5-pro", effort: "", want: 0, ok: false},
|
||||
{name: "invalid-unsupported", model: "gemini-2.5-pro", effort: "ultra", want: 0, ok: false},
|
||||
{name: "case-insensitive", model: "gemini-2.5-pro", effort: "LOW", want: 1024, ok: true},
|
||||
{name: "case-insensitive-medium", model: "gemini-2.5-pro", effort: "MEDIUM", want: 8192, ok: true},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
cs := cs
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
got, ok := util.ThinkingEffortToBudget(cs.model, cs.effort)
|
||||
if ok != cs.ok {
|
||||
t.Fatalf("ok mismatch for model=%s effort=%s: expect %v got %v", cs.model, cs.effort, cs.ok, ok)
|
||||
}
|
||||
if got != cs.want {
|
||||
t.Fatalf("value mismatch for model=%s effort=%s: expect %d got %d", cs.model, cs.effort, cs.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user