mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-26 22:02:23 +00:00
Merge pull request #2106 from router-for-me/model
feat(model_registry): enhance model registration and refresh mechanisms
This commit is contained in:
@@ -187,6 +187,7 @@ func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
|
||||
}
|
||||
|
||||
const defaultModelRegistryHookTimeout = 5 * time.Second
|
||||
const modelQuotaExceededWindow = 5 * time.Minute
|
||||
|
||||
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
|
||||
hook := r.hook
|
||||
@@ -388,6 +389,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
reg.InfoByProvider[provider] = cloneModelInfo(model)
|
||||
}
|
||||
reg.LastUpdated = now
|
||||
// Re-registering an existing client/model binding starts a fresh registry
|
||||
// snapshot for that binding. Cooldown and suspension are transient
|
||||
// scheduling state and must not survive this reconciliation step.
|
||||
if reg.QuotaExceededClients != nil {
|
||||
delete(reg.QuotaExceededClients, clientID)
|
||||
}
|
||||
@@ -781,7 +785,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
||||
|
||||
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
|
||||
models := make([]map[string]any, 0, len(r.models))
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
var expiresAt time.Time
|
||||
|
||||
for _, registration := range r.models {
|
||||
@@ -792,7 +795,7 @@ func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.
|
||||
if quotaTime == nil {
|
||||
continue
|
||||
}
|
||||
recoveryAt := quotaTime.Add(quotaExpiredDuration)
|
||||
recoveryAt := quotaTime.Add(modelQuotaExceededWindow)
|
||||
if now.Before(recoveryAt) {
|
||||
expiredClients++
|
||||
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
|
||||
@@ -927,7 +930,6 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
||||
return nil
|
||||
}
|
||||
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
now := time.Now()
|
||||
result := make([]*ModelInfo, 0, len(providerModels))
|
||||
|
||||
@@ -949,7 +951,7 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
||||
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
|
||||
continue
|
||||
}
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
|
||||
expiredClients++
|
||||
}
|
||||
}
|
||||
@@ -1003,12 +1005,11 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
now := time.Now()
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
|
||||
// Count clients that have exceeded quota but haven't recovered yet
|
||||
expiredClients := 0
|
||||
for _, quotaTime := range registration.QuotaExceededClients {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
|
||||
expiredClients++
|
||||
}
|
||||
}
|
||||
@@ -1217,12 +1218,11 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
invalidated := false
|
||||
|
||||
for modelID, registration := range r.models {
|
||||
for clientID, quotaTime := range registration.QuotaExceededClients {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) >= modelQuotaExceededWindow {
|
||||
delete(registration.QuotaExceededClients, clientID)
|
||||
invalidated = true
|
||||
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
||||
|
||||
@@ -15,7 +15,8 @@ import (
|
||||
)
|
||||
|
||||
// Timing parameters for the remote model catalog updater.
const (
	// modelsFetchTimeout bounds a single remote catalog fetch attempt.
	modelsFetchTimeout = 30 * time.Second
	// modelsRefreshInterval is the period of the background catalog refresh ticker.
	modelsRefreshInterval = 3 * time.Hour
)
|
||||
|
||||
var modelsURLs = []string{
|
||||
@@ -35,6 +36,34 @@ var modelsCatalogStore = &modelStore{}
|
||||
|
||||
var updaterOnce sync.Once
|
||||
|
||||
// ModelRefreshCallback is invoked when startup or periodic model refresh detects changes.
|
||||
// changedProviders contains the provider names whose model definitions changed.
|
||||
type ModelRefreshCallback func(changedProviders []string)
|
||||
|
||||
var (
|
||||
refreshCallbackMu sync.Mutex
|
||||
refreshCallback ModelRefreshCallback
|
||||
pendingRefreshChanges []string
|
||||
)
|
||||
|
||||
// SetModelRefreshCallback registers a callback that is invoked when startup or
|
||||
// periodic model refresh detects changes. Only one callback is supported;
|
||||
// subsequent calls replace the previous callback.
|
||||
func SetModelRefreshCallback(cb ModelRefreshCallback) {
|
||||
refreshCallbackMu.Lock()
|
||||
refreshCallback = cb
|
||||
var pending []string
|
||||
if cb != nil && len(pendingRefreshChanges) > 0 {
|
||||
pending = append([]string(nil), pendingRefreshChanges...)
|
||||
pendingRefreshChanges = nil
|
||||
}
|
||||
refreshCallbackMu.Unlock()
|
||||
|
||||
if cb != nil && len(pending) > 0 {
|
||||
cb(pending)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Load embedded data as fallback on startup.
|
||||
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
|
||||
@@ -42,23 +71,76 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// StartModelsUpdater runs a one-time models refresh on startup.
|
||||
// It blocks until the startup fetch attempt finishes so service initialization
|
||||
// can wait for the refreshed catalog before registering auth-backed models.
|
||||
// Safe to call multiple times; only one refresh will run.
|
||||
// StartModelsUpdater starts a background updater that fetches models
|
||||
// immediately on startup and then refreshes the model catalog every 3 hours.
|
||||
// Safe to call multiple times; only one updater will run.
|
||||
func StartModelsUpdater(ctx context.Context) {
|
||||
updaterOnce.Do(func() {
|
||||
runModelsUpdater(ctx)
|
||||
go runModelsUpdater(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func runModelsUpdater(ctx context.Context) {
|
||||
// Try network fetch once on startup, then stop.
|
||||
// Periodic refresh is disabled - models are only refreshed at startup.
|
||||
tryRefreshModels(ctx)
|
||||
tryStartupRefresh(ctx)
|
||||
periodicRefresh(ctx)
|
||||
}
|
||||
|
||||
func tryRefreshModels(ctx context.Context) {
|
||||
func periodicRefresh(ctx context.Context) {
|
||||
ticker := time.NewTicker(modelsRefreshInterval)
|
||||
defer ticker.Stop()
|
||||
log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
tryPeriodicRefresh(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryPeriodicRefresh fetches models from remote, compares with the current
|
||||
// catalog, and notifies the registered callback if any provider changed.
|
||||
func tryPeriodicRefresh(ctx context.Context) {
|
||||
tryRefreshModels(ctx, "periodic model refresh")
|
||||
}
|
||||
|
||||
// tryStartupRefresh fetches models from remote in the background during
|
||||
// process startup. It uses the same change detection as periodic refresh so
|
||||
// existing auth registrations can be updated after the callback is registered.
|
||||
func tryStartupRefresh(ctx context.Context) {
|
||||
tryRefreshModels(ctx, "startup model refresh")
|
||||
}
|
||||
|
||||
func tryRefreshModels(ctx context.Context, label string) {
|
||||
oldData := getModels()
|
||||
|
||||
parsed, url := fetchModelsFromRemote(ctx)
|
||||
if parsed == nil {
|
||||
log.Warnf("%s: fetch failed from all URLs, keeping current data", label)
|
||||
return
|
||||
}
|
||||
|
||||
// Detect changes before updating store.
|
||||
changed := detectChangedProviders(oldData, parsed)
|
||||
|
||||
// Update store with new data regardless.
|
||||
modelsCatalogStore.mu.Lock()
|
||||
modelsCatalogStore.data = parsed
|
||||
modelsCatalogStore.mu.Unlock()
|
||||
|
||||
if len(changed) == 0 {
|
||||
log.Infof("%s completed from %s, no changes detected", label, url)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("%s completed from %s, changes detected for providers: %v", label, url, changed)
|
||||
notifyModelRefresh(changed)
|
||||
}
|
||||
|
||||
// fetchModelsFromRemote tries all remote URLs and returns the parsed model catalog
|
||||
// along with the URL it was fetched from. Returns (nil, "") if all fetches fail.
|
||||
func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) {
|
||||
client := &http.Client{Timeout: modelsFetchTimeout}
|
||||
for _, url := range modelsURLs {
|
||||
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
|
||||
@@ -92,15 +174,126 @@ func tryRefreshModels(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := loadModelsFromBytes(data, url); err != nil {
|
||||
var parsed staticModelsJSON
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
log.Warnf("models parse failed from %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
if err := validateModelsCatalog(&parsed); err != nil {
|
||||
log.Warnf("models validate failed from %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("models updated from %s", url)
|
||||
return &parsed, url
|
||||
}
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// detectChangedProviders compares two model catalogs and returns provider names
|
||||
// whose model definitions differ. Codex tiers (free/team/plus/pro) are grouped
|
||||
// under a single "codex" provider.
|
||||
func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
|
||||
if oldData == nil || newData == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type section struct {
|
||||
provider string
|
||||
oldList []*ModelInfo
|
||||
newList []*ModelInfo
|
||||
}
|
||||
|
||||
sections := []section{
|
||||
{"claude", oldData.Claude, newData.Claude},
|
||||
{"gemini", oldData.Gemini, newData.Gemini},
|
||||
{"vertex", oldData.Vertex, newData.Vertex},
|
||||
{"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI},
|
||||
{"aistudio", oldData.AIStudio, newData.AIStudio},
|
||||
{"codex", oldData.CodexFree, newData.CodexFree},
|
||||
{"codex", oldData.CodexTeam, newData.CodexTeam},
|
||||
{"codex", oldData.CodexPlus, newData.CodexPlus},
|
||||
{"codex", oldData.CodexPro, newData.CodexPro},
|
||||
{"qwen", oldData.Qwen, newData.Qwen},
|
||||
{"iflow", oldData.IFlow, newData.IFlow},
|
||||
{"kimi", oldData.Kimi, newData.Kimi},
|
||||
{"antigravity", oldData.Antigravity, newData.Antigravity},
|
||||
}
|
||||
|
||||
seen := make(map[string]bool, len(sections))
|
||||
var changed []string
|
||||
for _, s := range sections {
|
||||
if seen[s.provider] {
|
||||
continue
|
||||
}
|
||||
if modelSectionChanged(s.oldList, s.newList) {
|
||||
changed = append(changed, s.provider)
|
||||
seen[s.provider] = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// modelSectionChanged reports whether two model slices differ.
|
||||
func modelSectionChanged(a, b []*ModelInfo) bool {
|
||||
if len(a) != len(b) {
|
||||
return true
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return false
|
||||
}
|
||||
aj, err1 := json.Marshal(a)
|
||||
bj, err2 := json.Marshal(b)
|
||||
if err1 != nil || err2 != nil {
|
||||
return true
|
||||
}
|
||||
return string(aj) != string(bj)
|
||||
}
|
||||
|
||||
func notifyModelRefresh(changedProviders []string) {
|
||||
if len(changedProviders) == 0 {
|
||||
return
|
||||
}
|
||||
log.Warn("models refresh failed from all URLs, using local data")
|
||||
|
||||
refreshCallbackMu.Lock()
|
||||
cb := refreshCallback
|
||||
if cb == nil {
|
||||
pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders)
|
||||
refreshCallbackMu.Unlock()
|
||||
return
|
||||
}
|
||||
refreshCallbackMu.Unlock()
|
||||
cb(changedProviders)
|
||||
}
|
||||
|
||||
// mergeProviderNames returns the union of existing and incoming provider
// names. Each name is trimmed and lower-cased; blank names and duplicates are
// dropped. First-seen order is preserved, with existing entries ahead of
// incoming ones.
//
// When incoming is empty, existing is returned as-is without normalisation.
func mergeProviderNames(existing, incoming []string) []string {
	if len(incoming) == 0 {
		return existing
	}

	total := len(existing) + len(incoming)
	seen := make(map[string]struct{}, total)
	merged := make([]string, 0, total)

	// Both inputs go through the same normalise-and-dedupe pass.
	for _, group := range [2][]string{existing, incoming} {
		for _, provider := range group {
			name := strings.ToLower(strings.TrimSpace(provider))
			if name == "" {
				continue
			}
			if _, dup := seen[name]; dup {
				continue
			}
			seen[name] = struct{}{}
			merged = append(merged, name)
		}
	}
	return merged
}
|
||||
|
||||
func loadModelsFromBytes(data []byte, source string) error {
|
||||
|
||||
@@ -434,6 +434,17 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) registerResolvedModelsForAuth(a *coreauth.Auth, providerKey string, models []*ModelInfo) {
|
||||
if a == nil || a.ID == "" {
|
||||
return
|
||||
}
|
||||
if len(models) == 0 {
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
return
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, models)
|
||||
}
|
||||
|
||||
// rebindExecutors refreshes provider executors so they observe the latest configuration.
|
||||
func (s *Service) rebindExecutors() {
|
||||
if s == nil || s.coreManager == nil {
|
||||
@@ -541,6 +552,44 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
s.hooks.OnBeforeStart(s.cfg)
|
||||
}
|
||||
|
||||
// Register callback for startup and periodic model catalog refresh.
|
||||
// When remote model definitions change, re-register models for affected providers.
|
||||
// This intentionally rebuilds per-auth model availability from the latest catalog
|
||||
// snapshot instead of preserving prior registry suppression state.
|
||||
registry.SetModelRefreshCallback(func(changedProviders []string) {
|
||||
if s == nil || s.coreManager == nil || len(changedProviders) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
providerSet := make(map[string]bool, len(changedProviders))
|
||||
for _, p := range changedProviders {
|
||||
providerSet[strings.ToLower(strings.TrimSpace(p))] = true
|
||||
}
|
||||
|
||||
auths := s.coreManager.List()
|
||||
refreshed := 0
|
||||
for _, item := range auths {
|
||||
if item == nil || item.ID == "" {
|
||||
continue
|
||||
}
|
||||
auth, ok := s.coreManager.GetByID(item.ID)
|
||||
if !ok || auth == nil || auth.Disabled {
|
||||
continue
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if !providerSet[provider] {
|
||||
continue
|
||||
}
|
||||
if s.refreshModelRegistrationForAuth(auth) {
|
||||
refreshed++
|
||||
}
|
||||
}
|
||||
|
||||
if refreshed > 0 {
|
||||
log.Infof("re-registered models for %d auth(s) due to model catalog changes: %v", refreshed, changedProviders)
|
||||
}
|
||||
})
|
||||
|
||||
s.serverErr = make(chan error, 1)
|
||||
go func() {
|
||||
if errStart := s.server.Start(); errStart != nil {
|
||||
@@ -926,7 +975,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if providerKey == "" {
|
||||
providerKey = "openai-compatibility"
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
|
||||
s.registerResolvedModelsForAuth(a, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
|
||||
} else {
|
||||
// Ensure stale registrations are cleared when model list becomes empty.
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
@@ -947,13 +996,60 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if key == "" {
|
||||
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||
s.registerResolvedModelsForAuth(a, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||
return
|
||||
}
|
||||
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
}
|
||||
|
||||
// refreshModelRegistrationForAuth re-applies the latest model registration for
|
||||
// one auth and reconciles any concurrent auth changes that race with the
|
||||
// refresh. Callers are expected to pre-filter provider membership.
|
||||
//
|
||||
// Re-registration is deliberate: registry cooldown/suspension state is treated
|
||||
// as part of the previous registration snapshot and is cleared when the auth is
|
||||
// rebound to the refreshed model catalog.
|
||||
func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool {
|
||||
if s == nil || s.coreManager == nil || current == nil || current.ID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if !current.Disabled {
|
||||
s.ensureExecutorsForAuth(current)
|
||||
}
|
||||
s.registerModelsForAuth(current)
|
||||
|
||||
latest, ok := s.latestAuthForModelRegistration(current.ID)
|
||||
if !ok || latest.Disabled {
|
||||
GlobalModelRegistry().UnregisterClient(current.ID)
|
||||
s.coreManager.RefreshSchedulerEntry(current.ID)
|
||||
return false
|
||||
}
|
||||
|
||||
// Re-apply the latest auth snapshot so concurrent auth updates cannot leave
|
||||
// stale model registrations behind. This may duplicate registration work when
|
||||
// no auth fields changed, but keeps the refresh path simple and correct.
|
||||
s.ensureExecutorsForAuth(latest)
|
||||
s.registerModelsForAuth(latest)
|
||||
s.coreManager.RefreshSchedulerEntry(current.ID)
|
||||
return true
|
||||
}
|
||||
|
||||
// latestAuthForModelRegistration returns the latest auth snapshot regardless of
|
||||
// provider membership. Callers use this after a registration attempt to restore
|
||||
// whichever state currently owns the client ID in the global registry.
|
||||
func (s *Service) latestAuthForModelRegistration(authID string) (*coreauth.Auth, bool) {
|
||||
if s == nil || s.coreManager == nil || authID == "" {
|
||||
return nil, false
|
||||
}
|
||||
auth, ok := s.coreManager.GetByID(authID)
|
||||
if !ok || auth == nil || auth.ID == "" {
|
||||
return nil, false
|
||||
}
|
||||
return auth, true
|
||||
}
|
||||
|
||||
func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey {
|
||||
if auth == nil || s.cfg == nil {
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user