Merge branch 'router-for-me:main' into main

This commit is contained in:
Luis Pater
2026-01-26 23:35:07 +08:00
committed by GitHub
10 changed files with 378 additions and 85 deletions

View File

@@ -61,6 +61,15 @@ func SetQuotaCooldownDisabled(disable bool) {
quotaCooldownDisabled.Store(disable)
}
func quotaCooldownDisabledForAuth(auth *Auth) bool {
if auth != nil {
if override, ok := auth.DisableCoolingOverride(); ok {
return override
}
}
return quotaCooldownDisabled.Load()
}
// Result captures execution outcome used to adjust auth state.
type Result struct {
// AuthID references the auth that produced this result.
@@ -468,20 +477,16 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
if attempts < 1 {
attempts = 1
}
_, maxWait := m.retrySettings()
var lastErr error
for attempt := 0; attempt < attempts; attempt++ {
for attempt := 0; ; attempt++ {
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
if errExec == nil {
return resp, nil
}
lastErr = errExec
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, normalized, req.Model, maxWait)
if !shouldRetry {
break
}
@@ -503,20 +508,16 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
if attempts < 1 {
attempts = 1
}
_, maxWait := m.retrySettings()
var lastErr error
for attempt := 0; attempt < attempts; attempt++ {
for attempt := 0; ; attempt++ {
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
if errExec == nil {
return resp, nil
}
lastErr = errExec
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, normalized, req.Model, maxWait)
if !shouldRetry {
break
}
@@ -538,20 +539,16 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
if attempts < 1 {
attempts = 1
}
_, maxWait := m.retrySettings()
var lastErr error
for attempt := 0; attempt < attempts; attempt++ {
for attempt := 0; ; attempt++ {
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
if errStream == nil {
return chunks, nil
}
lastErr = errStream
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait)
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait)
if !shouldRetry {
break
}
@@ -1034,11 +1031,15 @@ func (m *Manager) retrySettings() (int, time.Duration) {
return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load())
}
func (m *Manager) closestCooldownWait(providers []string, model string) (time.Duration, bool) {
func (m *Manager) closestCooldownWait(providers []string, model string, attempt int) (time.Duration, bool) {
if m == nil || len(providers) == 0 {
return 0, false
}
now := time.Now()
defaultRetry := int(m.requestRetry.Load())
if defaultRetry < 0 {
defaultRetry = 0
}
providerSet := make(map[string]struct{}, len(providers))
for i := range providers {
key := strings.TrimSpace(strings.ToLower(providers[i]))
@@ -1061,6 +1062,16 @@ func (m *Manager) closestCooldownWait(providers []string, model string) (time.Du
if _, ok := providerSet[providerKey]; !ok {
continue
}
effectiveRetry := defaultRetry
if override, ok := auth.RequestRetryOverride(); ok {
effectiveRetry = override
}
if effectiveRetry < 0 {
effectiveRetry = 0
}
if attempt >= effectiveRetry {
continue
}
blocked, reason, next := isAuthBlockedForModel(auth, model, now)
if !blocked || next.IsZero() || reason == blockReasonDisabled {
continue
@@ -1077,8 +1088,8 @@ func (m *Manager) closestCooldownWait(providers []string, model string) (time.Du
return minWait, found
}
func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
if err == nil || attempt >= maxAttempts-1 {
func (m *Manager) shouldRetryAfterError(err error, attempt int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
if err == nil {
return 0, false
}
if maxWait <= 0 {
@@ -1087,7 +1098,7 @@ func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, pro
if status := statusCodeFromError(err); status == http.StatusOK {
return 0, false
}
wait, found := m.closestCooldownWait(providers, model)
wait, found := m.closestCooldownWait(providers, model, attempt)
if !found || wait > maxWait {
return 0, false
}
@@ -1176,7 +1187,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
if result.RetryAfter != nil {
next = now.Add(*result.RetryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(backoffLevel)
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 {
next = now.Add(cooldown)
}
@@ -1193,7 +1204,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
shouldSuspendModel = true
setModelQuota = true
case 408, 500, 502, 503, 504:
if quotaCooldownDisabled.Load() {
if quotaCooldownDisabledForAuth(auth) {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(1 * time.Minute)
@@ -1439,7 +1450,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
if retryAfter != nil {
next = now.Add(*retryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel)
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 {
next = now.Add(cooldown)
}
@@ -1449,7 +1460,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
auth.NextRetryAfter = next
case 408, 500, 502, 503, 504:
auth.StatusMessage = "transient upstream error"
if quotaCooldownDisabled.Load() {
if quotaCooldownDisabledForAuth(auth) {
auth.NextRetryAfter = time.Time{}
} else {
auth.NextRetryAfter = now.Add(1 * time.Minute)
@@ -1462,11 +1473,11 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
}
// nextQuotaCooldown returns the next cooldown duration and updated backoff level for repeated quota errors.
func nextQuotaCooldown(prevLevel int) (time.Duration, int) {
func nextQuotaCooldown(prevLevel int, disableCooling bool) (time.Duration, int) {
if prevLevel < 0 {
prevLevel = 0
}
if quotaCooldownDisabled.Load() {
if disableCooling {
return 0, prevLevel
}
cooldown := quotaBackoffBase * time.Duration(1<<prevLevel)
@@ -1642,6 +1653,9 @@ func (m *Manager) persist(ctx context.Context, auth *Auth) error {
if m.store == nil || auth == nil {
return nil
}
if shouldSkipPersist(ctx) {
return nil
}
if auth.Attributes != nil {
if v := strings.ToLower(strings.TrimSpace(auth.Attributes["runtime_only"])); v == "true" {
return nil

View File

@@ -0,0 +1,97 @@
package auth
import (
"context"
"testing"
"time"
)
func TestManager_ShouldRetryAfterError_RespectsAuthRequestRetryOverride(t *testing.T) {
m := NewManager(nil, nil, nil)
m.SetRetryConfig(3, 30*time.Second)
model := "test-model"
next := time.Now().Add(5 * time.Second)
auth := &Auth{
ID: "auth-1",
Provider: "claude",
Metadata: map[string]any{
"request_retry": float64(0),
},
ModelStates: map[string]*ModelState{
model: {
Unavailable: true,
Status: StatusError,
NextRetryAfter: next,
},
},
}
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
t.Fatalf("register auth: %v", errRegister)
}
_, maxWait := m.retrySettings()
wait, shouldRetry := m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 0, []string{"claude"}, model, maxWait)
if shouldRetry {
t.Fatalf("expected shouldRetry=false for request_retry=0, got true (wait=%v)", wait)
}
auth.Metadata["request_retry"] = float64(1)
if _, errUpdate := m.Update(context.Background(), auth); errUpdate != nil {
t.Fatalf("update auth: %v", errUpdate)
}
wait, shouldRetry = m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 0, []string{"claude"}, model, maxWait)
if !shouldRetry {
t.Fatalf("expected shouldRetry=true for request_retry=1, got false")
}
if wait <= 0 {
t.Fatalf("expected wait > 0, got %v", wait)
}
_, shouldRetry = m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 1, []string{"claude"}, model, maxWait)
if shouldRetry {
t.Fatalf("expected shouldRetry=false on attempt=1 for request_retry=1, got true")
}
}
func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
prev := quotaCooldownDisabled.Load()
quotaCooldownDisabled.Store(false)
t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
m := NewManager(nil, nil, nil)
auth := &Auth{
ID: "auth-1",
Provider: "claude",
Metadata: map[string]any{
"disable_cooling": true,
},
}
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
t.Fatalf("register auth: %v", errRegister)
}
model := "test-model"
m.MarkResult(context.Background(), Result{
AuthID: "auth-1",
Provider: "claude",
Model: model,
Success: false,
Error: &Error{HTTPStatus: 500, Message: "boom"},
})
updated, ok := m.GetByID("auth-1")
if !ok || updated == nil {
t.Fatalf("expected auth to be present")
}
state := updated.ModelStates[model]
if state == nil {
t.Fatalf("expected model state to be present")
}
if !state.NextRetryAfter.IsZero() {
t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
}
}

View File

@@ -0,0 +1,24 @@
package auth
import "context"
type skipPersistContextKey struct{}
// WithSkipPersist returns a derived context that disables persistence for Manager Update/Register calls.
// It is intended for code paths that are reacting to file watcher events, where the file on disk is
// already the source of truth and persisting again would create a write-back loop.
func WithSkipPersist(ctx context.Context) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, skipPersistContextKey{}, true)
}
func shouldSkipPersist(ctx context.Context) bool {
if ctx == nil {
return false
}
v := ctx.Value(skipPersistContextKey{})
enabled, ok := v.(bool)
return ok && enabled
}

View File

@@ -0,0 +1,62 @@
package auth
import (
"context"
"sync/atomic"
"testing"
)
type countingStore struct {
saveCount atomic.Int32
}
func (s *countingStore) List(context.Context) ([]*Auth, error) { return nil, nil }
func (s *countingStore) Save(context.Context, *Auth) (string, error) {
s.saveCount.Add(1)
return "", nil
}
func (s *countingStore) Delete(context.Context, string) error { return nil }
func TestWithSkipPersist_DisablesUpdatePersistence(t *testing.T) {
store := &countingStore{}
mgr := NewManager(store, nil, nil)
auth := &Auth{
ID: "auth-1",
Provider: "antigravity",
Metadata: map[string]any{"type": "antigravity"},
}
if _, err := mgr.Update(context.Background(), auth); err != nil {
t.Fatalf("Update returned error: %v", err)
}
if got := store.saveCount.Load(); got != 1 {
t.Fatalf("expected 1 Save call, got %d", got)
}
ctxSkip := WithSkipPersist(context.Background())
if _, err := mgr.Update(ctxSkip, auth); err != nil {
t.Fatalf("Update(skipPersist) returned error: %v", err)
}
if got := store.saveCount.Load(); got != 1 {
t.Fatalf("expected Save call count to remain 1, got %d", got)
}
}
func TestWithSkipPersist_DisablesRegisterPersistence(t *testing.T) {
store := &countingStore{}
mgr := NewManager(store, nil, nil)
auth := &Auth{
ID: "auth-1",
Provider: "antigravity",
Metadata: map[string]any{"type": "antigravity"},
}
if _, err := mgr.Register(WithSkipPersist(context.Background()), auth); err != nil {
t.Fatalf("Register(skipPersist) returned error: %v", err)
}
if got := store.saveCount.Load(); got != 0 {
t.Fatalf("expected 0 Save calls, got %d", got)
}
}

View File

@@ -194,6 +194,108 @@ func (a *Auth) ProxyInfo() string {
return "via proxy"
}
// DisableCoolingOverride returns the auth-file scoped disable_cooling override when present.
// The value is read from metadata key "disable_cooling" (or legacy "disable-cooling").
func (a *Auth) DisableCoolingOverride() (bool, bool) {
if a == nil || a.Metadata == nil {
return false, false
}
if val, ok := a.Metadata["disable_cooling"]; ok {
if parsed, okParse := parseBoolAny(val); okParse {
return parsed, true
}
}
if val, ok := a.Metadata["disable-cooling"]; ok {
if parsed, okParse := parseBoolAny(val); okParse {
return parsed, true
}
}
return false, false
}
// RequestRetryOverride returns the auth-file scoped request_retry override when present.
// The value is read from metadata key "request_retry" (or legacy "request-retry").
func (a *Auth) RequestRetryOverride() (int, bool) {
if a == nil || a.Metadata == nil {
return 0, false
}
if val, ok := a.Metadata["request_retry"]; ok {
if parsed, okParse := parseIntAny(val); okParse {
if parsed < 0 {
parsed = 0
}
return parsed, true
}
}
if val, ok := a.Metadata["request-retry"]; ok {
if parsed, okParse := parseIntAny(val); okParse {
if parsed < 0 {
parsed = 0
}
return parsed, true
}
}
return 0, false
}
func parseBoolAny(val any) (bool, bool) {
switch typed := val.(type) {
case bool:
return typed, true
case string:
trimmed := strings.TrimSpace(typed)
if trimmed == "" {
return false, false
}
parsed, err := strconv.ParseBool(trimmed)
if err != nil {
return false, false
}
return parsed, true
case float64:
return typed != 0, true
case json.Number:
parsed, err := typed.Int64()
if err != nil {
return false, false
}
return parsed != 0, true
default:
return false, false
}
}
func parseIntAny(val any) (int, bool) {
switch typed := val.(type) {
case int:
return typed, true
case int32:
return int(typed), true
case int64:
return int(typed), true
case float64:
return int(typed), true
case json.Number:
parsed, err := typed.Int64()
if err != nil {
return 0, false
}
return int(parsed), true
case string:
trimmed := strings.TrimSpace(typed)
if trimmed == "" {
return 0, false
}
parsed, err := strconv.Atoi(trimmed)
if err != nil {
return 0, false
}
return parsed, true
default:
return 0, false
}
}
func (a *Auth) AccountInfo() (string, string) {
if a == nil {
return "", ""

View File

@@ -135,6 +135,7 @@ func (s *Service) ensureAuthUpdateQueue(ctx context.Context) {
}
func (s *Service) consumeAuthUpdates(ctx context.Context) {
ctx = coreauth.WithSkipPersist(ctx)
for {
select {
case <-ctx.Done():