diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index 8f56c43d..74ad6acf 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -187,6 +187,7 @@ func (r *ModelRegistry) SetHook(hook ModelRegistryHook) { } const defaultModelRegistryHookTimeout = 5 * time.Second +const modelQuotaExceededWindow = 5 * time.Minute func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) { hook := r.hook @@ -388,6 +389,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ reg.InfoByProvider[provider] = cloneModelInfo(model) } reg.LastUpdated = now + // Re-registering an existing client/model binding starts a fresh registry + // snapshot for that binding. Cooldown and suspension are transient + // scheduling state and must not survive this reconciliation step. if reg.QuotaExceededClients != nil { delete(reg.QuotaExceededClients, clientID) } @@ -781,7 +785,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) { models := make([]map[string]any, 0, len(r.models)) - quotaExpiredDuration := 5 * time.Minute var expiresAt time.Time for _, registration := range r.models { @@ -792,7 +795,7 @@ func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time. if quotaTime == nil { continue } - recoveryAt := quotaTime.Add(quotaExpiredDuration) + recoveryAt := quotaTime.Add(modelQuotaExceededWindow) if now.Before(recoveryAt) { expiredClients++ if expiresAt.IsZero() || recoveryAt.Before(expiresAt) { @@ -927,7 +930,6 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn return nil } - quotaExpiredDuration := 5 * time.Minute now := time.Now() result := make([]*ModelInfo, 0, len(providerModels)) @@ -949,7 +951,7 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider { continue } - if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration { + if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow { expiredClients++ } } @@ -1003,12 +1005,11 @@ func (r *ModelRegistry) GetModelCount(modelID string) int { if registration, exists := r.models[modelID]; exists { now := time.Now() - quotaExpiredDuration := 5 * time.Minute // Count clients that have exceeded quota but haven't recovered yet expiredClients := 0 for _, quotaTime := range registration.QuotaExceededClients { - if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration { + if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow { expiredClients++ } } @@ -1217,12 +1218,11 @@ func (r *ModelRegistry) CleanupExpiredQuotas() { defer r.mutex.Unlock() now := time.Now() - quotaExpiredDuration := 5 * time.Minute invalidated := false for modelID, registration := range r.models { for clientID, quotaTime := range registration.QuotaExceededClients { - if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration { + if quotaTime != nil && now.Sub(*quotaTime) >= modelQuotaExceededWindow { delete(registration.QuotaExceededClients, clientID) invalidated = true log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID) diff --git a/internal/registry/model_updater.go b/internal/registry/model_updater.go index 36d2dd32..197f6044 100644 --- a/internal/registry/model_updater.go +++ b/internal/registry/model_updater.go @@ -15,7 +15,8 @@ import ( ) const ( - modelsFetchTimeout = 30 * time.Second + modelsFetchTimeout = 30 * time.Second + modelsRefreshInterval = 3 * time.Hour ) var modelsURLs = []string{ @@ -35,6 +36,34 @@ var modelsCatalogStore = &modelStore{} var updaterOnce sync.Once +// ModelRefreshCallback is invoked when startup or periodic model refresh detects changes. +// changedProviders contains the provider names whose model definitions changed. +type ModelRefreshCallback func(changedProviders []string) + +var ( + refreshCallbackMu sync.Mutex + refreshCallback ModelRefreshCallback + pendingRefreshChanges []string +) + +// SetModelRefreshCallback registers a callback that is invoked when startup or +// periodic model refresh detects changes. Only one callback is supported; +// subsequent calls replace the previous callback. +func SetModelRefreshCallback(cb ModelRefreshCallback) { + refreshCallbackMu.Lock() + refreshCallback = cb + var pending []string + if cb != nil && len(pendingRefreshChanges) > 0 { + pending = append([]string(nil), pendingRefreshChanges...) + pendingRefreshChanges = nil + } + refreshCallbackMu.Unlock() + + if cb != nil && len(pending) > 0 { + cb(pending) + } +} + func init() { // Load embedded data as fallback on startup. if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil { @@ -42,23 +71,76 @@ func init() { } } -// StartModelsUpdater runs a one-time models refresh on startup. -// It blocks until the startup fetch attempt finishes so service initialization -// can wait for the refreshed catalog before registering auth-backed models. -// Safe to call multiple times; only one refresh will run. +// StartModelsUpdater starts a background updater that fetches models +// immediately on startup and then refreshes the model catalog every 3 hours. +// Safe to call multiple times; only one updater will run. func StartModelsUpdater(ctx context.Context) { updaterOnce.Do(func() { - runModelsUpdater(ctx) + go runModelsUpdater(ctx) }) } func runModelsUpdater(ctx context.Context) { - // Try network fetch once on startup, then stop. - // Periodic refresh is disabled - models are only refreshed at startup. - tryRefreshModels(ctx) + tryStartupRefresh(ctx) + periodicRefresh(ctx) } -func tryRefreshModels(ctx context.Context) { +func periodicRefresh(ctx context.Context) { + ticker := time.NewTicker(modelsRefreshInterval) + defer ticker.Stop() + log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + tryPeriodicRefresh(ctx) + } + } +} + +// tryPeriodicRefresh fetches models from remote, compares with the current +// catalog, and notifies the registered callback if any provider changed. +func tryPeriodicRefresh(ctx context.Context) { + tryRefreshModels(ctx, "periodic model refresh") +} + +// tryStartupRefresh fetches models from remote in the background during +// process startup. It uses the same change detection as periodic refresh so +// existing auth registrations can be updated after the callback is registered. +func tryStartupRefresh(ctx context.Context) { + tryRefreshModels(ctx, "startup model refresh") +} + +func tryRefreshModels(ctx context.Context, label string) { + oldData := getModels() + + parsed, url := fetchModelsFromRemote(ctx) + if parsed == nil { + log.Warnf("%s: fetch failed from all URLs, keeping current data", label) + return + } + + // Detect changes before updating store. + changed := detectChangedProviders(oldData, parsed) + + // Update store with new data regardless. + modelsCatalogStore.mu.Lock() + modelsCatalogStore.data = parsed + modelsCatalogStore.mu.Unlock() + + if len(changed) == 0 { + log.Infof("%s completed from %s, no changes detected", label, url) + return + } + + log.Infof("%s completed from %s, changes detected for providers: %v", label, url, changed) + notifyModelRefresh(changed) +} + +// fetchModelsFromRemote tries all remote URLs and returns the parsed model catalog +// along with the URL it was fetched from. Returns (nil, "") if all fetches fail. +func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) { client := &http.Client{Timeout: modelsFetchTimeout} for _, url := range modelsURLs { reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout) @@ -92,15 +174,126 @@ func tryRefreshModels(ctx context.Context) { continue } - if err := loadModelsFromBytes(data, url); err != nil { + var parsed staticModelsJSON + if err := json.Unmarshal(data, &parsed); err != nil { log.Warnf("models parse failed from %s: %v", url, err) continue } + if err := validateModelsCatalog(&parsed); err != nil { + log.Warnf("models validate failed from %s: %v", url, err) + continue + } - log.Infof("models updated from %s", url) + return &parsed, url + } + return nil, "" +} + +// detectChangedProviders compares two model catalogs and returns provider names +// whose model definitions differ. Codex tiers (free/team/plus/pro) are grouped +// under a single "codex" provider. +func detectChangedProviders(oldData, newData *staticModelsJSON) []string { + if oldData == nil || newData == nil { + return nil + } + + type section struct { + provider string + oldList []*ModelInfo + newList []*ModelInfo + } + + sections := []section{ + {"claude", oldData.Claude, newData.Claude}, + {"gemini", oldData.Gemini, newData.Gemini}, + {"vertex", oldData.Vertex, newData.Vertex}, + {"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI}, + {"aistudio", oldData.AIStudio, newData.AIStudio}, + {"codex", oldData.CodexFree, newData.CodexFree}, + {"codex", oldData.CodexTeam, newData.CodexTeam}, + {"codex", oldData.CodexPlus, newData.CodexPlus}, + {"codex", oldData.CodexPro, newData.CodexPro}, + {"qwen", oldData.Qwen, newData.Qwen}, + {"iflow", oldData.IFlow, newData.IFlow}, + {"kimi", oldData.Kimi, newData.Kimi}, + {"antigravity", oldData.Antigravity, newData.Antigravity}, + } + + seen := make(map[string]bool, len(sections)) + var changed []string + for _, s := range sections { + if seen[s.provider] { + continue + } + if modelSectionChanged(s.oldList, s.newList) { + changed = append(changed, s.provider) + seen[s.provider] = true + } + } + return changed +} + +// modelSectionChanged reports whether two model slices differ. +func modelSectionChanged(a, b []*ModelInfo) bool { + if len(a) != len(b) { + return true + } + if len(a) == 0 { + return false + } + aj, err1 := json.Marshal(a) + bj, err2 := json.Marshal(b) + if err1 != nil || err2 != nil { + return true + } + return string(aj) != string(bj) +} + +func notifyModelRefresh(changedProviders []string) { + if len(changedProviders) == 0 { return } - log.Warn("models refresh failed from all URLs, using local data") + + refreshCallbackMu.Lock() + cb := refreshCallback + if cb == nil { + pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders) + refreshCallbackMu.Unlock() + return + } + refreshCallbackMu.Unlock() + cb(changedProviders) +} + +func mergeProviderNames(existing, incoming []string) []string { + if len(incoming) == 0 { + return existing + } + seen := make(map[string]struct{}, len(existing)+len(incoming)) + merged := make([]string, 0, len(existing)+len(incoming)) + for _, provider := range existing { + name := strings.ToLower(strings.TrimSpace(provider)) + if name == "" { + continue + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + merged = append(merged, name) + } + for _, provider := range incoming { + name := strings.ToLower(strings.TrimSpace(provider)) + if name == "" { + continue + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + merged = append(merged, name) + } + return merged } func loadModelsFromBytes(data []byte, source string) error { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index af31f86a..abe1deed 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -434,6 +434,17 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace } } +func (s *Service) registerResolvedModelsForAuth(a *coreauth.Auth, providerKey string, models []*ModelInfo) { + if a == nil || a.ID == "" { + return + } + if len(models) == 0 { + GlobalModelRegistry().UnregisterClient(a.ID) + return + } + GlobalModelRegistry().RegisterClient(a.ID, providerKey, models) +} + // rebindExecutors refreshes provider executors so they observe the latest configuration. func (s *Service) rebindExecutors() { if s == nil || s.coreManager == nil { @@ -541,6 +552,44 @@ func (s *Service) Run(ctx context.Context) error { s.hooks.OnBeforeStart(s.cfg) } + // Register callback for startup and periodic model catalog refresh. + // When remote model definitions change, re-register models for affected providers. + // This intentionally rebuilds per-auth model availability from the latest catalog + // snapshot instead of preserving prior registry suppression state. + registry.SetModelRefreshCallback(func(changedProviders []string) { + if s == nil || s.coreManager == nil || len(changedProviders) == 0 { + return + } + + providerSet := make(map[string]bool, len(changedProviders)) + for _, p := range changedProviders { + providerSet[strings.ToLower(strings.TrimSpace(p))] = true + } + + auths := s.coreManager.List() + refreshed := 0 + for _, item := range auths { + if item == nil || item.ID == "" { + continue + } + auth, ok := s.coreManager.GetByID(item.ID) + if !ok || auth == nil || auth.Disabled { + continue + } + provider := strings.ToLower(strings.TrimSpace(auth.Provider)) + if !providerSet[provider] { + continue + } + if s.refreshModelRegistrationForAuth(auth) { + refreshed++ + } + } + + if refreshed > 0 { + log.Infof("re-registered models for %d auth(s) due to model catalog changes: %v", refreshed, changedProviders) + } + }) + s.serverErr = make(chan error, 1) go func() { if errStart := s.server.Start(); errStart != nil { @@ -926,7 +975,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { if providerKey == "" { providerKey = "openai-compatibility" } - GlobalModelRegistry().RegisterClient(a.ID, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix)) + s.registerResolvedModelsForAuth(a, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix)) } else { // Ensure stale registrations are cleared when model list becomes empty. GlobalModelRegistry().UnregisterClient(a.ID) @@ -947,13 +996,60 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { if key == "" { key = strings.ToLower(strings.TrimSpace(a.Provider)) } - GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix)) + s.registerResolvedModelsForAuth(a, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix)) return } GlobalModelRegistry().UnregisterClient(a.ID) } +// refreshModelRegistrationForAuth re-applies the latest model registration for +// one auth and reconciles any concurrent auth changes that race with the +// refresh. Callers are expected to pre-filter provider membership. +// +// Re-registration is deliberate: registry cooldown/suspension state is treated +// as part of the previous registration snapshot and is cleared when the auth is +// rebound to the refreshed model catalog. +func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool { + if s == nil || s.coreManager == nil || current == nil || current.ID == "" { + return false + } + + if !current.Disabled { + s.ensureExecutorsForAuth(current) + } + s.registerModelsForAuth(current) + + latest, ok := s.latestAuthForModelRegistration(current.ID) + if !ok || latest.Disabled { + GlobalModelRegistry().UnregisterClient(current.ID) + s.coreManager.RefreshSchedulerEntry(current.ID) + return false + } + + // Re-apply the latest auth snapshot so concurrent auth updates cannot leave + // stale model registrations behind. This may duplicate registration work when + // no auth fields changed, but keeps the refresh path simple and correct. + s.ensureExecutorsForAuth(latest) + s.registerModelsForAuth(latest) + s.coreManager.RefreshSchedulerEntry(current.ID) + return true +} + +// latestAuthForModelRegistration returns the latest auth snapshot regardless of +// provider membership. Callers use this after a registration attempt to restore +// whichever state currently owns the client ID in the global registry. +func (s *Service) latestAuthForModelRegistration(authID string) (*coreauth.Auth, bool) { + if s == nil || s.coreManager == nil || authID == "" { + return nil, false + } + auth, ok := s.coreManager.GetByID(authID) + if !ok || auth == nil || auth.ID == "" { + return nil, false + } + return auth, true +} + func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey { if auth == nil || s.cfg == nil { return nil