diff --git a/config.example.yaml b/config.example.yaml index 40fa8e95..057acf0d 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -71,6 +71,10 @@ quota-exceeded: switch-project: true # Whether to automatically switch to another project when a quota is exceeded switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded +# Routing strategy for selecting credentials when multiple match. +routing: + strategy: "round-robin" # round-robin (default), fill-first + # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: false diff --git a/internal/config/config.go b/internal/config/config.go index 43be4910..be68dcb9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -60,6 +60,9 @@ type Config struct { // QuotaExceeded defines the behavior when a quota is exceeded. QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"` + // Routing controls credential selection behavior. + Routing RoutingConfig `yaml:"routing" json:"routing"` + // WebsocketAuth enables or disables authentication for the WebSocket API. WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"` @@ -136,6 +139,13 @@ type QuotaExceeded struct { SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"` } +// RoutingConfig configures how credentials are selected for requests. +type RoutingConfig struct { + // Strategy selects the credential selection strategy. + // Supported values: "round-robin" (default), "fill-first". + Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"` +} + // AmpModelMapping defines a model name mapping for Amp CLI requests. // When Amp requests a model that isn't available locally, this mapping // allows routing to an alternative model that IS available. diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index 338ad566..b3876de6 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -135,6 +135,18 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { } } +func (m *Manager) SetSelector(selector Selector) { + if m == nil { + return + } + if selector == nil { + selector = &RoundRobinSelector{} + } + m.mu.Lock() + m.selector = selector + m.mu.Unlock() +} + // SetStore swaps the underlying persistence store. func (m *Manager) SetStore(store Store) { m.mu.Lock() diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index d4edc8bd..d7e120c5 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -20,6 +20,11 @@ type RoundRobinSelector struct { cursors map[string]int } +// FillFirstSelector selects the first available credential (deterministic ordering). +// This "burns" one account before moving to the next, which can help stagger +// rolling-window subscription caps (e.g. chat message limits). +type FillFirstSelector struct{} + type blockReason int const ( @@ -98,20 +103,8 @@ func (e *modelCooldownError) Headers() http.Header { return headers } -// Pick selects the next available auth for the provider in a round-robin manner. -func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { - _ = ctx - _ = opts - if len(auths) == 0 { - return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"} - } - if s.cursors == nil { - s.cursors = make(map[string]int) - } - available := make([]*Auth, 0, len(auths)) - now := time.Now() - cooldownCount := 0 - var earliest time.Time +func collectAvailable(auths []*Auth, model string, now time.Time) (available []*Auth, cooldownCount int, earliest time.Time) { + available = make([]*Auth, 0, len(auths)) for i := 0; i < len(auths); i++ { candidate := auths[i] blocked, reason, next := isAuthBlockedForModel(candidate, model, now) @@ -126,6 +119,18 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o } } } + if len(available) > 1 { + sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID }) + } + return available, cooldownCount, earliest +} + +func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]*Auth, error) { + if len(auths) == 0 { + return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"} + } + + available, cooldownCount, earliest := collectAvailable(auths, model, now) if len(available) == 0 { if cooldownCount == len(auths) && !earliest.IsZero() { resetIn := earliest.Sub(now) @@ -136,12 +141,24 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o } return nil, &Error{Code: "auth_unavailable", Message: "no auth available"} } - // Make round-robin deterministic even if caller's candidate order is unstable. - if len(available) > 1 { - sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID }) + + return available, nil +} + +// Pick selects the next available auth for the provider in a round-robin manner. +func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { + _ = ctx + _ = opts + now := time.Now() + available, err := getAvailableAuths(auths, provider, model, now) + if err != nil { + return nil, err } key := provider + ":" + model s.mu.Lock() + if s.cursors == nil { + s.cursors = make(map[string]int) + } index := s.cursors[key] if index >= 2_147_483_640 { @@ -154,6 +171,18 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o return available[index%len(available)], nil } +// Pick selects the first available auth for the provider in a deterministic manner. +func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { + _ = ctx + _ = opts + now := time.Now() + available, err := getAvailableAuths(auths, provider, model, now) + if err != nil { + return nil, err + } + return available[0], nil +} + func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, blockReason, time.Time) { if auth == nil { return true, blockReasonOther, time.Time{} diff --git a/sdk/cliproxy/auth/selector_test.go b/sdk/cliproxy/auth/selector_test.go new file mode 100644 index 00000000..f4beed03 --- /dev/null +++ b/sdk/cliproxy/auth/selector_test.go @@ -0,0 +1,113 @@ +package auth + +import ( + "context" + "errors" + "sync" + "testing" + + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +func TestFillFirstSelectorPick_Deterministic(t *testing.T) { + t.Parallel() + + selector := &FillFirstSelector{} + auths := []*Auth{ + {ID: "b"}, + {ID: "a"}, + {ID: "c"}, + } + + got, err := selector.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + if got == nil { + t.Fatalf("Pick() auth = nil") + } + if got.ID != "a" { + t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "a") + } +} + +func TestRoundRobinSelectorPick_CyclesDeterministic(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + auths := []*Auth{ + {ID: "b"}, + {ID: "a"}, + {ID: "c"}, + } + + want := []string{"a", "b", "c", "a", "b"} + for i, id := range want { + got, err := selector.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got == nil { + t.Fatalf("Pick() #%d auth = nil", i) + } + if got.ID != id { + t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, id) + } + } +} + +func TestRoundRobinSelectorPick_Concurrent(t *testing.T) { + selector := &RoundRobinSelector{} + auths := []*Auth{ + {ID: "b"}, + {ID: "a"}, + {ID: "c"}, + } + + start := make(chan struct{}) + var wg sync.WaitGroup + errCh := make(chan error, 1) + + goroutines := 32 + iterations := 100 + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + for j := 0; j < iterations; j++ { + got, err := selector.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, auths) + if err != nil { + select { + case errCh <- err: + default: + } + return + } + if got == nil { + select { + case errCh <- errors.New("Pick() returned nil auth"): + default: + } + return + } + if got.ID == "" { + select { + case errCh <- errors.New("Pick() returned auth with empty ID"): + default: + } + return + } + } + }() + } + + close(start) + wg.Wait() + + select { + case err := <-errCh: + t.Fatalf("concurrent Pick() error = %v", err) + default: + } +} diff --git a/sdk/cliproxy/builder.go b/sdk/cliproxy/builder.go index a85e91d9..381a0926 100644 --- a/sdk/cliproxy/builder.go +++ b/sdk/cliproxy/builder.go @@ -5,6 +5,7 @@ package cliproxy import ( "fmt" + "strings" "github.com/router-for-me/CLIProxyAPI/v6/internal/api" sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" @@ -197,7 +198,20 @@ func (b *Builder) Build() (*Service, error) { if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok && b.cfg != nil { dirSetter.SetBaseDir(b.cfg.AuthDir) } - coreManager = coreauth.NewManager(tokenStore, nil, nil) + + strategy := "" + if b.cfg != nil { + strategy = strings.ToLower(strings.TrimSpace(b.cfg.Routing.Strategy)) + } + var selector coreauth.Selector + switch strategy { + case "fill-first", "fillfirst", "ff": + selector = &coreauth.FillFirstSelector{} + default: + selector = &coreauth.RoundRobinSelector{} + } + + coreManager = coreauth.NewManager(tokenStore, selector, nil) } // Attach a default RoundTripper provider so providers can opt-in per-auth transports. coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider()) diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 6afacb00..1f568010 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -510,6 +510,13 @@ func (s *Service) Run(ctx context.Context) error { var watcherWrapper *WatcherWrapper reloadCallback := func(newCfg *config.Config) { + previousStrategy := "" + s.cfgMu.RLock() + if s.cfg != nil { + previousStrategy = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Strategy)) + } + s.cfgMu.RUnlock() + if newCfg == nil { s.cfgMu.RLock() newCfg = s.cfg @@ -518,6 +525,30 @@ func (s *Service) Run(ctx context.Context) error { if newCfg == nil { return } + + nextStrategy := strings.ToLower(strings.TrimSpace(newCfg.Routing.Strategy)) + normalizeStrategy := func(strategy string) string { + switch strategy { + case "fill-first", "fillfirst", "ff": + return "fill-first" + default: + return "round-robin" + } + } + previousStrategy = normalizeStrategy(previousStrategy) + nextStrategy = normalizeStrategy(nextStrategy) + if s.coreManager != nil && previousStrategy != nextStrategy { + var selector coreauth.Selector + switch nextStrategy { + case "fill-first": + selector = &coreauth.FillFirstSelector{} + default: + selector = &coreauth.RoundRobinSelector{} + } + s.coreManager.SetSelector(selector) + log.Infof("routing strategy updated to %s", nextStrategy) + } + s.applyRetryConfig(newCfg) if s.server != nil { s.server.UpdateClients(newCfg)