From 2b134fc37839d965e0b0dabcae29f1e9aa1dc546 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 8 Mar 2026 05:52:55 +0800 Subject: [PATCH] test(auth-scheduler): add unit tests and scheduler implementation - Added comprehensive unit tests for `authScheduler` and related components. - Implemented `authScheduler` with support for Round Robin, Fill First, and custom selector strategies. - Improved tracking of auth states, cooldowns, and recovery logic in scheduler. --- sdk/cliproxy/auth/conductor.go | 159 +++- sdk/cliproxy/auth/scheduler.go | 851 ++++++++++++++++++ sdk/cliproxy/auth/scheduler_benchmark_test.go | 197 ++++ sdk/cliproxy/auth/scheduler_test.go | 468 ++++++++++ 4 files changed, 1670 insertions(+), 5 deletions(-) create mode 100644 sdk/cliproxy/auth/scheduler.go create mode 100644 sdk/cliproxy/auth/scheduler_benchmark_test.go create mode 100644 sdk/cliproxy/auth/scheduler_test.go diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index e31f3300..aacf9322 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -134,6 +134,7 @@ type Manager struct { hook Hook mu sync.RWMutex auths map[string]*Auth + scheduler *authScheduler // providerOffsets tracks per-model provider rotation state for multi-provider routing. providerOffsets map[string]int @@ -185,9 +186,33 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { // atomic.Value requires non-nil initial value. manager.runtimeConfig.Store(&internalconfig.Config{}) manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil)) + manager.scheduler = newAuthScheduler(selector) return manager } +func isBuiltInSelector(selector Selector) bool { + switch selector.(type) { + case *RoundRobinSelector, *FillFirstSelector: + return true + default: + return false + } +} + +func (m *Manager) syncSchedulerFromSnapshot(auths []*Auth) { + if m == nil || m.scheduler == nil { + return + } + m.scheduler.rebuild(auths) +} + +func (m *Manager) syncScheduler() { + if m == nil || m.scheduler == nil { + return + } + m.syncSchedulerFromSnapshot(m.snapshotAuths()) +} + func (m *Manager) SetSelector(selector Selector) { if m == nil { return @@ -198,6 +223,10 @@ func (m *Manager) SetSelector(selector Selector) { m.mu.Lock() m.selector = selector m.mu.Unlock() + if m.scheduler != nil { + m.scheduler.setSelector(selector) + m.syncScheduler() + } } // SetStore swaps the underlying persistence store. @@ -759,10 +788,14 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) { auth.ID = uuid.NewString() } auth.EnsureIndex() + authClone := auth.Clone() m.mu.Lock() - m.auths[auth.ID] = auth.Clone() + m.auths[auth.ID] = authClone m.mu.Unlock() m.rebuildAPIKeyModelAliasFromRuntimeConfig() + if m.scheduler != nil { + m.scheduler.upsertAuth(authClone) + } _ = m.persist(ctx, auth) m.hook.OnAuthRegistered(ctx, auth.Clone()) return auth.Clone(), nil @@ -784,9 +817,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) { } } auth.EnsureIndex() - m.auths[auth.ID] = auth.Clone() + authClone := auth.Clone() + m.auths[auth.ID] = authClone m.mu.Unlock() m.rebuildAPIKeyModelAliasFromRuntimeConfig() + if m.scheduler != nil { + m.scheduler.upsertAuth(authClone) + } _ = m.persist(ctx, auth) m.hook.OnAuthUpdated(ctx, auth.Clone()) return auth.Clone(), nil @@ -795,12 +832,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) { // Load resets manager state from the backing store. func (m *Manager) Load(ctx context.Context) error { m.mu.Lock() - defer m.mu.Unlock() if m.store == nil { + m.mu.Unlock() return nil } items, err := m.store.List(ctx) if err != nil { + m.mu.Unlock() return err } m.auths = make(map[string]*Auth, len(items)) @@ -816,6 +854,8 @@ func (m *Manager) Load(ctx context.Context) error { cfg = &internalconfig.Config{} } m.rebuildAPIKeyModelAliasLocked(cfg) + m.mu.Unlock() + m.syncScheduler() return nil } @@ -1531,6 +1571,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { suspendReason := "" clearModelQuota := false setModelQuota := false + var authSnapshot *Auth m.mu.Lock() if auth, ok := m.auths[result.AuthID]; ok && auth != nil { @@ -1624,8 +1665,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { } _ = m.persist(ctx, auth) + authSnapshot = auth.Clone() } m.mu.Unlock() + if m.scheduler != nil && authSnapshot != nil { + m.scheduler.upsertAuth(authSnapshot) + } if clearModelQuota && result.Model != "" { registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model) @@ -1982,7 +2027,25 @@ func (m *Manager) CloseExecutionSession(sessionID string) { } } -func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) { +func (m *Manager) useSchedulerFastPath() bool { + if m == nil || m.scheduler == nil { + return false + } + return isBuiltInSelector(m.selector) +} + +func shouldRetrySchedulerPick(err error) bool { + if err == nil { + return false + } + var authErr *Error + if !errors.As(err, &authErr) || authErr == nil { + return false + } + return authErr.Code == "auth_not_found" || authErr.Code == "auth_unavailable" +} + +func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) { pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) m.mu.RLock() @@ -2042,7 +2105,38 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli return authCopy, executor, nil } -func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) { +func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) { + if !m.useSchedulerFastPath() { + return m.pickNextLegacy(ctx, provider, model, opts, tried) + } + executor, okExecutor := m.Executor(provider) + if !okExecutor { + return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"} + } + selected, errPick := m.scheduler.pickSingle(ctx, provider, model, opts, tried) + if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) { + m.syncScheduler() + selected, errPick = m.scheduler.pickSingle(ctx, provider, model, opts, tried) + } + if errPick != nil { + return nil, nil, errPick + } + if selected == nil { + return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"} + } + authCopy := selected.Clone() + if !selected.indexAssigned { + m.mu.Lock() + if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned { + current.EnsureIndex() + authCopy = current.Clone() + } + m.mu.Unlock() + } + return authCopy, executor, nil +} + +func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) { pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) providerSet := make(map[string]struct{}, len(providers)) @@ -2125,6 +2219,58 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s return authCopy, executor, providerKey, nil } +func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) { + if !m.useSchedulerFastPath() { + return m.pickNextMixedLegacy(ctx, providers, model, opts, tried) + } + + eligibleProviders := make([]string, 0, len(providers)) + seenProviders := make(map[string]struct{}, len(providers)) + for _, provider := range providers { + providerKey := strings.TrimSpace(strings.ToLower(provider)) + if providerKey == "" { + continue + } + if _, seen := seenProviders[providerKey]; seen { + continue + } + if _, okExecutor := m.Executor(providerKey); !okExecutor { + continue + } + seenProviders[providerKey] = struct{}{} + eligibleProviders = append(eligibleProviders, providerKey) + } + if len(eligibleProviders) == 0 { + return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"} + } + + selected, providerKey, errPick := m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried) + if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) { + m.syncScheduler() + selected, providerKey, errPick = m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried) + } + if errPick != nil { + return nil, nil, "", errPick + } + if selected == nil { + return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"} + } + executor, okExecutor := m.Executor(providerKey) + if !okExecutor { + return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"} + } + authCopy := selected.Clone() + if !selected.indexAssigned { + m.mu.Lock() + if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned { + current.EnsureIndex() + authCopy = current.Clone() + } + m.mu.Unlock() + } + return authCopy, executor, providerKey, nil +} + func (m *Manager) persist(ctx context.Context, auth *Auth) error { if m.store == nil || auth == nil { return nil @@ -2476,6 +2622,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) { current.NextRefreshAfter = now.Add(refreshFailureBackoff) current.LastError = &Error{Message: err.Error()} m.auths[id] = current + if m.scheduler != nil { + m.scheduler.upsertAuth(current.Clone()) + } } m.mu.Unlock() return diff --git a/sdk/cliproxy/auth/scheduler.go b/sdk/cliproxy/auth/scheduler.go new file mode 100644 index 00000000..1ede8934 --- /dev/null +++ b/sdk/cliproxy/auth/scheduler.go @@ -0,0 +1,851 @@ +package auth + +import ( + "context" + "sort" + "strings" + "sync" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +// schedulerStrategy identifies which built-in routing semantics the scheduler should apply. +type schedulerStrategy int + +const ( + schedulerStrategyCustom schedulerStrategy = iota + schedulerStrategyRoundRobin + schedulerStrategyFillFirst +) + +// scheduledState describes how an auth currently participates in a model shard. +type scheduledState int + +const ( + scheduledStateReady scheduledState = iota + scheduledStateCooldown + scheduledStateBlocked + scheduledStateDisabled +) + +// authScheduler keeps the incremental provider/model scheduling state used by Manager. +type authScheduler struct { + mu sync.Mutex + strategy schedulerStrategy + providers map[string]*providerScheduler + authProviders map[string]string + mixedCursors map[string]int +} + +// providerScheduler stores auth metadata and model shards for a single provider. +type providerScheduler struct { + providerKey string + auths map[string]*scheduledAuthMeta + modelShards map[string]*modelScheduler +} + +// scheduledAuthMeta stores the immutable scheduling fields derived from an auth snapshot. +type scheduledAuthMeta struct { + auth *Auth + providerKey string + priority int + virtualParent string + websocketEnabled bool + supportedModelSet map[string]struct{} +} + +// modelScheduler tracks ready and blocked auths for one provider/model combination. +type modelScheduler struct { + modelKey string + entries map[string]*scheduledAuth + priorityOrder []int + readyByPriority map[int]*readyBucket + blocked cooldownQueue +} + +// scheduledAuth stores the runtime scheduling state for a single auth inside a model shard. +type scheduledAuth struct { + meta *scheduledAuthMeta + auth *Auth + state scheduledState + nextRetryAt time.Time +} + +// readyBucket keeps the ready views for one priority level. +type readyBucket struct { + all readyView + ws readyView +} + +// readyView holds the selection order for flat or grouped round-robin traversal. +type readyView struct { + flat []*scheduledAuth + cursor int + parentOrder []string + parentCursor int + children map[string]*childBucket +} + +// childBucket keeps the per-parent rotation state for grouped Gemini virtual auths. +type childBucket struct { + items []*scheduledAuth + cursor int +} + +// cooldownQueue is the blocked auth collection ordered by next retry time during rebuilds. +type cooldownQueue []*scheduledAuth + +// newAuthScheduler constructs an empty scheduler configured for the supplied selector strategy. +func newAuthScheduler(selector Selector) *authScheduler { + return &authScheduler{ + strategy: selectorStrategy(selector), + providers: make(map[string]*providerScheduler), + authProviders: make(map[string]string), + mixedCursors: make(map[string]int), + } +} + +// selectorStrategy maps a selector implementation to the scheduler semantics it should emulate. +func selectorStrategy(selector Selector) schedulerStrategy { + switch selector.(type) { + case *FillFirstSelector: + return schedulerStrategyFillFirst + case nil, *RoundRobinSelector: + return schedulerStrategyRoundRobin + default: + return schedulerStrategyCustom + } +} + +// setSelector updates the active built-in strategy and resets mixed-provider cursors. +func (s *authScheduler) setSelector(selector Selector) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.strategy = selectorStrategy(selector) + clear(s.mixedCursors) +} + +// rebuild recreates the complete scheduler state from an auth snapshot. +func (s *authScheduler) rebuild(auths []*Auth) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.providers = make(map[string]*providerScheduler) + s.authProviders = make(map[string]string) + s.mixedCursors = make(map[string]int) + now := time.Now() + for _, auth := range auths { + s.upsertAuthLocked(auth, now) + } +} + +// upsertAuth incrementally synchronizes one auth into the scheduler. +func (s *authScheduler) upsertAuth(auth *Auth) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.upsertAuthLocked(auth, time.Now()) +} + +// removeAuth deletes one auth from every scheduler shard that references it. +func (s *authScheduler) removeAuth(authID string) { + if s == nil { + return + } + authID = strings.TrimSpace(authID) + if authID == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.removeAuthLocked(authID) +} + +// pickSingle returns the next auth for a single provider/model request using scheduler state. +func (s *authScheduler) pickSingle(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, error) { + if s == nil { + return nil, &Error{Code: "auth_not_found", Message: "no auth available"} + } + providerKey := strings.ToLower(strings.TrimSpace(provider)) + modelKey := canonicalModelKey(model) + pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) + preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == "" + + s.mu.Lock() + defer s.mu.Unlock() + providerState := s.providers[providerKey] + if providerState == nil { + return nil, &Error{Code: "auth_not_found", Message: "no auth available"} + } + shard := providerState.ensureModelLocked(modelKey, time.Now()) + if shard == nil { + return nil, &Error{Code: "auth_not_found", Message: "no auth available"} + } + predicate := func(entry *scheduledAuth) bool { + if entry == nil || entry.auth == nil { + return false + } + if pinnedAuthID != "" && entry.auth.ID != pinnedAuthID { + return false + } + if len(tried) > 0 { + if _, ok := tried[entry.auth.ID]; ok { + return false + } + } + return true + } + if picked := shard.pickReadyLocked(preferWebsocket, s.strategy, predicate); picked != nil { + return picked, nil + } + return nil, shard.unavailableErrorLocked(provider, model, predicate) +} + +// pickMixed returns the next auth and provider for a mixed-provider request. +func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) { + if s == nil { + return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"} + } + normalized := normalizeProviderKeys(providers) + if len(normalized) == 0 { + return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"} + } + pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) + modelKey := canonicalModelKey(model) + + s.mu.Lock() + defer s.mu.Unlock() + if pinnedAuthID != "" { + providerKey := s.authProviders[pinnedAuthID] + if providerKey == "" || !containsProvider(normalized, providerKey) { + return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"} + } + providerState := s.providers[providerKey] + if providerState == nil { + return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"} + } + shard := providerState.ensureModelLocked(modelKey, time.Now()) + predicate := func(entry *scheduledAuth) bool { + if entry == nil || entry.auth == nil || entry.auth.ID != pinnedAuthID { + return false + } + if len(tried) == 0 { + return true + } + _, ok := tried[pinnedAuthID] + return !ok + } + if picked := shard.pickReadyLocked(false, s.strategy, predicate); picked != nil { + return picked, providerKey, nil + } + return nil, "", shard.unavailableErrorLocked("mixed", model, predicate) + } + + if s.strategy == schedulerStrategyFillFirst { + for _, providerKey := range normalized { + providerState := s.providers[providerKey] + if providerState == nil { + continue + } + shard := providerState.ensureModelLocked(modelKey, time.Now()) + if shard == nil { + continue + } + picked := shard.pickReadyLocked(false, s.strategy, triedPredicate(tried)) + if picked != nil { + return picked, providerKey, nil + } + } + return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried) + } + + cursorKey := strings.Join(normalized, ",") + ":" + modelKey + start := 0 + if len(normalized) > 0 { + start = s.mixedCursors[cursorKey] % len(normalized) + } + for offset := 0; offset < len(normalized); offset++ { + providerIndex := (start + offset) % len(normalized) + providerKey := normalized[providerIndex] + providerState := s.providers[providerKey] + if providerState == nil { + continue + } + shard := providerState.ensureModelLocked(modelKey, time.Now()) + if shard == nil { + continue + } + picked := shard.pickReadyLocked(false, schedulerStrategyRoundRobin, triedPredicate(tried)) + if picked == nil { + continue + } + s.mixedCursors[cursorKey] = providerIndex + 1 + return picked, providerKey, nil + } + return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried) +} + +// mixedUnavailableErrorLocked synthesizes the mixed-provider cooldown or unavailable error. +func (s *authScheduler) mixedUnavailableErrorLocked(providers []string, model string, tried map[string]struct{}) error { + now := time.Now() + total := 0 + cooldownCount := 0 + earliest := time.Time{} + for _, providerKey := range providers { + providerState := s.providers[providerKey] + if providerState == nil { + continue + } + shard := providerState.ensureModelLocked(canonicalModelKey(model), now) + if shard == nil { + continue + } + localTotal, localCooldownCount, localEarliest := shard.availabilitySummaryLocked(triedPredicate(tried)) + total += localTotal + cooldownCount += localCooldownCount + if !localEarliest.IsZero() && (earliest.IsZero() || localEarliest.Before(earliest)) { + earliest = localEarliest + } + } + if total == 0 { + return &Error{Code: "auth_not_found", Message: "no auth available"} + } + if cooldownCount == total && !earliest.IsZero() { + resetIn := earliest.Sub(now) + if resetIn < 0 { + resetIn = 0 + } + return newModelCooldownError(model, "", resetIn) + } + return &Error{Code: "auth_unavailable", Message: "no auth available"} +} + +// triedPredicate builds a filter that excludes auths already attempted for the current request. +func triedPredicate(tried map[string]struct{}) func(*scheduledAuth) bool { + if len(tried) == 0 { + return func(entry *scheduledAuth) bool { return entry != nil && entry.auth != nil } + } + return func(entry *scheduledAuth) bool { + if entry == nil || entry.auth == nil { + return false + } + _, ok := tried[entry.auth.ID] + return !ok + } +} + +// normalizeProviderKeys lowercases, trims, and de-duplicates provider keys while preserving order. +func normalizeProviderKeys(providers []string) []string { + seen := make(map[string]struct{}, len(providers)) + out := make([]string, 0, len(providers)) + for _, provider := range providers { + providerKey := strings.ToLower(strings.TrimSpace(provider)) + if providerKey == "" { + continue + } + if _, ok := seen[providerKey]; ok { + continue + } + seen[providerKey] = struct{}{} + out = append(out, providerKey) + } + return out +} + +// containsProvider reports whether provider is present in the normalized provider list. +func containsProvider(providers []string, provider string) bool { + for _, candidate := range providers { + if candidate == provider { + return true + } + } + return false +} + +// upsertAuthLocked updates one auth in-place while the scheduler mutex is held. +func (s *authScheduler) upsertAuthLocked(auth *Auth, now time.Time) { + if auth == nil { + return + } + authID := strings.TrimSpace(auth.ID) + providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) + if authID == "" || providerKey == "" || auth.Disabled { + s.removeAuthLocked(authID) + return + } + if previousProvider := s.authProviders[authID]; previousProvider != "" && previousProvider != providerKey { + if previousState := s.providers[previousProvider]; previousState != nil { + previousState.removeAuthLocked(authID) + } + } + meta := buildScheduledAuthMeta(auth) + s.authProviders[authID] = providerKey + s.ensureProviderLocked(providerKey).upsertAuthLocked(meta, now) +} + +// removeAuthLocked removes one auth from the scheduler while the scheduler mutex is held. +func (s *authScheduler) removeAuthLocked(authID string) { + if authID == "" { + return + } + if providerKey := s.authProviders[authID]; providerKey != "" { + if providerState := s.providers[providerKey]; providerState != nil { + providerState.removeAuthLocked(authID) + } + delete(s.authProviders, authID) + } +} + +// ensureProviderLocked returns the provider scheduler for providerKey, creating it when needed. +func (s *authScheduler) ensureProviderLocked(providerKey string) *providerScheduler { + if s.providers == nil { + s.providers = make(map[string]*providerScheduler) + } + providerState := s.providers[providerKey] + if providerState == nil { + providerState = &providerScheduler{ + providerKey: providerKey, + auths: make(map[string]*scheduledAuthMeta), + modelShards: make(map[string]*modelScheduler), + } + s.providers[providerKey] = providerState + } + return providerState +} + +// buildScheduledAuthMeta extracts the scheduling metadata needed for shard bookkeeping. +func buildScheduledAuthMeta(auth *Auth) *scheduledAuthMeta { + providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) + virtualParent := "" + if auth.Attributes != nil { + virtualParent = strings.TrimSpace(auth.Attributes["gemini_virtual_parent"]) + } + return &scheduledAuthMeta{ + auth: auth, + providerKey: providerKey, + priority: authPriority(auth), + virtualParent: virtualParent, + websocketEnabled: authWebsocketsEnabled(auth), + supportedModelSet: supportedModelSetForAuth(auth.ID), + } +} + +// supportedModelSetForAuth snapshots the registry models currently registered for an auth. +func supportedModelSetForAuth(authID string) map[string]struct{} { + authID = strings.TrimSpace(authID) + if authID == "" { + return nil + } + models := registry.GetGlobalRegistry().GetModelsForClient(authID) + if len(models) == 0 { + return nil + } + set := make(map[string]struct{}, len(models)) + for _, model := range models { + if model == nil { + continue + } + modelKey := canonicalModelKey(model.ID) + if modelKey == "" { + continue + } + set[modelKey] = struct{}{} + } + return set +} + +// upsertAuthLocked updates every existing model shard that can reference the auth metadata. +func (p *providerScheduler) upsertAuthLocked(meta *scheduledAuthMeta, now time.Time) { + if p == nil || meta == nil || meta.auth == nil { + return + } + p.auths[meta.auth.ID] = meta + for modelKey, shard := range p.modelShards { + if shard == nil { + continue + } + if !meta.supportsModel(modelKey) { + shard.removeEntryLocked(meta.auth.ID) + continue + } + shard.upsertEntryLocked(meta, now) + } +} + +// removeAuthLocked removes an auth from all model shards owned by the provider scheduler. +func (p *providerScheduler) removeAuthLocked(authID string) { + if p == nil || authID == "" { + return + } + delete(p.auths, authID) + for _, shard := range p.modelShards { + if shard != nil { + shard.removeEntryLocked(authID) + } + } +} + +// ensureModelLocked returns the shard for modelKey, building it lazily from provider auths. +func (p *providerScheduler) ensureModelLocked(modelKey string, now time.Time) *modelScheduler { + if p == nil { + return nil + } + modelKey = canonicalModelKey(modelKey) + if shard, ok := p.modelShards[modelKey]; ok && shard != nil { + shard.promoteExpiredLocked(now) + return shard + } + shard := &modelScheduler{ + modelKey: modelKey, + entries: make(map[string]*scheduledAuth), + readyByPriority: make(map[int]*readyBucket), + } + for _, meta := range p.auths { + if meta == nil || !meta.supportsModel(modelKey) { + continue + } + shard.upsertEntryLocked(meta, now) + } + p.modelShards[modelKey] = shard + return shard +} + +// supportsModel reports whether the auth metadata currently supports modelKey. +func (m *scheduledAuthMeta) supportsModel(modelKey string) bool { + modelKey = canonicalModelKey(modelKey) + if modelKey == "" { + return true + } + if len(m.supportedModelSet) == 0 { + return false + } + _, ok := m.supportedModelSet[modelKey] + return ok +} + +// upsertEntryLocked updates or inserts one auth entry and rebuilds indexes when ordering changes. +func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Time) { + if m == nil || meta == nil || meta.auth == nil { + return + } + entry, ok := m.entries[meta.auth.ID] + if !ok || entry == nil { + entry = &scheduledAuth{} + m.entries[meta.auth.ID] = entry + } + previousState := entry.state + previousNextRetryAt := entry.nextRetryAt + previousPriority := 0 + previousParent := "" + previousWebsocketEnabled := false + if entry.meta != nil { + previousPriority = entry.meta.priority + previousParent = entry.meta.virtualParent + previousWebsocketEnabled = entry.meta.websocketEnabled + } + + entry.meta = meta + entry.auth = meta.auth + entry.nextRetryAt = time.Time{} + blocked, reason, next := isAuthBlockedForModel(meta.auth, m.modelKey, now) + switch { + case !blocked: + entry.state = scheduledStateReady + case reason == blockReasonCooldown: + entry.state = scheduledStateCooldown + entry.nextRetryAt = next + case reason == blockReasonDisabled: + entry.state = scheduledStateDisabled + default: + entry.state = scheduledStateBlocked + entry.nextRetryAt = next + } + + if ok && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousParent == meta.virtualParent && previousWebsocketEnabled == meta.websocketEnabled { + return + } + m.rebuildIndexesLocked() +} + +// removeEntryLocked deletes one auth entry and rebuilds the shard indexes if needed. +func (m *modelScheduler) removeEntryLocked(authID string) { + if m == nil || authID == "" { + return + } + if _, ok := m.entries[authID]; !ok { + return + } + delete(m.entries, authID) + m.rebuildIndexesLocked() +} + +// promoteExpiredLocked reevaluates blocked auths whose retry time has elapsed. +func (m *modelScheduler) promoteExpiredLocked(now time.Time) { + if m == nil || len(m.blocked) == 0 { + return + } + changed := false + for _, entry := range m.blocked { + if entry == nil || entry.auth == nil { + continue + } + if entry.nextRetryAt.IsZero() || entry.nextRetryAt.After(now) { + continue + } + blocked, reason, next := isAuthBlockedForModel(entry.auth, m.modelKey, now) + switch { + case !blocked: + entry.state = scheduledStateReady + entry.nextRetryAt = time.Time{} + case reason == blockReasonCooldown: + entry.state = scheduledStateCooldown + entry.nextRetryAt = next + case reason == blockReasonDisabled: + entry.state = scheduledStateDisabled + entry.nextRetryAt = time.Time{} + default: + entry.state = scheduledStateBlocked + entry.nextRetryAt = next + } + changed = true + } + if changed { + m.rebuildIndexesLocked() + } +} + +// pickReadyLocked selects the next ready auth from the highest available priority bucket. +func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth { + if m == nil { + return nil + } + m.promoteExpiredLocked(time.Now()) + for _, priority := range m.priorityOrder { + bucket := m.readyByPriority[priority] + if bucket == nil { + continue + } + view := &bucket.all + if preferWebsocket && len(bucket.ws.flat) > 0 { + view = &bucket.ws + } + var picked *scheduledAuth + if strategy == schedulerStrategyFillFirst { + picked = view.pickFirst(predicate) + } else { + picked = view.pickRoundRobin(predicate) + } + if picked != nil && picked.auth != nil { + return picked.auth + } + } + return nil +} + +// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard. +func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error { + now := time.Now() + total, cooldownCount, earliest := m.availabilitySummaryLocked(predicate) + if total == 0 { + return &Error{Code: "auth_not_found", Message: "no auth available"} + } + if cooldownCount == total && !earliest.IsZero() { + providerForError := provider + if providerForError == "mixed" { + providerForError = "" + } + resetIn := earliest.Sub(now) + if resetIn < 0 { + resetIn = 0 + } + return newModelCooldownError(model, providerForError, resetIn) + } + return &Error{Code: "auth_unavailable", Message: "no auth available"} +} + +// availabilitySummaryLocked summarizes total candidates, cooldown count, and earliest retry time. +func (m *modelScheduler) availabilitySummaryLocked(predicate func(*scheduledAuth) bool) (int, int, time.Time) { + if m == nil { + return 0, 0, time.Time{} + } + total := 0 + cooldownCount := 0 + earliest := time.Time{} + for _, entry := range m.entries { + if predicate != nil && !predicate(entry) { + continue + } + total++ + if entry == nil || entry.auth == nil { + continue + } + if entry.state != scheduledStateCooldown { + continue + } + cooldownCount++ + if !entry.nextRetryAt.IsZero() && (earliest.IsZero() || entry.nextRetryAt.Before(earliest)) { + earliest = entry.nextRetryAt + } + } + return total, cooldownCount, earliest +} + +// rebuildIndexesLocked reconstructs ready and blocked views from the current entry map. +func (m *modelScheduler) rebuildIndexesLocked() { + m.readyByPriority = make(map[int]*readyBucket) + m.priorityOrder = m.priorityOrder[:0] + m.blocked = m.blocked[:0] + priorityBuckets := make(map[int][]*scheduledAuth) + for _, entry := range m.entries { + if entry == nil || entry.auth == nil { + continue + } + switch entry.state { + case scheduledStateReady: + priority := entry.meta.priority + priorityBuckets[priority] = append(priorityBuckets[priority], entry) + case scheduledStateCooldown, scheduledStateBlocked: + m.blocked = append(m.blocked, entry) + } + } + for priority, entries := range priorityBuckets { + sort.Slice(entries, func(i, j int) bool { + return entries[i].auth.ID < entries[j].auth.ID + }) + m.readyByPriority[priority] = buildReadyBucket(entries) + m.priorityOrder = append(m.priorityOrder, priority) + } + sort.Slice(m.priorityOrder, func(i, j int) bool { + return m.priorityOrder[i] > m.priorityOrder[j] + }) + sort.Slice(m.blocked, func(i, j int) bool { + left := m.blocked[i] + right := m.blocked[j] + if left == nil || right == nil { + return left != nil + } + if left.nextRetryAt.Equal(right.nextRetryAt) { + return left.auth.ID < right.auth.ID + } + if left.nextRetryAt.IsZero() { + return false + } + if right.nextRetryAt.IsZero() { + return true + } + return left.nextRetryAt.Before(right.nextRetryAt) + }) +} + +// buildReadyBucket prepares the general and websocket-only ready views for one priority bucket. +func buildReadyBucket(entries []*scheduledAuth) *readyBucket { + bucket := &readyBucket{} + bucket.all = buildReadyView(entries) + wsEntries := make([]*scheduledAuth, 0, len(entries)) + for _, entry := range entries { + if entry != nil && entry.meta != nil && entry.meta.websocketEnabled { + wsEntries = append(wsEntries, entry) + } + } + bucket.ws = buildReadyView(wsEntries) + return bucket +} + +// buildReadyView creates either a flat view or a grouped parent/child view for rotation. +func buildReadyView(entries []*scheduledAuth) readyView { + view := readyView{flat: append([]*scheduledAuth(nil), entries...)} + if len(entries) == 0 { + return view + } + groups := make(map[string][]*scheduledAuth) + for _, entry := range entries { + if entry == nil || entry.meta == nil || entry.meta.virtualParent == "" { + return view + } + groups[entry.meta.virtualParent] = append(groups[entry.meta.virtualParent], entry) + } + if len(groups) <= 1 { + return view + } + view.children = make(map[string]*childBucket, len(groups)) + view.parentOrder = make([]string, 0, len(groups)) + for parent := range groups { + view.parentOrder = append(view.parentOrder, parent) + } + sort.Strings(view.parentOrder) + for _, parent := range view.parentOrder { + view.children[parent] = &childBucket{items: append([]*scheduledAuth(nil), groups[parent]...)} + } + return view +} + +// pickFirst returns the first ready entry that satisfies predicate without advancing cursors. +func (v *readyView) pickFirst(predicate func(*scheduledAuth) bool) *scheduledAuth { + for _, entry := range v.flat { + if predicate == nil || predicate(entry) { + return entry + } + } + return nil +} + +// pickRoundRobin returns the next ready entry using flat or grouped round-robin traversal. +func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth { + if len(v.parentOrder) > 1 && len(v.children) > 0 { + return v.pickGroupedRoundRobin(predicate) + } + if len(v.flat) == 0 { + return nil + } + start := 0 + if len(v.flat) > 0 { + start = v.cursor % len(v.flat) + } + for offset := 0; offset < len(v.flat); offset++ { + index := (start + offset) % len(v.flat) + entry := v.flat[index] + if predicate != nil && !predicate(entry) { + continue + } + v.cursor = index + 1 + return entry + } + return nil +} + +// pickGroupedRoundRobin rotates across parents first and then within the selected parent. +func (v *readyView) pickGroupedRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth { + start := 0 + if len(v.parentOrder) > 0 { + start = v.parentCursor % len(v.parentOrder) + } + for offset := 0; offset < len(v.parentOrder); offset++ { + parentIndex := (start + offset) % len(v.parentOrder) + parent := v.parentOrder[parentIndex] + child := v.children[parent] + if child == nil || len(child.items) == 0 { + continue + } + itemStart := child.cursor % len(child.items) + for itemOffset := 0; itemOffset < len(child.items); itemOffset++ { + itemIndex := (itemStart + itemOffset) % len(child.items) + entry := child.items[itemIndex] + if predicate != nil && !predicate(entry) { + continue + } + child.cursor = itemIndex + 1 + v.parentCursor = parentIndex + 1 + return entry + } + } + return nil +} diff --git a/sdk/cliproxy/auth/scheduler_benchmark_test.go b/sdk/cliproxy/auth/scheduler_benchmark_test.go new file mode 100644 index 00000000..33fec2d5 --- /dev/null +++ b/sdk/cliproxy/auth/scheduler_benchmark_test.go @@ -0,0 +1,197 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +type schedulerBenchmarkExecutor struct { + id string +} + +func (e schedulerBenchmarkExecutor) Identifier() string { return e.id } + +func (e schedulerBenchmarkExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, nil +} + +func (e schedulerBenchmarkExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + return nil, nil +} + +func (e schedulerBenchmarkExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) { + return auth, nil +} + +func (e schedulerBenchmarkExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, nil +} + +func (e schedulerBenchmarkExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) { + return nil, nil +} + +func benchmarkManagerSetup(b *testing.B, total int, mixed bool, withPriority bool) (*Manager, []string, string) { + b.Helper() + manager := NewManager(nil, &RoundRobinSelector{}, nil) + providers := []string{"gemini"} + manager.executors["gemini"] = schedulerBenchmarkExecutor{id: "gemini"} + if mixed { + providers = []string{"gemini", "claude"} + manager.executors["claude"] = schedulerBenchmarkExecutor{id: "claude"} + } + + reg := registry.GetGlobalRegistry() + model := "bench-model" + for index := 0; index < total; index++ { + provider := providers[0] + if mixed && index%2 == 1 { + provider = providers[1] + } + auth := &Auth{ID: fmt.Sprintf("bench-%s-%04d", provider, index), Provider: provider} + if withPriority { + priority := "0" + if index%2 == 0 { + priority = "10" + } + auth.Attributes = map[string]string{"priority": priority} + } + _, errRegister := manager.Register(context.Background(), auth) + if errRegister != nil { + b.Fatalf("Register(%s) error = %v", auth.ID, errRegister) + } + reg.RegisterClient(auth.ID, provider, []*registry.ModelInfo{{ID: model}}) + } + manager.syncScheduler() + b.Cleanup(func() { + for index := 0; index < total; index++ { + provider := providers[0] + if mixed && index%2 == 1 { + provider = providers[1] + } + reg.UnregisterClient(fmt.Sprintf("bench-%s-%04d", provider, index)) + } + }) + + return manager, providers, model +} + +func BenchmarkManagerPickNext500(b *testing.B) { + manager, _, model := benchmarkManagerSetup(b, 500, false, false) + ctx := context.Background() + opts := cliproxyexecutor.Options{} + tried := map[string]struct{}{} + if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil { + b.Fatalf("warmup pickNext error = %v", errWarm) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried) + if errPick != nil || auth == nil || exec == nil { + b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick) + } + } +} + +func BenchmarkManagerPickNext1000(b *testing.B) { + manager, _, model := benchmarkManagerSetup(b, 1000, false, false) + ctx := context.Background() + opts := cliproxyexecutor.Options{} + tried := map[string]struct{}{} + if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil { + b.Fatalf("warmup pickNext error = %v", errWarm) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried) + if errPick != nil || auth == nil || exec == nil { + b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick) + } + } +} + +func BenchmarkManagerPickNextPriority500(b *testing.B) { + manager, _, model := benchmarkManagerSetup(b, 500, false, true) + ctx := context.Background() + opts := cliproxyexecutor.Options{} + tried := map[string]struct{}{} + if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil { + b.Fatalf("warmup pickNext error = %v", errWarm) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried) + if errPick != nil || auth == nil || exec == nil { + b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick) + } + } +} + +func BenchmarkManagerPickNextPriority1000(b *testing.B) { + manager, _, model := benchmarkManagerSetup(b, 1000, false, true) + ctx := context.Background() + opts := cliproxyexecutor.Options{} + tried := map[string]struct{}{} + if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil { + b.Fatalf("warmup pickNext error = %v", errWarm) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried) + if errPick != nil || auth == nil || exec == nil { + b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick) + } + } +} + +func BenchmarkManagerPickNextMixed500(b *testing.B) { + manager, providers, model := benchmarkManagerSetup(b, 500, true, false) + ctx := context.Background() + opts := cliproxyexecutor.Options{} + tried := map[string]struct{}{} + if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil { + b.Fatalf("warmup pickNextMixed error = %v", errWarm) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried) + if errPick != nil || auth == nil || exec == nil || provider == "" { + b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick) + } + } +} + +func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) { + manager, _, model := benchmarkManagerSetup(b, 1000, false, false) + ctx := context.Background() + opts := cliproxyexecutor.Options{} + tried := map[string]struct{}{} + if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil { + b.Fatalf("warmup pickNext error = %v", errWarm) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + auth, _, errPick := manager.pickNext(ctx, "gemini", model, opts, tried) + if errPick != nil || auth == nil { + b.Fatalf("pickNext failed: auth=%v err=%v", auth, errPick) + } + manager.MarkResult(ctx, Result{AuthID: auth.ID, Provider: "gemini", Model: model, Success: true}) + } +} diff --git a/sdk/cliproxy/auth/scheduler_test.go b/sdk/cliproxy/auth/scheduler_test.go new file mode 100644 index 00000000..031071af --- /dev/null +++ b/sdk/cliproxy/auth/scheduler_test.go @@ -0,0 +1,468 @@ +package auth + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +type schedulerTestExecutor struct{} + +func (schedulerTestExecutor) Identifier() string { return "test" } + +func (schedulerTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, nil +} + +func (schedulerTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + return nil, nil +} + +func (schedulerTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) { + return auth, nil +} + +func (schedulerTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, nil +} + +func (schedulerTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) { + return nil, nil +} + +type trackingSelector struct { + calls int + lastAuthID []string +} + +func (s *trackingSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { + s.calls++ + s.lastAuthID = s.lastAuthID[:0] + for _, auth := range auths { + s.lastAuthID = append(s.lastAuthID, auth.ID) + } + if len(auths) == 0 { + return nil, nil + } + return auths[len(auths)-1], nil +} + +func newSchedulerForTest(selector Selector, auths ...*Auth) *authScheduler { + scheduler := newAuthScheduler(selector) + scheduler.rebuild(auths) + return scheduler +} + +func registerSchedulerModels(t *testing.T, provider string, model string, authIDs ...string) { + t.Helper() + reg := registry.GetGlobalRegistry() + for _, authID := range authIDs { + reg.RegisterClient(authID, provider, []*registry.ModelInfo{{ID: model}}) + } + t.Cleanup(func() { + for _, authID := range authIDs { + reg.UnregisterClient(authID) + } + }) +} + +func TestSchedulerPick_RoundRobinHighestPriority(t *testing.T) { + t.Parallel() + + scheduler := newSchedulerForTest( + &RoundRobinSelector{}, + &Auth{ID: "low", Provider: "gemini", Attributes: map[string]string{"priority": "0"}}, + &Auth{ID: "high-b", Provider: "gemini", Attributes: map[string]string{"priority": "10"}}, + &Auth{ID: "high-a", Provider: "gemini", Attributes: map[string]string{"priority": "10"}}, + ) + + want := []string{"high-a", "high-b", "high-a"} + for index, wantID := range want { + got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickSingle() #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("pickSingle() #%d auth = nil", index) + } + if got.ID != wantID { + t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID) + } + } +} + +func TestSchedulerPick_FillFirstSticksToFirstReady(t *testing.T) { + t.Parallel() + + scheduler := newSchedulerForTest( + &FillFirstSelector{}, + &Auth{ID: "b", Provider: "gemini"}, + &Auth{ID: "a", Provider: "gemini"}, + &Auth{ID: "c", Provider: "gemini"}, + ) + + for index := 0; index < 3; index++ { + got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickSingle() #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("pickSingle() #%d auth = nil", index) + } + if got.ID != "a" { + t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, "a") + } + } +} + +func TestSchedulerPick_PromotesExpiredCooldownBeforePick(t *testing.T) { + t.Parallel() + + model := "gemini-2.5-pro" + registerSchedulerModels(t, "gemini", model, "cooldown-expired") + scheduler := newSchedulerForTest( + &RoundRobinSelector{}, + &Auth{ + ID: "cooldown-expired", + Provider: "gemini", + ModelStates: map[string]*ModelState{ + model: { + Status: StatusError, + Unavailable: true, + NextRetryAfter: time.Now().Add(-1 * time.Second), + }, + }, + }, + ) + + got, errPick := scheduler.pickSingle(context.Background(), "gemini", model, cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickSingle() error = %v", errPick) + } + if got == nil { + t.Fatalf("pickSingle() auth = nil") + } + if got.ID != "cooldown-expired" { + t.Fatalf("pickSingle() auth.ID = %q, want %q", got.ID, "cooldown-expired") + } +} + +func TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation(t *testing.T) { + t.Parallel() + + registerSchedulerModels(t, "gemini-cli", "gemini-2.5-pro", "cred-a::proj-1", "cred-a::proj-2", "cred-b::proj-1", "cred-b::proj-2") + scheduler := newSchedulerForTest( + &RoundRobinSelector{}, + &Auth{ID: "cred-a::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}}, + &Auth{ID: "cred-a::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}}, + &Auth{ID: "cred-b::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}}, + &Auth{ID: "cred-b::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}}, + ) + + wantParents := []string{"cred-a", "cred-b", "cred-a", "cred-b"} + wantIDs := []string{"cred-a::proj-1", "cred-b::proj-1", "cred-a::proj-2", "cred-b::proj-2"} + for index := range wantIDs { + got, errPick := scheduler.pickSingle(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickSingle() #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("pickSingle() #%d auth = nil", index) + } + if got.ID != wantIDs[index] { + t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index]) + } + if got.Attributes["gemini_virtual_parent"] != wantParents[index] { + t.Fatalf("pickSingle() #%d parent = %q, want %q", index, got.Attributes["gemini_virtual_parent"], wantParents[index]) + } + } +} + +func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) { + t.Parallel() + + scheduler := newSchedulerForTest( + &RoundRobinSelector{}, + &Auth{ID: "codex-http", Provider: "codex"}, + &Auth{ID: "codex-ws-a", Provider: "codex", Attributes: map[string]string{"websockets": "true"}}, + &Auth{ID: "codex-ws-b", Provider: "codex", Attributes: map[string]string{"websockets": "true"}}, + ) + + ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background()) + want := []string{"codex-ws-a", "codex-ws-b", "codex-ws-a"} + for index, wantID := range want { + got, errPick := scheduler.pickSingle(ctx, "codex", "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickSingle() #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("pickSingle() #%d auth = nil", index) + } + if got.ID != wantID { + t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID) + } + } +} + +func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) { + t.Parallel() + + scheduler := newSchedulerForTest( + &RoundRobinSelector{}, + &Auth{ID: "gemini-a", Provider: "gemini"}, + &Auth{ID: "gemini-b", Provider: "gemini"}, + &Auth{ID: "claude-a", Provider: "claude"}, + ) + + wantProviders := []string{"gemini", "claude", "gemini", "claude"} + wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"} + for index := range wantProviders { + got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickMixed() #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("pickMixed() #%d auth = nil", index) + } + if provider != wantProviders[index] { + t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index]) + } + if got.ID != wantIDs[index] { + t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index]) + } + } +} + +func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, &RoundRobinSelector{}, nil) + manager.executors["gemini"] = schedulerTestExecutor{} + manager.executors["claude"] = schedulerTestExecutor{} + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(gemini-a) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(gemini-b) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil { + t.Fatalf("Register(claude-a) error = %v", errRegister) + } + + wantProviders := []string{"gemini", "claude", "gemini", "claude"} + wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"} + for index := range wantProviders { + got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{}) + if errPick != nil { + t.Fatalf("pickNextMixed() #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("pickNextMixed() #%d auth = nil", index) + } + if provider != wantProviders[index] { + t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index]) + } + if got.ID != wantIDs[index] { + t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index]) + } + } +} + +func TestManagerCustomSelector_FallsBackToLegacyPath(t *testing.T) { + t.Parallel() + + selector := &trackingSelector{} + manager := NewManager(nil, selector, nil) + manager.executors["gemini"] = schedulerTestExecutor{} + manager.auths["auth-a"] = &Auth{ID: "auth-a", Provider: "gemini"} + manager.auths["auth-b"] = &Auth{ID: "auth-b", Provider: "gemini"} + + got, _, errPick := manager.pickNext(context.Background(), "gemini", "", cliproxyexecutor.Options{}, map[string]struct{}{}) + if errPick != nil { + t.Fatalf("pickNext() error = %v", errPick) + } + if got == nil { + t.Fatalf("pickNext() auth = nil") + } + if selector.calls != 1 { + t.Fatalf("selector.calls = %d, want %d", selector.calls, 1) + } + if len(selector.lastAuthID) != 2 { + t.Fatalf("len(selector.lastAuthID) = %d, want %d", len(selector.lastAuthID), 2) + } + if got.ID != selector.lastAuthID[len(selector.lastAuthID)-1] { + t.Fatalf("pickNext() auth.ID = %q, want selector-picked %q", got.ID, selector.lastAuthID[len(selector.lastAuthID)-1]) + } +} + +func TestManager_InitializesSchedulerForBuiltInSelector(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, &RoundRobinSelector{}, nil) + if manager.scheduler == nil { + t.Fatalf("manager.scheduler = nil") + } + if manager.scheduler.strategy != schedulerStrategyRoundRobin { + t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyRoundRobin) + } + + manager.SetSelector(&FillFirstSelector{}) + if manager.scheduler.strategy != schedulerStrategyFillFirst { + t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyFillFirst) + } +} + +func TestManager_SchedulerTracksRegisterAndUpdate(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, &RoundRobinSelector{}, nil) + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(auth-b) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(auth-a) error = %v", errRegister) + } + + got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("scheduler.pickSingle() error = %v", errPick) + } + if got == nil || got.ID != "auth-a" { + t.Fatalf("scheduler.pickSingle() auth = %v, want auth-a", got) + } + + if _, errUpdate := manager.Update(context.Background(), &Auth{ID: "auth-a", Provider: "gemini", Disabled: true}); errUpdate != nil { + t.Fatalf("Update(auth-a) error = %v", errUpdate) + } + + got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("scheduler.pickSingle() after update error = %v", errPick) + } + if got == nil || got.ID != "auth-b" { + t.Fatalf("scheduler.pickSingle() after update auth = %v, want auth-b", got) + } +} + +func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, &RoundRobinSelector{}, nil) + manager.executors["gemini"] = schedulerTestExecutor{} + manager.executors["claude"] = schedulerTestExecutor{} + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(gemini-a) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(gemini-b) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil { + t.Fatalf("Register(claude-a) error = %v", errRegister) + } + + wantProviders := []string{"gemini", "claude", "gemini", "claude"} + wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"} + for index := range wantProviders { + got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickNextMixed() #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("pickNextMixed() #%d auth = nil", index) + } + if provider != wantProviders[index] { + t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index]) + } + if got.ID != wantIDs[index] { + t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index]) + } + } +} + +func TestManager_PickNextMixed_SkipsProvidersWithoutExecutors(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, &RoundRobinSelector{}, nil) + manager.executors["claude"] = schedulerTestExecutor{} + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(gemini-a) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil { + t.Fatalf("Register(claude-a) error = %v", errRegister) + } + + got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickNextMixed() error = %v", errPick) + } + if got == nil { + t.Fatalf("pickNextMixed() auth = nil") + } + if provider != "claude" { + t.Fatalf("pickNextMixed() provider = %q, want %q", provider, "claude") + } + if got.ID != "claude-a" { + t.Fatalf("pickNextMixed() auth.ID = %q, want %q", got.ID, "claude-a") + } +} + +func TestManager_SchedulerTracksMarkResultCooldownAndRecovery(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, &RoundRobinSelector{}, nil) + reg := registry.GetGlobalRegistry() + reg.RegisterClient("auth-a", "gemini", []*registry.ModelInfo{{ID: "test-model"}}) + reg.RegisterClient("auth-b", "gemini", []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + reg.UnregisterClient("auth-a") + reg.UnregisterClient("auth-b") + }) + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(auth-a) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(auth-b) error = %v", errRegister) + } + + manager.MarkResult(context.Background(), Result{ + AuthID: "auth-a", + Provider: "gemini", + Model: "test-model", + Success: false, + Error: &Error{HTTPStatus: 429, Message: "quota"}, + }) + + got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("scheduler.pickSingle() after cooldown error = %v", errPick) + } + if got == nil || got.ID != "auth-b" { + t.Fatalf("scheduler.pickSingle() after cooldown auth = %v, want auth-b", got) + } + + manager.MarkResult(context.Background(), Result{ + AuthID: "auth-a", + Provider: "gemini", + Model: "test-model", + Success: true, + }) + + seen := make(map[string]struct{}, 2) + for index := 0; index < 2; index++ { + got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("scheduler.pickSingle() after recovery #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("scheduler.pickSingle() after recovery #%d auth = nil", index) + } + seen[got.ID] = struct{}{} + } + if len(seen) != 2 { + t.Fatalf("len(seen) = %d, want %d", len(seen), 2) + } +}