diff --git a/internal/config/config.go b/internal/config/config.go index 4e4571da..7febe548 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -255,6 +255,10 @@ type ClaudeKey struct { // APIKey is the authentication key for accessing Claude API services. APIKey string `yaml:"api-key" json:"api-key"` + // Priority controls selection preference when multiple credentials match. + // Higher values are preferred; defaults to 0. + Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` + // Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4"). Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` @@ -293,6 +297,10 @@ type CodexKey struct { // APIKey is the authentication key for accessing Codex API services. APIKey string `yaml:"api-key" json:"api-key"` + // Priority controls selection preference when multiple credentials match. + // Higher values are preferred; defaults to 0. + Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` + // Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex"). Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` @@ -331,6 +339,10 @@ type GeminiKey struct { // APIKey is the authentication key for accessing Gemini API services. APIKey string `yaml:"api-key" json:"api-key"` + // Priority controls selection preference when multiple credentials match. + // Higher values are preferred; defaults to 0. + Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` + // Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview"). Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` @@ -397,6 +409,10 @@ type OpenAICompatibility struct { // Name is the identifier for this OpenAI compatibility configuration. Name string `yaml:"name" json:"name"` + // Priority controls selection preference when multiple providers or credentials match. + // Higher values are preferred; defaults to 0. + Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` + // Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2"). Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` diff --git a/internal/config/vertex_compat.go b/internal/config/vertex_compat.go index 94e162b7..632bf7cc 100644 --- a/internal/config/vertex_compat.go +++ b/internal/config/vertex_compat.go @@ -13,6 +13,10 @@ type VertexCompatKey struct { // Maps to the x-goog-api-key header. APIKey string `yaml:"api-key" json:"api-key"` + // Priority controls selection preference when multiple credentials match. + // Higher values are preferred; defaults to 0. + Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` + // Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro"). Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` diff --git a/internal/watcher/synthesizer/config.go b/internal/watcher/synthesizer/config.go index e976af4e..9ef04800 100644 --- a/internal/watcher/synthesizer/config.go +++ b/internal/watcher/synthesizer/config.go @@ -2,6 +2,7 @@ package synthesizer import ( "fmt" + "strconv" "strings" kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" @@ -63,6 +64,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea "source": fmt.Sprintf("config:gemini[%s]", token), "api_key": key, } + if entry.Priority != 0 { + attrs["priority"] = strconv.Itoa(entry.Priority) + } if base != "" { attrs["base_url"] = base } @@ -107,6 +111,9 @@ func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*corea "source": fmt.Sprintf("config:claude[%s]", token), "api_key": key, } + if ck.Priority != 0 { + attrs["priority"] = strconv.Itoa(ck.Priority) + } if base != "" { attrs["base_url"] = base } @@ -151,6 +158,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau "source": fmt.Sprintf("config:codex[%s]", token), "api_key": key, } + if ck.Priority != 0 { + attrs["priority"] = strconv.Itoa(ck.Priority) + } if ck.BaseURL != "" { attrs["base_url"] = ck.BaseURL } @@ -206,6 +216,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor "compat_name": compat.Name, "provider_key": providerName, } + if compat.Priority != 0 { + attrs["priority"] = strconv.Itoa(compat.Priority) + } if key != "" { attrs["api_key"] = key } @@ -237,6 +250,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor "compat_name": compat.Name, "provider_key": providerName, } + if compat.Priority != 0 { + attrs["priority"] = strconv.Itoa(compat.Priority) + } if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" { attrs["models_hash"] = hash } @@ -279,6 +295,9 @@ func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*cor "base_url": base, "provider_key": providerName, } + if compat.Priority != 0 { + attrs["priority"] = strconv.Itoa(compat.Priority) + } if key != "" { attrs["api_key"] = key } diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index b59acacf..210da57f 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -382,7 +382,7 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie // Call loadCodeAssist to get the project loadReqBody := map[string]any{ "metadata": map[string]string{ - "ideType": "IDE_UNSPECIFIED", + "ideType": "ANTIGRAVITY", "platform": "PLATFORM_UNSPECIFIED", "pluginType": "GEMINI", }, @@ -442,8 +442,134 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie } if projectID == "" { - return "", fmt.Errorf("no cloudaicompanionProject in response") + tierID := "legacy-tier" + if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers { + for _, rawTier := range tiers { + tier, okTier := rawTier.(map[string]any) + if !okTier { + continue + } + if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault { + if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" { + tierID = strings.TrimSpace(id) + break + } + } + } + } + + projectID, err = antigravityOnboardUser(ctx, accessToken, tierID, httpClient) + if err != nil { + return "", err + } + return projectID, nil } return projectID, nil } + +// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion. +// It returns an empty string when the operation times out or completes without a project ID. +func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) { + if httpClient == nil { + httpClient = http.DefaultClient + } + fmt.Println("Antigravity: onboarding user...", tierID) + requestBody := map[string]any{ + "tierId": tierID, + "metadata": map[string]string{ + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + }, + } + + rawBody, errMarshal := json.Marshal(requestBody) + if errMarshal != nil { + return "", fmt.Errorf("marshal request body: %w", errMarshal) + } + + maxAttempts := 5 + for attempt := 1; attempt <= maxAttempts; attempt++ { + log.Debugf("Polling attempt %d/%d", attempt, maxAttempts) + + reqCtx := ctx + var cancel context.CancelFunc + if reqCtx == nil { + reqCtx = context.Background() + } + reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second) + + endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion) + req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) + if errRequest != nil { + cancel() + return "", fmt.Errorf("create request: %w", errRequest) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", antigravityAPIUserAgent) + req.Header.Set("X-Goog-Api-Client", antigravityAPIClient) + req.Header.Set("Client-Metadata", antigravityClientMetadata) + + resp, errDo := httpClient.Do(req) + if errDo != nil { + cancel() + return "", fmt.Errorf("execute request: %w", errDo) + } + + bodyBytes, errRead := io.ReadAll(resp.Body) + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("close body error: %v", errClose) + } + cancel() + + if errRead != nil { + return "", fmt.Errorf("read response: %w", errRead) + } + + if resp.StatusCode == http.StatusOK { + var data map[string]any + if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil { + return "", fmt.Errorf("decode response: %w", errDecode) + } + + if done, okDone := data["done"].(bool); okDone && done { + projectID := "" + if responseData, okResp := data["response"].(map[string]any); okResp { + switch projectValue := responseData["cloudaicompanionProject"].(type) { + case map[string]any: + if id, okID := projectValue["id"].(string); okID { + projectID = strings.TrimSpace(id) + } + case string: + projectID = strings.TrimSpace(projectValue) + } + } + + if projectID != "" { + log.Infof("Successfully fetched project_id: %s", projectID) + return projectID, nil + } + + return "", fmt.Errorf("no project_id in response") + } + + time.Sleep(2 * time.Second) + continue + } + + responsePreview := strings.TrimSpace(string(bodyBytes)) + if len(responsePreview) > 500 { + responsePreview = responsePreview[:500] + } + + responseErr := responsePreview + if len(responseErr) > 200 { + responseErr = responseErr[:200] + } + return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr) + } + + return "", nil +} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index f38ccb96..12a23eec 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -271,7 +271,6 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye if len(normalized) == 0 { return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} } - rotated := m.rotateProviders(req.Model, normalized) retryTimes, maxWait := m.retrySettings() attempts := retryTimes + 1 @@ -281,14 +280,12 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye var lastErr error for attempt := 0; attempt < attempts; attempt++ { - resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) { - return m.executeWithProvider(execCtx, provider, req, opts) - }) + resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts) if errExec == nil { return resp, nil } lastErr = errExec - wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait) + wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait) if !shouldRetry { break } @@ -309,7 +306,6 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip if len(normalized) == 0 { return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} } - rotated := m.rotateProviders(req.Model, normalized) retryTimes, maxWait := m.retrySettings() attempts := retryTimes + 1 @@ -319,14 +315,12 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip var lastErr error for attempt := 0; attempt < attempts; attempt++ { - resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) { - return m.executeCountWithProvider(execCtx, provider, req, opts) - }) + resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts) if errExec == nil { return resp, nil } lastErr = errExec - wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait) + wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait) if !shouldRetry { break } @@ -347,7 +341,6 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli if len(normalized) == 0 { return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} } - rotated := m.rotateProviders(req.Model, normalized) retryTimes, maxWait := m.retrySettings() attempts := retryTimes + 1 @@ -357,14 +350,12 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli var lastErr error for attempt := 0; attempt < attempts; attempt++ { - chunks, errStream := m.executeStreamProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (<-chan cliproxyexecutor.StreamChunk, error) { - return m.executeStreamWithProvider(execCtx, provider, req, opts) - }) + chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts) if errStream == nil { return chunks, nil } lastErr = errStream - wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, rotated, req.Model, maxWait) + wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait) if !shouldRetry { break } @@ -378,6 +369,167 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli return nil, &Error{Code: "auth_not_found", Message: "no auth available"} } +func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + if len(providers) == 0 { + return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} + } + routeModel := req.Model + tried := make(map[string]struct{}) + var lastErr error + for { + auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried) + if errPick != nil { + if lastErr != nil { + return cliproxyexecutor.Response{}, lastErr + } + return cliproxyexecutor.Response{}, errPick + } + + entry := logEntryWithRequestID(ctx) + debugLogAuthSelection(entry, auth, provider, req.Model) + + tried[auth.ID] = struct{}{} + execCtx := ctx + if rt := m.roundTripperFor(auth); rt != nil { + execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) + execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) + } + execReq := req + execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) + execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata) + resp, errExec := executor.Execute(execCtx, auth, execReq, opts) + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} + if errExec != nil { + result.Error = &Error{Message: errExec.Error()} + var se cliproxyexecutor.StatusError + if errors.As(errExec, &se) && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + if ra := retryAfterFromError(errExec); ra != nil { + result.RetryAfter = ra + } + m.MarkResult(execCtx, result) + lastErr = errExec + continue + } + m.MarkResult(execCtx, result) + return resp, nil + } +} + +func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + if len(providers) == 0 { + return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} + } + routeModel := req.Model + tried := make(map[string]struct{}) + var lastErr error + for { + auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried) + if errPick != nil { + if lastErr != nil { + return cliproxyexecutor.Response{}, lastErr + } + return cliproxyexecutor.Response{}, errPick + } + + entry := logEntryWithRequestID(ctx) + debugLogAuthSelection(entry, auth, provider, req.Model) + + tried[auth.ID] = struct{}{} + execCtx := ctx + if rt := m.roundTripperFor(auth); rt != nil { + execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) + execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) + } + execReq := req + execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) + execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata) + resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} + if errExec != nil { + result.Error = &Error{Message: errExec.Error()} + var se cliproxyexecutor.StatusError + if errors.As(errExec, &se) && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + if ra := retryAfterFromError(errExec); ra != nil { + result.RetryAfter = ra + } + m.MarkResult(execCtx, result) + lastErr = errExec + continue + } + m.MarkResult(execCtx, result) + return resp, nil + } +} + +func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { + if len(providers) == 0 { + return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} + } + routeModel := req.Model + tried := make(map[string]struct{}) + var lastErr error + for { + auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried) + if errPick != nil { + if lastErr != nil { + return nil, lastErr + } + return nil, errPick + } + + entry := logEntryWithRequestID(ctx) + debugLogAuthSelection(entry, auth, provider, req.Model) + + tried[auth.ID] = struct{}{} + execCtx := ctx + if rt := m.roundTripperFor(auth); rt != nil { + execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) + execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) + } + execReq := req + execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) + execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata) + chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) + if errStream != nil { + rerr := &Error{Message: errStream.Error()} + var se cliproxyexecutor.StatusError + if errors.As(errStream, &se) && se != nil { + rerr.HTTPStatus = se.StatusCode() + } + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} + result.RetryAfter = retryAfterFromError(errStream) + m.MarkResult(execCtx, result) + lastErr = errStream + continue + } + out := make(chan cliproxyexecutor.StreamChunk) + go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) { + defer close(out) + var failed bool + for chunk := range streamChunks { + if chunk.Err != nil && !failed { + failed = true + rerr := &Error{Message: chunk.Err.Error()} + var se cliproxyexecutor.StatusError + if errors.As(chunk.Err, &se) && se != nil { + rerr.HTTPStatus = se.StatusCode() + } + m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr}) + } + out <- chunk + } + if !failed { + m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true}) + } + }(execCtx, auth.Clone(), provider, chunks) + return out, nil + } +} + func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { if provider == "" { return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"} @@ -1191,6 +1343,77 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli return authCopy, executor, nil } +func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) { + providerSet := make(map[string]struct{}, len(providers)) + for _, provider := range providers { + p := strings.TrimSpace(strings.ToLower(provider)) + if p == "" { + continue + } + providerSet[p] = struct{}{} + } + if len(providerSet) == 0 { + return nil, nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"} + } + + m.mu.RLock() + candidates := make([]*Auth, 0, len(m.auths)) + modelKey := strings.TrimSpace(model) + registryRef := registry.GetGlobalRegistry() + for _, candidate := range m.auths { + if candidate == nil || candidate.Disabled { + continue + } + providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider)) + if providerKey == "" { + continue + } + if _, ok := providerSet[providerKey]; !ok { + continue + } + if _, used := tried[candidate.ID]; used { + continue + } + if _, ok := m.executors[providerKey]; !ok { + continue + } + if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) { + continue + } + candidates = append(candidates, candidate) + } + if len(candidates) == 0 { + m.mu.RUnlock() + return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"} + } + selected, errPick := m.selector.Pick(ctx, "mixed", model, opts, candidates) + if errPick != nil { + m.mu.RUnlock() + return nil, nil, "", errPick + } + if selected == nil { + m.mu.RUnlock() + return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"} + } + providerKey := strings.TrimSpace(strings.ToLower(selected.Provider)) + executor, okExecutor := m.executors[providerKey] + if !okExecutor { + m.mu.RUnlock() + return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"} + } + authCopy := selected.Clone() + m.mu.RUnlock() + if !selected.indexAssigned { + m.mu.Lock() + if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned { + current.EnsureIndex() + authCopy = current.Clone() + } + m.mu.Unlock() + } + return authCopy, executor, providerKey, nil +} + func (m *Manager) persist(ctx context.Context, auth *Auth) error { if m.store == nil || auth == nil { return nil diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index d7e120c5..7febf219 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -8,6 +8,7 @@ import ( "net/http" "sort" "strconv" + "strings" "sync" "time" @@ -103,13 +104,29 @@ func (e *modelCooldownError) Headers() http.Header { return headers } -func collectAvailable(auths []*Auth, model string, now time.Time) (available []*Auth, cooldownCount int, earliest time.Time) { - available = make([]*Auth, 0, len(auths)) +func authPriority(auth *Auth) int { + if auth == nil || auth.Attributes == nil { + return 0 + } + raw := strings.TrimSpace(auth.Attributes["priority"]) + if raw == "" { + return 0 + } + parsed, err := strconv.Atoi(raw) + if err != nil { + return 0 + } + return parsed +} + +func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) { + available = make(map[int][]*Auth) for i := 0; i < len(auths); i++ { candidate := auths[i] blocked, reason, next := isAuthBlockedForModel(candidate, model, now) if !blocked { - available = append(available, candidate) + priority := authPriority(candidate) + available[priority] = append(available[priority], candidate) continue } if reason == blockReasonCooldown { @@ -119,9 +136,6 @@ func collectAvailable(auths []*Auth, model string, now time.Time) (available []* } } } - if len(available) > 1 { - sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID }) - } return available, cooldownCount, earliest } @@ -130,18 +144,35 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([] return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"} } - available, cooldownCount, earliest := collectAvailable(auths, model, now) - if len(available) == 0 { + availableByPriority, cooldownCount, earliest := collectAvailableByPriority(auths, model, now) + if len(availableByPriority) == 0 { if cooldownCount == len(auths) && !earliest.IsZero() { + providerForError := provider + if providerForError == "mixed" { + providerForError = "" + } resetIn := earliest.Sub(now) if resetIn < 0 { resetIn = 0 } - return nil, newModelCooldownError(model, provider, resetIn) + return nil, newModelCooldownError(model, providerForError, resetIn) } return nil, &Error{Code: "auth_unavailable", Message: "no auth available"} } + bestPriority := 0 + found := false + for priority := range availableByPriority { + if !found || priority > bestPriority { + bestPriority = priority + found = true + } + } + + available := availableByPriority[bestPriority] + if len(available) > 1 { + sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID }) + } return available, nil } diff --git a/sdk/cliproxy/auth/selector_test.go b/sdk/cliproxy/auth/selector_test.go index f4beed03..91a7ed14 100644 --- a/sdk/cliproxy/auth/selector_test.go +++ b/sdk/cliproxy/auth/selector_test.go @@ -5,6 +5,7 @@ import ( "errors" "sync" "testing" + "time" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" ) @@ -56,6 +57,69 @@ func TestRoundRobinSelectorPick_CyclesDeterministic(t *testing.T) { } } +func TestRoundRobinSelectorPick_PriorityBuckets(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + auths := []*Auth{ + {ID: "c", Attributes: map[string]string{"priority": "0"}}, + {ID: "a", Attributes: map[string]string{"priority": "10"}}, + {ID: "b", Attributes: map[string]string{"priority": "10"}}, + } + + want := []string{"a", "b", "a", "b"} + for i, id := range want { + got, err := selector.Pick(context.Background(), "mixed", "", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got == nil { + t.Fatalf("Pick() #%d auth = nil", i) + } + if got.ID != id { + t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, id) + } + if got.ID == "c" { + t.Fatalf("Pick() #%d unexpectedly selected lower priority auth", i) + } + } +} + +func TestFillFirstSelectorPick_PriorityFallbackCooldown(t *testing.T) { + t.Parallel() + + selector := &FillFirstSelector{} + now := time.Now() + model := "test-model" + + high := &Auth{ + ID: "high", + Attributes: map[string]string{"priority": "10"}, + ModelStates: map[string]*ModelState{ + model: { + Status: StatusActive, + Unavailable: true, + NextRetryAfter: now.Add(30 * time.Minute), + Quota: QuotaState{ + Exceeded: true, + }, + }, + }, + } + low := &Auth{ID: "low", Attributes: map[string]string{"priority": "0"}} + + got, err := selector.Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, []*Auth{high, low}) + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + if got == nil { + t.Fatalf("Pick() auth = nil") + } + if got.ID != "low" { + t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low") + } +} + func TestRoundRobinSelectorPick_Concurrent(t *testing.T) { selector := &RoundRobinSelector{} auths := []*Auth{