mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-30 01:06:39 +00:00
373 lines
9.7 KiB
Go
373 lines
9.7 KiB
Go
package registry
|
|
|
|
import (
|
|
"context"
|
|
_ "embed"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
const (
|
|
modelsFetchTimeout = 30 * time.Second
|
|
modelsRefreshInterval = 3 * time.Hour
|
|
)
|
|
|
|
var modelsURLs = []string{
|
|
"https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json",
|
|
"https://models.router-for.me/models.json",
|
|
}
|
|
|
|
//go:embed models/models.json
|
|
var embeddedModelsJSON []byte
|
|
|
|
type modelStore struct {
|
|
mu sync.RWMutex
|
|
data *staticModelsJSON
|
|
}
|
|
|
|
var modelsCatalogStore = &modelStore{}
|
|
|
|
var updaterOnce sync.Once
|
|
|
|
// ModelRefreshCallback is invoked when startup or periodic model refresh detects changes.
|
|
// changedProviders contains the provider names whose model definitions changed.
|
|
type ModelRefreshCallback func(changedProviders []string)
|
|
|
|
var (
|
|
refreshCallbackMu sync.Mutex
|
|
refreshCallback ModelRefreshCallback
|
|
pendingRefreshChanges []string
|
|
)
|
|
|
|
// SetModelRefreshCallback registers a callback that is invoked when startup or
|
|
// periodic model refresh detects changes. Only one callback is supported;
|
|
// subsequent calls replace the previous callback.
|
|
func SetModelRefreshCallback(cb ModelRefreshCallback) {
|
|
refreshCallbackMu.Lock()
|
|
refreshCallback = cb
|
|
var pending []string
|
|
if cb != nil && len(pendingRefreshChanges) > 0 {
|
|
pending = append([]string(nil), pendingRefreshChanges...)
|
|
pendingRefreshChanges = nil
|
|
}
|
|
refreshCallbackMu.Unlock()
|
|
|
|
if cb != nil && len(pending) > 0 {
|
|
cb(pending)
|
|
}
|
|
}
|
|
|
|
func init() {
|
|
// Load embedded data as fallback on startup.
|
|
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
|
|
panic(fmt.Sprintf("registry: failed to parse embedded models.json: %v", err))
|
|
}
|
|
}
|
|
|
|
// StartModelsUpdater starts a background updater that fetches models
|
|
// immediately on startup and then refreshes the model catalog every 3 hours.
|
|
// Safe to call multiple times; only one updater will run.
|
|
func StartModelsUpdater(ctx context.Context) {
|
|
updaterOnce.Do(func() {
|
|
go runModelsUpdater(ctx)
|
|
})
|
|
}
|
|
|
|
func runModelsUpdater(ctx context.Context) {
|
|
tryStartupRefresh(ctx)
|
|
periodicRefresh(ctx)
|
|
}
|
|
|
|
func periodicRefresh(ctx context.Context) {
|
|
ticker := time.NewTicker(modelsRefreshInterval)
|
|
defer ticker.Stop()
|
|
log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
tryPeriodicRefresh(ctx)
|
|
}
|
|
}
|
|
}
|
|
|
|
// tryPeriodicRefresh fetches models from remote, compares with the current
|
|
// catalog, and notifies the registered callback if any provider changed.
|
|
func tryPeriodicRefresh(ctx context.Context) {
|
|
tryRefreshModels(ctx, "periodic model refresh")
|
|
}
|
|
|
|
// tryStartupRefresh fetches models from remote in the background during
|
|
// process startup. It uses the same change detection as periodic refresh so
|
|
// existing auth registrations can be updated after the callback is registered.
|
|
func tryStartupRefresh(ctx context.Context) {
|
|
tryRefreshModels(ctx, "startup model refresh")
|
|
}
|
|
|
|
func tryRefreshModels(ctx context.Context, label string) {
|
|
oldData := getModels()
|
|
|
|
parsed, url := fetchModelsFromRemote(ctx)
|
|
if parsed == nil {
|
|
log.Warnf("%s: fetch failed from all URLs, keeping current data", label)
|
|
return
|
|
}
|
|
|
|
// Detect changes before updating store.
|
|
changed := detectChangedProviders(oldData, parsed)
|
|
|
|
// Update store with new data regardless.
|
|
modelsCatalogStore.mu.Lock()
|
|
modelsCatalogStore.data = parsed
|
|
modelsCatalogStore.mu.Unlock()
|
|
|
|
if len(changed) == 0 {
|
|
log.Infof("%s completed from %s, no changes detected", label, url)
|
|
return
|
|
}
|
|
|
|
log.Infof("%s completed from %s, changes detected for providers: %v", label, url, changed)
|
|
notifyModelRefresh(changed)
|
|
}
|
|
|
|
// fetchModelsFromRemote tries all remote URLs and returns the parsed model catalog
|
|
// along with the URL it was fetched from. Returns (nil, "") if all fetches fail.
|
|
func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) {
|
|
client := &http.Client{Timeout: modelsFetchTimeout}
|
|
for _, url := range modelsURLs {
|
|
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
|
|
req, err := http.NewRequestWithContext(reqCtx, "GET", url, nil)
|
|
if err != nil {
|
|
cancel()
|
|
log.Debugf("models fetch request creation failed for %s: %v", url, err)
|
|
continue
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
cancel()
|
|
log.Debugf("models fetch failed from %s: %v", url, err)
|
|
continue
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
resp.Body.Close()
|
|
cancel()
|
|
log.Debugf("models fetch returned %d from %s", resp.StatusCode, url)
|
|
continue
|
|
}
|
|
|
|
data, err := io.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
cancel()
|
|
|
|
if err != nil {
|
|
log.Debugf("models fetch read error from %s: %v", url, err)
|
|
continue
|
|
}
|
|
|
|
var parsed staticModelsJSON
|
|
if err := json.Unmarshal(data, &parsed); err != nil {
|
|
log.Warnf("models parse failed from %s: %v", url, err)
|
|
continue
|
|
}
|
|
if err := validateModelsCatalog(&parsed); err != nil {
|
|
log.Warnf("models validate failed from %s: %v", url, err)
|
|
continue
|
|
}
|
|
|
|
return &parsed, url
|
|
}
|
|
return nil, ""
|
|
}
|
|
|
|
// detectChangedProviders compares two model catalogs and returns provider names
|
|
// whose model definitions differ. Codex tiers (free/team/plus/pro) are grouped
|
|
// under a single "codex" provider.
|
|
func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
|
|
if oldData == nil || newData == nil {
|
|
return nil
|
|
}
|
|
|
|
type section struct {
|
|
provider string
|
|
oldList []*ModelInfo
|
|
newList []*ModelInfo
|
|
}
|
|
|
|
sections := []section{
|
|
{"claude", oldData.Claude, newData.Claude},
|
|
{"gemini", oldData.Gemini, newData.Gemini},
|
|
{"vertex", oldData.Vertex, newData.Vertex},
|
|
{"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI},
|
|
{"aistudio", oldData.AIStudio, newData.AIStudio},
|
|
{"codex", oldData.CodexFree, newData.CodexFree},
|
|
{"codex", oldData.CodexTeam, newData.CodexTeam},
|
|
{"codex", oldData.CodexPlus, newData.CodexPlus},
|
|
{"codex", oldData.CodexPro, newData.CodexPro},
|
|
{"qwen", oldData.Qwen, newData.Qwen},
|
|
{"iflow", oldData.IFlow, newData.IFlow},
|
|
{"kimi", oldData.Kimi, newData.Kimi},
|
|
{"antigravity", oldData.Antigravity, newData.Antigravity},
|
|
}
|
|
|
|
seen := make(map[string]bool, len(sections))
|
|
var changed []string
|
|
for _, s := range sections {
|
|
if seen[s.provider] {
|
|
continue
|
|
}
|
|
if modelSectionChanged(s.oldList, s.newList) {
|
|
changed = append(changed, s.provider)
|
|
seen[s.provider] = true
|
|
}
|
|
}
|
|
return changed
|
|
}
|
|
|
|
// modelSectionChanged reports whether two model slices differ.
|
|
func modelSectionChanged(a, b []*ModelInfo) bool {
|
|
if len(a) != len(b) {
|
|
return true
|
|
}
|
|
if len(a) == 0 {
|
|
return false
|
|
}
|
|
aj, err1 := json.Marshal(a)
|
|
bj, err2 := json.Marshal(b)
|
|
if err1 != nil || err2 != nil {
|
|
return true
|
|
}
|
|
return string(aj) != string(bj)
|
|
}
|
|
|
|
func notifyModelRefresh(changedProviders []string) {
|
|
if len(changedProviders) == 0 {
|
|
return
|
|
}
|
|
|
|
refreshCallbackMu.Lock()
|
|
cb := refreshCallback
|
|
if cb == nil {
|
|
pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders)
|
|
refreshCallbackMu.Unlock()
|
|
return
|
|
}
|
|
refreshCallbackMu.Unlock()
|
|
cb(changedProviders)
|
|
}
|
|
|
|
func mergeProviderNames(existing, incoming []string) []string {
|
|
if len(incoming) == 0 {
|
|
return existing
|
|
}
|
|
seen := make(map[string]struct{}, len(existing)+len(incoming))
|
|
merged := make([]string, 0, len(existing)+len(incoming))
|
|
for _, provider := range existing {
|
|
name := strings.ToLower(strings.TrimSpace(provider))
|
|
if name == "" {
|
|
continue
|
|
}
|
|
if _, ok := seen[name]; ok {
|
|
continue
|
|
}
|
|
seen[name] = struct{}{}
|
|
merged = append(merged, name)
|
|
}
|
|
for _, provider := range incoming {
|
|
name := strings.ToLower(strings.TrimSpace(provider))
|
|
if name == "" {
|
|
continue
|
|
}
|
|
if _, ok := seen[name]; ok {
|
|
continue
|
|
}
|
|
seen[name] = struct{}{}
|
|
merged = append(merged, name)
|
|
}
|
|
return merged
|
|
}
|
|
|
|
func loadModelsFromBytes(data []byte, source string) error {
|
|
var parsed staticModelsJSON
|
|
if err := json.Unmarshal(data, &parsed); err != nil {
|
|
return fmt.Errorf("%s: decode models catalog: %w", source, err)
|
|
}
|
|
if err := validateModelsCatalog(&parsed); err != nil {
|
|
return fmt.Errorf("%s: validate models catalog: %w", source, err)
|
|
}
|
|
|
|
modelsCatalogStore.mu.Lock()
|
|
modelsCatalogStore.data = &parsed
|
|
modelsCatalogStore.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func getModels() *staticModelsJSON {
|
|
modelsCatalogStore.mu.RLock()
|
|
defer modelsCatalogStore.mu.RUnlock()
|
|
return modelsCatalogStore.data
|
|
}
|
|
|
|
func validateModelsCatalog(data *staticModelsJSON) error {
|
|
if data == nil {
|
|
return fmt.Errorf("catalog is nil")
|
|
}
|
|
|
|
requiredSections := []struct {
|
|
name string
|
|
models []*ModelInfo
|
|
}{
|
|
{name: "claude", models: data.Claude},
|
|
{name: "gemini", models: data.Gemini},
|
|
{name: "vertex", models: data.Vertex},
|
|
{name: "gemini-cli", models: data.GeminiCLI},
|
|
{name: "aistudio", models: data.AIStudio},
|
|
{name: "codex-free", models: data.CodexFree},
|
|
{name: "codex-team", models: data.CodexTeam},
|
|
{name: "codex-plus", models: data.CodexPlus},
|
|
{name: "codex-pro", models: data.CodexPro},
|
|
{name: "qwen", models: data.Qwen},
|
|
{name: "iflow", models: data.IFlow},
|
|
{name: "kimi", models: data.Kimi},
|
|
{name: "antigravity", models: data.Antigravity},
|
|
}
|
|
|
|
for _, section := range requiredSections {
|
|
if err := validateModelSection(section.name, section.models); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateModelSection(section string, models []*ModelInfo) error {
|
|
if len(models) == 0 {
|
|
return fmt.Errorf("%s section is empty", section)
|
|
}
|
|
|
|
seen := make(map[string]struct{}, len(models))
|
|
for i, model := range models {
|
|
if model == nil {
|
|
return fmt.Errorf("%s[%d] is null", section, i)
|
|
}
|
|
modelID := strings.TrimSpace(model.ID)
|
|
if modelID == "" {
|
|
return fmt.Errorf("%s[%d] has empty id", section, i)
|
|
}
|
|
if _, exists := seen[modelID]; exists {
|
|
return fmt.Errorf("%s contains duplicate model id %q", section, modelID)
|
|
}
|
|
seen[modelID] = struct{}{}
|
|
}
|
|
return nil
|
|
}
|