From 70897247b25b75658193445d1485276c0145af14 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 26 Jan 2026 21:59:08 +0800 Subject: [PATCH] feat(auth): add support for request_retry and disable_cooling overrides Implement `request_retry` and `disable_cooling` metadata overrides for authentication management. Update retry and cooling logic accordingly across `Manager`, Antigravity executor, and file synthesizer. Add tests to validate new behaviors. --- .../runtime/executor/antigravity_executor.go | 19 ++-- internal/watcher/synthesizer/file.go | 10 ++ internal/watcher/synthesizer/file_test.go | 30 ++++-- sdk/cliproxy/auth/conductor.go | 73 +++++++------ sdk/cliproxy/auth/conductor_overrides_test.go | 97 +++++++++++++++++ sdk/cliproxy/auth/types.go | 102 ++++++++++++++++++ 6 files changed, 286 insertions(+), 45 deletions(-) create mode 100644 sdk/cliproxy/auth/conductor_overrides_test.go diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index a4156302..64d19951 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -148,7 +148,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) - attempts := antigravityRetryAttempts(e.cfg) + attempts := antigravityRetryAttempts(auth, e.cfg) attemptLoop: for attempt := 0; attempt < attempts; attempt++ { @@ -289,7 +289,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) - attempts := antigravityRetryAttempts(e.cfg) + attempts := antigravityRetryAttempts(auth, e.cfg) attemptLoop: for attempt := 0; attempt < attempts; attempt++ { @@ -677,7 +677,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) - attempts := antigravityRetryAttempts(e.cfg) + attempts := antigravityRetryAttempts(auth, e.cfg) attemptLoop: for attempt := 0; attempt < attempts; attempt++ { @@ -1447,11 +1447,16 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string { return defaultAntigravityAgent } -func antigravityRetryAttempts(cfg *config.Config) int { - if cfg == nil { - return 1 +func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int { + retry := 0 + if cfg != nil { + retry = cfg.RequestRetry + } + if auth != nil { + if override, ok := auth.RequestRetryOverride(); ok { + retry = override + } } - retry := cfg.RequestRetry if retry < 0 { retry = 0 } diff --git a/internal/watcher/synthesizer/file.go b/internal/watcher/synthesizer/file.go index 190d310a..ef0eb8c9 100644 --- a/internal/watcher/synthesizer/file.go +++ b/internal/watcher/synthesizer/file.go @@ -167,6 +167,16 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an "virtual_parent_id": primary.ID, "type": metadata["type"], } + if v, ok := metadata["disable_cooling"]; ok { + metadataCopy["disable_cooling"] = v + } else if v, ok := metadata["disable-cooling"]; ok { + metadataCopy["disable_cooling"] = v + } + if v, ok := metadata["request_retry"]; ok { + metadataCopy["request_retry"] = v + } else if v, ok := metadata["request-retry"]; ok { + metadataCopy["request_retry"] = v + } proxy := strings.TrimSpace(primary.ProxyURL) if proxy != "" { metadataCopy["proxy_url"] = proxy diff --git a/internal/watcher/synthesizer/file_test.go b/internal/watcher/synthesizer/file_test.go index 2e9d5f07..93025fba 100644 --- a/internal/watcher/synthesizer/file_test.go +++ b/internal/watcher/synthesizer/file_test.go @@ -69,10 +69,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) { // Create a valid auth file authData := map[string]any{ - "type": "claude", - "email": "test@example.com", - "proxy_url": "http://proxy.local", - "prefix": "test-prefix", + "type": "claude", + "email": "test@example.com", + "proxy_url": "http://proxy.local", + "prefix": "test-prefix", + "disable_cooling": true, + "request_retry": 2, } data, _ := json.Marshal(authData) err := os.WriteFile(filepath.Join(tempDir, "claude-auth.json"), data, 0644) @@ -108,6 +110,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) { if auths[0].ProxyURL != "http://proxy.local" { t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL) } + if v, ok := auths[0].Metadata["disable_cooling"].(bool); !ok || !v { + t.Errorf("expected disable_cooling true, got %v", auths[0].Metadata["disable_cooling"]) + } + if v, ok := auths[0].Metadata["request_retry"].(float64); !ok || int(v) != 2 { + t.Errorf("expected request_retry 2, got %v", auths[0].Metadata["request_retry"]) + } if auths[0].Status != coreauth.StatusActive { t.Errorf("expected status active, got %s", auths[0].Status) } @@ -336,9 +344,11 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) { }, } metadata := map[string]any{ - "project_id": "project-a, project-b, project-c", - "email": "test@example.com", - "type": "gemini", + "project_id": "project-a, project-b, project-c", + "email": "test@example.com", + "type": "gemini", + "request_retry": 2, + "disable_cooling": true, } virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now) @@ -376,6 +386,12 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) { if v.ProxyURL != "http://proxy.local" { t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL) } + if vv, ok := v.Metadata["disable_cooling"].(bool); !ok || !vv { + t.Errorf("expected disable_cooling true, got %v", v.Metadata["disable_cooling"]) + } + if vv, ok := v.Metadata["request_retry"].(int); !ok || vv != 2 { + t.Errorf("expected request_retry 2, got %v", v.Metadata["request_retry"]) + } if v.Attributes["runtime_only"] != "true" { t.Error("expected runtime_only=true") } diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 2154dc1f..fd7543b4 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -61,6 +61,15 @@ func SetQuotaCooldownDisabled(disable bool) { quotaCooldownDisabled.Store(disable) } +func quotaCooldownDisabledForAuth(auth *Auth) bool { + if auth != nil { + if override, ok := auth.DisableCoolingOverride(); ok { + return override + } + } + return quotaCooldownDisabled.Load() +} + // Result captures execution outcome used to adjust auth state. type Result struct { // AuthID references the auth that produced this result. @@ -468,20 +477,16 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} } - retryTimes, maxWait := m.retrySettings() - attempts := retryTimes + 1 - if attempts < 1 { - attempts = 1 - } + _, maxWait := m.retrySettings() var lastErr error - for attempt := 0; attempt < attempts; attempt++ { + for attempt := 0; ; attempt++ { resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts) if errExec == nil { return resp, nil } lastErr = errExec - wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait) + wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, normalized, req.Model, maxWait) if !shouldRetry { break } @@ -503,20 +508,16 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} } - retryTimes, maxWait := m.retrySettings() - attempts := retryTimes + 1 - if attempts < 1 { - attempts = 1 - } + _, maxWait := m.retrySettings() var lastErr error - for attempt := 0; attempt < attempts; attempt++ { + for attempt := 0; ; attempt++ { resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts) if errExec == nil { return resp, nil } lastErr = errExec - wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait) + wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, normalized, req.Model, maxWait) if !shouldRetry { break } @@ -538,20 +539,16 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} } - retryTimes, maxWait := m.retrySettings() - attempts := retryTimes + 1 - if attempts < 1 { - attempts = 1 - } + _, maxWait := m.retrySettings() var lastErr error - for attempt := 0; attempt < attempts; attempt++ { + for attempt := 0; ; attempt++ { chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts) if errStream == nil { return chunks, nil } lastErr = errStream - wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait) + wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait) if !shouldRetry { break } @@ -1034,11 +1031,15 @@ func (m *Manager) retrySettings() (int, time.Duration) { return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load()) } -func (m *Manager) closestCooldownWait(providers []string, model string) (time.Duration, bool) { +func (m *Manager) closestCooldownWait(providers []string, model string, attempt int) (time.Duration, bool) { if m == nil || len(providers) == 0 { return 0, false } now := time.Now() + defaultRetry := int(m.requestRetry.Load()) + if defaultRetry < 0 { + defaultRetry = 0 + } providerSet := make(map[string]struct{}, len(providers)) for i := range providers { key := strings.TrimSpace(strings.ToLower(providers[i])) @@ -1061,6 +1062,16 @@ func (m *Manager) closestCooldownWait(providers []string, model string) (time.Du if _, ok := providerSet[providerKey]; !ok { continue } + effectiveRetry := defaultRetry + if override, ok := auth.RequestRetryOverride(); ok { + effectiveRetry = override + } + if effectiveRetry < 0 { + effectiveRetry = 0 + } + if attempt >= effectiveRetry { + continue + } blocked, reason, next := isAuthBlockedForModel(auth, model, now) if !blocked || next.IsZero() || reason == blockReasonDisabled { continue @@ -1077,8 +1088,8 @@ func (m *Manager) closestCooldownWait(providers []string, model string) (time.Du return minWait, found } -func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) { - if err == nil || attempt >= maxAttempts-1 { +func (m *Manager) shouldRetryAfterError(err error, attempt int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) { + if err == nil { return 0, false } if maxWait <= 0 { @@ -1087,7 +1098,7 @@ func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, pro if status := statusCodeFromError(err); status == http.StatusOK { return 0, false } - wait, found := m.closestCooldownWait(providers, model) + wait, found := m.closestCooldownWait(providers, model, attempt) if !found || wait > maxWait { return 0, false } @@ -1176,7 +1187,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { if result.RetryAfter != nil { next = now.Add(*result.RetryAfter) } else { - cooldown, nextLevel := nextQuotaCooldown(backoffLevel) + cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth)) if cooldown > 0 { next = now.Add(cooldown) } @@ -1193,7 +1204,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { shouldSuspendModel = true setModelQuota = true case 408, 500, 502, 503, 504: - if quotaCooldownDisabled.Load() { + if quotaCooldownDisabledForAuth(auth) { state.NextRetryAfter = time.Time{} } else { next := now.Add(1 * time.Minute) @@ -1439,7 +1450,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati if retryAfter != nil { next = now.Add(*retryAfter) } else { - cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel) + cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, quotaCooldownDisabledForAuth(auth)) if cooldown > 0 { next = now.Add(cooldown) } @@ -1449,7 +1460,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati auth.NextRetryAfter = next case 408, 500, 502, 503, 504: auth.StatusMessage = "transient upstream error" - if quotaCooldownDisabled.Load() { + if quotaCooldownDisabledForAuth(auth) { auth.NextRetryAfter = time.Time{} } else { auth.NextRetryAfter = now.Add(1 * time.Minute) @@ -1462,11 +1473,11 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati } // nextQuotaCooldown returns the next cooldown duration and updated backoff level for repeated quota errors. -func nextQuotaCooldown(prevLevel int) (time.Duration, int) { +func nextQuotaCooldown(prevLevel int, disableCooling bool) (time.Duration, int) { if prevLevel < 0 { prevLevel = 0 } - if quotaCooldownDisabled.Load() { + if disableCooling { return 0, prevLevel } cooldown := quotaBackoffBase * time.Duration(1< 0, got %v", wait) + } + + _, shouldRetry = m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 1, []string{"claude"}, model, maxWait) + if shouldRetry { + t.Fatalf("expected shouldRetry=false on attempt=1 for request_retry=1, got true") + } +} + +func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) { + prev := quotaCooldownDisabled.Load() + quotaCooldownDisabled.Store(false) + t.Cleanup(func() { quotaCooldownDisabled.Store(prev) }) + + m := NewManager(nil, nil, nil) + + auth := &Auth{ + ID: "auth-1", + Provider: "claude", + Metadata: map[string]any{ + "disable_cooling": true, + }, + } + if _, errRegister := m.Register(context.Background(), auth); errRegister != nil { + t.Fatalf("register auth: %v", errRegister) + } + + model := "test-model" + m.MarkResult(context.Background(), Result{ + AuthID: "auth-1", + Provider: "claude", + Model: model, + Success: false, + Error: &Error{HTTPStatus: 500, Message: "boom"}, + }) + + updated, ok := m.GetByID("auth-1") + if !ok || updated == nil { + t.Fatalf("expected auth to be present") + } + state := updated.ModelStates[model] + if state == nil { + t.Fatalf("expected model state to be present") + } + if !state.NextRetryAfter.IsZero() { + t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter) + } +} diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go index 4c69ae90..b2bbe0a2 100644 --- a/sdk/cliproxy/auth/types.go +++ b/sdk/cliproxy/auth/types.go @@ -194,6 +194,108 @@ func (a *Auth) ProxyInfo() string { return "via proxy" } +// DisableCoolingOverride returns the auth-file scoped disable_cooling override when present. +// The value is read from metadata key "disable_cooling" (or legacy "disable-cooling"). +func (a *Auth) DisableCoolingOverride() (bool, bool) { + if a == nil || a.Metadata == nil { + return false, false + } + if val, ok := a.Metadata["disable_cooling"]; ok { + if parsed, okParse := parseBoolAny(val); okParse { + return parsed, true + } + } + if val, ok := a.Metadata["disable-cooling"]; ok { + if parsed, okParse := parseBoolAny(val); okParse { + return parsed, true + } + } + return false, false +} + +// RequestRetryOverride returns the auth-file scoped request_retry override when present. +// The value is read from metadata key "request_retry" (or legacy "request-retry"). +func (a *Auth) RequestRetryOverride() (int, bool) { + if a == nil || a.Metadata == nil { + return 0, false + } + if val, ok := a.Metadata["request_retry"]; ok { + if parsed, okParse := parseIntAny(val); okParse { + if parsed < 0 { + parsed = 0 + } + return parsed, true + } + } + if val, ok := a.Metadata["request-retry"]; ok { + if parsed, okParse := parseIntAny(val); okParse { + if parsed < 0 { + parsed = 0 + } + return parsed, true + } + } + return 0, false +} + +func parseBoolAny(val any) (bool, bool) { + switch typed := val.(type) { + case bool: + return typed, true + case string: + trimmed := strings.TrimSpace(typed) + if trimmed == "" { + return false, false + } + parsed, err := strconv.ParseBool(trimmed) + if err != nil { + return false, false + } + return parsed, true + case float64: + return typed != 0, true + case json.Number: + parsed, err := typed.Int64() + if err != nil { + return false, false + } + return parsed != 0, true + default: + return false, false + } +} + +func parseIntAny(val any) (int, bool) { + switch typed := val.(type) { + case int: + return typed, true + case int32: + return int(typed), true + case int64: + return int(typed), true + case float64: + return int(typed), true + case json.Number: + parsed, err := typed.Int64() + if err != nil { + return 0, false + } + return int(parsed), true + case string: + trimmed := strings.TrimSpace(typed) + if trimmed == "" { + return 0, false + } + parsed, err := strconv.Atoi(trimmed) + if err != nil { + return 0, false + } + return parsed, true + default: + return 0, false + } +} + func (a *Auth) AccountInfo() (string, string) { if a == nil { return "", ""