diff --git a/config.example.yaml b/config.example.yaml index 8457e103..9dfca5bc 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -62,6 +62,11 @@ ws-auth: false # headers: # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" +# excluded-models: +# - "gemini-2.5-pro" # exclude specific models from this provider (exact match) +# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro) +# - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview) +# - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite) # - api-key: "AIzaSy...02" # API keys for official Generative Language API (legacy compatibility) @@ -76,6 +81,11 @@ ws-auth: false # headers: # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override +# excluded-models: +# - "gpt-5.1" # exclude specific models (exact match) +# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex) +# - "*-mini" # wildcard matching suffix (e.g. gpt-5-codex-mini) +# - "*codex*" # wildcard matching substring (e.g. gpt-5-codex-low) # Claude API keys #claude-api-key: @@ -88,6 +98,11 @@ ws-auth: false # models: # - name: "claude-3-5-sonnet-20241022" # upstream model name # alias: "claude-sonnet-latest" # client alias mapped to the upstream model +# excluded-models: +# - "claude-opus-4-5-20251101" # exclude specific models (exact match) +# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219) +# - "*-think" # wildcard matching suffix (e.g. claude-opus-4-5-thinking) +# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022) # OpenAI compatibility providers #openai-compatibility: @@ -121,3 +136,25 @@ ws-auth: false # protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex # params: # JSON path (gjson/sjson syntax) -> value # "reasoning.effort": "high" + +# OAuth provider excluded models +#oauth-excluded-models: +# gemini-cli: +# - "gemini-2.5-pro" # exclude specific models (exact match) +# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro) +# - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview) +# - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite) +# vertex: +# - "gemini-3-pro-preview" +# aistudio: +# - "gemini-3-pro-preview" +# antigravity: +# - "gemini-3-pro-preview" +# claude: +# - "claude-3-5-haiku-20241022" +# codex: +# - "gpt-5-codex-mini" +# qwen: +# - "vision-model" +# iflow: +# - "tstars2.0" diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index b4b43b0f..71193084 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -223,6 +223,7 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) { value.APIKey = strings.TrimSpace(value.APIKey) value.BaseURL = strings.TrimSpace(value.BaseURL) value.ProxyURL = strings.TrimSpace(value.ProxyURL) + value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) if value.APIKey == "" { // Treat empty API key as delete. if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { @@ -504,6 +505,91 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) { c.JSON(400, gin.H{"error": "missing name or index"}) } +// oauth-excluded-models: map[string][]string +func (h *Handler) GetOAuthExcludedModels(c *gin.Context) { + c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)}) +} + +func (h *Handler) PutOAuthExcludedModels(c *gin.Context) { + data, err := c.GetRawData() + if err != nil { + c.JSON(400, gin.H{"error": "failed to read body"}) + return + } + var entries map[string][]string + if err = json.Unmarshal(data, &entries); err != nil { + var wrapper struct { + Items map[string][]string `json:"items"` + } + if err2 := json.Unmarshal(data, &wrapper); err2 != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + entries = wrapper.Items + } + h.cfg.OAuthExcludedModels = config.NormalizeOAuthExcludedModels(entries) + h.persist(c) +} + +func (h *Handler) PatchOAuthExcludedModels(c *gin.Context) { + var body struct { + Provider *string `json:"provider"` + Models []string `json:"models"` + } + if err := c.ShouldBindJSON(&body); err != nil || body.Provider == nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + provider := strings.ToLower(strings.TrimSpace(*body.Provider)) + if provider == "" { + c.JSON(400, gin.H{"error": "invalid provider"}) + return + } + normalized := config.NormalizeExcludedModels(body.Models) + if len(normalized) == 0 { + if h.cfg.OAuthExcludedModels == nil { + c.JSON(404, gin.H{"error": "provider not found"}) + return + } + if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok { + c.JSON(404, gin.H{"error": "provider not found"}) + return + } + delete(h.cfg.OAuthExcludedModels, provider) + if len(h.cfg.OAuthExcludedModels) == 0 { + h.cfg.OAuthExcludedModels = nil + } + h.persist(c) + return + } + if h.cfg.OAuthExcludedModels == nil { + h.cfg.OAuthExcludedModels = make(map[string][]string) + } + h.cfg.OAuthExcludedModels[provider] = normalized + h.persist(c) +} + +func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) { + provider := strings.ToLower(strings.TrimSpace(c.Query("provider"))) + if provider == "" { + c.JSON(400, gin.H{"error": "missing provider"}) + return + } + if h.cfg.OAuthExcludedModels == nil { + c.JSON(404, gin.H{"error": "provider not found"}) + return + } + if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok { + c.JSON(404, gin.H{"error": "provider not found"}) + return + } + delete(h.cfg.OAuthExcludedModels, provider) + if len(h.cfg.OAuthExcludedModels) == 0 { + h.cfg.OAuthExcludedModels = nil + } + h.persist(c) +} + // codex-api-key: []CodexKey func (h *Handler) GetCodexKeys(c *gin.Context) { c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey}) @@ -533,6 +619,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) { entry.BaseURL = strings.TrimSpace(entry.BaseURL) entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) entry.Headers = config.NormalizeHeaders(entry.Headers) + entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels) if entry.BaseURL == "" { continue } @@ -557,6 +644,7 @@ func (h *Handler) PatchCodexKey(c *gin.Context) { value.BaseURL = strings.TrimSpace(value.BaseURL) value.ProxyURL = strings.TrimSpace(value.ProxyURL) value.Headers = config.NormalizeHeaders(value.Headers) + value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) // If base-url becomes empty, delete instead of update if value.BaseURL == "" { if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { @@ -694,6 +782,7 @@ func normalizeClaudeKey(entry *config.ClaudeKey) { entry.BaseURL = strings.TrimSpace(entry.BaseURL) entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) entry.Headers = config.NormalizeHeaders(entry.Headers) + entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels) if len(entry.Models) == 0 { return } diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 8e5189ad..5bc0dc25 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -111,6 +111,14 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha ampAPI.Any("/otel", proxyHandler) ampAPI.Any("/otel/*path", proxyHandler) + // Root-level routes that AMP CLI expects without /api prefix + // These need the same security middleware as the /api/* routes + rootMiddleware := []gin.HandlerFunc{noCORSMiddleware()} + if restrictToLocalhost { + rootMiddleware = append(rootMiddleware, localhostOnlyMiddleware()) + } + engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) + // Google v1beta1 passthrough with OAuth fallback // AMP CLI uses non-standard paths like /publishers/google/models/... // We bridge these to our standard Gemini handler to enable local OAuth. diff --git a/internal/api/modules/amp/routes_test.go b/internal/api/modules/amp/routes_test.go index 12240981..51bd7abd 100644 --- a/internal/api/modules/amp/routes_test.go +++ b/internal/api/modules/amp/routes_test.go @@ -37,6 +37,7 @@ func TestRegisterManagementRoutes(t *testing.T) { {"/api/meta", http.MethodGet}, {"/api/telemetry", http.MethodGet}, {"/api/threads", http.MethodGet}, + {"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix) {"/api/otel", http.MethodGet}, // Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST {"/api/provider/google/v1beta1/models", http.MethodGet}, diff --git a/internal/api/server.go b/internal/api/server.go index 3dd78c93..ab9c0354 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -543,6 +543,11 @@ func (s *Server) registerManagementRoutes() { mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat) mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat) + mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels) + mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels) + mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels) + mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels) + mgmt.GET("/auth-files", s.mgmt.ListAuthFiles) mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile) mgmt.POST("/auth-files", s.mgmt.UploadAuthFile) diff --git a/internal/config/config.go b/internal/config/config.go index 31920075..97b5a0c2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -83,6 +83,9 @@ type Config struct { // Payload defines default and override rules for provider payload parameters. Payload PayloadConfig `yaml:"payload" json:"payload"` + + // OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries. + OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"` } // TLSConfig holds HTTPS server settings. @@ -157,6 +160,9 @@ type ClaudeKey struct { // Headers optionally adds extra HTTP headers for requests sent with this key. Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // ExcludedModels lists model IDs that should be excluded for this provider. + ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` } // ClaudeModel describes a mapping between an alias and the actual upstream model name. @@ -183,6 +189,9 @@ type CodexKey struct { // Headers optionally adds extra HTTP headers for requests sent with this key. Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // ExcludedModels lists model IDs that should be excluded for this provider. + ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` } // GeminiKey represents the configuration for a Gemini API key, @@ -199,6 +208,9 @@ type GeminiKey struct { // Headers optionally adds extra HTTP headers for requests sent with this key. Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // ExcludedModels lists model IDs that should be excluded for this provider. + ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` } // OpenAICompatibility represents the configuration for OpenAI API compatibility @@ -322,6 +334,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Sanitize OpenAI compatibility providers: drop entries without base-url cfg.SanitizeOpenAICompatibility() + // Normalize OAuth provider model exclusion map. + cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels) + // Return the populated configuration struct. return &cfg, nil } @@ -359,6 +374,7 @@ func (cfg *Config) SanitizeCodexKeys() { e := cfg.CodexKey[i] e.BaseURL = strings.TrimSpace(e.BaseURL) e.Headers = NormalizeHeaders(e.Headers) + e.ExcludedModels = NormalizeExcludedModels(e.ExcludedModels) if e.BaseURL == "" { continue } @@ -375,6 +391,7 @@ func (cfg *Config) SanitizeClaudeKeys() { for i := range cfg.ClaudeKey { entry := &cfg.ClaudeKey[i] entry.Headers = NormalizeHeaders(entry.Headers) + entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels) } } @@ -395,6 +412,7 @@ func (cfg *Config) SanitizeGeminiKeys() { entry.BaseURL = strings.TrimSpace(entry.BaseURL) entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) entry.Headers = NormalizeHeaders(entry.Headers) + entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels) if _, exists := seen[entry.APIKey]; exists { continue } @@ -457,6 +475,55 @@ func NormalizeHeaders(headers map[string]string) map[string]string { return clean } +// NormalizeExcludedModels trims, lowercases, and deduplicates model exclusion patterns. +// It preserves the order of first occurrences and drops empty entries. +func NormalizeExcludedModels(models []string) []string { + if len(models) == 0 { + return nil + } + seen := make(map[string]struct{}, len(models)) + out := make([]string, 0, len(models)) + for _, raw := range models { + trimmed := strings.ToLower(strings.TrimSpace(raw)) + if trimmed == "" { + continue + } + if _, exists := seen[trimmed]; exists { + continue + } + seen[trimmed] = struct{}{} + out = append(out, trimmed) + } + if len(out) == 0 { + return nil + } + return out +} + +// NormalizeOAuthExcludedModels cleans provider -> excluded models mappings by normalizing provider keys +// and applying model exclusion normalization to each entry. +func NormalizeOAuthExcludedModels(entries map[string][]string) map[string][]string { + if len(entries) == 0 { + return nil + } + out := make(map[string][]string, len(entries)) + for provider, models := range entries { + key := strings.ToLower(strings.TrimSpace(provider)) + if key == "" { + continue + } + normalized := NormalizeExcludedModels(models) + if len(normalized) == 0 { + continue + } + out[key] = normalized + } + if len(out) == 0 { + return nil + } + return out +} + // hashSecret hashes the given secret using bcrypt. func hashSecret(secret string) (string, error) { // Use default cost for simplicity. diff --git a/internal/runtime/executor/usage_helpers.go b/internal/runtime/executor/usage_helpers.go index 266a300e..5669d9bc 100644 --- a/internal/runtime/executor/usage_helpers.go +++ b/internal/runtime/executor/usage_helpers.go @@ -37,7 +37,7 @@ func newUsageReporter(ctx context.Context, provider, model string, auth *cliprox } if auth != nil { reporter.authID = auth.ID - reporter.authIndex = auth.Index + reporter.authIndex = auth.EnsureIndex() } return reporter } diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index fd8e0071..2c1671f5 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -271,7 +271,15 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ if resp == "" { resp = "{}" } - toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp)) + // Handle non-JSON output gracefully (matches dev branch approach) + if resp != "null" { + parsed := gjson.Parse(resp) + if parsed.Type == gjson.JSON { + toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw)) + } else { + toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", resp) + } + } pp++ } } diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 0d955064..a284541a 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -30,6 +30,16 @@ import ( log "github.com/sirupsen/logrus" ) +func matchProvider(provider string, targets []string) (string, bool) { + p := strings.ToLower(strings.TrimSpace(provider)) + for _, t := range targets { + if strings.EqualFold(p, strings.TrimSpace(t)) { + return p, true + } + } + return p, false +} + // storePersister captures persistence-capable token store methods used by the watcher. type storePersister interface { PersistConfig(ctx context.Context) error @@ -54,6 +64,7 @@ type Watcher struct { lastConfigHash string authQueue chan<- AuthUpdate currentAuths map[string]*coreauth.Auth + runtimeAuths map[string]*coreauth.Auth dispatchMu sync.Mutex dispatchCond *sync.Cond pendingUpdates map[string]AuthUpdate @@ -169,7 +180,7 @@ func (w *Watcher) Start(ctx context.Context) error { go w.processEvents(ctx) // Perform an initial full reload based on current config and auth dir - w.reloadClients(true) + w.reloadClients(true, nil) return nil } @@ -221,9 +232,57 @@ func (w *Watcher) SetAuthUpdateQueue(queue chan<- AuthUpdate) { } } +// DispatchRuntimeAuthUpdate allows external runtime providers (e.g., websocket-driven auths) +// to push auth updates through the same queue used by file/config watchers. +// Returns true if the update was enqueued; false if no queue is configured. +func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool { + if w == nil { + return false + } + w.clientsMutex.Lock() + if w.runtimeAuths == nil { + w.runtimeAuths = make(map[string]*coreauth.Auth) + } + switch update.Action { + case AuthUpdateActionAdd, AuthUpdateActionModify: + if update.Auth != nil && update.Auth.ID != "" { + clone := update.Auth.Clone() + w.runtimeAuths[clone.ID] = clone + if w.currentAuths == nil { + w.currentAuths = make(map[string]*coreauth.Auth) + } + w.currentAuths[clone.ID] = clone.Clone() + } + case AuthUpdateActionDelete: + id := update.ID + if id == "" && update.Auth != nil { + id = update.Auth.ID + } + if id != "" { + delete(w.runtimeAuths, id) + if w.currentAuths != nil { + delete(w.currentAuths, id) + } + } + } + w.clientsMutex.Unlock() + if w.getAuthQueue() == nil { + return false + } + w.dispatchAuthUpdates([]AuthUpdate{update}) + return true +} + func (w *Watcher) refreshAuthState() { auths := w.SnapshotCoreAuths() w.clientsMutex.Lock() + if len(w.runtimeAuths) > 0 { + for _, a := range w.runtimeAuths { + if a != nil { + auths = append(auths, a.Clone()) + } + } + } updates := w.prepareAuthUpdatesLocked(auths) w.clientsMutex.Unlock() w.dispatchAuthUpdates(updates) @@ -450,6 +509,142 @@ func computeClaudeModelsHash(models []config.ClaudeModel) string { return hex.EncodeToString(sum[:]) } +func computeExcludedModelsHash(excluded []string) string { + if len(excluded) == 0 { + return "" + } + normalized := make([]string, 0, len(excluded)) + for _, entry := range excluded { + if trimmed := strings.TrimSpace(entry); trimmed != "" { + normalized = append(normalized, strings.ToLower(trimmed)) + } + } + if len(normalized) == 0 { + return "" + } + sort.Strings(normalized) + data, err := json.Marshal(normalized) + if err != nil || len(data) == 0 { + return "" + } + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + +type excludedModelsSummary struct { + hash string + count int +} + +func summarizeExcludedModels(list []string) excludedModelsSummary { + if len(list) == 0 { + return excludedModelsSummary{} + } + seen := make(map[string]struct{}, len(list)) + normalized := make([]string, 0, len(list)) + for _, entry := range list { + if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" { + if _, exists := seen[trimmed]; exists { + continue + } + seen[trimmed] = struct{}{} + normalized = append(normalized, trimmed) + } + } + sort.Strings(normalized) + return excludedModelsSummary{ + hash: computeExcludedModelsHash(normalized), + count: len(normalized), + } +} + +func summarizeOAuthExcludedModels(entries map[string][]string) map[string]excludedModelsSummary { + if len(entries) == 0 { + return nil + } + out := make(map[string]excludedModelsSummary, len(entries)) + for k, v := range entries { + key := strings.ToLower(strings.TrimSpace(k)) + if key == "" { + continue + } + out[key] = summarizeExcludedModels(v) + } + return out +} + +func diffOAuthExcludedModelChanges(oldMap, newMap map[string][]string) ([]string, []string) { + oldSummary := summarizeOAuthExcludedModels(oldMap) + newSummary := summarizeOAuthExcludedModels(newMap) + keys := make(map[string]struct{}, len(oldSummary)+len(newSummary)) + for k := range oldSummary { + keys[k] = struct{}{} + } + for k := range newSummary { + keys[k] = struct{}{} + } + changes := make([]string, 0, len(keys)) + affected := make([]string, 0, len(keys)) + for key := range keys { + oldInfo, okOld := oldSummary[key] + newInfo, okNew := newSummary[key] + switch { + case okOld && !okNew: + changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: removed", key)) + affected = append(affected, key) + case !okOld && okNew: + changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: added (%d entries)", key, newInfo.count)) + affected = append(affected, key) + case okOld && okNew && oldInfo.hash != newInfo.hash: + changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count)) + affected = append(affected, key) + } + } + sort.Strings(changes) + sort.Strings(affected) + return changes, affected +} + +func applyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) { + if auth == nil || cfg == nil { + return + } + authKindKey := strings.ToLower(strings.TrimSpace(authKind)) + seen := make(map[string]struct{}) + add := func(list []string) { + for _, entry := range list { + if trimmed := strings.TrimSpace(entry); trimmed != "" { + key := strings.ToLower(trimmed) + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + } + } + } + if authKindKey == "apikey" { + add(perKey) + } else if cfg.OAuthExcludedModels != nil { + providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) + add(cfg.OAuthExcludedModels[providerKey]) + } + combined := make([]string, 0, len(seen)) + for k := range seen { + combined = append(combined, k) + } + sort.Strings(combined) + hash := computeExcludedModelsHash(combined) + if auth.Attributes == nil { + auth.Attributes = make(map[string]string) + } + if hash != "" { + auth.Attributes["excluded_models_hash"] = hash + } + if authKind != "" { + auth.Attributes["auth_kind"] = authKind + } +} + // SetClients sets the file-based clients. // SetClients removed // SetAPIKeyClients removed @@ -634,6 +829,11 @@ func (w *Watcher) reloadConfig() bool { w.config = newConfig w.clientsMutex.Unlock() + var affectedOAuthProviders []string + if oldConfig != nil { + _, affectedOAuthProviders = diffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels) + } + // Always apply the current log level based on the latest config. // This ensures logrus reflects the desired level even if change detection misses. util.SetLogLevel(newConfig) @@ -659,12 +859,12 @@ func (w *Watcher) reloadConfig() bool { log.Infof("config successfully reloaded, triggering client reload") // Reload clients with new config - w.reloadClients(authDirChanged) + w.reloadClients(authDirChanged, affectedOAuthProviders) return true } // reloadClients performs a full scan and reload of all clients. -func (w *Watcher) reloadClients(rescanAuth bool) { +func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string) { log.Debugf("starting full client load process") w.clientsMutex.RLock() @@ -676,6 +876,28 @@ func (w *Watcher) reloadClients(rescanAuth bool) { return } + if len(affectedOAuthProviders) > 0 { + w.clientsMutex.Lock() + if w.currentAuths != nil { + filtered := make(map[string]*coreauth.Auth, len(w.currentAuths)) + for id, auth := range w.currentAuths { + if auth == nil { + continue + } + provider := strings.ToLower(strings.TrimSpace(auth.Provider)) + if _, match := matchProvider(provider, affectedOAuthProviders); match { + continue + } + filtered[id] = auth + } + w.currentAuths = filtered + log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders) + } else { + w.currentAuths = nil + } + w.clientsMutex.Unlock() + } + // Unregister all old API key clients before creating new ones // no legacy clients to unregister @@ -849,6 +1071,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { CreatedAt: now, UpdatedAt: now, } + applyAuthExcludedModelsMeta(a, cfg, entry.ExcludedModels, "apikey") out = append(out, a) } // Claude API keys -> synthesize auths @@ -882,6 +1105,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { CreatedAt: now, UpdatedAt: now, } + applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey") out = append(out, a) } // Codex API keys -> synthesize auths @@ -911,6 +1135,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { CreatedAt: now, UpdatedAt: now, } + applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey") out = append(out, a) } for i := range cfg.OpenAICompatibility { @@ -1071,8 +1296,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { CreatedAt: now, UpdatedAt: now, } + applyAuthExcludedModelsMeta(a, cfg, nil, "oauth") if provider == "gemini-cli" { if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { + for _, v := range virtuals { + applyAuthExcludedModelsMeta(v, cfg, nil, "oauth") + } out = append(out, a) out = append(out, virtuals...) continue @@ -1464,6 +1693,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if !equalStringMap(o.Headers, n.Headers) { changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i)) } + oldExcluded := summarizeExcludedModels(o.ExcludedModels) + newExcluded := summarizeExcludedModels(n.ExcludedModels) + if oldExcluded.hash != newExcluded.hash { + changes = append(changes, fmt.Sprintf("gemini[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count)) + } } if !reflect.DeepEqual(trimStrings(oldCfg.GlAPIKey), trimStrings(newCfg.GlAPIKey)) { changes = append(changes, "generative-language-api-key: values updated (legacy view, redacted)") @@ -1492,6 +1726,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if !equalStringMap(o.Headers, n.Headers) { changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i)) } + oldExcluded := summarizeExcludedModels(o.ExcludedModels) + newExcluded := summarizeExcludedModels(n.ExcludedModels) + if oldExcluded.hash != newExcluded.hash { + changes = append(changes, fmt.Sprintf("claude[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count)) + } } } @@ -1517,9 +1756,18 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if !equalStringMap(o.Headers, n.Headers) { changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i)) } + oldExcluded := summarizeExcludedModels(o.ExcludedModels) + newExcluded := summarizeExcludedModels(n.ExcludedModels) + if oldExcluded.hash != newExcluded.hash { + changes = append(changes, fmt.Sprintf("codex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count)) + } } } + if entries, _ := diffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 { + changes = append(changes, entries...) + } + // Remote management (never print the key) if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote { changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote)) diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index eef70ee5..dc7887e7 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -1118,6 +1118,14 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli } authCopy := selected.Clone() m.mu.RUnlock() + if !selected.indexAssigned { + m.mu.Lock() + if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned { + current.EnsureIndex() + authCopy = current.Clone() + } + m.mu.Unlock() + } return authCopy, executor, nil } diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 8b66a9a9..4001c49c 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -146,6 +146,27 @@ func (s *Service) consumeAuthUpdates(ctx context.Context) { } } +func (s *Service) emitAuthUpdate(ctx context.Context, update watcher.AuthUpdate) { + if s == nil { + return + } + if ctx == nil { + ctx = context.Background() + } + if s.watcher != nil && s.watcher.DispatchRuntimeAuthUpdate(update) { + return + } + if s.authUpdates != nil { + select { + case s.authUpdates <- update: + return + default: + log.Debugf("auth update queue saturated, applying inline action=%v id=%s", update.Action, update.ID) + } + } + s.handleAuthUpdate(ctx, update) +} + func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdate) { if s == nil { return @@ -220,7 +241,11 @@ func (s *Service) wsOnConnected(channelID string) { Metadata: map[string]any{"email": channelID}, // metadata drives logging and usage tracking } log.Infof("websocket provider connected: %s", channelID) - s.applyCoreAuthAddOrUpdate(context.Background(), auth) + s.emitAuthUpdate(context.Background(), watcher.AuthUpdate{ + Action: watcher.AuthUpdateActionAdd, + ID: auth.ID, + Auth: auth, + }) } func (s *Service) wsOnDisconnected(channelID string, reason error) { @@ -237,7 +262,10 @@ func (s *Service) wsOnDisconnected(channelID string, reason error) { log.Infof("websocket provider disconnected: %s", channelID) } ctx := context.Background() - s.applyCoreAuthRemoval(ctx, channelID) + s.emitAuthUpdate(ctx, watcher.AuthUpdate{ + Action: watcher.AuthUpdateActionDelete, + ID: channelID, + }) } func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) { @@ -619,6 +647,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { if a == nil || a.ID == "" { return } + authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"])) if a.Attributes != nil { if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") { GlobalModelRegistry().UnregisterClient(a.ID) @@ -638,34 +667,59 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { if compatDetected { provider = "openai-compatibility" } + excluded := s.oauthExcludedModels(provider, authKind) var models []*ModelInfo switch provider { case "gemini": models = registry.GetGeminiModels() + if entry := s.resolveConfigGeminiKey(a); entry != nil { + if authKind == "apikey" { + excluded = entry.ExcludedModels + } + } + models = applyExcludedModels(models, excluded) case "vertex": // Vertex AI Gemini supports the same model identifiers as Gemini. models = registry.GetGeminiVertexModels() + models = applyExcludedModels(models, excluded) case "gemini-cli": models = registry.GetGeminiCLIModels() + models = applyExcludedModels(models, excluded) case "aistudio": models = registry.GetAIStudioModels() + models = applyExcludedModels(models, excluded) case "antigravity": ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) models = executor.FetchAntigravityModels(ctx, a, s.cfg) cancel() + models = applyExcludedModels(models, excluded) case "claude": models = registry.GetClaudeModels() - if entry := s.resolveConfigClaudeKey(a); entry != nil && len(entry.Models) > 0 { - models = buildClaudeConfigModels(entry) + if entry := s.resolveConfigClaudeKey(a); entry != nil { + if len(entry.Models) > 0 { + models = buildClaudeConfigModels(entry) + } + if authKind == "apikey" { + excluded = entry.ExcludedModels + } } + models = applyExcludedModels(models, excluded) case "codex": models = registry.GetOpenAIModels() + if entry := s.resolveConfigCodexKey(a); entry != nil { + if authKind == "apikey" { + excluded = entry.ExcludedModels + } + } + models = applyExcludedModels(models, excluded) case "qwen": models = registry.GetQwenModels() + models = applyExcludedModels(models, excluded) case "iflow": models = registry.GetIFlowModels() case "github-copilot": models = registry.GetGitHubCopilotModels() + models = applyExcludedModels(models, excluded) default: // Handle OpenAI-compatibility providers by name using config if s.cfg != nil { @@ -753,7 +807,10 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { key = strings.ToLower(strings.TrimSpace(a.Provider)) } GlobalModelRegistry().RegisterClient(a.ID, key, models) + return } + + GlobalModelRegistry().UnregisterClient(a.ID) } func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey { @@ -795,6 +852,150 @@ func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey return nil } +func (s *Service) resolveConfigGeminiKey(auth *coreauth.Auth) *config.GeminiKey { + if auth == nil || s.cfg == nil { + return nil + } + var attrKey, attrBase string + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range s.cfg.GeminiKey { + entry := &s.cfg.GeminiKey[i] + cfgKey := strings.TrimSpace(entry.APIKey) + cfgBase := strings.TrimSpace(entry.BaseURL) + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + return nil +} + +func (s *Service) resolveConfigCodexKey(auth *coreauth.Auth) *config.CodexKey { + if auth == nil || s.cfg == nil { + return nil + } + var attrKey, attrBase string + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range s.cfg.CodexKey { + entry := &s.cfg.CodexKey[i] + cfgKey := strings.TrimSpace(entry.APIKey) + cfgBase := strings.TrimSpace(entry.BaseURL) + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + return nil +} + +func (s *Service) oauthExcludedModels(provider, authKind string) []string { + cfg := s.cfg + if cfg == nil { + return nil + } + authKindKey := strings.ToLower(strings.TrimSpace(authKind)) + providerKey := strings.ToLower(strings.TrimSpace(provider)) + if authKindKey == "apikey" { + return nil + } + return cfg.OAuthExcludedModels[providerKey] +} + +func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo { + if len(models) == 0 || len(excluded) == 0 { + return models + } + + patterns := make([]string, 0, len(excluded)) + for _, item := range excluded { + if trimmed := strings.TrimSpace(item); trimmed != "" { + patterns = append(patterns, strings.ToLower(trimmed)) + } + } + if len(patterns) == 0 { + return models + } + + filtered := make([]*ModelInfo, 0, len(models)) + for _, model := range models { + if model == nil { + continue + } + modelID := strings.ToLower(strings.TrimSpace(model.ID)) + blocked := false + for _, pattern := range patterns { + if matchWildcard(pattern, modelID) { + blocked = true + break + } + } + if !blocked { + filtered = append(filtered, model) + } + } + return filtered +} + +// matchWildcard performs case-insensitive wildcard matching where '*' matches any substring. +func matchWildcard(pattern, value string) bool { + if pattern == "" { + return false + } + + // Fast path for exact match (no wildcard present). + if !strings.Contains(pattern, "*") { + return pattern == value + } + + parts := strings.Split(pattern, "*") + // Handle prefix. + if prefix := parts[0]; prefix != "" { + if !strings.HasPrefix(value, prefix) { + return false + } + value = value[len(prefix):] + } + + // Handle suffix. + if suffix := parts[len(parts)-1]; suffix != "" { + if !strings.HasSuffix(value, suffix) { + return false + } + value = value[:len(value)-len(suffix)] + } + + // Handle middle segments in order. + for i := 1; i < len(parts)-1; i++ { + segment := parts[i] + if segment == "" { + continue + } + idx := strings.Index(value, segment) + if idx < 0 { + return false + } + value = value[idx+len(segment):] + } + + return true +} + func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo { if entry == nil || len(entry.Models) == 0 { return nil diff --git a/sdk/cliproxy/types.go b/sdk/cliproxy/types.go index 1d577153..b44185d1 100644 --- a/sdk/cliproxy/types.go +++ b/sdk/cliproxy/types.go @@ -83,9 +83,10 @@ type WatcherWrapper struct { start func(ctx context.Context) error stop func() error - setConfig func(cfg *config.Config) - snapshotAuths func() []*coreauth.Auth - setUpdateQueue func(queue chan<- watcher.AuthUpdate) + setConfig func(cfg *config.Config) + snapshotAuths func() []*coreauth.Auth + setUpdateQueue func(queue chan<- watcher.AuthUpdate) + dispatchRuntimeUpdate func(update watcher.AuthUpdate) bool } // Start proxies to the underlying watcher Start implementation. @@ -112,6 +113,16 @@ func (w *WatcherWrapper) SetConfig(cfg *config.Config) { w.setConfig(cfg) } +// DispatchRuntimeAuthUpdate forwards runtime auth updates (e.g., websocket providers) +// into the watcher-managed auth update queue when available. +// Returns true if the update was enqueued successfully. +func (w *WatcherWrapper) DispatchRuntimeAuthUpdate(update watcher.AuthUpdate) bool { + if w == nil || w.dispatchRuntimeUpdate == nil { + return false + } + return w.dispatchRuntimeUpdate(update) +} + // SetClients updates the watcher file-backed clients registry. // SetClients and SetAPIKeyClients removed; watcher manages its own caches diff --git a/sdk/cliproxy/watcher.go b/sdk/cliproxy/watcher.go index 81e4c18a..921e2068 100644 --- a/sdk/cliproxy/watcher.go +++ b/sdk/cliproxy/watcher.go @@ -28,5 +28,8 @@ func defaultWatcherFactory(configPath, authDir string, reload func(*config.Confi setUpdateQueue: func(queue chan<- watcher.AuthUpdate) { w.SetAuthUpdateQueue(queue) }, + dispatchRuntimeUpdate: func(update watcher.AuthUpdate) bool { + return w.DispatchRuntimeAuthUpdate(update) + }, }, nil }