mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-30 01:06:39 +00:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c95129884 | ||
|
|
9fa2a7e9df | ||
|
|
d443c86620 | ||
|
|
7be3f1c36c | ||
|
|
f6ab6d97b9 | ||
|
|
bc866bac49 | ||
|
|
50e6d845f4 | ||
|
|
a8cb01819d | ||
|
|
1a99cfded4 | ||
|
|
530273906b | ||
|
|
06ddf575d9 | ||
|
|
cf369d4684 | ||
|
|
3099114cbb | ||
|
|
44b63f0767 | ||
|
|
6705d20194 | ||
|
|
a38a9c0b0f | ||
|
|
8286caa366 | ||
|
|
bd1ec8424d | ||
|
|
225e2c6797 | ||
|
|
d8fc485513 | ||
|
|
f137eb0ac4 | ||
|
|
f39a460487 | ||
|
|
ee171bc563 | ||
|
|
a95428f204 | ||
|
|
3ca5fb1046 | ||
|
|
a091d12f4e |
@@ -23,11 +23,14 @@ config.yaml
|
|||||||
|
|
||||||
# Development/editor
|
# Development/editor
|
||||||
bin/*
|
bin/*
|
||||||
.claude/*
|
|
||||||
.vscode/*
|
.vscode/*
|
||||||
|
.claude/*
|
||||||
|
.codex/*
|
||||||
.gemini/*
|
.gemini/*
|
||||||
.serena/*
|
.serena/*
|
||||||
.agent/*
|
.agent/*
|
||||||
|
.agents/*
|
||||||
|
.opencode/*
|
||||||
.bmad/*
|
.bmad/*
|
||||||
_bmad/*
|
_bmad/*
|
||||||
_bmad-output/*
|
_bmad-output/*
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -34,10 +34,14 @@ GEMINI.md
|
|||||||
|
|
||||||
# Tooling metadata
|
# Tooling metadata
|
||||||
.vscode/*
|
.vscode/*
|
||||||
|
.codex/*
|
||||||
.claude/*
|
.claude/*
|
||||||
.gemini/*
|
.gemini/*
|
||||||
.serena/*
|
.serena/*
|
||||||
.agent/*
|
.agent/*
|
||||||
|
.agents/*
|
||||||
|
.agents/*
|
||||||
|
.opencode/*
|
||||||
.bmad/*
|
.bmad/*
|
||||||
_bmad/*
|
_bmad/*
|
||||||
_bmad-output/*
|
_bmad-output/*
|
||||||
|
|||||||
@@ -434,7 +434,7 @@ func main() {
|
|||||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||||
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||||
|
|
||||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
|
if err = logging.ConfigureLogOutput(cfg); err != nil {
|
||||||
log.Errorf("failed to configure log output: %v", err)
|
log.Errorf("failed to configure log output: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ auth-dir: "~/.cli-proxy-api"
|
|||||||
api-keys:
|
api-keys:
|
||||||
- "your-api-key-1"
|
- "your-api-key-1"
|
||||||
- "your-api-key-2"
|
- "your-api-key-2"
|
||||||
|
- "your-api-key-3"
|
||||||
|
|
||||||
# Enable debug logging
|
# Enable debug logging
|
||||||
debug: false
|
debug: false
|
||||||
@@ -170,9 +171,9 @@ ws-auth: false
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# models: # optional: map aliases to upstream model names
|
# models: # optional: map aliases to upstream model names
|
||||||
# - name: "gemini-2.0-flash" # upstream model name
|
# - name: "gemini-2.5-flash" # upstream model name
|
||||||
# alias: "vertex-flash" # client-visible alias
|
# alias: "vertex-flash" # client-visible alias
|
||||||
# - name: "gemini-1.5-pro"
|
# - name: "gemini-2.5-pro"
|
||||||
# alias: "vertex-pro"
|
# alias: "vertex-pro"
|
||||||
|
|
||||||
# Amp Integration
|
# Amp Integration
|
||||||
@@ -181,6 +182,18 @@ ws-auth: false
|
|||||||
# upstream-url: "https://ampcode.com"
|
# upstream-url: "https://ampcode.com"
|
||||||
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
|
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
|
||||||
# upstream-api-key: ""
|
# upstream-api-key: ""
|
||||||
|
# # Per-client upstream API key mapping
|
||||||
|
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
|
||||||
|
# # Useful when different clients need to use different Amp accounts/quotas.
|
||||||
|
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
|
||||||
|
# upstream-api-keys:
|
||||||
|
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
|
||||||
|
# api-keys: # Client keys that use this upstream key
|
||||||
|
# - "your-api-key-1"
|
||||||
|
# - "your-api-key-2"
|
||||||
|
# - upstream-api-key: "amp_key_for_team_b"
|
||||||
|
# api-keys:
|
||||||
|
# - "your-api-key-3"
|
||||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
|
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
|
||||||
# restrict-management-to-localhost: false
|
# restrict-management-to-localhost: false
|
||||||
# # Force model mappings to run before checking local API keys (default: false)
|
# # Force model mappings to run before checking local API keys (default: false)
|
||||||
@@ -190,12 +203,42 @@ ws-auth: false
|
|||||||
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
||||||
# # but you have a similar model available (e.g., Claude Sonnet 4).
|
# # but you have a similar model available (e.g., Claude Sonnet 4).
|
||||||
# model-mappings:
|
# model-mappings:
|
||||||
# - from: "claude-opus-4.5" # Model requested by Amp CLI
|
# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
|
||||||
# to: "claude-sonnet-4" # Route to this available model instead
|
# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
|
||||||
# - from: "gpt-5"
|
# - from: "claude-sonnet-4-5-20250929"
|
||||||
# to: "gemini-2.5-pro"
|
# to: "gemini-claude-sonnet-4-5-thinking"
|
||||||
# - from: "claude-3-opus-20240229"
|
# - from: "claude-haiku-4-5-20251001"
|
||||||
# to: "claude-3-5-sonnet-20241022"
|
# to: "gemini-2.5-flash"
|
||||||
|
|
||||||
|
# Global OAuth model name mappings (per channel)
|
||||||
|
# These mappings rename model IDs for both model listing and request routing.
|
||||||
|
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||||
|
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||||
|
# oauth-model-mappings:
|
||||||
|
# gemini-cli:
|
||||||
|
# - name: "gemini-2.5-pro" # original model name under this channel
|
||||||
|
# alias: "g2.5p" # client-visible alias
|
||||||
|
# vertex:
|
||||||
|
# - name: "gemini-2.5-pro"
|
||||||
|
# alias: "g2.5p"
|
||||||
|
# aistudio:
|
||||||
|
# - name: "gemini-2.5-pro"
|
||||||
|
# alias: "g2.5p"
|
||||||
|
# antigravity:
|
||||||
|
# - name: "gemini-3-pro-preview"
|
||||||
|
# alias: "g3p"
|
||||||
|
# claude:
|
||||||
|
# - name: "claude-sonnet-4-5-20250929"
|
||||||
|
# alias: "cs4.5"
|
||||||
|
# codex:
|
||||||
|
# - name: "gpt-5"
|
||||||
|
# alias: "g5"
|
||||||
|
# qwen:
|
||||||
|
# - name: "qwen3-coder-plus"
|
||||||
|
# alias: "qwen-plus"
|
||||||
|
# iflow:
|
||||||
|
# - name: "glm-4.7"
|
||||||
|
# alias: "glm-god"
|
||||||
|
|
||||||
# OAuth provider excluded models
|
# OAuth provider excluded models
|
||||||
# oauth-excluded-models:
|
# oauth-excluded-models:
|
||||||
|
|||||||
@@ -431,9 +431,46 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
|||||||
log.WithError(err).Warnf("failed to stat auth file %s", path)
|
log.WithError(err).Warnf("failed to stat auth file %s", path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if claims := extractCodexIDTokenClaims(auth); claims != nil {
|
||||||
|
entry["id_token"] = claims
|
||||||
|
}
|
||||||
return entry
|
return entry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
|
||||||
|
if auth == nil || auth.Metadata == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
idTokenRaw, ok := auth.Metadata["id_token"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
idToken := strings.TrimSpace(idTokenRaw)
|
||||||
|
if idToken == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
claims, err := codex.ParseJWTToken(idToken)
|
||||||
|
if err != nil || claims == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := gin.H{}
|
||||||
|
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" {
|
||||||
|
result["chatgpt_account_id"] = v
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
|
||||||
|
result["plan_type"] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
func authEmail(auth *coreauth.Auth) string {
|
func authEmail(auth *coreauth.Auth) string {
|
||||||
if auth == nil {
|
if auth == nil {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -940,3 +940,151 @@ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
|
|||||||
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
|
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
|
||||||
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
|
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping.
|
||||||
|
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
|
||||||
|
if h == nil || h.cfg == nil {
|
||||||
|
c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings.
|
||||||
|
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&body); err != nil {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Normalize entries: trim whitespace, filter empty
|
||||||
|
normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value)
|
||||||
|
h.cfg.AmpCode.UpstreamAPIKeys = normalized
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries.
|
||||||
|
// Matching is done by upstream-api-key value.
|
||||||
|
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&body); err != nil {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
existing := make(map[string]int)
|
||||||
|
for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
|
||||||
|
existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, newEntry := range body.Value {
|
||||||
|
upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey)
|
||||||
|
if upstreamKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
normalizedEntry := config.AmpUpstreamAPIKeyEntry{
|
||||||
|
UpstreamAPIKey: upstreamKey,
|
||||||
|
APIKeys: normalizeAPIKeysList(newEntry.APIKeys),
|
||||||
|
}
|
||||||
|
if idx, ok := existing[upstreamKey]; ok {
|
||||||
|
h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry
|
||||||
|
} else {
|
||||||
|
h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry)
|
||||||
|
existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries.
|
||||||
|
// Body must be JSON: {"value": ["<upstream-api-key>", ...]}.
|
||||||
|
// If "value" is an empty array, clears all entries.
|
||||||
|
// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change.
|
||||||
|
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value []string `json:"value"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&body); err != nil {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if body.Value == nil {
|
||||||
|
c.JSON(400, gin.H{"error": "missing value"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty array means clear all
|
||||||
|
if len(body.Value) == 0 {
|
||||||
|
h.cfg.AmpCode.UpstreamAPIKeys = nil
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
toRemove := make(map[string]bool)
|
||||||
|
for _, key := range body.Value {
|
||||||
|
trimmed := strings.TrimSpace(key)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
toRemove[trimmed] = true
|
||||||
|
}
|
||||||
|
if len(toRemove) == 0 {
|
||||||
|
c.JSON(400, gin.H{"error": "empty value"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
|
||||||
|
for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
|
||||||
|
if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] {
|
||||||
|
newEntries = append(newEntries, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.cfg.AmpCode.UpstreamAPIKeys = newEntries
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries.
|
||||||
|
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
||||||
|
if upstreamKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
apiKeys := normalizeAPIKeysList(entry.APIKeys)
|
||||||
|
out = append(out, config.AmpUpstreamAPIKeyEntry{
|
||||||
|
UpstreamAPIKey: upstreamKey,
|
||||||
|
APIKeys: apiKeys,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeAPIKeysList trims and filters empty strings from a list of API keys.
|
||||||
|
func normalizeAPIKeysList(keys []string) []string {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(keys))
|
||||||
|
for _, k := range keys {
|
||||||
|
trimmed := strings.TrimSpace(k)
|
||||||
|
if trimmed != "" {
|
||||||
|
out = append(out, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|||||||
@@ -59,6 +59,11 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewHandler creates a new management handler instance.
|
||||||
|
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
|
||||||
|
return NewHandler(cfg, "", manager)
|
||||||
|
}
|
||||||
|
|
||||||
// SetConfig updates the in-memory config reference when the server hot-reloads.
|
// SetConfig updates the in-memory config reference when the server hot-reloads.
|
||||||
func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg }
|
func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg }
|
||||||
|
|
||||||
|
|||||||
@@ -227,11 +227,20 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check API key change
|
// Check API key change (both default and per-client mappings)
|
||||||
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
|
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
|
||||||
if apiKeyChanged {
|
upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
|
||||||
|
if apiKeyChanged || upstreamAPIKeysChanged {
|
||||||
if m.secretSource != nil {
|
if m.secretSource != nil {
|
||||||
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
||||||
|
if apiKeyChanged {
|
||||||
|
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
|
||||||
|
ms.InvalidateCache()
|
||||||
|
}
|
||||||
|
if upstreamAPIKeysChanged {
|
||||||
|
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
|
||||||
|
}
|
||||||
|
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||||
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
|
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
|
||||||
ms.InvalidateCache()
|
ms.InvalidateCache()
|
||||||
}
|
}
|
||||||
@@ -251,10 +260,22 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
|
|
||||||
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
|
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
|
||||||
if m.secretSource == nil {
|
if m.secretSource == nil {
|
||||||
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
// Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
|
||||||
|
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||||
|
mappedSource := NewMappedSecretSource(defaultSource)
|
||||||
|
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
||||||
|
m.secretSource = mappedSource
|
||||||
|
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
||||||
|
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
|
||||||
|
ms.InvalidateCache()
|
||||||
|
ms.UpdateMappings(settings.UpstreamAPIKeys)
|
||||||
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||||
|
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
|
||||||
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
|
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
|
||||||
ms.InvalidateCache()
|
ms.InvalidateCache()
|
||||||
|
mappedSource := NewMappedSecretSource(ms)
|
||||||
|
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
||||||
|
m.secretSource = mappedSource
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
||||||
@@ -313,6 +334,66 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
|
|||||||
return oldKey != newKey
|
return oldKey != newKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
|
||||||
|
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
||||||
|
if old == nil {
|
||||||
|
return len(new.UpstreamAPIKeys) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build map for comparison: upstreamKey -> set of clientKeys
|
||||||
|
type entryInfo struct {
|
||||||
|
upstreamKey string
|
||||||
|
clientKeys map[string]struct{}
|
||||||
|
}
|
||||||
|
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
|
||||||
|
for i, entry := range old.UpstreamAPIKeys {
|
||||||
|
clientKeys := make(map[string]struct{}, len(entry.APIKeys))
|
||||||
|
for _, k := range entry.APIKeys {
|
||||||
|
trimmed := strings.TrimSpace(k)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
clientKeys[trimmed] = struct{}{}
|
||||||
|
}
|
||||||
|
oldEntries[i] = entryInfo{
|
||||||
|
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
|
||||||
|
clientKeys: clientKeys,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, newEntry := range new.UpstreamAPIKeys {
|
||||||
|
if i >= len(oldEntries) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
oldE := oldEntries[i]
|
||||||
|
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
|
||||||
|
for _, k := range newEntry.APIKeys {
|
||||||
|
trimmed := strings.TrimSpace(k)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newKeys[trimmed] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(newKeys) != len(oldE.clientKeys) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for k := range newKeys {
|
||||||
|
if _, ok := oldE.clientKeys[k]; !ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// GetModelMapper returns the model mapper instance (for testing/debugging).
|
// GetModelMapper returns the model mapper instance (for testing/debugging).
|
||||||
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
||||||
return m.modelMapper
|
return m.modelMapper
|
||||||
|
|||||||
@@ -312,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
|
||||||
|
m := &AmpModule{}
|
||||||
|
|
||||||
|
oldCfg := &config.AmpCode{
|
||||||
|
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
newCfg := &config.AmpCode{
|
||||||
|
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
||||||
|
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
|
||||||
|
m := &AmpModule{}
|
||||||
|
|
||||||
|
oldCfg := &config.AmpCode{
|
||||||
|
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
newCfg := &config.AmpCode{
|
||||||
|
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
||||||
|
t.Fatal("expected no change when only whitespace/empty entries differ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,6 +18,33 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func removeQueryValuesMatching(req *http.Request, key string, match string) {
|
||||||
|
if req == nil || req.URL == nil || match == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
q := req.URL.Query()
|
||||||
|
values, ok := q[key]
|
||||||
|
if !ok || len(values) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
kept := make([]string, 0, len(values))
|
||||||
|
for _, v := range values {
|
||||||
|
if v == match {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(kept) == 0 {
|
||||||
|
q.Del(key)
|
||||||
|
} else {
|
||||||
|
q[key] = kept
|
||||||
|
}
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
// readCloser wraps a reader and forwards Close to a separate closer.
|
// readCloser wraps a reader and forwards Close to a separate closer.
|
||||||
// Used to restore peeked bytes while preserving upstream body Close behavior.
|
// Used to restore peeked bytes while preserving upstream body Close behavior.
|
||||||
type readCloser struct {
|
type readCloser struct {
|
||||||
@@ -48,6 +75,14 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
|||||||
// We will set our own Authorization using the configured upstream-api-key
|
// We will set our own Authorization using the configured upstream-api-key
|
||||||
req.Header.Del("Authorization")
|
req.Header.Del("Authorization")
|
||||||
req.Header.Del("X-Api-Key")
|
req.Header.Del("X-Api-Key")
|
||||||
|
req.Header.Del("X-Goog-Api-Key")
|
||||||
|
|
||||||
|
// Remove query-based credentials if they match the authenticated client API key.
|
||||||
|
// This prevents leaking client auth material to the Amp upstream while avoiding
|
||||||
|
// breaking unrelated upstream query parameters.
|
||||||
|
clientKey := getClientAPIKeyFromContext(req.Context())
|
||||||
|
removeQueryValuesMatching(req, "key", clientKey)
|
||||||
|
removeQueryValuesMatching(req, "auth_token", clientKey)
|
||||||
|
|
||||||
// Preserve correlation headers for debugging
|
// Preserve correlation headers for debugging
|
||||||
if req.Header.Get("X-Request-ID") == "" {
|
if req.Header.Get("X-Request-ID") == "" {
|
||||||
|
|||||||
@@ -3,11 +3,15 @@ package amp
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Helper: compress data with gzip
|
// Helper: compress data with gzip
|
||||||
@@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
|
||||||
|
type captured struct {
|
||||||
|
headers http.Header
|
||||||
|
query string
|
||||||
|
}
|
||||||
|
got := make(chan captured, 1)
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`ok`))
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Simulate clientAPIKeyMiddleware injection (per-request)
|
||||||
|
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
|
||||||
|
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer client-key")
|
||||||
|
req.Header.Set("X-Api-Key", "client-key")
|
||||||
|
req.Header.Set("X-Goog-Api-Key", "client-key")
|
||||||
|
|
||||||
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
c := <-got
|
||||||
|
|
||||||
|
// These are client-provided credentials and must not reach the upstream.
|
||||||
|
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
|
||||||
|
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
|
||||||
|
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
|
||||||
|
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
|
||||||
|
}
|
||||||
|
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
|
||||||
|
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query-based credentials should be stripped only when they match the authenticated client key.
|
||||||
|
// Should keep unrelated values and parameters.
|
||||||
|
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
|
||||||
|
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
|
||||||
|
}
|
||||||
|
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
|
||||||
|
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
|
||||||
|
gotHeaders := make(chan http.Header, 1)
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotHeaders <- r.Header.Clone()
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`ok`))
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
defaultSource := NewStaticSecretSource("default")
|
||||||
|
mapped := NewMappedSecretSource(defaultSource)
|
||||||
|
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u1",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy, err := createReverseProxy(upstream.URL, mapped)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Simulate clientAPIKeyMiddleware injection (per-request)
|
||||||
|
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
|
||||||
|
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
res, err := http.Get(srv.URL + "/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
hdr := <-gotHeaders
|
||||||
|
if hdr.Get("X-Api-Key") != "u1" {
|
||||||
|
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||||
|
}
|
||||||
|
if hdr.Get("Authorization") != "Bearer u1" {
|
||||||
|
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
|
||||||
|
gotHeaders := make(chan http.Header, 1)
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotHeaders <- r.Header.Clone()
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`ok`))
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
defaultSource := NewStaticSecretSource("default")
|
||||||
|
mapped := NewMappedSecretSource(defaultSource)
|
||||||
|
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u1",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy, err := createReverseProxy(upstream.URL, mapped)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
|
||||||
|
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
res, err := http.Get(srv.URL + "/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
hdr := <-gotHeaders
|
||||||
|
if hdr.Get("X-Api-Key") != "default" {
|
||||||
|
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||||
|
}
|
||||||
|
if hdr.Get("Authorization") != "Bearer default" {
|
||||||
|
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxy_ErrorHandler(t *testing.T) {
|
func TestReverseProxy_ErrorHandler(t *testing.T) {
|
||||||
// Point proxy to a non-routable address to trigger error
|
// Point proxy to a non-routable address to trigger error
|
||||||
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
|
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package amp
|
package amp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -16,6 +17,37 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// clientAPIKeyContextKey is the context key used to pass the client API key
|
||||||
|
// from gin.Context to the request context for SecretSource lookup.
|
||||||
|
type clientAPIKeyContextKey struct{}
|
||||||
|
|
||||||
|
// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
|
||||||
|
// into the request context so that SecretSource can look it up for per-client upstream routing.
|
||||||
|
func clientAPIKeyMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// Extract the client API key from gin context (set by AuthMiddleware)
|
||||||
|
if apiKey, exists := c.Get("apiKey"); exists {
|
||||||
|
if keyStr, ok := apiKey.(string); ok && keyStr != "" {
|
||||||
|
// Inject into request context for SecretSource.Get(ctx) to read
|
||||||
|
ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClientAPIKeyFromContext retrieves the client API key from request context.
|
||||||
|
// Returns empty string if not present.
|
||||||
|
func getClientAPIKeyFromContext(ctx context.Context) string {
|
||||||
|
if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
|
||||||
|
if keyStr, ok := val.(string); ok {
|
||||||
|
return keyStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
|
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
|
||||||
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
|
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
|
||||||
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
|
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
|
||||||
@@ -129,6 +161,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
|
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Inject client API key into request context for per-client upstream routing
|
||||||
|
ampAPI.Use(clientAPIKeyMiddleware())
|
||||||
|
|
||||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||||
proxyHandler := func(c *gin.Context) {
|
proxyHandler := func(c *gin.Context) {
|
||||||
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
||||||
@@ -175,6 +210,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
if authWithBypass != nil {
|
if authWithBypass != nil {
|
||||||
rootMiddleware = append(rootMiddleware, authWithBypass)
|
rootMiddleware = append(rootMiddleware, authWithBypass)
|
||||||
}
|
}
|
||||||
|
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
|
||||||
|
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
|
||||||
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
|
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
|
||||||
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||||
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
|
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
|
||||||
@@ -244,6 +281,8 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
|||||||
if auth != nil {
|
if auth != nil {
|
||||||
ampProviders.Use(auth)
|
ampProviders.Use(auth)
|
||||||
}
|
}
|
||||||
|
// Inject client API key into request context for per-client upstream routing
|
||||||
|
ampProviders.Use(clientAPIKeyMiddleware())
|
||||||
|
|
||||||
provider := ampProviders.Group("/:provider")
|
provider := ampProviders.Group("/:provider")
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SecretSource provides Amp API keys with configurable precedence and caching
|
// SecretSource provides Amp API keys with configurable precedence and caching
|
||||||
@@ -164,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource {
|
|||||||
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
|
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
|
||||||
return s.key, nil
|
return s.key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
|
||||||
|
// When a request context contains a client API key that matches a configured mapping,
|
||||||
|
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
|
||||||
|
type MappedSecretSource struct {
|
||||||
|
defaultSource SecretSource
|
||||||
|
mu sync.RWMutex
|
||||||
|
lookup map[string]string // clientKey -> upstreamKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
|
||||||
|
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
|
||||||
|
return &MappedSecretSource{
|
||||||
|
defaultSource: defaultSource,
|
||||||
|
lookup: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves the Amp API key, checking per-client mappings first.
|
||||||
|
// If the request context contains a client API key that matches a configured mapping,
|
||||||
|
// returns the corresponding upstream key. Otherwise, falls back to the default source.
|
||||||
|
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
|
||||||
|
// Try to get client API key from request context
|
||||||
|
clientKey := getClientAPIKeyFromContext(ctx)
|
||||||
|
if clientKey != "" {
|
||||||
|
s.mu.RLock()
|
||||||
|
if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return upstreamKey, nil
|
||||||
|
}
|
||||||
|
s.mu.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to default source
|
||||||
|
return s.defaultSource.Get(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
|
||||||
|
// If the same client key appears in multiple entries, logs a warning and uses the first one.
|
||||||
|
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
|
||||||
|
newLookup := make(map[string]string)
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
||||||
|
if upstreamKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, clientKey := range entry.APIKeys {
|
||||||
|
trimmedKey := strings.TrimSpace(clientKey)
|
||||||
|
if trimmedKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := newLookup[trimmedKey]; exists {
|
||||||
|
// Log warning for duplicate client key, first one wins
|
||||||
|
log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newLookup[trimmedKey] = upstreamKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.lookup = newLookup
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
|
||||||
|
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
|
||||||
|
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
||||||
|
ms.UpdateExplicitKey(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
|
||||||
|
func (s *MappedSecretSource) InvalidateCache() {
|
||||||
|
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
||||||
|
ms.InvalidateCache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,6 +8,10 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/sirupsen/logrus/hooks/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
|
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
|
||||||
@@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
|
|||||||
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
|
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
|
||||||
|
defaultSource := NewStaticSecretSource("default")
|
||||||
|
s := NewMappedSecretSource(defaultSource)
|
||||||
|
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u1",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
||||||
|
got, err := s.Get(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != "u1" {
|
||||||
|
t.Fatalf("want u1, got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
|
||||||
|
got, err = s.Get(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != "default" {
|
||||||
|
t.Fatalf("want default fallback, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
|
||||||
|
defaultSource := NewStaticSecretSource("default")
|
||||||
|
s := NewMappedSecretSource(defaultSource)
|
||||||
|
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u1",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u2",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
||||||
|
got, err := s.Get(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != "u1" {
|
||||||
|
t.Fatalf("want u1 (first wins), got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
|
||||||
|
hook := test.NewLocal(log.StandardLogger())
|
||||||
|
defer hook.Reset()
|
||||||
|
|
||||||
|
defaultSource := NewStaticSecretSource("default")
|
||||||
|
s := NewMappedSecretSource(defaultSource)
|
||||||
|
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u1",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u2",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
foundWarning := false
|
||||||
|
for _, entry := range hook.AllEntries() {
|
||||||
|
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
|
||||||
|
foundWarning = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundWarning {
|
||||||
|
t.Fatal("expected warning log for duplicate client key, but none was found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -571,6 +571,10 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
|
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
|
||||||
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||||
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||||
|
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
|
||||||
|
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
|
||||||
|
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
|
||||||
|
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
|
||||||
|
|
||||||
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
||||||
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
||||||
@@ -873,7 +877,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
|
if err := logging.ConfigureLogOutput(cfg); err != nil {
|
||||||
log.Errorf("failed to reconfigure log output: %v", err)
|
log.Errorf("failed to reconfigure log output: %v", err)
|
||||||
} else {
|
} else {
|
||||||
if oldCfg == nil {
|
if oldCfg == nil {
|
||||||
|
|||||||
@@ -98,6 +98,14 @@ type Config struct {
|
|||||||
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
||||||
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
||||||
|
|
||||||
|
// OAuthModelMappings defines global model name mappings for OAuth/file-backed auth channels.
|
||||||
|
// These mappings affect both model listing and model routing for supported channels:
|
||||||
|
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||||
|
//
|
||||||
|
// NOTE: This does not apply to existing per-credential model alias features under:
|
||||||
|
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
||||||
|
OAuthModelMappings map[string][]ModelNameMapping `yaml:"oauth-model-mappings,omitempty" json:"oauth-model-mappings,omitempty"`
|
||||||
|
|
||||||
// Payload defines default and override rules for provider payload parameters.
|
// Payload defines default and override rules for provider payload parameters.
|
||||||
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
||||||
|
|
||||||
@@ -149,6 +157,13 @@ type RoutingConfig struct {
|
|||||||
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ModelNameMapping defines a model ID rename mapping for a specific channel.
|
||||||
|
// It maps the original model name (Name) to the client-visible alias (Alias).
|
||||||
|
type ModelNameMapping struct {
|
||||||
|
Name string `yaml:"name" json:"name"`
|
||||||
|
Alias string `yaml:"alias" json:"alias"`
|
||||||
|
}
|
||||||
|
|
||||||
// AmpModelMapping defines a model name mapping for Amp CLI requests.
|
// AmpModelMapping defines a model name mapping for Amp CLI requests.
|
||||||
// When Amp requests a model that isn't available locally, this mapping
|
// When Amp requests a model that isn't available locally, this mapping
|
||||||
// allows routing to an alternative model that IS available.
|
// allows routing to an alternative model that IS available.
|
||||||
@@ -175,6 +190,11 @@ type AmpCode struct {
|
|||||||
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
|
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
|
||||||
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||||
|
|
||||||
|
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
|
||||||
|
// When a client authenticates with a key that matches an entry, that upstream key is used.
|
||||||
|
// If no match is found, falls back to UpstreamAPIKey (default behavior).
|
||||||
|
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
|
||||||
|
|
||||||
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
||||||
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
|
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
|
||||||
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
|
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
|
||||||
@@ -190,6 +210,17 @@ type AmpCode struct {
|
|||||||
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
|
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key.
|
||||||
|
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
|
||||||
|
// is used for the upstream Amp request.
|
||||||
|
type AmpUpstreamAPIKeyEntry struct {
|
||||||
|
// UpstreamAPIKey is the API key to use when proxying to the Amp upstream.
|
||||||
|
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||||
|
|
||||||
|
// APIKeys are the client API keys (from top-level api-keys) that map to this upstream key.
|
||||||
|
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||||
|
}
|
||||||
|
|
||||||
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
||||||
type PayloadConfig struct {
|
type PayloadConfig struct {
|
||||||
// Default defines rules that only set parameters when they are missing in the payload.
|
// Default defines rules that only set parameters when they are missing in the payload.
|
||||||
@@ -490,6 +521,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
// Normalize OAuth provider model exclusion map.
|
// Normalize OAuth provider model exclusion map.
|
||||||
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
|
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
|
||||||
|
|
||||||
|
// Normalize global OAuth model name mappings.
|
||||||
|
cfg.SanitizeOAuthModelMappings()
|
||||||
|
|
||||||
if cfg.legacyMigrationPending {
|
if cfg.legacyMigrationPending {
|
||||||
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||||
if !optional && configFile != "" {
|
if !optional && configFile != "" {
|
||||||
@@ -506,6 +540,50 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SanitizeOAuthModelMappings normalizes and deduplicates global OAuth model name mappings.
|
||||||
|
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||||
|
// and ensures (From, To) pairs are unique within each channel.
|
||||||
|
func (cfg *Config) SanitizeOAuthModelMappings() {
|
||||||
|
if cfg == nil || len(cfg.OAuthModelMappings) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out := make(map[string][]ModelNameMapping, len(cfg.OAuthModelMappings))
|
||||||
|
for rawChannel, mappings := range cfg.OAuthModelMappings {
|
||||||
|
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||||
|
if channel == "" || len(mappings) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenName := make(map[string]struct{}, len(mappings))
|
||||||
|
seenAlias := make(map[string]struct{}, len(mappings))
|
||||||
|
clean := make([]ModelNameMapping, 0, len(mappings))
|
||||||
|
for _, mapping := range mappings {
|
||||||
|
name := strings.TrimSpace(mapping.Name)
|
||||||
|
alias := strings.TrimSpace(mapping.Alias)
|
||||||
|
if name == "" || alias == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(name, alias) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nameKey := strings.ToLower(name)
|
||||||
|
aliasKey := strings.ToLower(alias)
|
||||||
|
if _, ok := seenName[nameKey]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seenAlias[aliasKey]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenName[nameKey] = struct{}{}
|
||||||
|
seenAlias[aliasKey] = struct{}{}
|
||||||
|
clean = append(clean, ModelNameMapping{Name: name, Alias: alias})
|
||||||
|
}
|
||||||
|
if len(clean) > 0 {
|
||||||
|
out[channel] = clean
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfg.OAuthModelMappings = out
|
||||||
|
}
|
||||||
|
|
||||||
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
|
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
|
||||||
// not actionable, specifically those missing a BaseURL. It trims whitespace before
|
// not actionable, specifically those missing a BaseURL. It trims whitespace before
|
||||||
// evaluation and preserves the relative order of remaining entries.
|
// evaluation and preserves the relative order of remaining entries.
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gopkg.in/natefinch/lumberjack.v2"
|
"gopkg.in/natefinch/lumberjack.v2"
|
||||||
@@ -84,10 +85,30 @@ func SetupBaseLogger() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isDirWritable checks if the specified directory exists and is writable by attempting to create and remove a test file.
|
||||||
|
func isDirWritable(dir string) bool {
|
||||||
|
info, err := os.Stat(dir)
|
||||||
|
if err != nil || !info.IsDir() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
testFile := filepath.Join(dir, ".perm_test")
|
||||||
|
f, err := os.Create(testFile)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = f.Close()
|
||||||
|
_ = os.Remove(testFile)
|
||||||
|
}()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
|
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
|
||||||
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
|
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
|
||||||
// until the total size is within the limit.
|
// until the total size is within the limit.
|
||||||
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
|
func ConfigureLogOutput(cfg *config.Config) error {
|
||||||
SetupBaseLogger()
|
SetupBaseLogger()
|
||||||
|
|
||||||
writerMu.Lock()
|
writerMu.Lock()
|
||||||
@@ -96,10 +117,12 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
|
|||||||
logDir := "logs"
|
logDir := "logs"
|
||||||
if base := util.WritablePath(); base != "" {
|
if base := util.WritablePath(); base != "" {
|
||||||
logDir = filepath.Join(base, "logs")
|
logDir = filepath.Join(base, "logs")
|
||||||
|
} else if !isDirWritable(logDir) {
|
||||||
|
logDir = filepath.Join(cfg.AuthDir, "logs")
|
||||||
}
|
}
|
||||||
|
|
||||||
protectedPath := ""
|
protectedPath := ""
|
||||||
if loggingToFile {
|
if cfg.LoggingToFile {
|
||||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||||
return fmt.Errorf("logging: failed to create log directory: %w", err)
|
return fmt.Errorf("logging: failed to create log directory: %w", err)
|
||||||
}
|
}
|
||||||
@@ -123,7 +146,7 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
|
|||||||
log.SetOutput(os.Stdout)
|
log.SetOutput(os.Stdout)
|
||||||
}
|
}
|
||||||
|
|
||||||
configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath)
|
configureLogDirCleanerLocked(logDir, cfg.LogsMaxTotalSizeMB, protectedPath)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -76,7 +76,12 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
|
|||||||
|
|
||||||
// Execute performs a non-streaming request to the Antigravity API.
|
// Execute performs a non-streaming request to the Antigravity API.
|
||||||
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
if strings.Contains(req.Model, "claude") {
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||||
|
if upstreamModel == "" {
|
||||||
|
upstreamModel = req.Model
|
||||||
|
}
|
||||||
|
isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
|
||||||
|
if isClaude {
|
||||||
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,7 +103,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
@@ -109,7 +114,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
var lastErr error
|
var lastErr error
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
for idx, baseURL := range baseURLs {
|
||||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL)
|
httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, false, opts.Alt, baseURL)
|
||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
err = errReq
|
err = errReq
|
||||||
return resp, err
|
return resp, err
|
||||||
@@ -190,10 +195,15 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
|||||||
to := sdktranslator.FromString("antigravity")
|
to := sdktranslator.FromString("antigravity")
|
||||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||||
|
if upstreamModel == "" {
|
||||||
|
upstreamModel = req.Model
|
||||||
|
}
|
||||||
|
|
||||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
translated = normalizeAntigravityThinking(req.Model, translated, true)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
@@ -204,7 +214,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
|||||||
var lastErr error
|
var lastErr error
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
for idx, baseURL := range baseURLs {
|
||||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
|
httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL)
|
||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
err = errReq
|
err = errReq
|
||||||
return resp, err
|
return resp, err
|
||||||
@@ -524,10 +534,16 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
|||||||
to := sdktranslator.FromString("antigravity")
|
to := sdktranslator.FromString("antigravity")
|
||||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||||
|
if upstreamModel == "" {
|
||||||
|
upstreamModel = req.Model
|
||||||
|
}
|
||||||
|
isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
|
||||||
|
|
||||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
@@ -538,7 +554,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
|||||||
var lastErr error
|
var lastErr error
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
for idx, baseURL := range baseURLs {
|
||||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
|
httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL)
|
||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
err = errReq
|
err = errReq
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -676,6 +692,12 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
to := sdktranslator.FromString("antigravity")
|
to := sdktranslator.FromString("antigravity")
|
||||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||||
|
|
||||||
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||||
|
if upstreamModel == "" {
|
||||||
|
upstreamModel = req.Model
|
||||||
|
}
|
||||||
|
isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
@@ -694,7 +716,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||||
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
|
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
|
||||||
payload = normalizeAntigravityThinking(req.Model, payload)
|
payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
|
||||||
payload = deleteJSONField(payload, "project")
|
payload = deleteJSONField(payload, "project")
|
||||||
payload = deleteJSONField(payload, "model")
|
payload = deleteJSONField(payload, "model")
|
||||||
payload = deleteJSONField(payload, "request.safetySettings")
|
payload = deleteJSONField(payload, "request.safetySettings")
|
||||||
@@ -1308,7 +1330,7 @@ func alias2ModelName(modelName string) string {
|
|||||||
|
|
||||||
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
|
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
|
||||||
// For Claude models, it additionally ensures thinking budget < max_tokens.
|
// For Claude models, it additionally ensures thinking budget < max_tokens.
|
||||||
func normalizeAntigravityThinking(model string, payload []byte) []byte {
|
func normalizeAntigravityThinking(model string, payload []byte, isClaude bool) []byte {
|
||||||
payload = util.StripThinkingConfigIfUnsupported(model, payload)
|
payload = util.StripThinkingConfigIfUnsupported(model, payload)
|
||||||
if !util.ModelSupportsThinking(model) {
|
if !util.ModelSupportsThinking(model) {
|
||||||
return payload
|
return payload
|
||||||
@@ -1320,7 +1342,6 @@ func normalizeAntigravityThinking(model string, payload []byte) []byte {
|
|||||||
raw := int(budget.Int())
|
raw := int(budget.Int())
|
||||||
normalized := util.NormalizeThinkingBudget(model, raw)
|
normalized := util.NormalizeThinkingBudget(model, raw)
|
||||||
|
|
||||||
isClaude := strings.Contains(strings.ToLower(model), "claude")
|
|
||||||
if isClaude {
|
if isClaude {
|
||||||
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
|
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
|
||||||
if effectiveMax > 0 && normalized >= effectiveMax {
|
if effectiveMax > 0 && normalized >= effectiveMax {
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
|
out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
|
||||||
}
|
}
|
||||||
} else if systemResult.Type == gjson.String {
|
} else if systemResult.Type == gjson.String {
|
||||||
out, _ = sjson.Set(out, "request.system_instruction.parts.-1.text", systemResult.String())
|
out, _ = sjson.Set(out, "system_instruction.parts.-1.text", systemResult.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// contents
|
// contents
|
||||||
|
|||||||
@@ -344,7 +344,7 @@ func cleanupRequiredFields(jsonStr string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
|
// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
|
||||||
// Claude VALIDATED mode requires at least one property in tool schemas.
|
// Claude VALIDATED mode requires at least one required property in tool schemas.
|
||||||
func addEmptySchemaPlaceholder(jsonStr string) string {
|
func addEmptySchemaPlaceholder(jsonStr string) string {
|
||||||
// Find all "type" fields
|
// Find all "type" fields
|
||||||
paths := findPaths(jsonStr, "type")
|
paths := findPaths(jsonStr, "type")
|
||||||
@@ -364,6 +364,9 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
|
|||||||
// Check if properties exists and is empty or missing
|
// Check if properties exists and is empty or missing
|
||||||
propsPath := joinPath(parentPath, "properties")
|
propsPath := joinPath(parentPath, "properties")
|
||||||
propsVal := gjson.Get(jsonStr, propsPath)
|
propsVal := gjson.Get(jsonStr, propsPath)
|
||||||
|
reqPath := joinPath(parentPath, "required")
|
||||||
|
reqVal := gjson.Get(jsonStr, reqPath)
|
||||||
|
hasRequiredProperties := reqVal.IsArray() && len(reqVal.Array()) > 0
|
||||||
|
|
||||||
needsPlaceholder := false
|
needsPlaceholder := false
|
||||||
if !propsVal.Exists() {
|
if !propsVal.Exists() {
|
||||||
@@ -381,8 +384,17 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
|
|||||||
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
|
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
|
||||||
|
|
||||||
// Add to required array
|
// Add to required array
|
||||||
reqPath := joinPath(parentPath, "required")
|
|
||||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
|
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If schema has properties but none are required, add a minimal placeholder.
|
||||||
|
if propsVal.IsObject() && !hasRequiredProperties {
|
||||||
|
placeholderPath := joinPath(propsPath, "_")
|
||||||
|
if !gjson.Get(jsonStr, placeholderPath).Exists() {
|
||||||
|
jsonStr, _ = sjson.Set(jsonStr, placeholderPath+".type", "boolean")
|
||||||
|
}
|
||||||
|
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"_"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -614,71 +614,6 @@ func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) {
|
|
||||||
// propertyNames is used to validate object property names (e.g., must match a pattern)
|
|
||||||
// Gemini doesn't support this keyword and will reject requests containing it
|
|
||||||
input := `{
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"metadata": {
|
|
||||||
"type": "object",
|
|
||||||
"propertyNames": {
|
|
||||||
"pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
|
|
||||||
},
|
|
||||||
"additionalProperties": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
|
|
||||||
expected := `{
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"metadata": {
|
|
||||||
"type": "object"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
|
|
||||||
result := CleanJSONSchemaForGemini(input)
|
|
||||||
compareJSON(t, expected, result)
|
|
||||||
|
|
||||||
// Verify propertyNames is completely removed
|
|
||||||
if strings.Contains(result, "propertyNames") {
|
|
||||||
t.Errorf("propertyNames keyword should be removed, got: %s", result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) {
|
|
||||||
// Test deeply nested propertyNames (as seen in real Claude tool schemas)
|
|
||||||
input := `{
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"items": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"config": {
|
|
||||||
"type": "object",
|
|
||||||
"propertyNames": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
|
|
||||||
result := CleanJSONSchemaForGemini(input)
|
|
||||||
|
|
||||||
if strings.Contains(result, "propertyNames") {
|
|
||||||
t.Errorf("Nested propertyNames should be removed, got: %s", result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
|
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
|
||||||
var expMap, actMap map[string]interface{}
|
var expMap, actMap map[string]interface{}
|
||||||
errExp := json.Unmarshal([]byte(expectedJSON), &expMap)
|
errExp := json.Unmarshal([]byte(expectedJSON), &expMap)
|
||||||
|
|||||||
@@ -7,10 +7,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ThinkingBudgetMetadataKey = "thinking_budget"
|
ThinkingBudgetMetadataKey = "thinking_budget"
|
||||||
ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts"
|
ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts"
|
||||||
ReasoningEffortMetadataKey = "reasoning_effort"
|
ReasoningEffortMetadataKey = "reasoning_effort"
|
||||||
ThinkingOriginalModelMetadataKey = "thinking_original_model"
|
ThinkingOriginalModelMetadataKey = "thinking_original_model"
|
||||||
|
ModelMappingOriginalModelMetadataKey = "model_mapping_original_model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NormalizeThinkingModel parses dynamic thinking suffixes on model names and returns
|
// NormalizeThinkingModel parses dynamic thinking suffixes on model names and returns
|
||||||
@@ -215,6 +216,13 @@ func ResolveOriginalModel(model string, metadata map[string]any) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if metadata != nil {
|
if metadata != nil {
|
||||||
|
if v, ok := metadata[ModelMappingOriginalModelMetadataKey]; ok {
|
||||||
|
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
|
||||||
|
if base := normalize(s); base != "" {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if v, ok := metadata[ThinkingOriginalModelMetadataKey]; ok {
|
if v, ok := metadata[ThinkingOriginalModelMetadataKey]; ok {
|
||||||
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
|
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
|
||||||
if base := normalize(s); base != "" {
|
if base := normalize(s); base != "" {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -126,7 +127,7 @@ func (w *Watcher) reloadConfig() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
|
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
|
||||||
forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix
|
forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelMappings, newConfig.OAuthModelMappings))
|
||||||
|
|
||||||
log.Infof("config successfully reloaded, triggering client reload")
|
log.Infof("config successfully reloaded, triggering client reload")
|
||||||
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
|
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
|
||||||
|
|||||||
@@ -185,6 +185,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
|
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
|
||||||
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
|
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
|
||||||
}
|
}
|
||||||
|
oldUpstreamAPIKeysCount := len(oldCfg.AmpCode.UpstreamAPIKeys)
|
||||||
|
newUpstreamAPIKeysCount := len(newCfg.AmpCode.UpstreamAPIKeys)
|
||||||
|
if !equalUpstreamAPIKeys(oldCfg.AmpCode.UpstreamAPIKeys, newCfg.AmpCode.UpstreamAPIKeys) {
|
||||||
|
changes = append(changes, fmt.Sprintf("ampcode.upstream-api-keys: updated (%d -> %d entries)", oldUpstreamAPIKeysCount, newUpstreamAPIKeysCount))
|
||||||
|
}
|
||||||
|
|
||||||
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
||||||
changes = append(changes, entries...)
|
changes = append(changes, entries...)
|
||||||
@@ -301,3 +306,43 @@ func formatProxyURL(raw string) string {
|
|||||||
}
|
}
|
||||||
return scheme + "://" + host
|
return scheme + "://" + host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func equalStringSet(a, b []string) bool {
|
||||||
|
if len(a) == 0 && len(b) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
aSet := make(map[string]struct{}, len(a))
|
||||||
|
for _, k := range a {
|
||||||
|
aSet[strings.TrimSpace(k)] = struct{}{}
|
||||||
|
}
|
||||||
|
bSet := make(map[string]struct{}, len(b))
|
||||||
|
for _, k := range b {
|
||||||
|
bSet[strings.TrimSpace(k)] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(aSet) != len(bSet) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for k := range aSet {
|
||||||
|
if _, ok := bSet[k]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// equalUpstreamAPIKeys compares two slices of AmpUpstreamAPIKeyEntry for equality.
|
||||||
|
// Comparison is done by count and content (upstream key and client keys).
|
||||||
|
func equalUpstreamAPIKeys(a, b []config.AmpUpstreamAPIKeyEntry) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if strings.TrimSpace(a[i].UpstreamAPIKey) != strings.TrimSpace(b[i].UpstreamAPIKey) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !equalStringSet(a[i].APIKeys, b[i].APIKeys) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
@@ -619,7 +619,22 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
body := BuildErrorResponseBody(status, errText)
|
body := BuildErrorResponseBody(status, errText)
|
||||||
c.Set("API_RESPONSE", bytes.Clone(body))
|
// Append first to preserve upstream response logs, then drop duplicate payloads if already recorded.
|
||||||
|
var previous []byte
|
||||||
|
if existing, exists := c.Get("API_RESPONSE"); exists {
|
||||||
|
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
|
||||||
|
previous = bytes.Clone(existingBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
appendAPIResponse(c, body)
|
||||||
|
trimmedErrText := strings.TrimSpace(errText)
|
||||||
|
trimmedBody := bytes.TrimSpace(body)
|
||||||
|
if len(previous) > 0 {
|
||||||
|
if (trimmedErrText != "" && bytes.Contains(previous, []byte(trimmedErrText))) ||
|
||||||
|
(len(trimmedBody) > 0 && bytes.Contains(previous, trimmedBody)) {
|
||||||
|
c.Set("API_RESPONSE", previous)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !c.Writer.Written() {
|
if !c.Writer.Written() {
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
|||||||
62
sdk/api/management.go
Normal file
62
sdk/api/management.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
// Package api exposes helpers for embedding CLIProxyAPI.
|
||||||
|
//
|
||||||
|
// It wraps internal management handler types so external projects can integrate
|
||||||
|
// management endpoints without importing internal packages.
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
internalmanagement "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ManagementTokenRequester exposes a limited subset of management endpoints for requesting tokens.
|
||||||
|
type ManagementTokenRequester interface {
|
||||||
|
RequestAnthropicToken(*gin.Context)
|
||||||
|
RequestGeminiCLIToken(*gin.Context)
|
||||||
|
RequestCodexToken(*gin.Context)
|
||||||
|
RequestAntigravityToken(*gin.Context)
|
||||||
|
RequestQwenToken(*gin.Context)
|
||||||
|
RequestIFlowToken(*gin.Context)
|
||||||
|
RequestIFlowCookieToken(*gin.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
type managementTokenRequester struct {
|
||||||
|
handler *internalmanagement.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManagementTokenRequester creates a limited management handler exposing only token request endpoints.
|
||||||
|
func NewManagementTokenRequester(cfg *config.Config, manager *coreauth.Manager) ManagementTokenRequester {
|
||||||
|
return &managementTokenRequester{
|
||||||
|
handler: internalmanagement.NewHandlerWithoutConfigFilePath(cfg, manager),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managementTokenRequester) RequestAnthropicToken(c *gin.Context) {
|
||||||
|
m.handler.RequestAnthropicToken(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managementTokenRequester) RequestGeminiCLIToken(c *gin.Context) {
|
||||||
|
m.handler.RequestGeminiCLIToken(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managementTokenRequester) RequestCodexToken(c *gin.Context) {
|
||||||
|
m.handler.RequestCodexToken(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managementTokenRequester) RequestAntigravityToken(c *gin.Context) {
|
||||||
|
m.handler.RequestAntigravityToken(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managementTokenRequester) RequestQwenToken(c *gin.Context) {
|
||||||
|
m.handler.RequestQwenToken(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managementTokenRequester) RequestIFlowToken(c *gin.Context) {
|
||||||
|
m.handler.RequestIFlowToken(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managementTokenRequester) RequestIFlowCookieToken(c *gin.Context) {
|
||||||
|
m.handler.RequestIFlowCookieToken(c)
|
||||||
|
}
|
||||||
@@ -111,6 +111,9 @@ type Manager struct {
|
|||||||
requestRetry atomic.Int32
|
requestRetry atomic.Int32
|
||||||
maxRetryInterval atomic.Int64
|
maxRetryInterval atomic.Int64
|
||||||
|
|
||||||
|
// modelNameMappings stores global model name alias mappings (alias -> upstream name) keyed by channel.
|
||||||
|
modelNameMappings atomic.Value
|
||||||
|
|
||||||
// Optional HTTP RoundTripper provider injected by host.
|
// Optional HTTP RoundTripper provider injected by host.
|
||||||
rtProvider RoundTripperProvider
|
rtProvider RoundTripperProvider
|
||||||
|
|
||||||
@@ -410,6 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
|||||||
}
|
}
|
||||||
execReq := req
|
execReq := req
|
||||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||||
|
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata)
|
||||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||||
if errExec != nil {
|
if errExec != nil {
|
||||||
@@ -471,6 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
|||||||
}
|
}
|
||||||
execReq := req
|
execReq := req
|
||||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||||
|
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata)
|
||||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||||
if errExec != nil {
|
if errExec != nil {
|
||||||
@@ -532,6 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
|||||||
}
|
}
|
||||||
execReq := req
|
execReq := req
|
||||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||||
|
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata)
|
||||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||||
if errStream != nil {
|
if errStream != nil {
|
||||||
rerr := &Error{Message: errStream.Error()}
|
rerr := &Error{Message: errStream.Error()}
|
||||||
@@ -592,6 +598,7 @@ func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]
|
|||||||
keys := []string{
|
keys := []string{
|
||||||
util.ThinkingOriginalModelMetadataKey,
|
util.ThinkingOriginalModelMetadataKey,
|
||||||
util.GeminiOriginalModelMetadataKey,
|
util.GeminiOriginalModelMetadataKey,
|
||||||
|
util.ModelMappingOriginalModelMetadataKey,
|
||||||
}
|
}
|
||||||
var out map[string]any
|
var out map[string]any
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
|
|||||||
172
sdk/cliproxy/auth/model_name_mappings.go
Normal file
172
sdk/cliproxy/auth/model_name_mappings.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type modelNameMappingTable struct {
|
||||||
|
// reverse maps channel -> alias (lower) -> original upstream model name.
|
||||||
|
reverse map[string]map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable {
|
||||||
|
if len(mappings) == 0 {
|
||||||
|
return &modelNameMappingTable{}
|
||||||
|
}
|
||||||
|
out := &modelNameMappingTable{
|
||||||
|
reverse: make(map[string]map[string]string, len(mappings)),
|
||||||
|
}
|
||||||
|
for rawChannel, entries := range mappings {
|
||||||
|
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||||
|
if channel == "" || len(entries) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rev := make(map[string]string, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
name := strings.TrimSpace(entry.Name)
|
||||||
|
alias := strings.TrimSpace(entry.Alias)
|
||||||
|
if name == "" || alias == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(name, alias) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
aliasKey := strings.ToLower(alias)
|
||||||
|
if _, exists := rev[aliasKey]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rev[aliasKey] = name
|
||||||
|
}
|
||||||
|
if len(rev) > 0 {
|
||||||
|
out.reverse[channel] = rev
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out.reverse) == 0 {
|
||||||
|
out.reverse = nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetOAuthModelMappings updates the OAuth model name mapping table used during execution.
|
||||||
|
// The mapping is applied per-auth channel to resolve the upstream model name while keeping the
|
||||||
|
// client-visible model name unchanged for translation/response formatting.
|
||||||
|
func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.ModelNameMapping) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
table := compileModelNameMappingTable(mappings)
|
||||||
|
// atomic.Value requires non-nil store values.
|
||||||
|
if table == nil {
|
||||||
|
table = &modelNameMappingTable{}
|
||||||
|
}
|
||||||
|
m.modelNameMappings.Store(table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any {
|
||||||
|
original := m.resolveOAuthUpstreamModel(auth, requestedModel)
|
||||||
|
if original == "" {
|
||||||
|
return metadata
|
||||||
|
}
|
||||||
|
if metadata != nil {
|
||||||
|
if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok {
|
||||||
|
if s, okStr := v.(string); okStr && strings.EqualFold(s, original) {
|
||||||
|
return metadata
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out := make(map[string]any, 1)
|
||||||
|
if len(metadata) > 0 {
|
||||||
|
out = make(map[string]any, len(metadata)+1)
|
||||||
|
for k, v := range metadata {
|
||||||
|
out[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out[util.ModelMappingOriginalModelMetadataKey] = original
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
|
||||||
|
if m == nil || auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
channel := modelMappingChannel(auth)
|
||||||
|
if channel == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
key := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||||
|
if key == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
raw := m.modelNameMappings.Load()
|
||||||
|
table, _ := raw.(*modelNameMappingTable)
|
||||||
|
if table == nil || table.reverse == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
rev := table.reverse[channel]
|
||||||
|
if rev == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
original := strings.TrimSpace(rev[key])
|
||||||
|
if original == "" || strings.EqualFold(original, requestedModel) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return original
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelMappingChannel extracts the OAuth model mapping channel from an Auth object.
|
||||||
|
// It determines the provider and auth kind from the Auth's attributes and delegates
|
||||||
|
// to OAuthModelMappingChannel for the actual channel resolution.
|
||||||
|
func modelMappingChannel(auth *Auth) string {
|
||||||
|
if auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||||
|
authKind := ""
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
authKind = strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"]))
|
||||||
|
}
|
||||||
|
if authKind == "" {
|
||||||
|
if kind, _ := auth.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||||
|
authKind = "apikey"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return OAuthModelMappingChannel(provider, authKind)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuthModelMappingChannel returns the OAuth model mapping channel name for a given provider
|
||||||
|
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
|
||||||
|
// OAuth model mappings (e.g., API key authentication).
|
||||||
|
//
|
||||||
|
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||||
|
func OAuthModelMappingChannel(provider, authKind string) string {
|
||||||
|
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||||
|
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
||||||
|
switch provider {
|
||||||
|
case "gemini":
|
||||||
|
// gemini provider uses gemini-api-key config, not oauth-model-mappings.
|
||||||
|
// OAuth-based gemini auth is converted to "gemini-cli" by the synthesizer.
|
||||||
|
return ""
|
||||||
|
case "vertex":
|
||||||
|
if authKind == "apikey" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "vertex"
|
||||||
|
case "claude":
|
||||||
|
if authKind == "apikey" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "claude"
|
||||||
|
case "codex":
|
||||||
|
if authKind == "apikey" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "codex"
|
||||||
|
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
|
||||||
|
return provider
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -215,6 +215,7 @@ func (b *Builder) Build() (*Service, error) {
|
|||||||
}
|
}
|
||||||
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
|
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
|
||||||
coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())
|
coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())
|
||||||
|
coreManager.SetOAuthModelMappings(b.cfg.OAuthModelMappings)
|
||||||
|
|
||||||
service := &Service{
|
service := &Service{
|
||||||
cfg: b.cfg,
|
cfg: b.cfg,
|
||||||
|
|||||||
@@ -556,6 +556,9 @@ func (s *Service) Run(ctx context.Context) error {
|
|||||||
s.cfgMu.Lock()
|
s.cfgMu.Lock()
|
||||||
s.cfg = newCfg
|
s.cfg = newCfg
|
||||||
s.cfgMu.Unlock()
|
s.cfgMu.Unlock()
|
||||||
|
if s.coreManager != nil {
|
||||||
|
s.coreManager.SetOAuthModelMappings(newCfg.OAuthModelMappings)
|
||||||
|
}
|
||||||
s.rebindExecutors()
|
s.rebindExecutors()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -681,6 +684,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
|
authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
|
||||||
|
if authKind == "" {
|
||||||
|
if kind, _ := a.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||||
|
authKind = "apikey"
|
||||||
|
}
|
||||||
|
}
|
||||||
if a.Attributes != nil {
|
if a.Attributes != nil {
|
||||||
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
|
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
|
||||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||||
@@ -845,6 +853,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
models = applyOAuthModelMappings(s.cfg, provider, authKind, models)
|
||||||
if len(models) > 0 {
|
if len(models) > 0 {
|
||||||
key := provider
|
key := provider
|
||||||
if key == "" {
|
if key == "" {
|
||||||
@@ -1154,6 +1163,93 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rewriteModelInfoName(name, oldID, newID string) string {
|
||||||
|
trimmed := strings.TrimSpace(name)
|
||||||
|
if trimmed == "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
oldID = strings.TrimSpace(oldID)
|
||||||
|
newID = strings.TrimSpace(newID)
|
||||||
|
if oldID == "" || newID == "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
if strings.EqualFold(oldID, newID) {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(trimmed, "/"+oldID) {
|
||||||
|
prefix := strings.TrimSuffix(trimmed, oldID)
|
||||||
|
return prefix + newID
|
||||||
|
}
|
||||||
|
if trimmed == "models/"+oldID {
|
||||||
|
return "models/" + newID
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
|
||||||
|
if cfg == nil || len(models) == 0 {
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
channel := coreauth.OAuthModelMappingChannel(provider, authKind)
|
||||||
|
if channel == "" || len(cfg.OAuthModelMappings) == 0 {
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
mappings := cfg.OAuthModelMappings[channel]
|
||||||
|
if len(mappings) == 0 {
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
forward := make(map[string]string, len(mappings))
|
||||||
|
for i := range mappings {
|
||||||
|
name := strings.TrimSpace(mappings[i].Name)
|
||||||
|
alias := strings.TrimSpace(mappings[i].Alias)
|
||||||
|
if name == "" || alias == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(name, alias) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := strings.ToLower(name)
|
||||||
|
if _, exists := forward[key]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
forward[key] = alias
|
||||||
|
}
|
||||||
|
if len(forward) == 0 {
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
out := make([]*ModelInfo, 0, len(models))
|
||||||
|
seen := make(map[string]struct{}, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id := strings.TrimSpace(model.ID)
|
||||||
|
if id == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mappedID := id
|
||||||
|
if to, ok := forward[strings.ToLower(id)]; ok && strings.TrimSpace(to) != "" {
|
||||||
|
mappedID = strings.TrimSpace(to)
|
||||||
|
}
|
||||||
|
uniqueKey := strings.ToLower(mappedID)
|
||||||
|
if _, exists := seen[uniqueKey]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[uniqueKey] = struct{}{}
|
||||||
|
if mappedID == id {
|
||||||
|
out = append(out, model)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
clone := *model
|
||||||
|
clone.ID = mappedID
|
||||||
|
if clone.Name != "" {
|
||||||
|
clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
|
||||||
|
}
|
||||||
|
out = append(out, &clone)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
|
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
|
||||||
if entry == nil || len(entry.Models) == 0 {
|
if entry == nil || len(entry.Models) == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ type StreamingConfig = internalconfig.StreamingConfig
|
|||||||
type TLSConfig = internalconfig.TLSConfig
|
type TLSConfig = internalconfig.TLSConfig
|
||||||
type RemoteManagement = internalconfig.RemoteManagement
|
type RemoteManagement = internalconfig.RemoteManagement
|
||||||
type AmpCode = internalconfig.AmpCode
|
type AmpCode = internalconfig.AmpCode
|
||||||
|
type ModelNameMapping = internalconfig.ModelNameMapping
|
||||||
type PayloadConfig = internalconfig.PayloadConfig
|
type PayloadConfig = internalconfig.PayloadConfig
|
||||||
type PayloadRule = internalconfig.PayloadRule
|
type PayloadRule = internalconfig.PayloadRule
|
||||||
type PayloadModelRule = internalconfig.PayloadModelRule
|
type PayloadModelRule = internalconfig.PayloadModelRule
|
||||||
|
|||||||
@@ -56,6 +56,10 @@ func setupAmpRouter(h *management.Handler) *gin.Engine {
|
|||||||
mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey)
|
mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey)
|
||||||
mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey)
|
mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey)
|
||||||
mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey)
|
mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey)
|
||||||
|
mgmt.GET("/ampcode/upstream-api-keys", h.GetAmpUpstreamAPIKeys)
|
||||||
|
mgmt.PUT("/ampcode/upstream-api-keys", h.PutAmpUpstreamAPIKeys)
|
||||||
|
mgmt.PATCH("/ampcode/upstream-api-keys", h.PatchAmpUpstreamAPIKeys)
|
||||||
|
mgmt.DELETE("/ampcode/upstream-api-keys", h.DeleteAmpUpstreamAPIKeys)
|
||||||
mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost)
|
mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost)
|
||||||
mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost)
|
mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost)
|
||||||
mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings)
|
mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings)
|
||||||
@@ -188,6 +192,90 @@ func TestPutAmpUpstreamAPIKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPutAmpUpstreamAPIKeys_PersistsAndReturns(t *testing.T) {
|
||||||
|
h, configPath := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value":[{"upstream-api-key":" u1 ","api-keys":[" k1 ","","k2"]}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it was persisted to disk
|
||||||
|
loaded, err := config.LoadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load config from disk: %v", err)
|
||||||
|
}
|
||||||
|
if len(loaded.AmpCode.UpstreamAPIKeys) != 1 {
|
||||||
|
t.Fatalf("expected 1 upstream-api-keys entry, got %d", len(loaded.AmpCode.UpstreamAPIKeys))
|
||||||
|
}
|
||||||
|
entry := loaded.AmpCode.UpstreamAPIKeys[0]
|
||||||
|
if entry.UpstreamAPIKey != "u1" {
|
||||||
|
t.Fatalf("expected upstream-api-key u1, got %q", entry.UpstreamAPIKey)
|
||||||
|
}
|
||||||
|
if len(entry.APIKeys) != 2 || entry.APIKeys[0] != "k1" || entry.APIKeys[1] != "k2" {
|
||||||
|
t.Fatalf("expected api-keys [k1 k2], got %#v", entry.APIKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it is returned by GET /ampcode
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
var resp map[string]config.AmpCode
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
if got := resp["ampcode"].UpstreamAPIKeys; len(got) != 1 || got[0].UpstreamAPIKey != "u1" {
|
||||||
|
t.Fatalf("expected upstream-api-keys to be present after update, got %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteAmpUpstreamAPIKeys_ClearsAll(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
// Seed with one entry
|
||||||
|
putBody := `{"value":[{"upstream-api-key":"u1","api-keys":["k1"]}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(putBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
deleteBody := `{"value":[]}`
|
||||||
|
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(deleteBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-keys", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
var resp map[string][]config.AmpUpstreamAPIKeyEntry
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
if resp["upstream-api-keys"] != nil && len(resp["upstream-api-keys"]) != 0 {
|
||||||
|
t.Fatalf("expected cleared list, got %#v", resp["upstream-api-keys"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key.
|
// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key.
|
||||||
func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
|
func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
|
||||||
h, _ := newAmpTestHandler(t)
|
h, _ := newAmpTestHandler(t)
|
||||||
|
|||||||
Reference in New Issue
Block a user