diff --git a/config.example.yaml b/config.example.yaml index 40bb8721..348aabd8 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -187,6 +187,17 @@ nonstream-keepalive-interval: 0 # models: # The models supported by the provider. # - name: "moonshotai/kimi-k2:free" # The actual model name. # alias: "kimi-k2" # The alias used in the API. +# # You may repeat the same alias to build an internal model pool. +# # The client still sees only one alias in the model list. +# # Requests to that alias will round-robin across the upstream names below, +# # and if the chosen upstream fails before producing output, the request will +# # continue with the next upstream model in the same alias pool. +# - name: "qwen3.5-plus" +# alias: "claude-opus-4.66" +# - name: "glm-5" +# alias: "claude-opus-4.66" +# - name: "kimi-k2.5" +# alias: "claude-opus-4.66" # Vertex API keys (Vertex-compatible endpoints, use API key + base URL) # vertex-api-key: diff --git a/internal/thinking/apply.go b/internal/thinking/apply.go index b8a0fcae..c79ecd8e 100644 --- a/internal/thinking/apply.go +++ b/internal/thinking/apply.go @@ -257,7 +257,10 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromForma if suffixResult.HasSuffix { config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID) } else { - config = extractThinkingConfig(body, toFormat) + config = extractThinkingConfig(body, fromFormat) + if !hasThinkingConfig(config) && fromFormat != toFormat { + config = extractThinkingConfig(body, toFormat) + } } if !hasThinkingConfig(config) { @@ -293,6 +296,9 @@ func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat stri if config.Mode != ModeLevel { return config } + if toFormat == "claude" { + return config + } if !isBudgetCapableProvider(toFormat) { return config } diff --git a/internal/thinking/apply_user_defined_test.go b/internal/thinking/apply_user_defined_test.go new file mode 100644 index 00000000..aa24ab8e --- /dev/null +++ 
b/internal/thinking/apply_user_defined_test.go @@ -0,0 +1,55 @@ +package thinking_test + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + _ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude" + "github.com/tidwall/gjson" +) + +func TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel(t *testing.T) { + reg := registry.GetGlobalRegistry() + clientID := "test-user-defined-claude-" + t.Name() + modelID := "custom-claude-4-6" + reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ID: modelID, UserDefined: true}}) + t.Cleanup(func() { + reg.UnregisterClient(clientID) + }) + + tests := []struct { + name string + model string + body []byte + }{ + { + name: "claude adaptive effort body", + model: modelID, + body: []byte(`{"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`), + }, + { + name: "suffix level", + model: modelID + "(high)", + body: []byte(`{}`), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := thinking.ApplyThinking(tt.body, tt.model, "openai", "claude", "claude") + if err != nil { + t.Fatalf("ApplyThinking() error = %v", err) + } + if got := gjson.GetBytes(out, "thinking.type").String(); got != "adaptive" { + t.Fatalf("thinking.type = %q, want %q, body=%s", got, "adaptive", string(out)) + } + if got := gjson.GetBytes(out, "output_config.effort").String(); got != "high" { + t.Fatalf("output_config.effort = %q, want %q, body=%s", got, "high", string(out)) + } + if gjson.GetBytes(out, "thinking.budget_tokens").Exists() { + t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out)) + } + }) + } +} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index ae5b745c..e31f3300 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -149,6 +149,9 @@ type Manager struct { // Keyed by auth.ID, value is 
alias(lower) -> upstream model (including suffix). apiKeyModelAlias atomic.Value + // modelPoolOffsets tracks per-auth alias pool rotation state. + modelPoolOffsets map[string]int + // runtimeConfig stores the latest application config for request-time decisions. // It is initialized in NewManager; never Load() before first Store(). runtimeConfig atomic.Value @@ -176,6 +179,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { hook: hook, auths: make(map[string]*Auth), providerOffsets: make(map[string]int), + modelPoolOffsets: make(map[string]int), refreshSemaphore: make(chan struct{}, refreshMaxConcurrency), } // atomic.Value requires non-nil initial value. @@ -251,16 +255,323 @@ func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) strin if resolved == "" { return "" } - // Preserve thinking suffix from the client's requested model unless config already has one. - requestResult := thinking.ParseSuffix(requestedModel) - if thinking.ParseSuffix(resolved).HasSuffix { - return resolved - } - if requestResult.HasSuffix && requestResult.RawSuffix != "" { - return resolved + "(" + requestResult.RawSuffix + ")" - } - return resolved + return preserveRequestedModelSuffix(requestedModel, resolved) +} +func isAPIKeyAuth(auth *Auth) bool { + if auth == nil { + return false + } + kind, _ := auth.AccountInfo() + return strings.EqualFold(strings.TrimSpace(kind), "api_key") +} + +func isOpenAICompatAPIKeyAuth(auth *Auth) bool { + if !isAPIKeyAuth(auth) { + return false + } + if strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") { + return true + } + if auth.Attributes == nil { + return false + } + return strings.TrimSpace(auth.Attributes["compat_name"]) != "" +} + +func openAICompatProviderKey(auth *Auth) string { + if auth == nil { + return "" + } + if auth.Attributes != nil { + if providerKey := strings.TrimSpace(auth.Attributes["provider_key"]); providerKey != "" { + return strings.ToLower(providerKey) + } + 
if compatName := strings.TrimSpace(auth.Attributes["compat_name"]); compatName != "" { + return strings.ToLower(compatName) + } + } + return strings.ToLower(strings.TrimSpace(auth.Provider)) +} + +func openAICompatModelPoolKey(auth *Auth, requestedModel string) string { + base := strings.TrimSpace(thinking.ParseSuffix(requestedModel).ModelName) + if base == "" { + base = strings.TrimSpace(requestedModel) + } + return strings.ToLower(strings.TrimSpace(auth.ID)) + "|" + openAICompatProviderKey(auth) + "|" + strings.ToLower(base) +} + +func (m *Manager) nextModelPoolOffset(key string, size int) int { + if m == nil || size <= 1 { + return 0 + } + key = strings.TrimSpace(key) + if key == "" { + return 0 + } + m.mu.Lock() + defer m.mu.Unlock() + if m.modelPoolOffsets == nil { + m.modelPoolOffsets = make(map[string]int) + } + offset := m.modelPoolOffsets[key] + if offset >= 2_147_483_640 { + offset = 0 + } + m.modelPoolOffsets[key] = offset + 1 + if size <= 0 { + return 0 + } + return offset % size +} + +func rotateStrings(values []string, offset int) []string { + if len(values) <= 1 { + return values + } + if offset <= 0 { + out := make([]string, len(values)) + copy(out, values) + return out + } + offset = offset % len(values) + out := make([]string, 0, len(values)) + out = append(out, values[offset:]...) + out = append(out, values[:offset]...) 
+ return out +} + +func (m *Manager) resolveOpenAICompatUpstreamModelPool(auth *Auth, requestedModel string) []string { + if m == nil || !isOpenAICompatAPIKeyAuth(auth) { + return nil + } + requestedModel = strings.TrimSpace(requestedModel) + if requestedModel == "" { + return nil + } + cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config) + if cfg == nil { + cfg = &internalconfig.Config{} + } + providerKey := "" + compatName := "" + if auth.Attributes != nil { + providerKey = strings.TrimSpace(auth.Attributes["provider_key"]) + compatName = strings.TrimSpace(auth.Attributes["compat_name"]) + } + entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider) + if entry == nil { + return nil + } + return resolveModelAliasPoolFromConfigModels(requestedModel, asModelAliasEntries(entry.Models)) +} + +func preserveRequestedModelSuffix(requestedModel, resolved string) string { + return preserveResolvedModelSuffix(resolved, thinking.ParseSuffix(requestedModel)) +} + +func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string { + return m.prepareExecutionModels(auth, routeModel) +} + +func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string { + requestedModel := rewriteModelForAuth(routeModel, auth) + requestedModel = m.applyOAuthModelAlias(auth, requestedModel) + if pool := m.resolveOpenAICompatUpstreamModelPool(auth, requestedModel); len(pool) > 0 { + if len(pool) == 1 { + return pool + } + offset := m.nextModelPoolOffset(openAICompatModelPoolKey(auth, requestedModel), len(pool)) + return rotateStrings(pool, offset) + } + resolved := m.applyAPIKeyModelAlias(auth, requestedModel) + if strings.TrimSpace(resolved) == "" { + resolved = requestedModel + } + return []string{resolved} +} + +func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) { + if ch == nil { + return + } + go func() { + for range ch { + } + }() +} + +func readStreamBootstrap(ctx context.Context, ch <-chan 
cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) { + if ch == nil { + return nil, true, nil + } + buffered := make([]cliproxyexecutor.StreamChunk, 0, 1) + for { + var ( + chunk cliproxyexecutor.StreamChunk + ok bool + ) + if ctx != nil { + select { + case <-ctx.Done(): + return nil, false, ctx.Err() + case chunk, ok = <-ch: + } + } else { + chunk, ok = <-ch + } + if !ok { + return buffered, true, nil + } + if chunk.Err != nil { + return nil, false, chunk.Err + } + buffered = append(buffered, chunk) + if len(chunk.Payload) > 0 { + return buffered, false, nil + } + } +} + +func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult { + out := make(chan cliproxyexecutor.StreamChunk) + go func() { + defer close(out) + var failed bool + forward := true + emit := func(chunk cliproxyexecutor.StreamChunk) bool { + if chunk.Err != nil && !failed { + failed = true + rerr := &Error{Message: chunk.Err.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil { + rerr.HTTPStatus = se.StatusCode() + } + m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}) + } + if !forward { + return false + } + if ctx == nil { + out <- chunk + return true + } + select { + case <-ctx.Done(): + forward = false + return false + case out <- chunk: + return true + } + } + for _, chunk := range buffered { + if ok := emit(chunk); !ok { + discardStreamChunks(remaining) + return + } + } + for chunk := range remaining { + if ok := emit(chunk); !ok { + discardStreamChunks(remaining) + return + } + } + if !failed { + m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: true}) + } + }() + return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out} +} + +func (m 
*Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) { + if executor == nil { + return nil, &Error{Code: "executor_not_found", Message: "executor not registered"} + } + execModels := m.prepareExecutionModels(auth, routeModel) + var lastErr error + for idx, execModel := range execModels { + execReq := req + execReq.Model = execModel + streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts) + if errStream != nil { + if errCtx := ctx.Err(); errCtx != nil { + return nil, errCtx + } + rerr := &Error{Message: errStream.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil { + rerr.HTTPStatus = se.StatusCode() + } + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} + result.RetryAfter = retryAfterFromError(errStream) + m.MarkResult(ctx, result) + if isRequestInvalidError(errStream) { + return nil, errStream + } + lastErr = errStream + continue + } + + buffered, closed, bootstrapErr := readStreamBootstrap(ctx, streamResult.Chunks) + if bootstrapErr != nil { + if errCtx := ctx.Err(); errCtx != nil { + discardStreamChunks(streamResult.Chunks) + return nil, errCtx + } + if isRequestInvalidError(bootstrapErr) { + rerr := &Error{Message: bootstrapErr.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil { + rerr.HTTPStatus = se.StatusCode() + } + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} + result.RetryAfter = retryAfterFromError(bootstrapErr) + m.MarkResult(ctx, result) + discardStreamChunks(streamResult.Chunks) + return nil, bootstrapErr + } + if idx < len(execModels)-1 { + rerr := &Error{Message: bootstrapErr.Error()} + if se, ok := 
errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil { + rerr.HTTPStatus = se.StatusCode() + } + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} + result.RetryAfter = retryAfterFromError(bootstrapErr) + m.MarkResult(ctx, result) + discardStreamChunks(streamResult.Chunks) + lastErr = bootstrapErr + continue + } + errCh := make(chan cliproxyexecutor.StreamChunk, 1) + errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr} + close(errCh) + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil + } + + if closed && len(buffered) == 0 { + emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true} + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: emptyErr} + m.MarkResult(ctx, result) + if idx < len(execModels)-1 { + lastErr = emptyErr + continue + } + errCh := make(chan cliproxyexecutor.StreamChunk, 1) + errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr} + close(errCh) + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil + } + + remaining := streamResult.Chunks + if closed { + closedCh := make(chan cliproxyexecutor.StreamChunk) + close(closedCh) + remaining = closedCh + } + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil + } + if lastErr == nil { + lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"} + } + return nil, lastErr } func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() { @@ -634,32 +945,42 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - execReq := req - execReq.Model = rewriteModelForAuth(routeModel, auth) - 
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) - execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) - resp, errExec := executor.Execute(execCtx, auth, execReq, opts) - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} - if errExec != nil { - if errCtx := execCtx.Err(); errCtx != nil { - return cliproxyexecutor.Response{}, errCtx - } - result.Error = &Error{Message: errExec.Error()} - if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil { - result.Error.HTTPStatus = se.StatusCode() - } - if ra := retryAfterFromError(errExec); ra != nil { - result.RetryAfter = ra + + models := m.prepareExecutionModels(auth, routeModel) + var authErr error + for _, upstreamModel := range models { + execReq := req + execReq.Model = upstreamModel + resp, errExec := executor.Execute(execCtx, auth, execReq, opts) + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} + if errExec != nil { + if errCtx := execCtx.Err(); errCtx != nil { + return cliproxyexecutor.Response{}, errCtx + } + result.Error = &Error{Message: errExec.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + if ra := retryAfterFromError(errExec); ra != nil { + result.RetryAfter = ra + } + m.MarkResult(execCtx, result) + if isRequestInvalidError(errExec) { + return cliproxyexecutor.Response{}, errExec + } + authErr = errExec + continue } m.MarkResult(execCtx, result) - if isRequestInvalidError(errExec) { - return cliproxyexecutor.Response{}, errExec + return resp, nil + } + if authErr != nil { + if isRequestInvalidError(authErr) { + return cliproxyexecutor.Response{}, authErr } - lastErr = errExec + lastErr = authErr continue } - m.MarkResult(execCtx, result) - return resp, nil } } @@ -696,32 +1017,42 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, 
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - execReq := req - execReq.Model = rewriteModelForAuth(routeModel, auth) - execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) - execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) - resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} - if errExec != nil { - if errCtx := execCtx.Err(); errCtx != nil { - return cliproxyexecutor.Response{}, errCtx - } - result.Error = &Error{Message: errExec.Error()} - if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil { - result.Error.HTTPStatus = se.StatusCode() - } - if ra := retryAfterFromError(errExec); ra != nil { - result.RetryAfter = ra + + models := m.prepareExecutionModels(auth, routeModel) + var authErr error + for _, upstreamModel := range models { + execReq := req + execReq.Model = upstreamModel + resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} + if errExec != nil { + if errCtx := execCtx.Err(); errCtx != nil { + return cliproxyexecutor.Response{}, errCtx + } + result.Error = &Error{Message: errExec.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + if ra := retryAfterFromError(errExec); ra != nil { + result.RetryAfter = ra + } + m.hook.OnResult(execCtx, result) + if isRequestInvalidError(errExec) { + return cliproxyexecutor.Response{}, errExec + } + authErr = errExec + continue } m.hook.OnResult(execCtx, result) - if isRequestInvalidError(errExec) { - return cliproxyexecutor.Response{}, errExec + return resp, nil + } + if authErr != nil { + if isRequestInvalidError(authErr) { + return cliproxyexecutor.Response{}, 
authErr } - lastErr = errExec + lastErr = authErr continue } - m.hook.OnResult(execCtx, result) - return resp, nil } } @@ -758,63 +1089,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - execReq := req - execReq.Model = rewriteModelForAuth(routeModel, auth) - execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) - execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) - streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) + streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel) if errStream != nil { if errCtx := execCtx.Err(); errCtx != nil { return nil, errCtx } - rerr := &Error{Message: errStream.Error()} - if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil { - rerr.HTTPStatus = se.StatusCode() - } - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} - result.RetryAfter = retryAfterFromError(errStream) - m.MarkResult(execCtx, result) if isRequestInvalidError(errStream) { return nil, errStream } lastErr = errStream continue } - out := make(chan cliproxyexecutor.StreamChunk) - go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) { - defer close(out) - var failed bool - forward := true - for chunk := range streamChunks { - if chunk.Err != nil && !failed { - failed = true - rerr := &Error{Message: chunk.Err.Error()} - if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil { - rerr.HTTPStatus = se.StatusCode() - } - m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr}) - } - if !forward { - continue - } - if streamCtx == nil { - out <- chunk - 
continue - } - select { - case <-streamCtx.Done(): - forward = false - case out <- chunk: - } - } - if !failed { - m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true}) - } - }(execCtx, auth.Clone(), provider, streamResult.Chunks) - return &cliproxyexecutor.StreamResult{ - Headers: streamResult.Headers, - Chunks: out, - }, nil + return streamResult, nil } } @@ -1533,18 +1819,22 @@ func statusCodeFromResult(err *Error) int { } // isRequestInvalidError returns true if the error represents a client request -// error that should not be retried. Specifically, it checks for 400 Bad Request -// with "invalid_request_error" in the message, indicating the request itself is -// malformed and switching to a different auth will not help. +// error that should not be retried. Specifically, it treats 400 responses with +// "invalid_request_error" and all 422 responses as request-shape failures, +// where switching auths or pooled upstream models will not help. 
func isRequestInvalidError(err error) bool { if err == nil { return false } status := statusCodeFromError(err) - if status != http.StatusBadRequest { + switch status { + case http.StatusBadRequest: + return strings.Contains(err.Error(), "invalid_request_error") + case http.StatusUnprocessableEntity: + return true + default: return false } - return strings.Contains(err.Error(), "invalid_request_error") } func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) { diff --git a/sdk/cliproxy/auth/oauth_model_alias.go b/sdk/cliproxy/auth/oauth_model_alias.go index d5d2ff8a..77a11c19 100644 --- a/sdk/cliproxy/auth/oauth_model_alias.go +++ b/sdk/cliproxy/auth/oauth_model_alias.go @@ -80,54 +80,98 @@ func (m *Manager) applyOAuthModelAlias(auth *Auth, requestedModel string) string return upstreamModel } -func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string { +func modelAliasLookupCandidates(requestedModel string) (thinking.SuffixResult, []string) { requestedModel = strings.TrimSpace(requestedModel) if requestedModel == "" { - return "" + return thinking.SuffixResult{}, nil } - if len(models) == 0 { - return "" - } - requestResult := thinking.ParseSuffix(requestedModel) base := requestResult.ModelName + if base == "" { + base = requestedModel + } candidates := []string{base} if base != requestedModel { candidates = append(candidates, requestedModel) } + return requestResult, candidates +} - preserveSuffix := func(resolved string) string { - resolved = strings.TrimSpace(resolved) - if resolved == "" { - return "" - } - if thinking.ParseSuffix(resolved).HasSuffix { - return resolved - } - if requestResult.HasSuffix && requestResult.RawSuffix != "" { - return resolved + "(" + requestResult.RawSuffix + ")" - } +func preserveResolvedModelSuffix(resolved string, requestResult thinking.SuffixResult) string { + resolved = strings.TrimSpace(resolved) + if resolved == "" { + return "" + } + if 
thinking.ParseSuffix(resolved).HasSuffix { return resolved } + if requestResult.HasSuffix && requestResult.RawSuffix != "" { + return resolved + "(" + requestResult.RawSuffix + ")" + } + return resolved +} +func resolveModelAliasPoolFromConfigModels(requestedModel string, models []modelAliasEntry) []string { + requestedModel = strings.TrimSpace(requestedModel) + if requestedModel == "" { + return nil + } + if len(models) == 0 { + return nil + } + + requestResult, candidates := modelAliasLookupCandidates(requestedModel) + if len(candidates) == 0 { + return nil + } + + out := make([]string, 0) + seen := make(map[string]struct{}) for i := range models { name := strings.TrimSpace(models[i].GetName()) alias := strings.TrimSpace(models[i].GetAlias()) for _, candidate := range candidates { - if candidate == "" { + if candidate == "" || alias == "" || !strings.EqualFold(alias, candidate) { continue } - if alias != "" && strings.EqualFold(alias, candidate) { - if name != "" { - return preserveSuffix(name) - } - return preserveSuffix(candidate) + resolved := candidate + if name != "" { + resolved = name } - if name != "" && strings.EqualFold(name, candidate) { - return preserveSuffix(name) + resolved = preserveResolvedModelSuffix(resolved, requestResult) + key := strings.ToLower(strings.TrimSpace(resolved)) + if key == "" { + break } + if _, exists := seen[key]; exists { + break + } + seen[key] = struct{}{} + out = append(out, resolved) + break } } + if len(out) > 0 { + return out + } + + for i := range models { + name := strings.TrimSpace(models[i].GetName()) + for _, candidate := range candidates { + if candidate == "" || name == "" || !strings.EqualFold(name, candidate) { + continue + } + return []string{preserveResolvedModelSuffix(name, requestResult)} + } + } + return nil +} + +func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string { + resolved := resolveModelAliasPoolFromConfigModels(requestedModel, models) + if len(resolved) > 0 
{ + return resolved[0] + } return "" } diff --git a/sdk/cliproxy/auth/openai_compat_pool_test.go b/sdk/cliproxy/auth/openai_compat_pool_test.go new file mode 100644 index 00000000..5a5ecb4f --- /dev/null +++ b/sdk/cliproxy/auth/openai_compat_pool_test.go @@ -0,0 +1,419 @@ +package auth + +import ( + "context" + "net/http" + "sync" + "testing" + + internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +type openAICompatPoolExecutor struct { + id string + + mu sync.Mutex + executeModels []string + countModels []string + streamModels []string + executeErrors map[string]error + countErrors map[string]error + streamFirstErrors map[string]error + streamPayloads map[string][]cliproxyexecutor.StreamChunk +} + +func (e *openAICompatPoolExecutor) Identifier() string { return e.id } + +func (e *openAICompatPoolExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + _ = ctx + _ = auth + _ = opts + e.mu.Lock() + e.executeModels = append(e.executeModels, req.Model) + err := e.executeErrors[req.Model] + e.mu.Unlock() + if err != nil { + return cliproxyexecutor.Response{}, err + } + return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil +} + +func (e *openAICompatPoolExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + _ = ctx + _ = auth + _ = opts + e.mu.Lock() + e.streamModels = append(e.streamModels, req.Model) + err := e.streamFirstErrors[req.Model] + payloadChunks, hasCustomChunks := e.streamPayloads[req.Model] + chunks := append([]cliproxyexecutor.StreamChunk(nil), payloadChunks...) 
+ e.mu.Unlock() + ch := make(chan cliproxyexecutor.StreamChunk, max(1, len(chunks))) + if err != nil { + ch <- cliproxyexecutor.StreamChunk{Err: err} + close(ch) + return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil + } + if !hasCustomChunks { + ch <- cliproxyexecutor.StreamChunk{Payload: []byte(req.Model)} + } else { + for _, chunk := range chunks { + ch <- chunk + } + } + close(ch) + return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil +} + +func (e *openAICompatPoolExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) { + return auth, nil +} + +func (e *openAICompatPoolExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + _ = ctx + _ = auth + _ = opts + e.mu.Lock() + e.countModels = append(e.countModels, req.Model) + err := e.countErrors[req.Model] + e.mu.Unlock() + if err != nil { + return cliproxyexecutor.Response{}, err + } + return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil +} + +func (e *openAICompatPoolExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) { + _ = ctx + _ = auth + _ = req + return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"} +} + +func (e *openAICompatPoolExecutor) ExecuteModels() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.executeModels)) + copy(out, e.executeModels) + return out +} + +func (e *openAICompatPoolExecutor) CountModels() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.countModels)) + copy(out, e.countModels) + return out +} + +func (e *openAICompatPoolExecutor) StreamModels() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.streamModels)) + copy(out, e.streamModels) + return out +} + +func newOpenAICompatPoolTestManager(t 
*testing.T, alias string, models []internalconfig.OpenAICompatibilityModel, executor *openAICompatPoolExecutor) *Manager { + t.Helper() + cfg := &internalconfig.Config{ + OpenAICompatibility: []internalconfig.OpenAICompatibility{{ + Name: "pool", + Models: models, + }}, + } + m := NewManager(nil, nil, nil) + m.SetConfig(cfg) + if executor == nil { + executor = &openAICompatPoolExecutor{id: "pool"} + } + m.RegisterExecutor(executor) + + auth := &Auth{ + ID: "pool-auth-" + t.Name(), + Provider: "pool", + Status: StatusActive, + Attributes: map[string]string{ + "api_key": "test-key", + "compat_name": "pool", + "provider_key": "pool", + }, + } + if _, err := m.Register(context.Background(), auth); err != nil { + t.Fatalf("register auth: %v", err) + } + + reg := registry.GetGlobalRegistry() + reg.RegisterClient(auth.ID, "pool", []*registry.ModelInfo{{ID: alias}}) + t.Cleanup(func() { + reg.UnregisterClient(auth.ID) + }) + return m +} + +func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) { + alias := "claude-opus-4.66" + invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"} + executor := &openAICompatPoolExecutor{ + id: "pool", + countErrors: map[string]error{"qwen3.5-plus": invalidErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + _, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err == nil || err.Error() != invalidErr.Error() { + t.Fatalf("execute count error = %v, want %v", err, invalidErr) + } + got := executor.CountModels() + if len(got) != 1 || got[0] != "qwen3.5-plus" { + t.Fatalf("count calls = %v, want only first invalid model", got) + } +} +func TestResolveModelAliasPoolFromConfigModels(t *testing.T) { + models := []modelAliasEntry{ + 
internalconfig.OpenAICompatibilityModel{Name: "qwen3.5-plus", Alias: "claude-opus-4.66"}, + internalconfig.OpenAICompatibilityModel{Name: "glm-5", Alias: "claude-opus-4.66"}, + internalconfig.OpenAICompatibilityModel{Name: "kimi-k2.5", Alias: "claude-opus-4.66"}, + } + got := resolveModelAliasPoolFromConfigModels("claude-opus-4.66(8192)", models) + want := []string{"qwen3.5-plus(8192)", "glm-5(8192)", "kimi-k2.5(8192)"} + if len(got) != len(want) { + t.Fatalf("pool len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("pool[%d] = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) { + alias := "claude-opus-4.66" + executor := &openAICompatPoolExecutor{id: "pool"} + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + for i := 0; i < 3; i++ { + resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute %d: %v", i, err) + } + if len(resp.Payload) == 0 { + t.Fatalf("execute %d returned empty payload", i) + } + } + + got := executor.ExecuteModels() + want := []string{"qwen3.5-plus", "glm-5", "qwen3.5-plus"} + if len(got) != len(want) { + t.Fatalf("execute calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) { + alias := "claude-opus-4.66" + invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"} + executor := &openAICompatPoolExecutor{ + id: "pool", + executeErrors: map[string]error{"qwen3.5-plus": invalidErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, 
[]internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + _, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err == nil || err.Error() != invalidErr.Error() { + t.Fatalf("execute error = %v, want %v", err, invalidErr) + } + got := executor.ExecuteModels() + if len(got) != 1 || got[0] != "qwen3.5-plus" { + t.Fatalf("execute calls = %v, want only first invalid model", got) + } +} +func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) { + alias := "claude-opus-4.66" + executor := &openAICompatPoolExecutor{ + id: "pool", + executeErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute: %v", err) + } + if string(resp.Payload) != "glm-5" { + t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5") + } + got := executor.ExecuteModels() + want := []string{"qwen3.5-plus", "glm-5"} + if len(got) != len(want) { + t.Fatalf("execute calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *testing.T) { + alias := "claude-opus-4.66" + executor := &openAICompatPoolExecutor{ + id: "pool", + streamPayloads: map[string][]cliproxyexecutor.StreamChunk{ + "qwen3.5-plus": {}, + }, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + streamResult, err :=
m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute stream: %v", err) + } + var payload []byte + for chunk := range streamResult.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream error: %v", chunk.Err) + } + payload = append(payload, chunk.Payload...) + } + if string(payload) != "glm-5" { + t.Fatalf("payload = %q, want %q", string(payload), "glm-5") + } + got := executor.StreamModels() + want := []string{"qwen3.5-plus", "glm-5"} + if len(got) != len(want) { + t.Fatalf("stream calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *testing.T) { + alias := "claude-opus-4.66" + executor := &openAICompatPoolExecutor{ + id: "pool", + streamFirstErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute stream: %v", err) + } + var payload []byte + for chunk := range streamResult.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream error: %v", chunk.Err) + } + payload = append(payload, chunk.Payload...)
+ } + if string(payload) != "glm-5" { + t.Fatalf("payload = %q, want %q", string(payload), "glm-5") + } + got := executor.StreamModels() + want := []string{"qwen3.5-plus", "glm-5"} + if len(got) != len(want) { + t.Fatalf("stream calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i]) + } + } + if gotHeader := streamResult.Headers.Get("X-Model"); gotHeader != "glm-5" { + t.Fatalf("header X-Model = %q, want %q", gotHeader, "glm-5") + } +} + +func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) { + alias := "claude-opus-4.66" + invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"} + executor := &openAICompatPoolExecutor{ + id: "pool", + streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + _, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err == nil || err.Error() != invalidErr.Error() { + t.Fatalf("execute stream error = %v, want %v", err, invalidErr) + } + got := executor.StreamModels() + if len(got) != 1 || got[0] != "qwen3.5-plus" { + t.Fatalf("stream calls = %v, want only first invalid model", got) + } +} +func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) { + alias := "claude-opus-4.66" + executor := &openAICompatPoolExecutor{id: "pool"} + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + for i := 0; i < 2; i++ { + resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute count %d: %v", i, err) + } + if
len(resp.Payload) == 0 { + t.Fatalf("execute count %d returned empty payload", i) + } + } + + got := executor.CountModels() + want := []string{"qwen3.5-plus", "glm-5"} + if len(got) != len(want) { + t.Fatalf("count calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *testing.T) { + alias := "claude-opus-4.66" + invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"} + executor := &openAICompatPoolExecutor{ + id: "pool", + streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err == nil { + t.Fatal("expected invalid request error") + } + if err != invalidErr { + t.Fatalf("error = %v, want %v", err, invalidErr) + } + if streamResult != nil { + t.Fatalf("streamResult = %#v, want nil on invalid bootstrap", streamResult) + } + if got := executor.StreamModels(); len(got) != 1 || got[0] != "qwen3.5-plus" { + t.Fatalf("stream calls = %v, want only first upstream model", got) + } +}