Merge branch 'router-for-me:main' into main

This commit is contained in:
Luis Pater
2026-01-15 12:09:22 +08:00
committed by GitHub
7 changed files with 509 additions and 26 deletions

View File

@@ -255,6 +255,10 @@ type ClaudeKey struct {
// APIKey is the authentication key for accessing Claude API services.
APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -293,6 +297,10 @@ type CodexKey struct {
// APIKey is the authentication key for accessing Codex API services.
APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -331,6 +339,10 @@ type GeminiKey struct {
// APIKey is the authentication key for accessing Gemini API services.
APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -397,6 +409,10 @@ type OpenAICompatibility struct {
// Name is the identifier for this OpenAI compatibility configuration.
Name string `yaml:"name" json:"name"`
// Priority controls selection preference when multiple providers or credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`

View File

@@ -13,6 +13,10 @@ type VertexCompatKey struct {
// Maps to the x-goog-api-key header.
APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`

View File

@@ -2,6 +2,7 @@ package synthesizer
import (
"fmt"
"strconv"
"strings"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
@@ -63,6 +64,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea
"source": fmt.Sprintf("config:gemini[%s]", token),
"api_key": key,
}
if entry.Priority != 0 {
attrs["priority"] = strconv.Itoa(entry.Priority)
}
if base != "" {
attrs["base_url"] = base
}
@@ -107,6 +111,9 @@ func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*corea
"source": fmt.Sprintf("config:claude[%s]", token),
"api_key": key,
}
if ck.Priority != 0 {
attrs["priority"] = strconv.Itoa(ck.Priority)
}
if base != "" {
attrs["base_url"] = base
}
@@ -151,6 +158,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
"source": fmt.Sprintf("config:codex[%s]", token),
"api_key": key,
}
if ck.Priority != 0 {
attrs["priority"] = strconv.Itoa(ck.Priority)
}
if ck.BaseURL != "" {
attrs["base_url"] = ck.BaseURL
}
@@ -206,6 +216,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor
"compat_name": compat.Name,
"provider_key": providerName,
}
if compat.Priority != 0 {
attrs["priority"] = strconv.Itoa(compat.Priority)
}
if key != "" {
attrs["api_key"] = key
}
@@ -237,6 +250,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor
"compat_name": compat.Name,
"provider_key": providerName,
}
if compat.Priority != 0 {
attrs["priority"] = strconv.Itoa(compat.Priority)
}
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
attrs["models_hash"] = hash
}
@@ -279,6 +295,9 @@ func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*cor
"base_url": base,
"provider_key": providerName,
}
if compat.Priority != 0 {
attrs["priority"] = strconv.Itoa(compat.Priority)
}
if key != "" {
attrs["api_key"] = key
}

View File

@@ -382,7 +382,7 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie
// Call loadCodeAssist to get the project
loadReqBody := map[string]any{
"metadata": map[string]string{
"ideType": "IDE_UNSPECIFIED",
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
@@ -442,8 +442,134 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie
}
if projectID == "" {
return "", fmt.Errorf("no cloudaicompanionProject in response")
tierID := "legacy-tier"
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
for _, rawTier := range tiers {
tier, okTier := rawTier.(map[string]any)
if !okTier {
continue
}
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
tierID = strings.TrimSpace(id)
break
}
}
}
}
projectID, err = antigravityOnboardUser(ctx, accessToken, tierID, httpClient)
if err != nil {
return "", err
}
return projectID, nil
}
return projectID, nil
}
// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion.
// It returns an empty string when the operation times out or completes without a project ID.
func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) {
if httpClient == nil {
httpClient = http.DefaultClient
}
fmt.Println("Antigravity: onboarding user...", tierID)
requestBody := map[string]any{
"tierId": tierID,
"metadata": map[string]string{
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(requestBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
maxAttempts := 5
for attempt := 1; attempt <= maxAttempts; attempt++ {
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
reqCtx := ctx
var cancel context.CancelFunc
if reqCtx == nil {
reqCtx = context.Background()
}
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion)
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if errRequest != nil {
cancel()
return "", fmt.Errorf("create request: %w", errRequest)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", antigravityAPIUserAgent)
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
req.Header.Set("Client-Metadata", antigravityClientMetadata)
resp, errDo := httpClient.Do(req)
if errDo != nil {
cancel()
return "", fmt.Errorf("execute request: %w", errDo)
}
bodyBytes, errRead := io.ReadAll(resp.Body)
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("close body error: %v", errClose)
}
cancel()
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode == http.StatusOK {
var data map[string]any
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
if done, okDone := data["done"].(bool); okDone && done {
projectID := ""
if responseData, okResp := data["response"].(map[string]any); okResp {
switch projectValue := responseData["cloudaicompanionProject"].(type) {
case map[string]any:
if id, okID := projectValue["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
case string:
projectID = strings.TrimSpace(projectValue)
}
}
if projectID != "" {
log.Infof("Successfully fetched project_id: %s", projectID)
return projectID, nil
}
return "", fmt.Errorf("no project_id in response")
}
time.Sleep(2 * time.Second)
continue
}
responsePreview := strings.TrimSpace(string(bodyBytes))
if len(responsePreview) > 500 {
responsePreview = responsePreview[:500]
}
responseErr := responsePreview
if len(responseErr) > 200 {
responseErr = responseErr[:200]
}
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
}
return "", nil
}

View File

@@ -271,7 +271,6 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
if len(normalized) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
rotated := m.rotateProviders(req.Model, normalized)
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
@@ -281,14 +280,12 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
var lastErr error
for attempt := 0; attempt < attempts; attempt++ {
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
return m.executeWithProvider(execCtx, provider, req, opts)
})
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
if errExec == nil {
return resp, nil
}
lastErr = errExec
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
if !shouldRetry {
break
}
@@ -309,7 +306,6 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
if len(normalized) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
rotated := m.rotateProviders(req.Model, normalized)
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
@@ -319,14 +315,12 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
var lastErr error
for attempt := 0; attempt < attempts; attempt++ {
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
return m.executeCountWithProvider(execCtx, provider, req, opts)
})
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
if errExec == nil {
return resp, nil
}
lastErr = errExec
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
if !shouldRetry {
break
}
@@ -347,7 +341,6 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
if len(normalized) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
rotated := m.rotateProviders(req.Model, normalized)
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
@@ -357,14 +350,12 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
var lastErr error
for attempt := 0; attempt < attempts; attempt++ {
chunks, errStream := m.executeStreamProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (<-chan cliproxyexecutor.StreamChunk, error) {
return m.executeStreamWithProvider(execCtx, provider, req, opts)
})
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
if errStream == nil {
return chunks, nil
}
lastErr = errStream
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, rotated, req.Model, maxWait)
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait)
if !shouldRetry {
break
}
@@ -378,6 +369,167 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
lastErr = errExec
continue
}
m.MarkResult(execCtx, result)
return resp, nil
}
}
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
lastErr = errExec
continue
}
m.MarkResult(execCtx, result)
return resp, nil
}
}
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
if len(providers) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return nil, lastErr
}
return nil, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
rerr := &Error{Message: errStream.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errStream, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(execCtx, result)
lastErr = errStream
continue
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out)
var failed bool
for chunk := range streamChunks {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
var se cliproxyexecutor.StatusError
if errors.As(chunk.Err, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
}
out <- chunk
}
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
}(execCtx, auth.Clone(), provider, chunks)
return out, nil
}
}
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if provider == "" {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
@@ -1191,6 +1343,77 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
return authCopy, executor, nil
}
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
providerSet := make(map[string]struct{}, len(providers))
for _, provider := range providers {
p := strings.TrimSpace(strings.ToLower(provider))
if p == "" {
continue
}
providerSet[p] = struct{}{}
}
if len(providerSet) == 0 {
return nil, nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
m.mu.RLock()
candidates := make([]*Auth, 0, len(m.auths))
modelKey := strings.TrimSpace(model)
registryRef := registry.GetGlobalRegistry()
for _, candidate := range m.auths {
if candidate == nil || candidate.Disabled {
continue
}
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
if providerKey == "" {
continue
}
if _, ok := providerSet[providerKey]; !ok {
continue
}
if _, used := tried[candidate.ID]; used {
continue
}
if _, ok := m.executors[providerKey]; !ok {
continue
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
continue
}
candidates = append(candidates, candidate)
}
if len(candidates) == 0 {
m.mu.RUnlock()
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
selected, errPick := m.selector.Pick(ctx, "mixed", model, opts, candidates)
if errPick != nil {
m.mu.RUnlock()
return nil, nil, "", errPick
}
if selected == nil {
m.mu.RUnlock()
return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
}
providerKey := strings.TrimSpace(strings.ToLower(selected.Provider))
executor, okExecutor := m.executors[providerKey]
if !okExecutor {
m.mu.RUnlock()
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
}
authCopy := selected.Clone()
m.mu.RUnlock()
if !selected.indexAssigned {
m.mu.Lock()
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
current.EnsureIndex()
authCopy = current.Clone()
}
m.mu.Unlock()
}
return authCopy, executor, providerKey, nil
}
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
if m.store == nil || auth == nil {
return nil

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"sort"
"strconv"
"strings"
"sync"
"time"
@@ -103,13 +104,29 @@ func (e *modelCooldownError) Headers() http.Header {
return headers
}
func collectAvailable(auths []*Auth, model string, now time.Time) (available []*Auth, cooldownCount int, earliest time.Time) {
available = make([]*Auth, 0, len(auths))
func authPriority(auth *Auth) int {
if auth == nil || auth.Attributes == nil {
return 0
}
raw := strings.TrimSpace(auth.Attributes["priority"])
if raw == "" {
return 0
}
parsed, err := strconv.Atoi(raw)
if err != nil {
return 0
}
return parsed
}
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
available = make(map[int][]*Auth)
for i := 0; i < len(auths); i++ {
candidate := auths[i]
blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
if !blocked {
available = append(available, candidate)
priority := authPriority(candidate)
available[priority] = append(available[priority], candidate)
continue
}
if reason == blockReasonCooldown {
@@ -119,9 +136,6 @@ func collectAvailable(auths []*Auth, model string, now time.Time) (available []*
}
}
}
if len(available) > 1 {
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
}
return available, cooldownCount, earliest
}
@@ -130,18 +144,35 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
}
available, cooldownCount, earliest := collectAvailable(auths, model, now)
if len(available) == 0 {
availableByPriority, cooldownCount, earliest := collectAvailableByPriority(auths, model, now)
if len(availableByPriority) == 0 {
if cooldownCount == len(auths) && !earliest.IsZero() {
providerForError := provider
if providerForError == "mixed" {
providerForError = ""
}
resetIn := earliest.Sub(now)
if resetIn < 0 {
resetIn = 0
}
return nil, newModelCooldownError(model, provider, resetIn)
return nil, newModelCooldownError(model, providerForError, resetIn)
}
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
}
bestPriority := 0
found := false
for priority := range availableByPriority {
if !found || priority > bestPriority {
bestPriority = priority
found = true
}
}
available := availableByPriority[bestPriority]
if len(available) > 1 {
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
}
return available, nil
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"sync"
"testing"
"time"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
@@ -56,6 +57,69 @@ func TestRoundRobinSelectorPick_CyclesDeterministic(t *testing.T) {
}
}
func TestRoundRobinSelectorPick_PriorityBuckets(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
auths := []*Auth{
{ID: "c", Attributes: map[string]string{"priority": "0"}},
{ID: "a", Attributes: map[string]string{"priority": "10"}},
{ID: "b", Attributes: map[string]string{"priority": "10"}},
}
want := []string{"a", "b", "a", "b"}
for i, id := range want {
got, err := selector.Pick(context.Background(), "mixed", "", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
if got.ID != id {
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, id)
}
if got.ID == "c" {
t.Fatalf("Pick() #%d unexpectedly selected lower priority auth", i)
}
}
}
func TestFillFirstSelectorPick_PriorityFallbackCooldown(t *testing.T) {
t.Parallel()
selector := &FillFirstSelector{}
now := time.Now()
model := "test-model"
high := &Auth{
ID: "high",
Attributes: map[string]string{"priority": "10"},
ModelStates: map[string]*ModelState{
model: {
Status: StatusActive,
Unavailable: true,
NextRetryAfter: now.Add(30 * time.Minute),
Quota: QuotaState{
Exceeded: true,
},
},
},
}
low := &Auth{ID: "low", Attributes: map[string]string{"priority": "0"}}
got, err := selector.Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, []*Auth{high, low})
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if got == nil {
t.Fatalf("Pick() auth = nil")
}
if got.ID != "low" {
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low")
}
}
func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
selector := &RoundRobinSelector{}
auths := []*Auth{