From e68a6037e24b617ae934d423ebc929af1061ed4e Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Tue, 23 Sep 2025 09:24:55 +0800 Subject: [PATCH] feat(auth): enable model suspension and resumption logic in `AuthManager` - Added model suspension with reason tracking for 401 (unauthorized) and 402/403 (payment-related) errors. - Implemented resumption logic upon model quota recovery or auth state changes. - Enhanced registry to manage suspended clients, including counts and observability data. - Updated availability computation to exclude suspended clients, ensuring accurate client model tracking. --- internal/registry/model_registry.go | 97 +++++++++++++++++++++++++++-- sdk/cliproxy/auth/manager.go | 36 +++++++++-- 2 files changed, 124 insertions(+), 9 deletions(-) diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index b0cfd2cf..74e7abf4 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -58,6 +58,8 @@ type ModelRegistration struct { QuotaExceededClients map[string]*time.Time // Providers tracks available clients grouped by provider identifier Providers map[string]int + // SuspendedClients tracks temporarily disabled clients keyed by client ID + SuspendedClients map[string]string } // ModelRegistry manages the global registry of available models @@ -112,6 +114,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ // Model already exists, increment count existing.Count++ existing.LastUpdated = now + if existing.SuspendedClients == nil { + existing.SuspendedClients = make(map[string]string) + } if provider != "" { if existing.Providers == nil { existing.Providers = make(map[string]int) @@ -126,6 +131,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ Count: 1, LastUpdated: now, QuotaExceededClients: make(map[string]*time.Time), + SuspendedClients: make(map[string]string), } if provider != "" { registration.Providers = map[string]int{provider: 1} @@ -172,6 +178,9 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) { // Remove quota tracking for this client delete(registration.QuotaExceededClients, clientID) + if registration.SuspendedClients != nil { + delete(registration.SuspendedClients, clientID) + } if hasProvider && registration.Providers != nil { if count, ok := registration.Providers[provider]; ok { @@ -229,6 +238,60 @@ func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) { } } +// SuspendClientModel marks a client's model as temporarily unavailable until explicitly resumed. +// Parameters: +// - clientID: The client to suspend +// - modelID: The model affected by the suspension +// - reason: Optional description for observability +func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) { + if clientID == "" || modelID == "" { + return + } + r.mutex.Lock() + defer r.mutex.Unlock() + + registration, exists := r.models[modelID] + if !exists || registration == nil { + return + } + if registration.SuspendedClients == nil { + registration.SuspendedClients = make(map[string]string) + } + if _, already := registration.SuspendedClients[clientID]; already { + return + } + registration.SuspendedClients[clientID] = reason + registration.LastUpdated = time.Now() + if reason != "" { + log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason) + } else { + log.Debugf("Suspended client %s for model %s", clientID, modelID) + } +} + +// ResumeClientModel clears a previous suspension so the client counts toward availability again. +// Parameters: +// - clientID: The client to resume +// - modelID: The model being resumed +func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) { + if clientID == "" || modelID == "" { + return + } + r.mutex.Lock() + defer r.mutex.Unlock() + + registration, exists := r.models[modelID] + if !exists || registration == nil || registration.SuspendedClients == nil { + return + } + if _, ok := registration.SuspendedClients[clientID]; !ok { + return + } + delete(registration.SuspendedClients, clientID) + registration.LastUpdated = time.Now() + log.Debugf("Resumed client %s for model %s", clientID, modelID) +} + // GetAvailableModels returns all models that have at least one available client // Parameters: // - handlerType: The handler type to filter models for (e.g., "openai", "claude", "gemini") @@ -255,7 +318,14 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any } } - effectiveClients := availableClients - expiredClients + suspendedClients := 0 + if registration.SuspendedClients != nil { + suspendedClients = len(registration.SuspendedClients) + } + effectiveClients := availableClients - expiredClients - suspendedClients + if effectiveClients < 0 { + effectiveClients = 0 + } // Only include models that have available clients if effectiveClients > 0 { @@ -290,8 +360,15 @@ func (r *ModelRegistry) GetModelCount(modelID string) int { expiredClients++ } } - - return registration.Count - expiredClients + suspendedClients := 0 + if registration.SuspendedClients != nil { + suspendedClients = len(registration.SuspendedClients) + } + result := registration.Count - expiredClients - suspendedClients + if result < 0 { + return 0 + } + return result } return 0 } @@ -316,11 +393,23 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string { count int } providers := make([]providerCount, 0, len(registration.Providers)) + suspendedByProvider := make(map[string]int) + if registration.SuspendedClients != nil { + for clientID := range registration.SuspendedClients { + if provider, ok := r.clientProviders[clientID]; ok && provider != "" { + suspendedByProvider[provider]++ + } + } + } for name, count := range registration.Providers { if count <= 0 { continue } - providers = append(providers, providerCount{name: name, count: count}) + adjusted := count - suspendedByProvider[name] + if adjusted <= 0 { + continue + } + providers = append(providers, providerCount{name: name, count: adjusted}) } if len(providers) == 0 { return nil diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index 0071e86a..4a6491ce 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -486,6 +486,9 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { return } // Update in-memory auth status based on result. + shouldResumeModel := false + shouldSuspendModel := false + suspendReason := "" m.mu.Lock() if auth, ok := m.auths[result.AuthID]; ok && auth != nil { now := time.Now() @@ -501,6 +504,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { auth.UpdatedAt = now if result.Model != "" { registry.GetGlobalRegistry().ClearModelQuotaExceeded(auth.ID, result.Model) + shouldResumeModel = true } } else { // Default transient error state. @@ -511,7 +515,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { auth.LastError = &Error{Code: result.Error.Code, Message: result.Error.Message, Retryable: result.Error.Retryable} } // If the error carries a status code, adjust backoff/quota accordingly. - // 401 -> auth issue; 402/429 -> quota; 5xx -> transient. + // 401 -> auth issue; 402 -> billing; 403 -> forbidden; 429 -> quota; 5xx -> transient. var statusCode int if se, isOk := any(result.Error).(interface{ StatusCode() int }); isOk && se != nil { statusCode = se.StatusCode() @@ -519,19 +523,35 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { switch statusCode { case 401: auth.StatusMessage = "unauthorized" - auth.NextRefreshAfter = now.Add(5 * time.Minute) - case 402, 429: + auth.NextRefreshAfter = now.Add(30 * time.Minute) + if result.Model != "" { + shouldSuspendModel = true + suspendReason = "unauthorized" + } + case 402, 403: + auth.StatusMessage = "payment_required" + auth.NextRefreshAfter = now.Add(30 * time.Minute) + if result.Model != "" { + shouldSuspendModel = true + suspendReason = "payment_required" + } + case 429: auth.StatusMessage = "quota exhausted" auth.Quota.Exceeded = true auth.Quota.Reason = "quota" - auth.Quota.NextRecoverAt = now.Add(10 * time.Minute) + auth.Quota.NextRecoverAt = now.Add(30 * time.Minute) auth.NextRefreshAfter = auth.Quota.NextRecoverAt if result.Model != "" { + shouldSuspendModel = true registry.GetGlobalRegistry().SetModelQuotaExceeded(auth.ID, result.Model) } - case 403, 408, 500, 502, 503, 504: + case 408, 500, 502, 503, 504: auth.StatusMessage = "transient upstream error" auth.NextRefreshAfter = now.Add(1 * time.Minute) + if result.Model != "" { + shouldSuspendModel = false + suspendReason = "forbidden" + } default: // keep generic if auth.StatusMessage == "" { @@ -544,6 +564,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { } m.mu.Unlock() + if shouldResumeModel { + registry.GetGlobalRegistry().ResumeClientModel(result.AuthID, result.Model) + } else if shouldSuspendModel { + registry.GetGlobalRegistry().SuspendClientModel(result.AuthID, result.Model, suspendReason) + } + m.hook.OnResult(ctx, result) }