diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 0384eaaf..a1a95b2c 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -214,19 +214,46 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) { func (h *Handler) DeleteGeminiKey(c *gin.Context) { if val := strings.TrimSpace(c.Query("api-key")); val != "" { - out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey)) - for _, v := range h.cfg.GeminiKey { - if v.APIKey != val { + if baseRaw, okBase := c.GetQuery("base-url"); okBase { + base := strings.TrimSpace(baseRaw) + out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey)) + for _, v := range h.cfg.GeminiKey { + if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base { + continue + } out = append(out, v) } + if len(out) != len(h.cfg.GeminiKey) { + h.cfg.GeminiKey = out + h.cfg.SanitizeGeminiKeys() + h.persist(c) + } else { + c.JSON(404, gin.H{"error": "item not found"}) + } + return } - if len(out) != len(h.cfg.GeminiKey) { - h.cfg.GeminiKey = out - h.cfg.SanitizeGeminiKeys() - h.persist(c) - } else { + + matchIndex := -1 + matchCount := 0 + for i := range h.cfg.GeminiKey { + if strings.TrimSpace(h.cfg.GeminiKey[i].APIKey) == val { + matchCount++ + if matchIndex == -1 { + matchIndex = i + } + } + } + if matchCount == 0 { c.JSON(404, gin.H{"error": "item not found"}) + return } + if matchCount > 1 { + c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"}) + return + } + h.cfg.GeminiKey = append(h.cfg.GeminiKey[:matchIndex], h.cfg.GeminiKey[matchIndex+1:]...) + h.cfg.SanitizeGeminiKeys() + h.persist(c) return } if idxStr := c.Query("index"); idxStr != "" { @@ -335,14 +362,39 @@ func (h *Handler) PatchClaudeKey(c *gin.Context) { } func (h *Handler) DeleteClaudeKey(c *gin.Context) { - if val := c.Query("api-key"); val != "" { - out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey)) - for _, v := range h.cfg.ClaudeKey { - if v.APIKey != val { + if val := strings.TrimSpace(c.Query("api-key")); val != "" { + if baseRaw, okBase := c.GetQuery("base-url"); okBase { + base := strings.TrimSpace(baseRaw) + out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey)) + for _, v := range h.cfg.ClaudeKey { + if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base { + continue + } out = append(out, v) } + h.cfg.ClaudeKey = out + h.cfg.SanitizeClaudeKeys() + h.persist(c) + return + } + + matchIndex := -1 + matchCount := 0 + for i := range h.cfg.ClaudeKey { + if strings.TrimSpace(h.cfg.ClaudeKey[i].APIKey) == val { + matchCount++ + if matchIndex == -1 { + matchIndex = i + } + } + } + if matchCount > 1 { + c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"}) + return + } + if matchIndex != -1 { + h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:matchIndex], h.cfg.ClaudeKey[matchIndex+1:]...) } - h.cfg.ClaudeKey = out h.cfg.SanitizeClaudeKeys() h.persist(c) return @@ -601,13 +653,38 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) { func (h *Handler) DeleteVertexCompatKey(c *gin.Context) { if val := strings.TrimSpace(c.Query("api-key")); val != "" { - out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey)) - for _, v := range h.cfg.VertexCompatAPIKey { - if v.APIKey != val { + if baseRaw, okBase := c.GetQuery("base-url"); okBase { + base := strings.TrimSpace(baseRaw) + out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey)) + for _, v := range h.cfg.VertexCompatAPIKey { + if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base { + continue + } out = append(out, v) } + h.cfg.VertexCompatAPIKey = out + h.cfg.SanitizeVertexCompatKeys() + h.persist(c) + return + } + + matchIndex := -1 + matchCount := 0 + for i := range h.cfg.VertexCompatAPIKey { + if strings.TrimSpace(h.cfg.VertexCompatAPIKey[i].APIKey) == val { + matchCount++ + if matchIndex == -1 { + matchIndex = i + } + } + } + if matchCount > 1 { + c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"}) + return + } + if matchIndex != -1 { + h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:matchIndex], h.cfg.VertexCompatAPIKey[matchIndex+1:]...) } - h.cfg.VertexCompatAPIKey = out h.cfg.SanitizeVertexCompatKeys() h.persist(c) return @@ -919,14 +996,39 @@ func (h *Handler) PatchCodexKey(c *gin.Context) { } func (h *Handler) DeleteCodexKey(c *gin.Context) { - if val := c.Query("api-key"); val != "" { - out := make([]config.CodexKey, 0, len(h.cfg.CodexKey)) - for _, v := range h.cfg.CodexKey { - if v.APIKey != val { + if val := strings.TrimSpace(c.Query("api-key")); val != "" { + if baseRaw, okBase := c.GetQuery("base-url"); okBase { + base := strings.TrimSpace(baseRaw) + out := make([]config.CodexKey, 0, len(h.cfg.CodexKey)) + for _, v := range h.cfg.CodexKey { + if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base { + continue + } out = append(out, v) } + h.cfg.CodexKey = out + h.cfg.SanitizeCodexKeys() + h.persist(c) + return + } + + matchIndex := -1 + matchCount := 0 + for i := range h.cfg.CodexKey { + if strings.TrimSpace(h.cfg.CodexKey[i].APIKey) == val { + matchCount++ + if matchIndex == -1 { + matchIndex = i + } + } + } + if matchCount > 1 { + c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"}) + return + } + if matchIndex != -1 { + h.cfg.CodexKey = append(h.cfg.CodexKey[:matchIndex], h.cfg.CodexKey[matchIndex+1:]...) } - h.cfg.CodexKey = out h.cfg.SanitizeCodexKeys() h.persist(c) return diff --git a/internal/api/handlers/management/config_lists_delete_keys_test.go b/internal/api/handlers/management/config_lists_delete_keys_test.go new file mode 100644 index 00000000..aaa43910 --- /dev/null +++ b/internal/api/handlers/management/config_lists_delete_keys_test.go @@ -0,0 +1,172 @@ +package management + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func writeTestConfigFile(t *testing.T) string { + t.Helper() + + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + if errWrite := os.WriteFile(path, []byte("{}\n"), 0o600); errWrite != nil { + t.Fatalf("failed to write test config: %v", errWrite) + } + return path +} + +func TestDeleteGeminiKey_RequiresBaseURLWhenAPIKeyDuplicated(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + h := &Handler{ + cfg: &config.Config{ + GeminiKey: []config.GeminiKey{ + {APIKey: "shared-key", BaseURL: "https://a.example.com"}, + {APIKey: "shared-key", BaseURL: "https://b.example.com"}, + }, + }, + configFilePath: writeTestConfigFile(t), + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/gemini-api-key?api-key=shared-key", nil) + + h.DeleteGeminiKey(c) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusBadRequest, rec.Body.String()) + } + if got := len(h.cfg.GeminiKey); got != 2 { + t.Fatalf("gemini keys len = %d, want 2", got) + } +} + +func TestDeleteGeminiKey_DeletesOnlyMatchingBaseURL(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + h := &Handler{ + cfg: &config.Config{ + GeminiKey: []config.GeminiKey{ + {APIKey: "shared-key", BaseURL: "https://a.example.com"}, + {APIKey: "shared-key", BaseURL: "https://b.example.com"}, + }, + }, + configFilePath: writeTestConfigFile(t), + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/gemini-api-key?api-key=shared-key&base-url=https://a.example.com", nil) + + h.DeleteGeminiKey(c) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if got := len(h.cfg.GeminiKey); got != 1 { + t.Fatalf("gemini keys len = %d, want 1", got) + } + if got := h.cfg.GeminiKey[0].BaseURL; got != "https://b.example.com" { + t.Fatalf("remaining base-url = %q, want %q", got, "https://b.example.com") + } +} + +func TestDeleteClaudeKey_DeletesEmptyBaseURLWhenExplicitlyProvided(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + h := &Handler{ + cfg: &config.Config{ + ClaudeKey: []config.ClaudeKey{ + {APIKey: "shared-key", BaseURL: ""}, + {APIKey: "shared-key", BaseURL: "https://claude.example.com"}, + }, + }, + configFilePath: writeTestConfigFile(t), + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/claude-api-key?api-key=shared-key&base-url=", nil) + + h.DeleteClaudeKey(c) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if got := len(h.cfg.ClaudeKey); got != 1 { + t.Fatalf("claude keys len = %d, want 1", got) + } + if got := h.cfg.ClaudeKey[0].BaseURL; got != "https://claude.example.com" { + t.Fatalf("remaining base-url = %q, want %q", got, "https://claude.example.com") + } +} + +func TestDeleteVertexCompatKey_DeletesOnlyMatchingBaseURL(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + h := &Handler{ + cfg: &config.Config{ + VertexCompatAPIKey: []config.VertexCompatKey{ + {APIKey: "shared-key", BaseURL: "https://a.example.com"}, + {APIKey: "shared-key", BaseURL: "https://b.example.com"}, + }, + }, + configFilePath: writeTestConfigFile(t), + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/vertex-api-key?api-key=shared-key&base-url=https://b.example.com", nil) + + h.DeleteVertexCompatKey(c) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if got := len(h.cfg.VertexCompatAPIKey); got != 1 { + t.Fatalf("vertex keys len = %d, want 1", got) + } + if got := h.cfg.VertexCompatAPIKey[0].BaseURL; got != "https://a.example.com" { + t.Fatalf("remaining base-url = %q, want %q", got, "https://a.example.com") + } +} + +func TestDeleteCodexKey_RequiresBaseURLWhenAPIKeyDuplicated(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + h := &Handler{ + cfg: &config.Config{ + CodexKey: []config.CodexKey{ + {APIKey: "shared-key", BaseURL: "https://a.example.com"}, + {APIKey: "shared-key", BaseURL: "https://b.example.com"}, + }, + }, + configFilePath: writeTestConfigFile(t), + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/codex-api-key?api-key=shared-key", nil) + + h.DeleteCodexKey(c) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusBadRequest, rec.Body.String()) + } + if got := len(h.cfg.CodexKey); got != 2 { + t.Fatalf("codex keys len = %d, want 2", got) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 7a1383b5..fedf77da 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -981,6 +981,7 @@ func (cfg *Config) SanitizeKiroKeys() { } // SanitizeGeminiKeys deduplicates and normalizes Gemini credentials. +// It uses API key + base URL as the uniqueness key. func (cfg *Config) SanitizeGeminiKeys() { if cfg == nil { return @@ -999,10 +1000,11 @@ func (cfg *Config) SanitizeGeminiKeys() { entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) entry.Headers = NormalizeHeaders(entry.Headers) entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels) - if _, exists := seen[entry.APIKey]; exists { + uniqueKey := entry.APIKey + "|" + entry.BaseURL + if _, exists := seen[uniqueKey]; exists { continue } - seen[entry.APIKey] = struct{}{} + seen[uniqueKey] = struct{}{} out = append(out, entry) } cfg.GeminiKey = out diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index cf4a9975..5c8ff039 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -632,6 +632,26 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) { r.Header.Set("Accept", "application/json") } +func normaliseQwenBaseURL(resourceURL string) string { + raw := strings.TrimSpace(resourceURL) + if raw == "" { + return "" + } + + normalized := raw + lower := strings.ToLower(normalized) + if !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") { + normalized = "https://" + normalized + } + + normalized = strings.TrimRight(normalized, "/") + if !strings.HasSuffix(strings.ToLower(normalized), "/v1") { + normalized += "/v1" + } + + return normalized +} + func qwenCreds(a *cliproxyauth.Auth) (token, baseURL string) { if a == nil { return "", "" @@ -649,7 +669,7 @@ func qwenCreds(a *cliproxyauth.Auth) (token, baseURL string) { token = v } if v, ok := a.Metadata["resource_url"].(string); ok { - baseURL = fmt.Sprintf("https://%s/v1", v) + baseURL = normaliseQwenBaseURL(v) } } return diff --git a/internal/runtime/executor/qwen_executor_test.go b/internal/runtime/executor/qwen_executor_test.go index d12c0a0b..cf9ed21f 100644 --- a/internal/runtime/executor/qwen_executor_test.go +++ b/internal/runtime/executor/qwen_executor_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" "github.com/tidwall/gjson" ) @@ -176,3 +177,35 @@ func TestWrapQwenError_Maps403QuotaTo429WithoutRetryAfter(t *testing.T) { t.Fatalf("wrapQwenError retryAfter = %v, want nil", *retryAfter) } } + +func TestQwenCreds_NormalizesResourceURL(t *testing.T) { + tests := []struct { + name string + resourceURL string + wantBaseURL string + }{ + {"host only", "portal.qwen.ai", "https://portal.qwen.ai/v1"}, + {"scheme no v1", "https://portal.qwen.ai", "https://portal.qwen.ai/v1"}, + {"scheme with v1", "https://portal.qwen.ai/v1", "https://portal.qwen.ai/v1"}, + {"scheme with v1 slash", "https://portal.qwen.ai/v1/", "https://portal.qwen.ai/v1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := &cliproxyauth.Auth{ + Metadata: map[string]any{ + "access_token": "test-token", + "resource_url": tt.resourceURL, + }, + } + + token, baseURL := qwenCreds(auth) + if token != "test-token" { + t.Fatalf("qwenCreds token = %q, want %q", token, "test-token") + } + if baseURL != tt.wantBaseURL { + t.Fatalf("qwenCreds baseURL = %q, want %q", baseURL, tt.wantBaseURL) + } + }) + } +} diff --git a/sdk/auth/qwen.go b/sdk/auth/qwen.go index 310d4987..d891021a 100644 --- a/sdk/auth/qwen.go +++ b/sdk/auth/qwen.go @@ -27,7 +27,7 @@ func (a *QwenAuthenticator) Provider() string { } func (a *QwenAuthenticator) RefreshLead() *time.Duration { - return new(3 * time.Hour) + return new(20 * time.Minute) } func (a *QwenAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { diff --git a/sdk/auth/qwen_refresh_lead_test.go b/sdk/auth/qwen_refresh_lead_test.go new file mode 100644 index 00000000..56f41fc0 --- /dev/null +++ b/sdk/auth/qwen_refresh_lead_test.go @@ -0,0 +1,19 @@ +package auth + +import ( + "testing" + "time" +) + +func TestQwenAuthenticator_RefreshLeadIsSane(t *testing.T) { + lead := NewQwenAuthenticator().RefreshLead() + if lead == nil { + t.Fatal("RefreshLead() = nil, want non-nil") + } + if *lead <= 0 { + t.Fatalf("RefreshLead() = %s, want > 0", *lead) + } + if *lead > 30*time.Minute { + t.Fatalf("RefreshLead() = %s, want <= %s", *lead, 30*time.Minute) + } +} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 1f53242e..71587e89 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -1850,6 +1850,50 @@ func (m *Manager) closestCooldownWait(providers []string, model string, attempt return minWait, found } +func (m *Manager) retryAllowed(attempt int, providers []string) bool { + if m == nil || attempt < 0 || len(providers) == 0 { + return false + } + defaultRetry := int(m.requestRetry.Load()) + if defaultRetry < 0 { + defaultRetry = 0 + } + providerSet := make(map[string]struct{}, len(providers)) + for i := range providers { + key := strings.TrimSpace(strings.ToLower(providers[i])) + if key == "" { + continue + } + providerSet[key] = struct{}{} + } + if len(providerSet) == 0 { + return false + } + + m.mu.RLock() + defer m.mu.RUnlock() + for _, auth := range m.auths { + if auth == nil { + continue + } + providerKey := strings.TrimSpace(strings.ToLower(auth.Provider)) + if _, ok := providerSet[providerKey]; !ok { + continue + } + effectiveRetry := defaultRetry + if override, ok := auth.RequestRetryOverride(); ok { + effectiveRetry = override + } + if effectiveRetry < 0 { + effectiveRetry = 0 + } + if attempt < effectiveRetry { + return true + } + } + return false +} + func (m *Manager) shouldRetryAfterError(err error, attempt int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) { if err == nil { return 0, false @@ -1857,17 +1901,31 @@ func (m *Manager) shouldRetryAfterError(err error, attempt int, providers []stri if maxWait <= 0 { return 0, false } - if status := statusCodeFromError(err); status == http.StatusOK { + status := statusCodeFromError(err) + if status == http.StatusOK { return 0, false } if isRequestInvalidError(err) { return 0, false } wait, found := m.closestCooldownWait(providers, model, attempt) - if !found || wait > maxWait { + if found { + if wait > maxWait { + return 0, false + } + return wait, true + } + if status != http.StatusTooManyRequests { return 0, false } - return wait, true + if !m.retryAllowed(attempt, providers) { + return 0, false + } + retryAfter := retryAfterFromError(err) + if retryAfter == nil || *retryAfter <= 0 || *retryAfter > maxWait { + return 0, false + } + return *retryAfter, true } func waitForCooldown(ctx context.Context, wait time.Duration) error { diff --git a/sdk/cliproxy/auth/conductor_overrides_test.go b/sdk/cliproxy/auth/conductor_overrides_test.go index e8dc1393..1b74aab1 100644 --- a/sdk/cliproxy/auth/conductor_overrides_test.go +++ b/sdk/cliproxy/auth/conductor_overrides_test.go @@ -690,6 +690,57 @@ func TestManager_Execute_DisableCooling_DoesNotBlackoutAfter429RetryAfter(t *tes } } +func TestManager_Execute_DisableCooling_RetriesAfter429RetryAfter(t *testing.T) { + prev := quotaCooldownDisabled.Load() + quotaCooldownDisabled.Store(false) + t.Cleanup(func() { quotaCooldownDisabled.Store(prev) }) + + m := NewManager(nil, nil, nil) + m.SetRetryConfig(3, 100*time.Millisecond, 0) + + executor := &authFallbackExecutor{ + id: "claude", + executeErrors: map[string]error{ + "auth-429-retryafter-exec": &retryAfterStatusError{ + status: http.StatusTooManyRequests, + message: "quota exhausted", + retryAfter: 5 * time.Millisecond, + }, + }, + } + m.RegisterExecutor(executor) + + auth := &Auth{ + ID: "auth-429-retryafter-exec", + Provider: "claude", + Metadata: map[string]any{ + "disable_cooling": true, + }, + } + if _, errRegister := m.Register(context.Background(), auth); errRegister != nil { + t.Fatalf("register auth: %v", errRegister) + } + + model := "test-model-429-retryafter-exec" + reg := registry.GetGlobalRegistry() + reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}}) + t.Cleanup(func() { reg.UnregisterClient(auth.ID) }) + + req := cliproxyexecutor.Request{Model: model} + _, errExecute := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{}) + if errExecute == nil { + t.Fatal("expected execute error") + } + if statusCodeFromError(errExecute) != http.StatusTooManyRequests { + t.Fatalf("execute status = %d, want %d", statusCodeFromError(errExecute), http.StatusTooManyRequests) + } + + calls := executor.ExecuteCalls() + if len(calls) != 4 { + t.Fatalf("execute calls = %d, want 4 (initial + 3 retries)", len(calls)) + } +} + func TestManager_MarkResult_RequestScopedNotFoundDoesNotCooldownAuth(t *testing.T) { m := NewManager(nil, nil, nil) diff --git a/sdk/cliproxy/auth/scheduler.go b/sdk/cliproxy/auth/scheduler.go index bb6021ab..f391aa38 100644 --- a/sdk/cliproxy/auth/scheduler.go +++ b/sdk/cliproxy/auth/scheduler.go @@ -97,6 +97,72 @@ type childBucket struct { // cooldownQueue is the blocked auth collection ordered by next retry time during rebuilds. type cooldownQueue []*scheduledAuth +type readyViewCursorState struct { + cursor int + parentCursor int + childCursors map[string]int +} + +type readyBucketCursorState struct { + all readyViewCursorState + ws readyViewCursorState +} + +func snapshotReadyViewCursors(view readyView) readyViewCursorState { + state := readyViewCursorState{ + cursor: view.cursor, + parentCursor: view.parentCursor, + } + if len(view.children) == 0 { + return state + } + state.childCursors = make(map[string]int, len(view.children)) + for parent, child := range view.children { + if child == nil { + continue + } + state.childCursors[parent] = child.cursor + } + return state +} + +func restoreReadyViewCursors(view *readyView, state readyViewCursorState) { + if view == nil { + return + } + if len(view.flat) > 0 { + view.cursor = normalizeCursor(state.cursor, len(view.flat)) + } + if len(view.parentOrder) == 0 || len(view.children) == 0 { + return + } + view.parentCursor = normalizeCursor(state.parentCursor, len(view.parentOrder)) + if len(state.childCursors) == 0 { + return + } + for parent, child := range view.children { + if child == nil || len(child.items) == 0 { + continue + } + cursor, ok := state.childCursors[parent] + if !ok { + continue + } + child.cursor = normalizeCursor(cursor, len(child.items)) + } +} + +func normalizeCursor(cursor, size int) int { + if size <= 0 || cursor <= 0 { + return 0 + } + cursor = cursor % size + if cursor < 0 { + cursor += size + } + return cursor +} + // newAuthScheduler constructs an empty scheduler configured for the supplied selector strategy. func newAuthScheduler(selector Selector) *authScheduler { return &authScheduler{ @@ -829,6 +895,17 @@ func (m *modelScheduler) availabilitySummaryLocked(predicate func(*scheduledAuth // rebuildIndexesLocked reconstructs ready and blocked views from the current entry map. func (m *modelScheduler) rebuildIndexesLocked() { + cursorStates := make(map[int]readyBucketCursorState, len(m.readyByPriority)) + for priority, bucket := range m.readyByPriority { + if bucket == nil { + continue + } + cursorStates[priority] = readyBucketCursorState{ + all: snapshotReadyViewCursors(bucket.all), + ws: snapshotReadyViewCursors(bucket.ws), + } + } + m.readyByPriority = make(map[int]*readyBucket) m.priorityOrder = m.priorityOrder[:0] m.blocked = m.blocked[:0] @@ -849,7 +926,12 @@ func (m *modelScheduler) rebuildIndexesLocked() { sort.Slice(entries, func(i, j int) bool { return entries[i].auth.ID < entries[j].auth.ID }) - m.readyByPriority[priority] = buildReadyBucket(entries) + bucket := buildReadyBucket(entries) + if cursorState, ok := cursorStates[priority]; ok && bucket != nil { + restoreReadyViewCursors(&bucket.all, cursorState.all) + restoreReadyViewCursors(&bucket.ws, cursorState.ws) + } + m.readyByPriority[priority] = bucket m.priorityOrder = append(m.priorityOrder, priority) } sort.Slice(m.priorityOrder, func(i, j int) bool {