mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-09 15:25:17 +00:00
test(auth-scheduler): add unit tests and scheduler implementation
- Added comprehensive unit tests for `authScheduler` and related components.
- Implemented `authScheduler` with support for Round Robin, Fill First, and custom selector strategies.
- Improved tracking of auth states, cooldowns, and recovery logic in the scheduler.
This commit is contained in:
@@ -134,6 +134,7 @@ type Manager struct {
|
||||
hook Hook
|
||||
mu sync.RWMutex
|
||||
auths map[string]*Auth
|
||||
scheduler *authScheduler
|
||||
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
|
||||
providerOffsets map[string]int
|
||||
|
||||
@@ -185,9 +186,33 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
||||
// atomic.Value requires non-nil initial value.
|
||||
manager.runtimeConfig.Store(&internalconfig.Config{})
|
||||
manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil))
|
||||
manager.scheduler = newAuthScheduler(selector)
|
||||
return manager
|
||||
}
|
||||
|
||||
func isBuiltInSelector(selector Selector) bool {
|
||||
switch selector.(type) {
|
||||
case *RoundRobinSelector, *FillFirstSelector:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) syncSchedulerFromSnapshot(auths []*Auth) {
|
||||
if m == nil || m.scheduler == nil {
|
||||
return
|
||||
}
|
||||
m.scheduler.rebuild(auths)
|
||||
}
|
||||
|
||||
func (m *Manager) syncScheduler() {
|
||||
if m == nil || m.scheduler == nil {
|
||||
return
|
||||
}
|
||||
m.syncSchedulerFromSnapshot(m.snapshotAuths())
|
||||
}
|
||||
|
||||
func (m *Manager) SetSelector(selector Selector) {
|
||||
if m == nil {
|
||||
return
|
||||
@@ -198,6 +223,10 @@ func (m *Manager) SetSelector(selector Selector) {
|
||||
m.mu.Lock()
|
||||
m.selector = selector
|
||||
m.mu.Unlock()
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.setSelector(selector)
|
||||
m.syncScheduler()
|
||||
}
|
||||
}
|
||||
|
||||
// SetStore swaps the underlying persistence store.
|
||||
@@ -759,10 +788,14 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
auth.ID = uuid.NewString()
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
authClone := auth.Clone()
|
||||
m.mu.Lock()
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
m.auths[auth.ID] = authClone
|
||||
m.mu.Unlock()
|
||||
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.upsertAuth(authClone)
|
||||
}
|
||||
_ = m.persist(ctx, auth)
|
||||
m.hook.OnAuthRegistered(ctx, auth.Clone())
|
||||
return auth.Clone(), nil
|
||||
@@ -784,9 +817,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
}
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
authClone := auth.Clone()
|
||||
m.auths[auth.ID] = authClone
|
||||
m.mu.Unlock()
|
||||
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.upsertAuth(authClone)
|
||||
}
|
||||
_ = m.persist(ctx, auth)
|
||||
m.hook.OnAuthUpdated(ctx, auth.Clone())
|
||||
return auth.Clone(), nil
|
||||
@@ -795,12 +832,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
// Load resets manager state from the backing store.
|
||||
func (m *Manager) Load(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.store == nil {
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
items, err := m.store.List(ctx)
|
||||
if err != nil {
|
||||
m.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
m.auths = make(map[string]*Auth, len(items))
|
||||
@@ -816,6 +854,8 @@ func (m *Manager) Load(ctx context.Context) error {
|
||||
cfg = &internalconfig.Config{}
|
||||
}
|
||||
m.rebuildAPIKeyModelAliasLocked(cfg)
|
||||
m.mu.Unlock()
|
||||
m.syncScheduler()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1531,6 +1571,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
suspendReason := ""
|
||||
clearModelQuota := false
|
||||
setModelQuota := false
|
||||
var authSnapshot *Auth
|
||||
|
||||
m.mu.Lock()
|
||||
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
|
||||
@@ -1624,8 +1665,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
}
|
||||
|
||||
_ = m.persist(ctx, auth)
|
||||
authSnapshot = auth.Clone()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
if m.scheduler != nil && authSnapshot != nil {
|
||||
m.scheduler.upsertAuth(authSnapshot)
|
||||
}
|
||||
|
||||
if clearModelQuota && result.Model != "" {
|
||||
registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model)
|
||||
@@ -1982,7 +2027,25 @@ func (m *Manager) CloseExecutionSession(sessionID string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
||||
func (m *Manager) useSchedulerFastPath() bool {
|
||||
if m == nil || m.scheduler == nil {
|
||||
return false
|
||||
}
|
||||
return isBuiltInSelector(m.selector)
|
||||
}
|
||||
|
||||
func shouldRetrySchedulerPick(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var authErr *Error
|
||||
if !errors.As(err, &authErr) || authErr == nil {
|
||||
return false
|
||||
}
|
||||
return authErr.Code == "auth_not_found" || authErr.Code == "auth_unavailable"
|
||||
}
|
||||
|
||||
func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
|
||||
m.mu.RLock()
|
||||
@@ -2042,7 +2105,38 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
||||
return authCopy, executor, nil
|
||||
}
|
||||
|
||||
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
||||
if !m.useSchedulerFastPath() {
|
||||
return m.pickNextLegacy(ctx, provider, model, opts, tried)
|
||||
}
|
||||
executor, okExecutor := m.Executor(provider)
|
||||
if !okExecutor {
|
||||
return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||
}
|
||||
selected, errPick := m.scheduler.pickSingle(ctx, provider, model, opts, tried)
|
||||
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
|
||||
m.syncScheduler()
|
||||
selected, errPick = m.scheduler.pickSingle(ctx, provider, model, opts, tried)
|
||||
}
|
||||
if errPick != nil {
|
||||
return nil, nil, errPick
|
||||
}
|
||||
if selected == nil {
|
||||
return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
|
||||
}
|
||||
authCopy := selected.Clone()
|
||||
if !selected.indexAssigned {
|
||||
m.mu.Lock()
|
||||
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
|
||||
current.EnsureIndex()
|
||||
authCopy = current.Clone()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
return authCopy, executor, nil
|
||||
}
|
||||
|
||||
func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
|
||||
providerSet := make(map[string]struct{}, len(providers))
|
||||
@@ -2125,6 +2219,58 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
|
||||
return authCopy, executor, providerKey, nil
|
||||
}
|
||||
|
||||
// pickNextMixed selects the next auth, executor, and provider key for a
// request that may be served by any of several providers. It uses the
// scheduler fast path for built-in selectors and the legacy path otherwise.
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
	if !m.useSchedulerFastPath() {
		return m.pickNextMixedLegacy(ctx, providers, model, opts, tried)
	}

	// Normalize, de-duplicate, and keep only providers that actually have a
	// registered executor; order is preserved for deterministic rotation.
	eligibleProviders := make([]string, 0, len(providers))
	seenProviders := make(map[string]struct{}, len(providers))
	for _, provider := range providers {
		providerKey := strings.TrimSpace(strings.ToLower(provider))
		if providerKey == "" {
			continue
		}
		if _, seen := seenProviders[providerKey]; seen {
			continue
		}
		if _, okExecutor := m.Executor(providerKey); !okExecutor {
			continue
		}
		seenProviders[providerKey] = struct{}{}
		eligibleProviders = append(eligibleProviders, providerKey)
	}
	if len(eligibleProviders) == 0 {
		return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
	}

	selected, providerKey, errPick := m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
	// On a retryable miss, resync scheduler state once and try again.
	if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
		m.syncScheduler()
		selected, providerKey, errPick = m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
	}
	if errPick != nil {
		return nil, nil, "", errPick
	}
	if selected == nil {
		return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
	}
	executor, okExecutor := m.Executor(providerKey)
	if !okExecutor {
		return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
	}
	authCopy := selected.Clone()
	// Lazily assign a stable index on first use; re-check under the write
	// lock because another request may have assigned it concurrently.
	if !selected.indexAssigned {
		m.mu.Lock()
		if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
			current.EnsureIndex()
			authCopy = current.Clone()
		}
		m.mu.Unlock()
	}
	return authCopy, executor, providerKey, nil
}
|
||||
|
||||
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
||||
if m.store == nil || auth == nil {
|
||||
return nil
|
||||
@@ -2476,6 +2622,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
|
||||
current.LastError = &Error{Message: err.Error()}
|
||||
m.auths[id] = current
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.upsertAuth(current.Clone())
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
return
|
||||
|
||||
851
sdk/cliproxy/auth/scheduler.go
Normal file
851
sdk/cliproxy/auth/scheduler.go
Normal file
@@ -0,0 +1,851 @@
|
||||
package auth
|
||||
|
||||
import (
	"context"
	"slices"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
|
||||
|
||||
// schedulerStrategy identifies which built-in routing semantics the scheduler should apply.
type schedulerStrategy int

const (
	// schedulerStrategyCustom marks a user-supplied selector; the scheduler
	// fast path is not used in this mode.
	schedulerStrategyCustom schedulerStrategy = iota
	// schedulerStrategyRoundRobin rotates through ready auths in order.
	schedulerStrategyRoundRobin
	// schedulerStrategyFillFirst always picks the first ready auth.
	schedulerStrategyFillFirst
)
|
||||
|
||||
// scheduledState describes how an auth currently participates in a model shard.
type scheduledState int

const (
	// scheduledStateReady means the auth can be picked immediately.
	scheduledStateReady scheduledState = iota
	// scheduledStateCooldown means the auth is temporarily blocked until nextRetryAt.
	scheduledStateCooldown
	// scheduledStateBlocked means the auth is blocked for another reason but may recover.
	scheduledStateBlocked
	// scheduledStateDisabled means the auth is administratively disabled.
	scheduledStateDisabled
)

// authScheduler keeps the incremental provider/model scheduling state used by Manager.
// All fields are guarded by mu.
type authScheduler struct {
	mu       sync.Mutex
	strategy schedulerStrategy
	// providers maps a normalized provider key to its per-provider state.
	providers map[string]*providerScheduler
	// authProviders maps an auth ID to the provider key it is registered under,
	// so moves between providers can be detected during upserts.
	authProviders map[string]string
	// mixedCursors tracks per provider-set/model round-robin positions for mixed picks.
	mixedCursors map[string]int
}

// providerScheduler stores auth metadata and model shards for a single provider.
type providerScheduler struct {
	providerKey string
	// auths maps auth ID to its scheduling metadata snapshot.
	auths map[string]*scheduledAuthMeta
	// modelShards maps a canonical model key to its lazily-built shard.
	modelShards map[string]*modelScheduler
}

// scheduledAuthMeta stores the immutable scheduling fields derived from an auth snapshot.
type scheduledAuthMeta struct {
	auth             *Auth
	providerKey      string
	priority         int
	virtualParent    string
	websocketEnabled bool
	// supportedModelSet is the set of canonical model keys registered for the
	// auth at snapshot time; nil means no models were registered.
	supportedModelSet map[string]struct{}
}

// modelScheduler tracks ready and blocked auths for one provider/model combination.
type modelScheduler struct {
	modelKey string
	// entries maps auth ID to its runtime scheduling entry.
	entries map[string]*scheduledAuth
	// priorityOrder lists the priorities to scan, highest preference first.
	priorityOrder []int
	// readyByPriority holds the ready views for each priority level.
	readyByPriority map[int]*readyBucket
	// blocked collects entries awaiting recovery, ordered by retry time during rebuilds.
	blocked cooldownQueue
}

// scheduledAuth stores the runtime scheduling state for a single auth inside a model shard.
type scheduledAuth struct {
	meta  *scheduledAuthMeta
	auth  *Auth
	state scheduledState
	// nextRetryAt is the earliest time a cooldown/blocked entry should be re-evaluated;
	// the zero time means no retry deadline applies.
	nextRetryAt time.Time
}

// readyBucket keeps the ready views for one priority level: all auths, plus the
// websocket-capable subset used when the downstream prefers websockets.
type readyBucket struct {
	all readyView
	ws  readyView
}

// readyView holds the selection order for flat or grouped round-robin traversal.
type readyView struct {
	flat   []*scheduledAuth
	cursor int
	// parentOrder/parentCursor/children implement two-level rotation across
	// grouped virtual auths sharing a parent.
	parentOrder  []string
	parentCursor int
	children     map[string]*childBucket
}

// childBucket keeps the per-parent rotation state for grouped Gemini virtual auths.
type childBucket struct {
	items  []*scheduledAuth
	cursor int
}

// cooldownQueue is the blocked auth collection ordered by next retry time during rebuilds.
type cooldownQueue []*scheduledAuth
|
||||
|
||||
// newAuthScheduler constructs an empty scheduler configured for the supplied selector strategy.
|
||||
func newAuthScheduler(selector Selector) *authScheduler {
|
||||
return &authScheduler{
|
||||
strategy: selectorStrategy(selector),
|
||||
providers: make(map[string]*providerScheduler),
|
||||
authProviders: make(map[string]string),
|
||||
mixedCursors: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// selectorStrategy maps a selector implementation to the scheduler semantics it should emulate.
|
||||
func selectorStrategy(selector Selector) schedulerStrategy {
|
||||
switch selector.(type) {
|
||||
case *FillFirstSelector:
|
||||
return schedulerStrategyFillFirst
|
||||
case nil, *RoundRobinSelector:
|
||||
return schedulerStrategyRoundRobin
|
||||
default:
|
||||
return schedulerStrategyCustom
|
||||
}
|
||||
}
|
||||
|
||||
// setSelector updates the active built-in strategy and resets mixed-provider cursors.
|
||||
func (s *authScheduler) setSelector(selector Selector) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.strategy = selectorStrategy(selector)
|
||||
clear(s.mixedCursors)
|
||||
}
|
||||
|
||||
// rebuild recreates the complete scheduler state from an auth snapshot.
|
||||
func (s *authScheduler) rebuild(auths []*Auth) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.providers = make(map[string]*providerScheduler)
|
||||
s.authProviders = make(map[string]string)
|
||||
s.mixedCursors = make(map[string]int)
|
||||
now := time.Now()
|
||||
for _, auth := range auths {
|
||||
s.upsertAuthLocked(auth, now)
|
||||
}
|
||||
}
|
||||
|
||||
// upsertAuth incrementally synchronizes one auth into the scheduler.
|
||||
func (s *authScheduler) upsertAuth(auth *Auth) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.upsertAuthLocked(auth, time.Now())
|
||||
}
|
||||
|
||||
// removeAuth deletes one auth from every scheduler shard that references it.
|
||||
func (s *authScheduler) removeAuth(authID string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.removeAuthLocked(authID)
|
||||
}
|
||||
|
||||
// pickSingle returns the next auth for a single provider/model request using
// scheduler state. A pinned auth ID in the request metadata restricts the
// pick to that auth; auths listed in tried are skipped. Websocket-capable
// auths are preferred only for the codex provider when the downstream is a
// websocket and no auth is pinned.
func (s *authScheduler) pickSingle(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, error) {
	if s == nil {
		return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	providerKey := strings.ToLower(strings.TrimSpace(provider))
	modelKey := canonicalModelKey(model)
	pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
	preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == ""

	s.mu.Lock()
	defer s.mu.Unlock()
	providerState := s.providers[providerKey]
	if providerState == nil {
		return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	shard := providerState.ensureModelLocked(modelKey, time.Now())
	if shard == nil {
		return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	// predicate filters out entries that are pinned-mismatched or already tried.
	predicate := func(entry *scheduledAuth) bool {
		if entry == nil || entry.auth == nil {
			return false
		}
		if pinnedAuthID != "" && entry.auth.ID != pinnedAuthID {
			return false
		}
		if len(tried) > 0 {
			if _, ok := tried[entry.auth.ID]; ok {
				return false
			}
		}
		return true
	}
	if picked := shard.pickReadyLocked(preferWebsocket, s.strategy, predicate); picked != nil {
		return picked, nil
	}
	return nil, shard.unavailableErrorLocked(provider, model, predicate)
}
|
||||
|
||||
// pickMixed returns the next auth and provider key for a mixed-provider
// request. A pinned auth restricts the pick to its own provider; otherwise
// fill-first scans providers in order, and round robin rotates across
// providers using a cursor keyed by the provider set and model.
func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) {
	if s == nil {
		return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	normalized := normalizeProviderKeys(providers)
	if len(normalized) == 0 {
		return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
	modelKey := canonicalModelKey(model)

	s.mu.Lock()
	defer s.mu.Unlock()
	// Pinned path: the pin must resolve to one of the requested providers.
	if pinnedAuthID != "" {
		providerKey := s.authProviders[pinnedAuthID]
		if providerKey == "" || !containsProvider(normalized, providerKey) {
			return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
		}
		providerState := s.providers[providerKey]
		if providerState == nil {
			return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
		}
		shard := providerState.ensureModelLocked(modelKey, time.Now())
		// predicate accepts only the pinned auth, and only if it has not been tried.
		predicate := func(entry *scheduledAuth) bool {
			if entry == nil || entry.auth == nil || entry.auth.ID != pinnedAuthID {
				return false
			}
			if len(tried) == 0 {
				return true
			}
			_, ok := tried[pinnedAuthID]
			return !ok
		}
		if picked := shard.pickReadyLocked(false, s.strategy, predicate); picked != nil {
			return picked, providerKey, nil
		}
		return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
	}

	// Fill-first: take the first ready auth scanning providers in request order.
	if s.strategy == schedulerStrategyFillFirst {
		for _, providerKey := range normalized {
			providerState := s.providers[providerKey]
			if providerState == nil {
				continue
			}
			shard := providerState.ensureModelLocked(modelKey, time.Now())
			if shard == nil {
				continue
			}
			picked := shard.pickReadyLocked(false, s.strategy, triedPredicate(tried))
			if picked != nil {
				return picked, providerKey, nil
			}
		}
		return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
	}

	// Round robin across providers: the cursor is keyed by the exact provider
	// set and model so distinct request mixes rotate independently.
	cursorKey := strings.Join(normalized, ",") + ":" + modelKey
	start := 0
	if len(normalized) > 0 {
		start = s.mixedCursors[cursorKey] % len(normalized)
	}
	for offset := 0; offset < len(normalized); offset++ {
		providerIndex := (start + offset) % len(normalized)
		providerKey := normalized[providerIndex]
		providerState := s.providers[providerKey]
		if providerState == nil {
			continue
		}
		shard := providerState.ensureModelLocked(modelKey, time.Now())
		if shard == nil {
			continue
		}
		picked := shard.pickReadyLocked(false, schedulerStrategyRoundRobin, triedPredicate(tried))
		if picked == nil {
			continue
		}
		// Advance past the provider that served this pick.
		s.mixedCursors[cursorKey] = providerIndex + 1
		return picked, providerKey, nil
	}
	return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
|
||||
|
||||
// mixedUnavailableErrorLocked synthesizes the mixed-provider cooldown or unavailable error.
|
||||
func (s *authScheduler) mixedUnavailableErrorLocked(providers []string, model string, tried map[string]struct{}) error {
|
||||
now := time.Now()
|
||||
total := 0
|
||||
cooldownCount := 0
|
||||
earliest := time.Time{}
|
||||
for _, providerKey := range providers {
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
continue
|
||||
}
|
||||
shard := providerState.ensureModelLocked(canonicalModelKey(model), now)
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
localTotal, localCooldownCount, localEarliest := shard.availabilitySummaryLocked(triedPredicate(tried))
|
||||
total += localTotal
|
||||
cooldownCount += localCooldownCount
|
||||
if !localEarliest.IsZero() && (earliest.IsZero() || localEarliest.Before(earliest)) {
|
||||
earliest = localEarliest
|
||||
}
|
||||
}
|
||||
if total == 0 {
|
||||
return &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
if cooldownCount == total && !earliest.IsZero() {
|
||||
resetIn := earliest.Sub(now)
|
||||
if resetIn < 0 {
|
||||
resetIn = 0
|
||||
}
|
||||
return newModelCooldownError(model, "", resetIn)
|
||||
}
|
||||
return &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||
}
|
||||
|
||||
// triedPredicate builds a filter that excludes auths already attempted for the current request.
|
||||
func triedPredicate(tried map[string]struct{}) func(*scheduledAuth) bool {
|
||||
if len(tried) == 0 {
|
||||
return func(entry *scheduledAuth) bool { return entry != nil && entry.auth != nil }
|
||||
}
|
||||
return func(entry *scheduledAuth) bool {
|
||||
if entry == nil || entry.auth == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := tried[entry.auth.ID]
|
||||
return !ok
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeProviderKeys lowercases, trims, and de-duplicates provider keys
// while preserving their first-seen order. Blank entries are dropped.
func normalizeProviderKeys(providers []string) []string {
	out := make([]string, 0, len(providers))
	seen := make(map[string]struct{}, len(providers))
	for _, raw := range providers {
		key := strings.ToLower(strings.TrimSpace(raw))
		if key == "" {
			continue
		}
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		out = append(out, key)
	}
	return out
}
|
||||
|
||||
// containsProvider reports whether provider is present in the normalized
// provider list. slices.Contains replaces the hand-rolled loop (the file
// already relies on Go 1.21+ builtins such as clear).
func containsProvider(providers []string, provider string) bool {
	return slices.Contains(providers, provider)
}
|
||||
|
||||
// upsertAuthLocked updates one auth in-place while the scheduler mutex is held.
// Auths with a blank ID/provider, or that are disabled, are removed instead of
// inserted. If the auth moved to a different provider, it is first detached
// from the old provider's shards.
func (s *authScheduler) upsertAuthLocked(auth *Auth, now time.Time) {
	if auth == nil {
		return
	}
	authID := strings.TrimSpace(auth.ID)
	providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
	if authID == "" || providerKey == "" || auth.Disabled {
		// Treat unusable auths as removals so stale entries do not linger.
		s.removeAuthLocked(authID)
		return
	}
	// Detach from the previous provider when the auth changed providers.
	if previousProvider := s.authProviders[authID]; previousProvider != "" && previousProvider != providerKey {
		if previousState := s.providers[previousProvider]; previousState != nil {
			previousState.removeAuthLocked(authID)
		}
	}
	meta := buildScheduledAuthMeta(auth)
	s.authProviders[authID] = providerKey
	s.ensureProviderLocked(providerKey).upsertAuthLocked(meta, now)
}
|
||||
|
||||
// removeAuthLocked removes one auth from the scheduler while the scheduler mutex is held.
|
||||
func (s *authScheduler) removeAuthLocked(authID string) {
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
if providerKey := s.authProviders[authID]; providerKey != "" {
|
||||
if providerState := s.providers[providerKey]; providerState != nil {
|
||||
providerState.removeAuthLocked(authID)
|
||||
}
|
||||
delete(s.authProviders, authID)
|
||||
}
|
||||
}
|
||||
|
||||
// ensureProviderLocked returns the provider scheduler for providerKey, creating it when needed.
|
||||
func (s *authScheduler) ensureProviderLocked(providerKey string) *providerScheduler {
|
||||
if s.providers == nil {
|
||||
s.providers = make(map[string]*providerScheduler)
|
||||
}
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
providerState = &providerScheduler{
|
||||
providerKey: providerKey,
|
||||
auths: make(map[string]*scheduledAuthMeta),
|
||||
modelShards: make(map[string]*modelScheduler),
|
||||
}
|
||||
s.providers[providerKey] = providerState
|
||||
}
|
||||
return providerState
|
||||
}
|
||||
|
||||
// buildScheduledAuthMeta extracts the scheduling metadata needed for shard
// bookkeeping from an auth snapshot: normalized provider key, priority,
// websocket capability, the virtual-parent grouping attribute, and the set of
// models currently registered for the auth.
func buildScheduledAuthMeta(auth *Auth) *scheduledAuthMeta {
	providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
	virtualParent := ""
	if auth.Attributes != nil {
		// Gemini virtual auths carry their parent ID in this attribute so the
		// scheduler can group siblings for per-parent rotation.
		virtualParent = strings.TrimSpace(auth.Attributes["gemini_virtual_parent"])
	}
	return &scheduledAuthMeta{
		auth:              auth,
		providerKey:       providerKey,
		priority:          authPriority(auth),
		virtualParent:     virtualParent,
		websocketEnabled:  authWebsocketsEnabled(auth),
		supportedModelSet: supportedModelSetForAuth(auth.ID),
	}
}
|
||||
|
||||
// supportedModelSetForAuth snapshots the registry models currently registered for an auth.
|
||||
func supportedModelSetForAuth(authID string) map[string]struct{} {
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return nil
|
||||
}
|
||||
models := registry.GetGlobalRegistry().GetModelsForClient(authID)
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
set := make(map[string]struct{}, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
modelKey := canonicalModelKey(model.ID)
|
||||
if modelKey == "" {
|
||||
continue
|
||||
}
|
||||
set[modelKey] = struct{}{}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
// upsertAuthLocked updates every existing model shard that can reference the auth metadata.
|
||||
func (p *providerScheduler) upsertAuthLocked(meta *scheduledAuthMeta, now time.Time) {
|
||||
if p == nil || meta == nil || meta.auth == nil {
|
||||
return
|
||||
}
|
||||
p.auths[meta.auth.ID] = meta
|
||||
for modelKey, shard := range p.modelShards {
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
if !meta.supportsModel(modelKey) {
|
||||
shard.removeEntryLocked(meta.auth.ID)
|
||||
continue
|
||||
}
|
||||
shard.upsertEntryLocked(meta, now)
|
||||
}
|
||||
}
|
||||
|
||||
// removeAuthLocked removes an auth from all model shards owned by the provider scheduler.
|
||||
func (p *providerScheduler) removeAuthLocked(authID string) {
|
||||
if p == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
delete(p.auths, authID)
|
||||
for _, shard := range p.modelShards {
|
||||
if shard != nil {
|
||||
shard.removeEntryLocked(authID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensureModelLocked returns the shard for modelKey, building it lazily from
// the provider's auth metadata on first access. Existing shards get expired
// cooldown entries re-evaluated before being returned.
func (p *providerScheduler) ensureModelLocked(modelKey string, now time.Time) *modelScheduler {
	if p == nil {
		return nil
	}
	modelKey = canonicalModelKey(modelKey)
	if shard, ok := p.modelShards[modelKey]; ok && shard != nil {
		// Re-promote entries whose cooldown has elapsed before handing the shard out.
		shard.promoteExpiredLocked(now)
		return shard
	}
	shard := &modelScheduler{
		modelKey:        modelKey,
		entries:         make(map[string]*scheduledAuth),
		readyByPriority: make(map[int]*readyBucket),
	}
	// Seed the new shard with every provider auth that supports this model.
	for _, meta := range p.auths {
		if meta == nil || !meta.supportsModel(modelKey) {
			continue
		}
		shard.upsertEntryLocked(meta, now)
	}
	p.modelShards[modelKey] = shard
	return shard
}
|
||||
|
||||
// supportsModel reports whether the auth metadata currently supports modelKey.
|
||||
func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
|
||||
modelKey = canonicalModelKey(modelKey)
|
||||
if modelKey == "" {
|
||||
return true
|
||||
}
|
||||
if len(m.supportedModelSet) == 0 {
|
||||
return false
|
||||
}
|
||||
_, ok := m.supportedModelSet[modelKey]
|
||||
return ok
|
||||
}
|
||||
|
||||
// upsertEntryLocked updates or inserts one auth entry, recomputes its state
// from the auth's block/cooldown status, and rebuilds the shard indexes only
// when something that affects ordering actually changed.
func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Time) {
	if m == nil || meta == nil || meta.auth == nil {
		return
	}
	entry, ok := m.entries[meta.auth.ID]
	if !ok || entry == nil {
		entry = &scheduledAuth{}
		m.entries[meta.auth.ID] = entry
	}
	// Capture the pre-update fields used for the change-detection check below.
	previousState := entry.state
	previousNextRetryAt := entry.nextRetryAt
	previousPriority := 0
	previousParent := ""
	previousWebsocketEnabled := false
	if entry.meta != nil {
		previousPriority = entry.meta.priority
		previousParent = entry.meta.virtualParent
		previousWebsocketEnabled = entry.meta.websocketEnabled
	}

	entry.meta = meta
	entry.auth = meta.auth
	entry.nextRetryAt = time.Time{}
	// Map the auth's current block status onto the shard state machine.
	blocked, reason, next := isAuthBlockedForModel(meta.auth, m.modelKey, now)
	switch {
	case !blocked:
		entry.state = scheduledStateReady
	case reason == blockReasonCooldown:
		entry.state = scheduledStateCooldown
		entry.nextRetryAt = next
	case reason == blockReasonDisabled:
		entry.state = scheduledStateDisabled
	default:
		entry.state = scheduledStateBlocked
		entry.nextRetryAt = next
	}

	// Skip the index rebuild when nothing ordering-relevant changed for an
	// existing entry (state, retry time, priority, parent grouping, ws flag).
	if ok && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousParent == meta.virtualParent && previousWebsocketEnabled == meta.websocketEnabled {
		return
	}
	m.rebuildIndexesLocked()
}
|
||||
|
||||
// removeEntryLocked deletes one auth entry and rebuilds the shard indexes if needed.
|
||||
func (m *modelScheduler) removeEntryLocked(authID string) {
|
||||
if m == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := m.entries[authID]; !ok {
|
||||
return
|
||||
}
|
||||
delete(m.entries, authID)
|
||||
m.rebuildIndexesLocked()
|
||||
}
|
||||
|
||||
// promoteExpiredLocked reevaluates blocked auths whose retry time has elapsed,
// moving recovered entries back to ready. Indexes are rebuilt once at the end
// if any entry was touched.
func (m *modelScheduler) promoteExpiredLocked(now time.Time) {
	if m == nil || len(m.blocked) == 0 {
		return
	}
	changed := false
	for _, entry := range m.blocked {
		if entry == nil || entry.auth == nil {
			continue
		}
		// A zero retry time means no deadline; a future one is not due yet.
		if entry.nextRetryAt.IsZero() || entry.nextRetryAt.After(now) {
			continue
		}
		// Re-derive the state from the auth's current block status.
		blocked, reason, next := isAuthBlockedForModel(entry.auth, m.modelKey, now)
		switch {
		case !blocked:
			entry.state = scheduledStateReady
			entry.nextRetryAt = time.Time{}
		case reason == blockReasonCooldown:
			entry.state = scheduledStateCooldown
			entry.nextRetryAt = next
		case reason == blockReasonDisabled:
			entry.state = scheduledStateDisabled
			entry.nextRetryAt = time.Time{}
		default:
			entry.state = scheduledStateBlocked
			entry.nextRetryAt = next
		}
		changed = true
	}
	if changed {
		m.rebuildIndexesLocked()
	}
}
|
||||
|
||||
// pickReadyLocked selects the next ready auth from the highest available
// priority bucket. Expired cooldowns are promoted first so freshly recovered
// auths are eligible. When preferWebsocket is set and the bucket contains
// websocket-capable entries, only that subset is rotated; otherwise the full
// bucket is used. Caller must hold the shard lock.
func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
	if m == nil {
		return nil
	}
	m.promoteExpiredLocked(time.Now())
	// priorityOrder is sorted highest-first (see rebuildIndexesLocked), so
	// the first bucket that yields a candidate wins.
	for _, priority := range m.priorityOrder {
		bucket := m.readyByPriority[priority]
		if bucket == nil {
			continue
		}
		view := &bucket.all
		if preferWebsocket && len(bucket.ws.flat) > 0 {
			view = &bucket.ws
		}
		var picked *scheduledAuth
		if strategy == schedulerStrategyFillFirst {
			picked = view.pickFirst(predicate)
		} else {
			picked = view.pickRoundRobin(predicate)
		}
		if picked != nil && picked.auth != nil {
			return picked.auth
		}
	}
	return nil
}
|
||||
|
||||
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
|
||||
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
|
||||
now := time.Now()
|
||||
total, cooldownCount, earliest := m.availabilitySummaryLocked(predicate)
|
||||
if total == 0 {
|
||||
return &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
if cooldownCount == total && !earliest.IsZero() {
|
||||
providerForError := provider
|
||||
if providerForError == "mixed" {
|
||||
providerForError = ""
|
||||
}
|
||||
resetIn := earliest.Sub(now)
|
||||
if resetIn < 0 {
|
||||
resetIn = 0
|
||||
}
|
||||
return newModelCooldownError(model, providerForError, resetIn)
|
||||
}
|
||||
return &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||
}
|
||||
|
||||
// availabilitySummaryLocked summarizes total candidates, cooldown count, and earliest retry time.
|
||||
func (m *modelScheduler) availabilitySummaryLocked(predicate func(*scheduledAuth) bool) (int, int, time.Time) {
|
||||
if m == nil {
|
||||
return 0, 0, time.Time{}
|
||||
}
|
||||
total := 0
|
||||
cooldownCount := 0
|
||||
earliest := time.Time{}
|
||||
for _, entry := range m.entries {
|
||||
if predicate != nil && !predicate(entry) {
|
||||
continue
|
||||
}
|
||||
total++
|
||||
if entry == nil || entry.auth == nil {
|
||||
continue
|
||||
}
|
||||
if entry.state != scheduledStateCooldown {
|
||||
continue
|
||||
}
|
||||
cooldownCount++
|
||||
if !entry.nextRetryAt.IsZero() && (earliest.IsZero() || entry.nextRetryAt.Before(earliest)) {
|
||||
earliest = entry.nextRetryAt
|
||||
}
|
||||
}
|
||||
return total, cooldownCount, earliest
|
||||
}
|
||||
|
||||
// rebuildIndexesLocked reconstructs ready and blocked views from the current
// entry map. Ready entries are grouped into per-priority buckets (rotated via
// readyBucket views); cooldown/blocked entries go into a slice sorted by next
// retry time so promoteExpiredLocked can scan the soonest-due first. Disabled
// entries are indexed nowhere. Caller must hold the shard lock.
func (m *modelScheduler) rebuildIndexesLocked() {
	m.readyByPriority = make(map[int]*readyBucket)
	m.priorityOrder = m.priorityOrder[:0]
	m.blocked = m.blocked[:0]
	priorityBuckets := make(map[int][]*scheduledAuth)
	for _, entry := range m.entries {
		if entry == nil || entry.auth == nil {
			continue
		}
		switch entry.state {
		case scheduledStateReady:
			// NOTE(review): assumes entry.meta is non-nil for ready entries;
			// upsertEntryLocked assigns meta before setting the state, but
			// confirm no other writer leaves meta nil.
			priority := entry.meta.priority
			priorityBuckets[priority] = append(priorityBuckets[priority], entry)
		case scheduledStateCooldown, scheduledStateBlocked:
			m.blocked = append(m.blocked, entry)
		}
	}
	// Deterministic intra-bucket order: sort by auth ID before building views.
	for priority, entries := range priorityBuckets {
		sort.Slice(entries, func(i, j int) bool {
			return entries[i].auth.ID < entries[j].auth.ID
		})
		m.readyByPriority[priority] = buildReadyBucket(entries)
		m.priorityOrder = append(m.priorityOrder, priority)
	}
	// Higher numeric priority is served first.
	sort.Slice(m.priorityOrder, func(i, j int) bool {
		return m.priorityOrder[i] > m.priorityOrder[j]
	})
	// Blocked entries sort by ascending retry time; zero times (no scheduled
	// retry) sort last, ties break by auth ID.
	sort.Slice(m.blocked, func(i, j int) bool {
		left := m.blocked[i]
		right := m.blocked[j]
		if left == nil || right == nil {
			return left != nil
		}
		if left.nextRetryAt.Equal(right.nextRetryAt) {
			return left.auth.ID < right.auth.ID
		}
		if left.nextRetryAt.IsZero() {
			return false
		}
		if right.nextRetryAt.IsZero() {
			return true
		}
		return left.nextRetryAt.Before(right.nextRetryAt)
	})
}
|
||||
|
||||
// buildReadyBucket prepares the general and websocket-only ready views for one priority bucket.
|
||||
func buildReadyBucket(entries []*scheduledAuth) *readyBucket {
|
||||
bucket := &readyBucket{}
|
||||
bucket.all = buildReadyView(entries)
|
||||
wsEntries := make([]*scheduledAuth, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry != nil && entry.meta != nil && entry.meta.websocketEnabled {
|
||||
wsEntries = append(wsEntries, entry)
|
||||
}
|
||||
}
|
||||
bucket.ws = buildReadyView(wsEntries)
|
||||
return bucket
|
||||
}
|
||||
|
||||
// buildReadyView creates either a flat view or a grouped parent/child view for
// rotation. Grouping activates only when every entry carries a virtual parent
// and more than one distinct parent exists; a single entry without a parent
// immediately falls back to the flat view for the whole bucket.
func buildReadyView(entries []*scheduledAuth) readyView {
	// flat is always populated (a copy, so later index rebuilds cannot
	// mutate it underneath a view).
	view := readyView{flat: append([]*scheduledAuth(nil), entries...)}
	if len(entries) == 0 {
		return view
	}
	groups := make(map[string][]*scheduledAuth)
	for _, entry := range entries {
		// Any ungrouped entry disables two-level rotation entirely.
		if entry == nil || entry.meta == nil || entry.meta.virtualParent == "" {
			return view
		}
		groups[entry.meta.virtualParent] = append(groups[entry.meta.virtualParent], entry)
	}
	if len(groups) <= 1 {
		return view
	}
	view.children = make(map[string]*childBucket, len(groups))
	view.parentOrder = make([]string, 0, len(groups))
	for parent := range groups {
		view.parentOrder = append(view.parentOrder, parent)
	}
	// Sort parents so rotation order is deterministic across rebuilds.
	sort.Strings(view.parentOrder)
	for _, parent := range view.parentOrder {
		view.children[parent] = &childBucket{items: append([]*scheduledAuth(nil), groups[parent]...)}
	}
	return view
}
|
||||
|
||||
// pickFirst returns the first ready entry that satisfies predicate without advancing cursors.
|
||||
func (v *readyView) pickFirst(predicate func(*scheduledAuth) bool) *scheduledAuth {
|
||||
for _, entry := range v.flat {
|
||||
if predicate == nil || predicate(entry) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pickRoundRobin returns the next ready entry using flat or grouped round-robin traversal.
|
||||
func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
|
||||
if len(v.parentOrder) > 1 && len(v.children) > 0 {
|
||||
return v.pickGroupedRoundRobin(predicate)
|
||||
}
|
||||
if len(v.flat) == 0 {
|
||||
return nil
|
||||
}
|
||||
start := 0
|
||||
if len(v.flat) > 0 {
|
||||
start = v.cursor % len(v.flat)
|
||||
}
|
||||
for offset := 0; offset < len(v.flat); offset++ {
|
||||
index := (start + offset) % len(v.flat)
|
||||
entry := v.flat[index]
|
||||
if predicate != nil && !predicate(entry) {
|
||||
continue
|
||||
}
|
||||
v.cursor = index + 1
|
||||
return entry
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pickGroupedRoundRobin rotates across parents first and then within the
// selected parent. On a successful pick both the parent cursor and the chosen
// parent's child cursor advance, yielding fair two-level rotation across
// parent groups.
func (v *readyView) pickGroupedRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
	start := 0
	if len(v.parentOrder) > 0 {
		start = v.parentCursor % len(v.parentOrder)
	}
	for offset := 0; offset < len(v.parentOrder); offset++ {
		parentIndex := (start + offset) % len(v.parentOrder)
		parent := v.parentOrder[parentIndex]
		child := v.children[parent]
		if child == nil || len(child.items) == 0 {
			continue
		}
		itemStart := child.cursor % len(child.items)
		for itemOffset := 0; itemOffset < len(child.items); itemOffset++ {
			itemIndex := (itemStart + itemOffset) % len(child.items)
			entry := child.items[itemIndex]
			if predicate != nil && !predicate(entry) {
				continue
			}
			// Advance both cursors past the chosen entry so the next call
			// moves to the next parent, and this parent's next child.
			child.cursor = itemIndex + 1
			v.parentCursor = parentIndex + 1
			return entry
		}
	}
	return nil
}
|
||||
197
sdk/cliproxy/auth/scheduler_benchmark_test.go
Normal file
197
sdk/cliproxy/auth/scheduler_benchmark_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
// schedulerBenchmarkExecutor is a minimal executor stub used by the scheduler
// benchmarks; every method is a no-op so measurements isolate pick logic.
type schedulerBenchmarkExecutor struct {
	id string
}

// Identifier returns the provider key this stub was registered under.
func (e schedulerBenchmarkExecutor) Identifier() string { return e.id }

// Execute is a no-op; benchmarks never execute requests.
func (e schedulerBenchmarkExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// ExecuteStream is a no-op; benchmarks never stream responses.
func (e schedulerBenchmarkExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, nil
}

// Refresh returns the auth unchanged.
func (e schedulerBenchmarkExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

// CountTokens is a no-op.
func (e schedulerBenchmarkExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// HttpRequest is a no-op.
func (e schedulerBenchmarkExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
	return nil, nil
}
|
||||
|
||||
// benchmarkManagerSetup builds a Manager populated with `total` auths for
// benchmarking pick paths. When mixed is true, auths alternate between the
// gemini and claude providers; when withPriority is true, every other auth
// gets priority "10" so two priority buckets exist. Each auth's model is
// registered in the global registry (removed again via b.Cleanup). Returns
// the manager, the provider list, and the benchmark model ID.
func benchmarkManagerSetup(b *testing.B, total int, mixed bool, withPriority bool) (*Manager, []string, string) {
	b.Helper()
	manager := NewManager(nil, &RoundRobinSelector{}, nil)
	providers := []string{"gemini"}
	manager.executors["gemini"] = schedulerBenchmarkExecutor{id: "gemini"}
	if mixed {
		providers = []string{"gemini", "claude"}
		manager.executors["claude"] = schedulerBenchmarkExecutor{id: "claude"}
	}

	reg := registry.GetGlobalRegistry()
	model := "bench-model"
	for index := 0; index < total; index++ {
		// Odd indexes go to the second provider in mixed mode.
		provider := providers[0]
		if mixed && index%2 == 1 {
			provider = providers[1]
		}
		auth := &Auth{ID: fmt.Sprintf("bench-%s-%04d", provider, index), Provider: provider}
		if withPriority {
			priority := "0"
			if index%2 == 0 {
				priority = "10"
			}
			auth.Attributes = map[string]string{"priority": priority}
		}
		_, errRegister := manager.Register(context.Background(), auth)
		if errRegister != nil {
			b.Fatalf("Register(%s) error = %v", auth.ID, errRegister)
		}
		reg.RegisterClient(auth.ID, provider, []*registry.ModelInfo{{ID: model}})
	}
	manager.syncScheduler()
	b.Cleanup(func() {
		// Mirror the registration loop so exactly the IDs created above are
		// removed from the global registry.
		for index := 0; index < total; index++ {
			provider := providers[0]
			if mixed && index%2 == 1 {
				provider = providers[1]
			}
			reg.UnregisterClient(fmt.Sprintf("bench-%s-%04d", provider, index))
		}
	})

	return manager, providers, model
}
|
||||
|
||||
// BenchmarkManagerPickNext500 measures single-provider pickNext throughput
// with 500 flat-priority auths.
func BenchmarkManagerPickNext500(b *testing.B) {
	manager, _, model := benchmarkManagerSetup(b, 500, false, false)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	// One untimed warmup call keeps first-use setup cost out of the loop.
	if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}
|
||||
|
||||
// BenchmarkManagerPickNext1000 measures single-provider pickNext throughput
// with 1000 flat-priority auths.
func BenchmarkManagerPickNext1000(b *testing.B) {
	manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	// One untimed warmup call keeps first-use setup cost out of the loop.
	if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}
|
||||
|
||||
// BenchmarkManagerPickNextPriority500 measures pickNext throughput with 500
// auths split across two priority buckets.
func BenchmarkManagerPickNextPriority500(b *testing.B) {
	manager, _, model := benchmarkManagerSetup(b, 500, false, true)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	// One untimed warmup call keeps first-use setup cost out of the loop.
	if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}
|
||||
|
||||
// BenchmarkManagerPickNextPriority1000 measures pickNext throughput with 1000
// auths split across two priority buckets.
func BenchmarkManagerPickNextPriority1000(b *testing.B) {
	manager, _, model := benchmarkManagerSetup(b, 1000, false, true)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	// One untimed warmup call keeps first-use setup cost out of the loop.
	if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}
|
||||
|
||||
// BenchmarkManagerPickNextMixed500 measures pickNextMixed throughput with 500
// auths alternating between two providers.
func BenchmarkManagerPickNextMixed500(b *testing.B) {
	manager, providers, model := benchmarkManagerSetup(b, 500, true, false)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	// One untimed warmup call keeps first-use setup cost out of the loop.
	if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNextMixed error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
		if errPick != nil || auth == nil || exec == nil || provider == "" {
			b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
		}
	}
}
|
||||
|
||||
// BenchmarkManagerPickNextAndMarkResult1000 measures the combined cost of a
// pick plus a successful MarkResult round-trip with 1000 auths, approximating
// the per-request scheduling overhead.
func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) {
	manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	// One untimed warmup call keeps first-use setup cost out of the loop.
	if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, _, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
		if errPick != nil || auth == nil {
			b.Fatalf("pickNext failed: auth=%v err=%v", auth, errPick)
		}
		manager.MarkResult(ctx, Result{AuthID: auth.ID, Provider: "gemini", Model: model, Success: true})
	}
}
|
||||
468
sdk/cliproxy/auth/scheduler_test.go
Normal file
468
sdk/cliproxy/auth/scheduler_test.go
Normal file
@@ -0,0 +1,468 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
// schedulerTestExecutor is a stateless no-op executor stub used to satisfy the
// Manager's executor requirement in scheduler unit tests.
type schedulerTestExecutor struct{}

// Identifier returns a fixed test identifier.
func (schedulerTestExecutor) Identifier() string { return "test" }

// Execute is a no-op; tests never execute requests.
func (schedulerTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// ExecuteStream is a no-op; tests never stream responses.
func (schedulerTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, nil
}

// Refresh returns the auth unchanged.
func (schedulerTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

// CountTokens is a no-op.
func (schedulerTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// HttpRequest is a no-op.
func (schedulerTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
	return nil, nil
}
|
||||
|
||||
// trackingSelector is a Selector test double that records how often Pick was
// called and which candidate auth IDs it was offered. It always returns the
// last candidate so tests can distinguish it from the built-in strategies.
type trackingSelector struct {
	calls int
	lastAuthID []string
}

// Pick records the candidate set, then returns the final auth (nil, nil when
// the candidate list is empty).
func (s *trackingSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
	s.calls++
	// Reuse the backing array across calls; only the latest call's IDs matter.
	s.lastAuthID = s.lastAuthID[:0]
	for _, auth := range auths {
		s.lastAuthID = append(s.lastAuthID, auth.ID)
	}
	if len(auths) == 0 {
		return nil, nil
	}
	return auths[len(auths)-1], nil
}
|
||||
|
||||
func newSchedulerForTest(selector Selector, auths ...*Auth) *authScheduler {
|
||||
scheduler := newAuthScheduler(selector)
|
||||
scheduler.rebuild(auths)
|
||||
return scheduler
|
||||
}
|
||||
|
||||
// registerSchedulerModels registers model support for each auth ID in the
// global registry and schedules matching unregistration via t.Cleanup so
// tests remain isolated from one another.
func registerSchedulerModels(t *testing.T, provider string, model string, authIDs ...string) {
	t.Helper()
	reg := registry.GetGlobalRegistry()
	for _, authID := range authIDs {
		reg.RegisterClient(authID, provider, []*registry.ModelInfo{{ID: model}})
	}
	t.Cleanup(func() {
		for _, authID := range authIDs {
			reg.UnregisterClient(authID)
		}
	})
}
|
||||
|
||||
// TestSchedulerPick_RoundRobinHighestPriority verifies that round-robin
// rotation cycles only through the highest-priority bucket (ordered by auth
// ID) and never reaches lower-priority auths while higher ones are ready.
func TestSchedulerPick_RoundRobinHighestPriority(t *testing.T) {
	t.Parallel()

	scheduler := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{ID: "low", Provider: "gemini", Attributes: map[string]string{"priority": "0"}},
		&Auth{ID: "high-b", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
		&Auth{ID: "high-a", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
	)

	// Expect rotation over the two priority-10 auths in ID order, wrapping.
	want := []string{"high-a", "high-b", "high-a"}
	for index, wantID := range want {
		got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickSingle() #%d error = %v", index, errPick)
		}
		if got == nil {
			t.Fatalf("pickSingle() #%d auth = nil", index)
		}
		if got.ID != wantID {
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
		}
	}
}
|
||||
|
||||
// TestSchedulerPick_FillFirstSticksToFirstReady verifies that the fill-first
// strategy repeatedly returns the first ready auth in ID order instead of
// rotating.
func TestSchedulerPick_FillFirstSticksToFirstReady(t *testing.T) {
	t.Parallel()

	scheduler := newSchedulerForTest(
		&FillFirstSelector{},
		&Auth{ID: "b", Provider: "gemini"},
		&Auth{ID: "a", Provider: "gemini"},
		&Auth{ID: "c", Provider: "gemini"},
	)

	// Every pick should land on "a" — first by ID order — with no rotation.
	for index := 0; index < 3; index++ {
		got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickSingle() #%d error = %v", index, errPick)
		}
		if got == nil {
			t.Fatalf("pickSingle() #%d auth = nil", index)
		}
		if got.ID != "a" {
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, "a")
		}
	}
}
|
||||
|
||||
// TestSchedulerPick_PromotesExpiredCooldownBeforePick verifies that an auth
// whose per-model cooldown has already expired (NextRetryAfter in the past)
// is promoted back to ready and returned by the next pick.
func TestSchedulerPick_PromotesExpiredCooldownBeforePick(t *testing.T) {
	t.Parallel()

	model := "gemini-2.5-pro"
	registerSchedulerModels(t, "gemini", model, "cooldown-expired")
	scheduler := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{
			ID:       "cooldown-expired",
			Provider: "gemini",
			// Unavailable with a retry time one second in the past: the
			// scheduler should treat the cooldown as expired.
			ModelStates: map[string]*ModelState{
				model: {
					Status:         StatusError,
					Unavailable:    true,
					NextRetryAfter: time.Now().Add(-1 * time.Second),
				},
			},
		},
	)

	got, errPick := scheduler.pickSingle(context.Background(), "gemini", model, cliproxyexecutor.Options{}, nil)
	if errPick != nil {
		t.Fatalf("pickSingle() error = %v", errPick)
	}
	if got == nil {
		t.Fatalf("pickSingle() auth = nil")
	}
	if got.ID != "cooldown-expired" {
		t.Fatalf("pickSingle() auth.ID = %q, want %q", got.ID, "cooldown-expired")
	}
}
|
||||
|
||||
// TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation verifies two-level
// rotation for auths grouped by gemini_virtual_parent: parents alternate
// first, and each parent's children rotate on successive visits.
func TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation(t *testing.T) {
	t.Parallel()

	registerSchedulerModels(t, "gemini-cli", "gemini-2.5-pro", "cred-a::proj-1", "cred-a::proj-2", "cred-b::proj-1", "cred-b::proj-2")
	scheduler := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{ID: "cred-a::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
		&Auth{ID: "cred-a::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
		&Auth{ID: "cred-b::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
		&Auth{ID: "cred-b::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
	)

	// Parents alternate a/b/a/b; within each parent, projects advance.
	wantParents := []string{"cred-a", "cred-b", "cred-a", "cred-b"}
	wantIDs := []string{"cred-a::proj-1", "cred-b::proj-1", "cred-a::proj-2", "cred-b::proj-2"}
	for index := range wantIDs {
		got, errPick := scheduler.pickSingle(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickSingle() #%d error = %v", index, errPick)
		}
		if got == nil {
			t.Fatalf("pickSingle() #%d auth = nil", index)
		}
		if got.ID != wantIDs[index] {
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
		}
		if got.Attributes["gemini_virtual_parent"] != wantParents[index] {
			t.Fatalf("pickSingle() #%d parent = %q, want %q", index, got.Attributes["gemini_virtual_parent"], wantParents[index])
		}
	}
}
|
||||
|
||||
// TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset verifies that
// a downstream-websocket context rotates only over websocket-enabled auths,
// skipping HTTP-only candidates entirely.
func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) {
	t.Parallel()

	scheduler := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{ID: "codex-http", Provider: "codex"},
		&Auth{ID: "codex-ws-a", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
		&Auth{ID: "codex-ws-b", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
	)

	// Mark the context as a downstream websocket request.
	ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
	want := []string{"codex-ws-a", "codex-ws-b", "codex-ws-a"}
	for index, wantID := range want {
		got, errPick := scheduler.pickSingle(ctx, "codex", "", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickSingle() #%d error = %v", index, errPick)
		}
		if got == nil {
			t.Fatalf("pickSingle() #%d auth = nil", index)
		}
		if got.ID != wantID {
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
		}
	}
}
|
||||
|
||||
// TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates
// verifies that pickMixed alternates providers first, then rotates
// credentials within each provider (gemini has two auths, claude one).
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
	t.Parallel()

	scheduler := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{ID: "gemini-a", Provider: "gemini"},
		&Auth{ID: "gemini-b", Provider: "gemini"},
		&Auth{ID: "claude-a", Provider: "claude"},
	)

	wantProviders := []string{"gemini", "claude", "gemini", "claude"}
	wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
	for index := range wantProviders {
		got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickMixed() #%d error = %v", index, errPick)
		}
		if got == nil {
			t.Fatalf("pickMixed() #%d auth = nil", index)
		}
		if provider != wantProviders[index] {
			t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
		}
		if got.ID != wantIDs[index] {
			t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
		}
	}
}
|
||||
|
||||
// TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation
// verifies the Manager-level mixed pick path (with an explicit, empty tried
// set) rotates providers before rotating credentials within a provider.
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
	t.Parallel()

	manager := NewManager(nil, &RoundRobinSelector{}, nil)
	manager.executors["gemini"] = schedulerTestExecutor{}
	manager.executors["claude"] = schedulerTestExecutor{}
	if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(gemini-a) error = %v", errRegister)
	}
	if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(gemini-b) error = %v", errRegister)
	}
	if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
		t.Fatalf("Register(claude-a) error = %v", errRegister)
	}

	wantProviders := []string{"gemini", "claude", "gemini", "claude"}
	wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
	for index := range wantProviders {
		got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
		if errPick != nil {
			t.Fatalf("pickNextMixed() #%d error = %v", index, errPick)
		}
		if got == nil {
			t.Fatalf("pickNextMixed() #%d auth = nil", index)
		}
		if provider != wantProviders[index] {
			t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
		}
		if got.ID != wantIDs[index] {
			t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
		}
	}
}
|
||||
|
||||
// TestManagerCustomSelector_FallsBackToLegacyPath verifies that a non-built-in
// selector bypasses the scheduler fast path: the custom selector's Pick is
// invoked exactly once with the full candidate set, and its choice (the last
// candidate) is honored. Auths are inserted directly into manager.auths to
// exercise the legacy snapshot path.
func TestManagerCustomSelector_FallsBackToLegacyPath(t *testing.T) {
	t.Parallel()

	selector := &trackingSelector{}
	manager := NewManager(nil, selector, nil)
	manager.executors["gemini"] = schedulerTestExecutor{}
	manager.auths["auth-a"] = &Auth{ID: "auth-a", Provider: "gemini"}
	manager.auths["auth-b"] = &Auth{ID: "auth-b", Provider: "gemini"}

	got, _, errPick := manager.pickNext(context.Background(), "gemini", "", cliproxyexecutor.Options{}, map[string]struct{}{})
	if errPick != nil {
		t.Fatalf("pickNext() error = %v", errPick)
	}
	if got == nil {
		t.Fatalf("pickNext() auth = nil")
	}
	if selector.calls != 1 {
		t.Fatalf("selector.calls = %d, want %d", selector.calls, 1)
	}
	if len(selector.lastAuthID) != 2 {
		t.Fatalf("len(selector.lastAuthID) = %d, want %d", len(selector.lastAuthID), 2)
	}
	if got.ID != selector.lastAuthID[len(selector.lastAuthID)-1] {
		t.Fatalf("pickNext() auth.ID = %q, want selector-picked %q", got.ID, selector.lastAuthID[len(selector.lastAuthID)-1])
	}
}
|
||||
|
||||
// TestManager_InitializesSchedulerForBuiltInSelector verifies that NewManager
// creates a scheduler with the strategy matching the built-in selector, and
// that SetSelector updates the scheduler strategy accordingly.
func TestManager_InitializesSchedulerForBuiltInSelector(t *testing.T) {
	t.Parallel()

	manager := NewManager(nil, &RoundRobinSelector{}, nil)
	if manager.scheduler == nil {
		t.Fatalf("manager.scheduler = nil")
	}
	if manager.scheduler.strategy != schedulerStrategyRoundRobin {
		t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyRoundRobin)
	}

	// Switching selectors must retarget the scheduler strategy too.
	manager.SetSelector(&FillFirstSelector{})
	if manager.scheduler.strategy != schedulerStrategyFillFirst {
		t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyFillFirst)
	}
}
|
||||
|
||||
// TestManager_SchedulerTracksRegisterAndUpdate verifies the scheduler stays in
// sync with Manager mutations: Register makes auths pickable (ID order), and
// disabling one via Update removes it from rotation.
func TestManager_SchedulerTracksRegisterAndUpdate(t *testing.T) {
	t.Parallel()

	manager := NewManager(nil, &RoundRobinSelector{}, nil)
	if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(auth-b) error = %v", errRegister)
	}
	if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(auth-a) error = %v", errRegister)
	}

	// First pick follows ID order regardless of registration order.
	got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
	if errPick != nil {
		t.Fatalf("scheduler.pickSingle() error = %v", errPick)
	}
	if got == nil || got.ID != "auth-a" {
		t.Fatalf("scheduler.pickSingle() auth = %v, want auth-a", got)
	}

	// Disabling auth-a should drop it from the scheduler's ready set.
	if _, errUpdate := manager.Update(context.Background(), &Auth{ID: "auth-a", Provider: "gemini", Disabled: true}); errUpdate != nil {
		t.Fatalf("Update(auth-a) error = %v", errUpdate)
	}

	got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
	if errPick != nil {
		t.Fatalf("scheduler.pickSingle() after update error = %v", errPick)
	}
	if got == nil || got.ID != "auth-b" {
		t.Fatalf("scheduler.pickSingle() after update auth = %v, want auth-b", got)
	}
}
|
||||
|
||||
func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
manager.executors["gemini"] = schedulerTestExecutor{}
|
||||
manager.executors["claude"] = schedulerTestExecutor{}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(gemini-a) error = %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(gemini-b) error = %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
|
||||
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
||||
}
|
||||
|
||||
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
||||
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
||||
for index := range wantProviders {
|
||||
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickNextMixed() #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickNextMixed() #%d auth = nil", index)
|
||||
}
|
||||
if provider != wantProviders[index] {
|
||||
t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
|
||||
}
|
||||
if got.ID != wantIDs[index] {
|
||||
t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_PickNextMixed_SkipsProvidersWithoutExecutors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
manager.executors["claude"] = schedulerTestExecutor{}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(gemini-a) error = %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
|
||||
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
||||
}
|
||||
|
||||
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickNextMixed() error = %v", errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickNextMixed() auth = nil")
|
||||
}
|
||||
if provider != "claude" {
|
||||
t.Fatalf("pickNextMixed() provider = %q, want %q", provider, "claude")
|
||||
}
|
||||
if got.ID != "claude-a" {
|
||||
t.Fatalf("pickNextMixed() auth.ID = %q, want %q", got.ID, "claude-a")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_SchedulerTracksMarkResultCooldownAndRecovery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("auth-a", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
|
||||
reg.RegisterClient("auth-b", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient("auth-a")
|
||||
reg.UnregisterClient("auth-b")
|
||||
})
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(auth-a) error = %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(auth-b) error = %v", errRegister)
|
||||
}
|
||||
|
||||
manager.MarkResult(context.Background(), Result{
|
||||
AuthID: "auth-a",
|
||||
Provider: "gemini",
|
||||
Model: "test-model",
|
||||
Success: false,
|
||||
Error: &Error{HTTPStatus: 429, Message: "quota"},
|
||||
})
|
||||
|
||||
got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("scheduler.pickSingle() after cooldown error = %v", errPick)
|
||||
}
|
||||
if got == nil || got.ID != "auth-b" {
|
||||
t.Fatalf("scheduler.pickSingle() after cooldown auth = %v, want auth-b", got)
|
||||
}
|
||||
|
||||
manager.MarkResult(context.Background(), Result{
|
||||
AuthID: "auth-a",
|
||||
Provider: "gemini",
|
||||
Model: "test-model",
|
||||
Success: true,
|
||||
})
|
||||
|
||||
seen := make(map[string]struct{}, 2)
|
||||
for index := 0; index < 2; index++ {
|
||||
got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("scheduler.pickSingle() after recovery #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("scheduler.pickSingle() after recovery #%d auth = nil", index)
|
||||
}
|
||||
seen[got.ID] = struct{}{}
|
||||
}
|
||||
if len(seen) != 2 {
|
||||
t.Fatalf("len(seen) = %d, want %d", len(seen), 2)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user