Merge pull request #2106 from router-for-me/model

feat(model_registry): enhance model registration and refresh mechanisms
This commit is contained in:
Luis Pater
2026-03-13 11:18:51 +08:00
committed by GitHub
3 changed files with 312 additions and 23 deletions

View File

@@ -187,6 +187,7 @@ func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
}
const (
	// defaultModelRegistryHookTimeout is the default time budget for registry
	// hook invocations.
	// NOTE(review): its use site is not visible in this excerpt — confirm.
	defaultModelRegistryHookTimeout = 5 * time.Second

	// modelQuotaExceededWindow is how long a client stays counted as
	// quota-exceeded for a model after the recorded quota timestamp. Once the
	// window has elapsed the client is treated as recovered and its tracking
	// entry becomes eligible for cleanup.
	modelQuotaExceededWindow = 5 * time.Minute
)
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
hook := r.hook
@@ -388,6 +389,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
reg.InfoByProvider[provider] = cloneModelInfo(model)
}
reg.LastUpdated = now
// Re-registering an existing client/model binding starts a fresh registry
// snapshot for that binding. Cooldown and suspension are transient
// scheduling state and must not survive this reconciliation step.
if reg.QuotaExceededClients != nil {
delete(reg.QuotaExceededClients, clientID)
}
@@ -781,7 +785,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
models := make([]map[string]any, 0, len(r.models))
quotaExpiredDuration := 5 * time.Minute
var expiresAt time.Time
for _, registration := range r.models {
@@ -792,7 +795,7 @@ func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.
if quotaTime == nil {
continue
}
recoveryAt := quotaTime.Add(quotaExpiredDuration)
recoveryAt := quotaTime.Add(modelQuotaExceededWindow)
if now.Before(recoveryAt) {
expiredClients++
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
@@ -927,7 +930,6 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
return nil
}
quotaExpiredDuration := 5 * time.Minute
now := time.Now()
result := make([]*ModelInfo, 0, len(providerModels))
@@ -949,7 +951,7 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
continue
}
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
expiredClients++
}
}
@@ -1003,12 +1005,11 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
if registration, exists := r.models[modelID]; exists {
now := time.Now()
quotaExpiredDuration := 5 * time.Minute
// Count clients that have exceeded quota but haven't recovered yet
expiredClients := 0
for _, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
expiredClients++
}
}
@@ -1217,12 +1218,11 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
defer r.mutex.Unlock()
now := time.Now()
quotaExpiredDuration := 5 * time.Minute
invalidated := false
for modelID, registration := range r.models {
for clientID, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
if quotaTime != nil && now.Sub(*quotaTime) >= modelQuotaExceededWindow {
delete(registration.QuotaExceededClients, clientID)
invalidated = true
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)

View File

@@ -15,7 +15,8 @@ import (
)
const (
	// modelsFetchTimeout bounds a single remote catalog fetch attempt; it is
	// used both as the HTTP client timeout and as the per-request context
	// deadline.
	modelsFetchTimeout = 30 * time.Second

	// modelsRefreshInterval is the delay between periodic catalog refreshes
	// performed by the background updater.
	modelsRefreshInterval = 3 * time.Hour
)
var modelsURLs = []string{
@@ -35,6 +36,34 @@ var modelsCatalogStore = &modelStore{}
// updaterOnce ensures StartModelsUpdater launches the models updater at most
// once, no matter how many times it is called.
var updaterOnce sync.Once
// ModelRefreshCallback is the hook signature used to report model catalog
// changes found by the startup or periodic refresh. changedProviders lists the
// provider names whose model definitions differ from the previous catalog.
type ModelRefreshCallback func(changedProviders []string)

var (
	// refreshCallbackMu guards refreshCallback and pendingRefreshChanges.
	refreshCallbackMu sync.Mutex
	// refreshCallback is the currently registered hook, or nil when none is set.
	refreshCallback ModelRefreshCallback
	// pendingRefreshChanges accumulates provider names reported while no
	// callback was registered; they are handed to the next non-nil callback.
	pendingRefreshChanges []string
)

// SetModelRefreshCallback installs cb as the single model-refresh callback,
// replacing any previously registered one. If provider changes were queued
// while no callback was set and cb is non-nil, the queue is drained and
// delivered to cb immediately, after the internal lock has been released.
// Passing nil clears the callback and leaves any queued changes in place.
func SetModelRefreshCallback(cb ModelRefreshCallback) {
	// Swap the callback and take ownership of the queued names under the lock;
	// the callback itself must run unlocked so it may call back into this
	// package without deadlocking.
	queued := func() []string {
		refreshCallbackMu.Lock()
		defer refreshCallbackMu.Unlock()

		refreshCallback = cb
		if cb == nil || len(pendingRefreshChanges) == 0 {
			return nil
		}
		snapshot := make([]string, len(pendingRefreshChanges))
		copy(snapshot, pendingRefreshChanges)
		pendingRefreshChanges = nil
		return snapshot
	}()

	if len(queued) == 0 {
		return
	}
	cb(queued)
}
func init() {
// Load embedded data as fallback on startup.
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
@@ -42,23 +71,76 @@ func init() {
}
}
// StartModelsUpdater starts a background updater that fetches models
// immediately on startup and then refreshes the model catalog every
// modelsRefreshInterval (3 hours). It returns right away without waiting for
// the first fetch; the updater goroutine runs until ctx is cancelled.
// Safe to call multiple times; only one updater will run.
func StartModelsUpdater(ctx context.Context) {
	updaterOnce.Do(func() {
		// The updater loops until ctx is done, so it must never run on the
		// caller's goroutine.
		go runModelsUpdater(ctx)
	})
}
func runModelsUpdater(ctx context.Context) {
// Try network fetch once on startup, then stop.
// Periodic refresh is disabled - models are only refreshed at startup.
tryRefreshModels(ctx)
tryStartupRefresh(ctx)
periodicRefresh(ctx)
}
func tryRefreshModels(ctx context.Context) {
// periodicRefresh triggers a catalog refresh every modelsRefreshInterval and
// returns when ctx is cancelled. The ticker is stopped on exit.
func periodicRefresh(ctx context.Context) {
	refreshTicker := time.NewTicker(modelsRefreshInterval)
	defer refreshTicker.Stop()

	log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval)

	done := ctx.Done()
	for {
		select {
		case <-refreshTicker.C:
			tryPeriodicRefresh(ctx)
		case <-done:
			return
		}
	}
}
// tryPeriodicRefresh runs one scheduled refresh cycle: it delegates to
// tryRefreshModels, which fetches the remote catalog, compares it with the
// current one, and notifies the registered callback about changed providers.
func tryPeriodicRefresh(ctx context.Context) {
	const label = "periodic model refresh"
	tryRefreshModels(ctx, label)
}
// tryStartupRefresh runs the initial refresh cycle at process startup. It goes
// through the same tryRefreshModels path as the periodic refresh, so provider
// changes detected here are reported (or queued) exactly the same way and can
// update existing auth registrations once a callback is registered.
func tryStartupRefresh(ctx context.Context) {
	const label = "startup model refresh"
	tryRefreshModels(ctx, label)
}
// tryRefreshModels performs one refresh cycle. It snapshots the current
// catalog, fetches a new one from the remote URLs, swaps the new catalog into
// the store, and notifies listeners about providers whose definitions changed.
// label prefixes every log line so startup and periodic runs can be told apart.
// When every fetch fails the current catalog is left untouched.
func tryRefreshModels(ctx context.Context, label string) {
	previous := getModels()

	fetched, source := fetchModelsFromRemote(ctx)
	if fetched == nil {
		log.Warnf("%s: fetch failed from all URLs, keeping current data", label)
		return
	}

	// The diff must be computed against the old snapshot before it is replaced.
	changedProviders := detectChangedProviders(previous, fetched)

	// The store always takes the new data, even when nothing was flagged.
	modelsCatalogStore.mu.Lock()
	modelsCatalogStore.data = fetched
	modelsCatalogStore.mu.Unlock()

	if len(changedProviders) > 0 {
		log.Infof("%s completed from %s, changes detected for providers: %v", label, source, changedProviders)
		notifyModelRefresh(changedProviders)
		return
	}
	log.Infof("%s completed from %s, no changes detected", label, source)
}
// fetchModelsFromRemote tries all remote URLs and returns the parsed model catalog
// along with the URL it was fetched from. Returns (nil, "") if all fetches fail.
func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) {
client := &http.Client{Timeout: modelsFetchTimeout}
for _, url := range modelsURLs {
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
@@ -92,15 +174,126 @@ func tryRefreshModels(ctx context.Context) {
continue
}
if err := loadModelsFromBytes(data, url); err != nil {
var parsed staticModelsJSON
if err := json.Unmarshal(data, &parsed); err != nil {
log.Warnf("models parse failed from %s: %v", url, err)
continue
}
if err := validateModelsCatalog(&parsed); err != nil {
log.Warnf("models validate failed from %s: %v", url, err)
continue
}
log.Infof("models updated from %s", url)
return &parsed, url
}
return nil, ""
}
// detectChangedProviders diffs two model catalogs section by section and
// returns the provider names whose model lists differ, in catalog order and
// without duplicates. The four Codex tiers (free/team/plus/pro) all report as
// the single provider "codex". It returns nil when either catalog is nil.
func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
	if oldData == nil || newData == nil {
		return nil
	}

	var changed []string
	reported := make(map[string]struct{})

	// check compares one catalog section and records its provider on the first
	// difference; providers already reported are not compared again.
	check := func(provider string, before, after []*ModelInfo) {
		if _, done := reported[provider]; done {
			return
		}
		if !modelSectionChanged(before, after) {
			return
		}
		reported[provider] = struct{}{}
		changed = append(changed, provider)
	}

	check("claude", oldData.Claude, newData.Claude)
	check("gemini", oldData.Gemini, newData.Gemini)
	check("vertex", oldData.Vertex, newData.Vertex)
	check("gemini-cli", oldData.GeminiCLI, newData.GeminiCLI)
	check("aistudio", oldData.AIStudio, newData.AIStudio)
	check("codex", oldData.CodexFree, newData.CodexFree)
	check("codex", oldData.CodexTeam, newData.CodexTeam)
	check("codex", oldData.CodexPlus, newData.CodexPlus)
	check("codex", oldData.CodexPro, newData.CodexPro)
	check("qwen", oldData.Qwen, newData.Qwen)
	check("iflow", oldData.IFlow, newData.IFlow)
	check("kimi", oldData.Kimi, newData.Kimi)
	check("antigravity", oldData.Antigravity, newData.Antigravity)

	return changed
}
// modelSectionChanged reports whether two model slices differ. Slices of
// different length always differ and two empty slices never do; otherwise the
// slices are compared through their JSON encodings, so element order matters.
// A marshalling failure on either side is reported as a change.
func modelSectionChanged(a, b []*ModelInfo) bool {
	switch {
	case len(a) != len(b):
		return true
	case len(a) == 0:
		return false
	}

	encodedA, errA := json.Marshal(a)
	encodedB, errB := json.Marshal(b)
	if errA != nil || errB != nil {
		return true
	}
	return string(encodedA) != string(encodedB)
}
func notifyModelRefresh(changedProviders []string) {
if len(changedProviders) == 0 {
return
}
log.Warn("models refresh failed from all URLs, using local data")
refreshCallbackMu.Lock()
cb := refreshCallback
if cb == nil {
pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders)
refreshCallbackMu.Unlock()
return
}
refreshCallbackMu.Unlock()
cb(changedProviders)
}
// mergeProviderNames returns the union of existing and incoming provider
// names. Names are trimmed and lower-cased, blank names are dropped, and the
// first occurrence wins so existing entries keep their position ahead of new
// ones. When incoming is empty, existing is returned as-is without
// normalization.
func mergeProviderNames(existing, incoming []string) []string {
	if len(incoming) == 0 {
		return existing
	}

	total := len(existing) + len(incoming)
	seen := make(map[string]struct{}, total)
	merged := make([]string, 0, total)

	for _, group := range [][]string{existing, incoming} {
		for _, raw := range group {
			name := strings.ToLower(strings.TrimSpace(raw))
			if name == "" {
				continue
			}
			if _, dup := seen[name]; dup {
				continue
			}
			seen[name] = struct{}{}
			merged = append(merged, name)
		}
	}
	return merged
}
func loadModelsFromBytes(data []byte, source string) error {

View File

@@ -434,6 +434,17 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
}
}
// registerResolvedModelsForAuth publishes an already-resolved model list for
// auth a in the global model registry under providerKey. An empty list removes
// the auth's registration instead, so stale entries do not linger. A nil auth
// or an auth without an ID is ignored.
func (s *Service) registerResolvedModelsForAuth(a *coreauth.Auth, providerKey string, models []*ModelInfo) {
	if a == nil || a.ID == "" {
		return
	}

	reg := GlobalModelRegistry()
	if len(models) > 0 {
		reg.RegisterClient(a.ID, providerKey, models)
		return
	}
	reg.UnregisterClient(a.ID)
}
// rebindExecutors refreshes provider executors so they observe the latest configuration.
func (s *Service) rebindExecutors() {
if s == nil || s.coreManager == nil {
@@ -541,6 +552,44 @@ func (s *Service) Run(ctx context.Context) error {
s.hooks.OnBeforeStart(s.cfg)
}
// Register callback for startup and periodic model catalog refresh.
// When remote model definitions change, re-register models for affected providers.
// This intentionally rebuilds per-auth model availability from the latest catalog
// snapshot instead of preserving prior registry suppression state.
registry.SetModelRefreshCallback(func(changedProviders []string) {
if s == nil || s.coreManager == nil || len(changedProviders) == 0 {
return
}
providerSet := make(map[string]bool, len(changedProviders))
for _, p := range changedProviders {
providerSet[strings.ToLower(strings.TrimSpace(p))] = true
}
auths := s.coreManager.List()
refreshed := 0
for _, item := range auths {
if item == nil || item.ID == "" {
continue
}
auth, ok := s.coreManager.GetByID(item.ID)
if !ok || auth == nil || auth.Disabled {
continue
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if !providerSet[provider] {
continue
}
if s.refreshModelRegistrationForAuth(auth) {
refreshed++
}
}
if refreshed > 0 {
log.Infof("re-registered models for %d auth(s) due to model catalog changes: %v", refreshed, changedProviders)
}
})
s.serverErr = make(chan error, 1)
go func() {
if errStart := s.server.Start(); errStart != nil {
@@ -926,7 +975,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
if providerKey == "" {
providerKey = "openai-compatibility"
}
GlobalModelRegistry().RegisterClient(a.ID, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
s.registerResolvedModelsForAuth(a, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
} else {
// Ensure stale registrations are cleared when model list becomes empty.
GlobalModelRegistry().UnregisterClient(a.ID)
@@ -947,13 +996,60 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
if key == "" {
key = strings.ToLower(strings.TrimSpace(a.Provider))
}
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
s.registerResolvedModelsForAuth(a, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
return
}
GlobalModelRegistry().UnregisterClient(a.ID)
}
// refreshModelRegistrationForAuth rebuilds the model registration of a single
// auth against the current catalog and reports whether the auth ended up
// registered. Provider filtering is the caller's job.
//
// The auth is registered from the snapshot passed in, then re-read from the
// core manager and registered again from that latest snapshot, so an auth
// update racing with the refresh cannot leave a stale registration behind. If
// the auth has meanwhile disappeared or been disabled, its registration is
// removed instead. The scheduler entry is refreshed in both outcomes.
//
// Re-registering is intentional: registry cooldown/suspension state belongs to
// the previous registration snapshot and is dropped when the auth is rebound.
func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool {
	if s == nil || s.coreManager == nil || current == nil || current.ID == "" {
		return false
	}

	if !current.Disabled {
		s.ensureExecutorsForAuth(current)
	}
	s.registerModelsForAuth(current)

	latest, found := s.latestAuthForModelRegistration(current.ID)
	if found && !latest.Disabled {
		// Second pass with the freshest snapshot. This repeats work when
		// nothing changed, but keeps the refresh path simple and correct.
		s.ensureExecutorsForAuth(latest)
		s.registerModelsForAuth(latest)
		s.coreManager.RefreshSchedulerEntry(current.ID)
		return true
	}

	// The auth is gone or disabled: make sure nothing stays registered for it.
	GlobalModelRegistry().UnregisterClient(current.ID)
	s.coreManager.RefreshSchedulerEntry(current.ID)
	return false
}
// latestAuthForModelRegistration looks up the current snapshot of authID in
// the core manager, without any provider filtering. It returns (nil, false)
// when the service or manager is missing, authID is empty, or the lookup
// yields no usable auth (not found, nil, or without an ID). Callers use it
// after a registration attempt to reconcile against whichever state currently
// owns the client ID.
func (s *Service) latestAuthForModelRegistration(authID string) (*coreauth.Auth, bool) {
	if s == nil || s.coreManager == nil || authID == "" {
		return nil, false
	}
	if auth, found := s.coreManager.GetByID(authID); found && auth != nil && auth.ID != "" {
		return auth, true
	}
	return nil, false
}
func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey {
if auth == nil || s.cfg == nil {
return nil