diff --git a/internal/auth/kimi/kimi.go b/internal/auth/kimi/kimi.go index 86052277..8427a057 100644 --- a/internal/auth/kimi/kimi.go +++ b/internal/auth/kimi/kimi.go @@ -30,7 +30,7 @@ const ( // kimiTokenURL is the endpoint for exchanging device codes for tokens. kimiTokenURL = kimiOAuthHost + "/api/oauth/token" // KimiAPIBaseURL is the base URL for Kimi API requests. - KimiAPIBaseURL = "https://api.kimi.com/coding/v1" + KimiAPIBaseURL = "https://api.kimi.com/coding" // defaultPollInterval is the default interval for polling token endpoint. defaultPollInterval = 5 * time.Second // maxPollDuration is the maximum time to wait for user authorization. diff --git a/internal/runtime/executor/kimi_executor.go b/internal/runtime/executor/kimi_executor.go index 1cc66341..94a78331 100644 --- a/internal/runtime/executor/kimi_executor.go +++ b/internal/runtime/executor/kimi_executor.go @@ -25,6 +25,7 @@ import ( // KimiExecutor is a stateless executor for Kimi API using OpenAI-compatible chat completions. type KimiExecutor struct { + ClaudeExecutor cfg *config.Config } @@ -64,6 +65,12 @@ func (e *KimiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, // Execute performs a non-streaming chat completion request to Kimi. func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + from := opts.SourceFormat + if from.String() == "claude" { + auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL + return e.ClaudeExecutor.Execute(ctx, auth, req, opts) + } + baseModel := thinking.ParseSuffix(req.Model).ModelName token := kimiCreds(auth) @@ -71,7 +78,6 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) - from := opts.SourceFormat to := sdktranslator.FromString("openai") originalPayload := bytes.Clone(req.Payload) if len(opts.OriginalRequest) > 0 { @@ -95,7 +101,7 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req requestedModel := payloadRequestedModel(opts, req.Model) body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) - url := kimiauth.KimiAPIBaseURL + "/chat/completions" + url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions" httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { return resp, err @@ -155,14 +161,18 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req // ExecuteStream performs a streaming chat completion request to Kimi. func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { - baseModel := thinking.ParseSuffix(req.Model).ModelName + from := opts.SourceFormat + if from.String() == "claude" { + auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL + return e.ClaudeExecutor.ExecuteStream(ctx, auth, req, opts) + } + baseModel := thinking.ParseSuffix(req.Model).ModelName token := kimiCreds(auth) reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) - from := opts.SourceFormat to := sdktranslator.FromString("openai") originalPayload := bytes.Clone(req.Payload) if len(opts.OriginalRequest) > 0 { @@ -190,7 +200,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut requestedModel := payloadRequestedModel(opts, req.Model) body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) - url := kimiauth.KimiAPIBaseURL + "/chat/completions" + url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions" httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { return nil, err @@ -269,26 +279,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut // CountTokens estimates token count for Kimi requests. func (e *KimiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - baseModel := thinking.ParseSuffix(req.Model).ModelName - - from := opts.SourceFormat - to := sdktranslator.FromString("openai") - body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) - - // Use a generic tokenizer for estimation - enc, err := tokenizerForModel("gpt-4") - if err != nil { - return cliproxyexecutor.Response{}, fmt.Errorf("kimi executor: tokenizer init failed: %w", err) - } - - count, err := countOpenAIChatTokens(enc, body) - if err != nil { - return cliproxyexecutor.Response{}, fmt.Errorf("kimi executor: token counting failed: %w", err) - } - - usageJSON := buildOpenAIUsageJSON(count) - translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON) - return cliproxyexecutor.Response{Payload: []byte(translated)}, nil + auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL + return e.ClaudeExecutor.CountTokens(ctx, auth, req, opts) } // Refresh refreshes the Kimi token using the refresh token. diff --git a/internal/store/gitstore.go b/internal/store/gitstore.go index 3b68e4b0..c8db660c 100644 --- a/internal/store/gitstore.go +++ b/internal/store/gitstore.go @@ -21,6 +21,9 @@ import ( cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" ) +// gcInterval defines minimum time between garbage collection runs. +const gcInterval = 5 * time.Minute + // GitTokenStore persists token records and auth metadata using git as the backing storage. type GitTokenStore struct { mu sync.Mutex @@ -31,6 +34,7 @@ type GitTokenStore struct { remote string username string password string + lastGC time.Time } // NewGitTokenStore creates a token store that saves credentials to disk through the @@ -613,6 +617,7 @@ func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string) } else if errRewrite := s.rewriteHeadAsSingleCommit(repo, headRef.Name(), commitHash, message, signature); errRewrite != nil { return errRewrite } + s.maybeRunGC(repo) if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil { if errors.Is(err, git.NoErrAlreadyUpToDate) { return nil @@ -652,6 +657,23 @@ func (s *GitTokenStore) rewriteHeadAsSingleCommit(repo *git.Repository, branch p return nil } +func (s *GitTokenStore) maybeRunGC(repo *git.Repository) { + now := time.Now() + if now.Sub(s.lastGC) < gcInterval { + return + } + s.lastGC = now + + pruneOpts := git.PruneOptions{ + OnlyObjectsOlderThan: now, + Handler: repo.DeleteObject, + } + if err := repo.Prune(pruneOpts); err != nil && !errors.Is(err, git.ErrLooseObjectsNotSupported) { + return + } + _ = repo.RepackObjects(&git.RepackConfig{}) +} + // PersistConfig commits and pushes configuration changes to git. func (s *GitTokenStore) PersistConfig(_ context.Context) error { if err := s.EnsureRepository(); err != nil { diff --git a/internal/translator/codex/claude/codex_claude_response.go b/internal/translator/codex/claude/codex_claude_response.go index 238d3e24..b39494b7 100644 --- a/internal/translator/codex/claude/codex_claude_response.go +++ b/internal/translator/codex/claude/codex_claude_response.go @@ -113,10 +113,10 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall stopReason := rootResult.Get("response.stop_reason").String() - if stopReason != "" { - template, _ = sjson.Set(template, "delta.stop_reason", stopReason) - } else if p { + if p { template, _ = sjson.Set(template, "delta.stop_reason", "tool_use") + } else if stopReason == "max_tokens" || stopReason == "stop" { + template, _ = sjson.Set(template, "delta.stop_reason", stopReason) } else { template, _ = sjson.Set(template, "delta.stop_reason", "end_turn") } diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go index 97c18c1e..4867085e 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go +++ b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go @@ -78,11 +78,16 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ template, _ = sjson.Set(template, "id", responseIDResult.String()) } - // Extract and set the finish reason. - if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() { - template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String())) - template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String())) + finishReason := "" + if stopReasonResult := gjson.GetBytes(rawJSON, "response.stop_reason"); stopReasonResult.Exists() { + finishReason = stopReasonResult.String() } + if finishReason == "" { + if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() { + finishReason = finishReasonResult.String() + } + } + finishReason = strings.ToLower(finishReason) // Extract and set usage metadata (token counts). if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() { @@ -197,6 +202,12 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ if hasFunctionCall { template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls") template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls") + } else if finishReason != "" && (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex == 0 { + // Only pass through specific finish reasons + if finishReason == "max_tokens" || finishReason == "stop" { + template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason) + template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason) + } } return []string{template} diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go index 9cce35f9..ee581c46 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go @@ -129,11 +129,16 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR candidateIndex := int(candidate.Get("index").Int()) template, _ = sjson.Set(template, "choices.0.index", candidateIndex) - // Extract and set the finish reason. - if finishReasonResult := candidate.Get("finishReason"); finishReasonResult.Exists() { - template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String())) - template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String())) + finishReason := "" + if stopReasonResult := gjson.GetBytes(rawJSON, "stop_reason"); stopReasonResult.Exists() { + finishReason = stopReasonResult.String() } + if finishReason == "" { + if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() { + finishReason = finishReasonResult.String() + } + } + finishReason = strings.ToLower(finishReason) partsResult := candidate.Get("content.parts") hasFunctionCall := false @@ -225,6 +230,12 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR if hasFunctionCall { template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls") template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls") + } else if finishReason != "" { + // Only pass through specific finish reasons + if finishReason == "max_tokens" || finishReason == "stop" { + template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason) + template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason) + } } responseStrings = append(responseStrings, template) diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index f6d0806d..38dadc7c 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -607,6 +607,9 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req result.RetryAfter = ra } m.MarkResult(execCtx, result) + if isRequestInvalidError(errExec) { + return cliproxyexecutor.Response{}, errExec + } lastErr = errExec continue } @@ -660,6 +663,9 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, result.RetryAfter = ra } m.MarkResult(execCtx, result) + if isRequestInvalidError(errExec) { + return cliproxyexecutor.Response{}, errExec + } lastErr = errExec continue } @@ -711,6 +717,9 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} result.RetryAfter = retryAfterFromError(errStream) m.MarkResult(execCtx, result) + if isRequestInvalidError(errStream) { + return nil, errStream + } lastErr = errStream continue } @@ -1110,6 +1119,9 @@ func (m *Manager) shouldRetryAfterError(err error, attempt int, providers []stri if status := statusCodeFromError(err); status == http.StatusOK { return 0, false } + if isRequestInvalidError(err) { + return 0, false + } wait, found := m.closestCooldownWait(providers, model, attempt) if !found || wait > maxWait { return 0, false @@ -1299,7 +1311,7 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) { stateUnavailable = true } else if state.Unavailable { if state.NextRetryAfter.IsZero() { - stateUnavailable = true + stateUnavailable = false } else if state.NextRetryAfter.After(now) { stateUnavailable = true if earliestRetry.IsZero() || state.NextRetryAfter.Before(earliestRetry) { @@ -1430,6 +1442,21 @@ func statusCodeFromResult(err *Error) int { return err.StatusCode() } +// isRequestInvalidError returns true if the error represents a client request +// error that should not be retried. Specifically, it checks for 400 Bad Request +// with "invalid_request_error" in the message, indicating the request itself is +// malformed and switching to a different auth will not help. +func isRequestInvalidError(err error) bool { + if err == nil { + return false + } + status := statusCodeFromError(err) + if status != http.StatusBadRequest { + return false + } + return strings.Contains(err.Error(), "invalid_request_error") +} + func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) { if auth == nil { return diff --git a/sdk/cliproxy/auth/conductor_availability_test.go b/sdk/cliproxy/auth/conductor_availability_test.go new file mode 100644 index 00000000..87caa267 --- /dev/null +++ b/sdk/cliproxy/auth/conductor_availability_test.go @@ -0,0 +1,62 @@ +package auth + +import ( + "testing" + "time" +) + +func TestUpdateAggregatedAvailability_UnavailableWithoutNextRetryDoesNotBlockAuth(t *testing.T) { + t.Parallel() + + now := time.Now() + model := "test-model" + auth := &Auth{ + ID: "a", + ModelStates: map[string]*ModelState{ + model: { + Status: StatusError, + Unavailable: true, + }, + }, + } + + updateAggregatedAvailability(auth, now) + + if auth.Unavailable { + t.Fatalf("auth.Unavailable = true, want false") + } + if !auth.NextRetryAfter.IsZero() { + t.Fatalf("auth.NextRetryAfter = %v, want zero", auth.NextRetryAfter) + } +} + +func TestUpdateAggregatedAvailability_FutureNextRetryBlocksAuth(t *testing.T) { + t.Parallel() + + now := time.Now() + model := "test-model" + next := now.Add(5 * time.Minute) + auth := &Auth{ + ID: "a", + ModelStates: map[string]*ModelState{ + model: { + Status: StatusError, + Unavailable: true, + NextRetryAfter: next, + }, + }, + } + + updateAggregatedAvailability(auth, now) + + if !auth.Unavailable { + t.Fatalf("auth.Unavailable = false, want true") + } + if auth.NextRetryAfter.IsZero() { + t.Fatalf("auth.NextRetryAfter = zero, want %v", next) + } + if auth.NextRetryAfter.Sub(next) > time.Second || next.Sub(auth.NextRetryAfter) > time.Second { + t.Fatalf("auth.NextRetryAfter = %v, want %v", auth.NextRetryAfter, next) + } +} + diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index 7febf219..28500881 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" ) @@ -19,6 +20,7 @@ import ( type RoundRobinSelector struct { mu sync.Mutex cursors map[string]int + maxKeys int } // FillFirstSelector selects the first available credential (deterministic ordering). @@ -119,6 +121,19 @@ func authPriority(auth *Auth) int { return parsed } +func canonicalModelKey(model string) string { + model = strings.TrimSpace(model) + if model == "" { + return "" + } + parsed := thinking.ParseSuffix(model) + modelName := strings.TrimSpace(parsed.ModelName) + if modelName == "" { + return model + } + return modelName +} + func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) { available = make(map[int][]*Auth) for i := 0; i < len(auths); i++ { @@ -185,11 +200,18 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o if err != nil { return nil, err } - key := provider + ":" + model + key := provider + ":" + canonicalModelKey(model) s.mu.Lock() if s.cursors == nil { s.cursors = make(map[string]int) } + limit := s.maxKeys + if limit <= 0 { + limit = 4096 + } + if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit { + s.cursors = make(map[string]int) + } index := s.cursors[key] if index >= 2_147_483_640 { @@ -223,7 +245,14 @@ func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, block } if model != "" { if len(auth.ModelStates) > 0 { - if state, ok := auth.ModelStates[model]; ok && state != nil { + state, ok := auth.ModelStates[model] + if (!ok || state == nil) && model != "" { + baseModel := canonicalModelKey(model) + if baseModel != "" && baseModel != model { + state, ok = auth.ModelStates[baseModel] + } + } + if ok && state != nil { if state.Status == StatusDisabled { return true, blockReasonDisabled, time.Time{} } diff --git a/sdk/cliproxy/auth/selector_test.go b/sdk/cliproxy/auth/selector_test.go index 91a7ed14..fe1cf15e 100644 --- a/sdk/cliproxy/auth/selector_test.go +++ b/sdk/cliproxy/auth/selector_test.go @@ -2,7 +2,9 @@ package auth import ( "context" + "encoding/json" "errors" + "net/http" "sync" "testing" "time" @@ -175,3 +177,228 @@ func TestRoundRobinSelectorPick_Concurrent(t *testing.T) { default: } } + +func TestSelectorPick_AllCooldownReturnsModelCooldownError(t *testing.T) { + t.Parallel() + + model := "test-model" + now := time.Now() + next := now.Add(60 * time.Second) + auths := []*Auth{ + { + ID: "a", + ModelStates: map[string]*ModelState{ + model: { + Status: StatusActive, + Unavailable: true, + NextRetryAfter: next, + Quota: QuotaState{ + Exceeded: true, + NextRecoverAt: next, + }, + }, + }, + }, + { + ID: "b", + ModelStates: map[string]*ModelState{ + model: { + Status: StatusActive, + Unavailable: true, + NextRetryAfter: next, + Quota: QuotaState{ + Exceeded: true, + NextRecoverAt: next, + }, + }, + }, + }, + } + + t.Run("mixed provider redacts provider field", func(t *testing.T) { + t.Parallel() + + selector := &FillFirstSelector{} + _, err := selector.Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, auths) + if err == nil { + t.Fatalf("Pick() error = nil") + } + + var mce *modelCooldownError + if !errors.As(err, &mce) { + t.Fatalf("Pick() error = %T, want *modelCooldownError", err) + } + if mce.StatusCode() != http.StatusTooManyRequests { + t.Fatalf("StatusCode() = %d, want %d", mce.StatusCode(), http.StatusTooManyRequests) + } + + headers := mce.Headers() + if got := headers.Get("Retry-After"); got == "" { + t.Fatalf("Headers().Get(Retry-After) = empty") + } + + var payload map[string]any + if err := json.Unmarshal([]byte(mce.Error()), &payload); err != nil { + t.Fatalf("json.Unmarshal(Error()) error = %v", err) + } + rawErr, ok := payload["error"].(map[string]any) + if !ok { + t.Fatalf("Error() payload missing error object: %v", payload) + } + if got, _ := rawErr["code"].(string); got != "model_cooldown" { + t.Fatalf("Error().error.code = %q, want %q", got, "model_cooldown") + } + if _, ok := rawErr["provider"]; ok { + t.Fatalf("Error().error.provider exists for mixed provider: %v", rawErr["provider"]) + } + }) + + t.Run("non-mixed provider includes provider field", func(t *testing.T) { + t.Parallel() + + selector := &FillFirstSelector{} + _, err := selector.Pick(context.Background(), "gemini", model, cliproxyexecutor.Options{}, auths) + if err == nil { + t.Fatalf("Pick() error = nil") + } + + var mce *modelCooldownError + if !errors.As(err, &mce) { + t.Fatalf("Pick() error = %T, want *modelCooldownError", err) + } + + var payload map[string]any + if err := json.Unmarshal([]byte(mce.Error()), &payload); err != nil { + t.Fatalf("json.Unmarshal(Error()) error = %v", err) + } + rawErr, ok := payload["error"].(map[string]any) + if !ok { + t.Fatalf("Error() payload missing error object: %v", payload) + } + if got, _ := rawErr["provider"].(string); got != "gemini" { + t.Fatalf("Error().error.provider = %q, want %q", got, "gemini") + } + }) +} + +func TestIsAuthBlockedForModel_UnavailableWithoutNextRetryIsNotBlocked(t *testing.T) { + t.Parallel() + + now := time.Now() + model := "test-model" + auth := &Auth{ + ID: "a", + ModelStates: map[string]*ModelState{ + model: { + Status: StatusActive, + Unavailable: true, + Quota: QuotaState{ + Exceeded: true, + }, + }, + }, + } + + blocked, reason, next := isAuthBlockedForModel(auth, model, now) + if blocked { + t.Fatalf("blocked = true, want false") + } + if reason != blockReasonNone { + t.Fatalf("reason = %v, want %v", reason, blockReasonNone) + } + if !next.IsZero() { + t.Fatalf("next = %v, want zero", next) + } +} + +func TestFillFirstSelectorPick_ThinkingSuffixFallsBackToBaseModelState(t *testing.T) { + t.Parallel() + + selector := &FillFirstSelector{} + now := time.Now() + + baseModel := "test-model" + requestedModel := "test-model(high)" + + high := &Auth{ + ID: "high", + Attributes: map[string]string{"priority": "10"}, + ModelStates: map[string]*ModelState{ + baseModel: { + Status: StatusActive, + Unavailable: true, + NextRetryAfter: now.Add(30 * time.Minute), + Quota: QuotaState{ + Exceeded: true, + }, + }, + }, + } + low := &Auth{ + ID: "low", + Attributes: map[string]string{"priority": "0"}, + } + + got, err := selector.Pick(context.Background(), "mixed", requestedModel, cliproxyexecutor.Options{}, []*Auth{high, low}) + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + if got == nil { + t.Fatalf("Pick() auth = nil") + } + if got.ID != "low" { + t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low") + } +} + +func TestRoundRobinSelectorPick_ThinkingSuffixSharesCursor(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + auths := []*Auth{ + {ID: "b"}, + {ID: "a"}, + } + + first, err := selector.Pick(context.Background(), "gemini", "test-model(high)", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() first error = %v", err) + } + second, err := selector.Pick(context.Background(), "gemini", "test-model(low)", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() second error = %v", err) + } + if first == nil || second == nil { + t.Fatalf("Pick() returned nil auth") + } + if first.ID != "a" { + t.Fatalf("Pick() first auth.ID = %q, want %q", first.ID, "a") + } + if second.ID != "b" { + t.Fatalf("Pick() second auth.ID = %q, want %q", second.ID, "b") + } +} + +func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{maxKeys: 2} + auths := []*Auth{{ID: "a"}} + + _, _ = selector.Pick(context.Background(), "gemini", "m1", cliproxyexecutor.Options{}, auths) + _, _ = selector.Pick(context.Background(), "gemini", "m2", cliproxyexecutor.Options{}, auths) + _, _ = selector.Pick(context.Background(), "gemini", "m3", cliproxyexecutor.Options{}, auths) + + selector.mu.Lock() + defer selector.mu.Unlock() + + if selector.cursors == nil { + t.Fatalf("selector.cursors = nil") + } + if len(selector.cursors) != 1 { + t.Fatalf("len(selector.cursors) = %d, want %d", len(selector.cursors), 1) + } + if _, ok := selector.cursors["gemini:m3"]; !ok { + t.Fatalf("selector.cursors missing key %q", "gemini:m3") + } +}