diff --git a/config.example.yaml b/config.example.yaml
index e9bdafd5..9b9745ac 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -38,6 +38,10 @@ proxy-url: ""
 # Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
 request-retry: 3
 
+# Maximum time in seconds to wait for a cooled-down credential to become usable again before retrying.
+# If set to 0, or if no credential recovers within this time, no cooldown retry is attempted.
+max-retry-interval: 30
+
 # Quota exceeded behavior
 quota-exceeded:
   switch-project: true # Whether to automatically switch to another project when a quota is exceeded
diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go
index e6b18ea7..8f57171e 100644
--- a/internal/api/handlers/management/config_basic.go
+++ b/internal/api/handlers/management/config_basic.go
@@ -172,6 +172,14 @@ func (h *Handler) PutRequestRetry(c *gin.Context) {
 	h.updateIntField(c, func(v int) { h.cfg.RequestRetry = v })
 }
 
+// Max retry interval
+func (h *Handler) GetMaxRetryInterval(c *gin.Context) {
+	c.JSON(200, gin.H{"max-retry-interval": h.cfg.MaxRetryInterval})
+}
+func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
+	h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
+}
+
 // Proxy URL
 func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
 func (h *Handler) PutProxyURL(c *gin.Context) {
diff --git a/internal/api/server.go b/internal/api/server.go
index 8e0de284..0583eee1 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -247,6 +247,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
 	// Save initial YAML snapshot
 	s.oldConfigYaml, _ = yaml.Marshal(cfg)
 	s.applyAccessConfig(nil, cfg)
+	if authManager != nil {
+		authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
+	}
 	managementasset.SetCurrentConfig(cfg)
 	auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
 	// Initialize management handler
@@ -521,6 +524,9 @@ func (s *Server) registerManagementRoutes() {
 		mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
 		mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
 		mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
+		mgmt.GET("/max-retry-interval", s.mgmt.GetMaxRetryInterval)
+		mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
+		mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
 
 		mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
 		mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
@@ -816,6 +822,9 @@ func (s *Server) UpdateClients(cfg *config.Config) {
 			log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
 		}
 	}
+	if s.handlers != nil && s.handlers.AuthManager != nil {
+		s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
+	}
 
 	// Update log level dynamically when debug flag changes
 	if oldCfg == nil || oldCfg.Debug != cfg.Debug {
diff --git a/internal/config/config.go b/internal/config/config.go
index ec97064e..726e585f 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -63,6 +63,8 @@ type Config struct {
 
 	// RequestRetry defines the retry times when the request failed.
 	RequestRetry int `yaml:"request-retry" json:"request-retry"`
+	// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
+	MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"`
 
 	// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
 	ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go
index 518ce332..0d955064 100644
--- a/internal/watcher/watcher.go
+++ b/internal/watcher/watcher.go
@@ -1419,6 +1419,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
 	if oldCfg.RequestRetry != newCfg.RequestRetry {
 		changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
 	}
+	if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
+		changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
+	}
 	if oldCfg.ProxyURL != newCfg.ProxyURL {
 		changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
 	}
diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go
index d60bf920..363189ed 100644
--- a/sdk/cliproxy/auth/manager.go
+++ b/sdk/cliproxy/auth/manager.go
@@ -106,6 +106,10 @@ type Manager struct {
 	// providerOffsets tracks per-model provider rotation state for multi-provider routing.
 	providerOffsets map[string]int
 
+	// Retry settings installed by SetRetryConfig; maxRetryInterval is stored in nanoseconds.
+	requestRetry     atomic.Int32
+	maxRetryInterval atomic.Int64
+
 	// Optional HTTP RoundTripper provider injected by host.
 	rtProvider RoundTripperProvider
 
@@ -145,6 +149,27 @@ func (m *Manager) SetRoundTripperProvider(p RoundTripperProvider) {
 	m.mu.Unlock()
 }
 
+// SetRetryConfig updates the number of retry attempts and the maximum cooldown wait.
+// Negative values are clamped to zero and retry is capped at the int32 range used for storage.
+func (m *Manager) SetRetryConfig(retry int, maxRetryInterval time.Duration) {
+	// maxStoredRetry is the largest value that fits the atomic.Int32 backing field.
+	const maxStoredRetry = 1<<31 - 1
+	if m == nil {
+		return
+	}
+	if retry < 0 {
+		retry = 0
+	}
+	if retry > maxStoredRetry {
+		retry = maxStoredRetry
+	}
+	if maxRetryInterval < 0 {
+		maxRetryInterval = 0
+	}
+	m.requestRetry.Store(int32(retry))
+	m.maxRetryInterval.Store(maxRetryInterval.Nanoseconds())
+}
+
 // RegisterExecutor registers a provider executor with the manager.
 func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
 	if executor == nil {
@@ -229,13 +254,28 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
 	rotated := m.rotateProviders(req.Model, normalized)
 	defer m.advanceProviderCursor(req.Model, normalized)
 
+	retryTimes, maxWait := m.retrySettings()
+	attempts := retryTimes + 1
+	if attempts < 1 {
+		attempts = 1
+	}
+
 	var lastErr error
-	for _, provider := range rotated {
-		resp, errExec := m.executeWithProvider(ctx, provider, req, opts)
+	for attempt := 0; attempt < attempts; attempt++ {
+		resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
+			return m.executeWithProvider(execCtx, provider, req, opts)
+		})
 		if errExec == nil {
 			return resp, nil
 		}
 		lastErr = errExec
+		wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
+		if !shouldRetry {
+			break
+		}
+		if errWait := waitForCooldown(ctx, wait); errWait != nil {
+			return cliproxyexecutor.Response{}, errWait
+		}
 	}
 	if lastErr != nil {
 		return cliproxyexecutor.Response{}, lastErr
@@ -253,13 +293,28 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
 	rotated := m.rotateProviders(req.Model, normalized)
 	defer m.advanceProviderCursor(req.Model, normalized)
 
+	retryTimes, maxWait := m.retrySettings()
+	attempts := retryTimes + 1
+	if attempts < 1 {
+		attempts = 1
+	}
+
 	var lastErr error
-	for _, provider := range rotated {
-		resp, errExec := m.executeCountWithProvider(ctx, provider, req, opts)
+	for attempt := 0; attempt < attempts; attempt++ {
+		resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
+			return m.executeCountWithProvider(execCtx, provider, req, opts)
+		})
 		if errExec == nil {
 			return resp, nil
 		}
 		lastErr = errExec
+		wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
+		if !shouldRetry {
+			break
+		}
+		if errWait := waitForCooldown(ctx, wait); errWait != nil {
+			return cliproxyexecutor.Response{}, errWait
+		}
 	}
 	if lastErr != nil {
 		return cliproxyexecutor.Response{}, lastErr
@@ -277,13 +332,28 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
 	rotated := m.rotateProviders(req.Model, normalized)
 	defer m.advanceProviderCursor(req.Model, normalized)
 
+	retryTimes, maxWait := m.retrySettings()
+	attempts := retryTimes + 1
+	if attempts < 1 {
+		attempts = 1
+	}
+
 	var lastErr error
-	for _, provider := range rotated {
-		chunks, errStream := m.executeStreamWithProvider(ctx, provider, req, opts)
+	for attempt := 0; attempt < attempts; attempt++ {
+		chunks, errStream := m.executeStreamProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (<-chan cliproxyexecutor.StreamChunk, error) {
+			return m.executeStreamWithProvider(execCtx, provider, req, opts)
+		})
 		if errStream == nil {
 			return chunks, nil
 		}
 		lastErr = errStream
+		wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, rotated, req.Model, maxWait)
+		if !shouldRetry {
+			break
+		}
+		if errWait := waitForCooldown(ctx, wait); errWait != nil {
+			return nil, errWait
+		}
 	}
 	if lastErr != nil {
 		return nil, lastErr
@@ -507,6 +577,141 @@ func (m *Manager) advanceProviderCursor(model string, providers []string) {
 	m.mu.Unlock()
 }
 
+// retrySettings returns the configured retry count and maximum cooldown wait.
+// A nil manager yields zero values, which disables retries.
+func (m *Manager) retrySettings() (int, time.Duration) {
+	if m == nil {
+		return 0, 0
+	}
+	return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load())
+}
+
+// closestCooldownWait returns the shortest remaining block time among auths of the given
+// providers that are currently blocked for model. Auths blocked with blockReasonDisabled,
+// or without a known recovery time, are skipped. The boolean is false when nothing
+// qualifies.
+func (m *Manager) closestCooldownWait(providers []string, model string) (time.Duration, bool) {
+	if m == nil || len(providers) == 0 {
+		return 0, false
+	}
+	now := time.Now()
+	// Normalize provider names so matching ignores case and surrounding whitespace.
+	providerSet := make(map[string]struct{}, len(providers))
+	for i := range providers {
+		key := strings.TrimSpace(strings.ToLower(providers[i]))
+		if key == "" {
+			continue
+		}
+		providerSet[key] = struct{}{}
+	}
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	var (
+		found   bool
+		minWait time.Duration
+	)
+	for _, auth := range m.auths {
+		if auth == nil {
+			continue
+		}
+		providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
+		if _, ok := providerSet[providerKey]; !ok {
+			continue
+		}
+		blocked, reason, next := isAuthBlockedForModel(auth, model, now)
+		if !blocked || next.IsZero() || reason == blockReasonDisabled {
+			continue
+		}
+		wait := next.Sub(now)
+		if wait < 0 {
+			continue
+		}
+		if !found || wait < minWait {
+			minWait = wait
+			found = true
+		}
+	}
+	return minWait, found
+}
+
+// shouldRetryAfterError decides whether another attempt is worthwhile after err.
+// It returns the wait time and true only when attempts remain, maxWait is positive,
+// err does not report HTTP 200, and at least one matching credential leaves its
+// cooldown within maxWait.
+func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
+	if err == nil || attempt >= maxAttempts-1 {
+		return 0, false
+	}
+	if maxWait <= 0 {
+		return 0, false
+	}
+	if status := statusCodeFromError(err); status == http.StatusOK {
+		return 0, false
+	}
+	wait, found := m.closestCooldownWait(providers, model)
+	if !found || wait > maxWait {
+		return 0, false
+	}
+	return wait, true
+}
+
+// waitForCooldown blocks for wait or until ctx is done, whichever comes first.
+// It returns ctx.Err() if the context ends first. A non-positive wait does not sleep
+// but still reports an already-cancelled context so callers stop retrying promptly.
+func waitForCooldown(ctx context.Context, wait time.Duration) error {
+	if wait <= 0 {
+		return ctx.Err()
+	}
+	timer := time.NewTimer(wait)
+	defer timer.Stop()
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-timer.C:
+		return nil
+	}
+}
+
+// executeProvidersOnce calls fn for each provider in order and returns the first
+// successful response. If every provider fails, the last error is returned.
+func (m *Manager) executeProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (cliproxyexecutor.Response, error)) (cliproxyexecutor.Response, error) {
+	if len(providers) == 0 {
+		return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
+	}
+	var lastErr error
+	for _, provider := range providers {
+		resp, errExec := fn(ctx, provider)
+		if errExec == nil {
+			return resp, nil
+		}
+		lastErr = errExec
+	}
+	if lastErr != nil {
+		return cliproxyexecutor.Response{}, lastErr
+	}
+	return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
+}
+
+// executeStreamProvidersOnce calls fn for each provider in order and returns the first
+// stream that opens successfully. If every provider fails, the last error is returned.
+func (m *Manager) executeStreamProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (<-chan cliproxyexecutor.StreamChunk, error)) (<-chan cliproxyexecutor.StreamChunk, error) {
+	if len(providers) == 0 {
+		return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
+	}
+	var lastErr error
+	for _, provider := range providers {
+		chunks, errExec := fn(ctx, provider)
+		if errExec == nil {
+			return chunks, nil
+		}
+		lastErr = errExec
+	}
+	if lastErr != nil {
+		return nil, lastErr
+	}
+	return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
+}
+
 // MarkResult records an execution result and notifies hooks.
 func (m *Manager) MarkResult(ctx context.Context, result Result) {
 	if result.AuthID == "" {
@@ -762,6 +967,22 @@ func cloneError(err *Error) *Error {
 	}
 }
 
+// statusCodeFromError extracts an HTTP status code from err when any error in its
+// chain implements StatusCode() int. It returns 0 if no status is available.
+func statusCodeFromError(err error) int {
+	if err == nil {
+		return 0
+	}
+	type statusCoder interface {
+		StatusCode() int
+	}
+	var sc statusCoder
+	if errors.As(err, &sc) && sc != nil {
+		return sc.StatusCode()
+	}
+	return 0
+}
+
 func retryAfterFromError(err error) *time.Duration {
 	if err == nil {
 		return nil
diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go
index 5be25799..6e303ed2 100644
--- a/sdk/cliproxy/service.go
+++ b/sdk/cliproxy/service.go
@@ -281,6 +281,16 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
 	}
 }
 
+// applyRetryConfig pushes the retry settings from cfg to the core auth manager.
+// MaxRetryInterval is interpreted as seconds. It is a no-op when s, its core manager, or cfg is nil.
+func (s *Service) applyRetryConfig(cfg *config.Config) {
+	if s == nil || s.coreManager == nil || cfg == nil {
+		return
+	}
+	maxInterval := time.Duration(cfg.MaxRetryInterval) * time.Second
+	s.coreManager.SetRetryConfig(cfg.RequestRetry, maxInterval)
+}
+
 func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName string, ok bool) {
 	if a == nil {
 		return "", "", false
@@ -394,6 +404,8 @@ func (s *Service) Run(ctx context.Context) error {
 		return err
 	}
 
+	s.applyRetryConfig(s.cfg)
+
 	if s.coreManager != nil {
 		if errLoad := s.coreManager.Load(ctx); errLoad != nil {
 			log.Warnf("failed to load auth store: %v", errLoad)
@@ -476,6 +488,7 @@ func (s *Service) Run(ctx context.Context) error {
 		if newCfg == nil {
 			return
 		}
+		s.applyRetryConfig(newCfg)
 		if s.server != nil {
 			s.server.UpdateClients(newCfg)
 		}