diff --git a/config.example.yaml b/config.example.yaml index a6e65faa..a7274126 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -195,6 +195,32 @@ ws-auth: false # - from: "claude-3-opus-20240229" # to: "claude-3-5-sonnet-20241022" +# Global model name mappings (per channel) +# These mappings rename model IDs for both model listing and request routing. +# NOTE: Mappings do not apply to codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode. +# model-name-mappings: +# gemini: +# - from: "gemini-2.5-pro" # original model name under this channel +# to: "gpt-5" # client-visible alias +# apikey-gemini: +# - from: "gemini-2.5-pro" +# to: "gpt-5" +# claude: +# - from: "claude-sonnet-4" +# to: "gpt-4o" +# vertex: +# - from: "gemini-2.5-pro" +# to: "gpt-5" +# qwen: +# - from: "qwen3-coder-plus" +# to: "gpt-4o-mini" +# iflow: +# - from: "glm-4.7" +# to: "gpt-5.1-mini" +# antigravity: +# - from: "gemini-3-pro-preview" +# to: "gpt-5" + # OAuth provider excluded models # oauth-excluded-models: # gemini-cli: diff --git a/internal/config/config.go b/internal/config/config.go index afe08333..0c311b70 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -91,6 +91,13 @@ type Config struct { // OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries. OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"` + // ModelNameMappings defines global per-channel model name mappings. + // These mappings affect both model listing and model routing for supported channels. + // + // NOTE: This does not apply to existing per-credential model alias features under: + // codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode. + ModelNameMappings map[string][]ModelNameMapping `yaml:"model-name-mappings,omitempty" json:"model-name-mappings,omitempty"` + // Payload defines default and override rules for provider payload parameters. Payload PayloadConfig `yaml:"payload" json:"payload"` @@ -137,6 +144,13 @@ type RoutingConfig struct { Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"` } +// ModelNameMapping defines a model ID rename mapping for a specific channel. +// It maps the original model name (From) to the client-visible alias (To). +type ModelNameMapping struct { + From string `yaml:"from" json:"from"` + To string `yaml:"to" json:"to"` +} + // AmpModelMapping defines a model name mapping for Amp CLI requests. // When Amp requests a model that isn't available locally, this mapping // allows routing to an alternative model that IS available. @@ -461,6 +475,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Normalize OAuth provider model exclusion map. cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels) + // Normalize global model name mappings. + cfg.SanitizeModelNameMappings() + if cfg.legacyMigrationPending { fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...") if !optional && configFile != "" { @@ -477,6 +494,50 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { return &cfg, nil } +// SanitizeModelNameMappings normalizes and deduplicates global model name mappings. +// It trims whitespace, normalizes channel keys to lower-case, drops empty entries, +// and ensures (From, To) pairs are unique within each channel. +func (cfg *Config) SanitizeModelNameMappings() { + if cfg == nil || len(cfg.ModelNameMappings) == 0 { + return + } + out := make(map[string][]ModelNameMapping, len(cfg.ModelNameMappings)) + for rawChannel, mappings := range cfg.ModelNameMappings { + channel := strings.ToLower(strings.TrimSpace(rawChannel)) + if channel == "" || len(mappings) == 0 { + continue + } + seenFrom := make(map[string]struct{}, len(mappings)) + seenTo := make(map[string]struct{}, len(mappings)) + clean := make([]ModelNameMapping, 0, len(mappings)) + for _, mapping := range mappings { + from := strings.TrimSpace(mapping.From) + to := strings.TrimSpace(mapping.To) + if from == "" || to == "" { + continue + } + if strings.EqualFold(from, to) { + continue + } + fromKey := strings.ToLower(from) + toKey := strings.ToLower(to) + if _, ok := seenFrom[fromKey]; ok { + continue + } + if _, ok := seenTo[toKey]; ok { + continue + } + seenFrom[fromKey] = struct{}{} + seenTo[toKey] = struct{}{} + clean = append(clean, ModelNameMapping{From: from, To: to}) + } + if len(clean) > 0 { + out[channel] = clean + } + } + cfg.ModelNameMappings = out +} + // SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are // not actionable, specifically those missing a BaseURL. It trims whitespace before // evaluation and preserves the relative order of remaining entries. diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 2b4ec748..9ade4fbb 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -76,7 +76,12 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au // Execute performs a non-streaming request to the Antigravity API. func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { - if strings.Contains(req.Model, "claude") { + upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + if upstreamModel == "" { + upstreamModel = req.Model + } + isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + if isClaude { return e.executeClaudeNonStream(ctx, auth, req, opts) } @@ -98,7 +103,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) - translated = normalizeAntigravityThinking(req.Model, translated) + translated = normalizeAntigravityThinking(req.Model, translated, isClaude) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) @@ -109,7 +114,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, false, opts.Alt, baseURL) if errReq != nil { err = errReq return resp, err @@ -190,10 +195,15 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * to := sdktranslator.FromString("antigravity") translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + if upstreamModel == "" { + upstreamModel = req.Model + } + translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) - translated = normalizeAntigravityThinking(req.Model, translated) + translated = normalizeAntigravityThinking(req.Model, translated, true) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) @@ -204,7 +214,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) if errReq != nil { err = errReq return resp, err @@ -524,10 +534,16 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya to := sdktranslator.FromString("antigravity") translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + if upstreamModel == "" { + upstreamModel = req.Model + } + isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) - translated = normalizeAntigravityThinking(req.Model, translated) + translated = normalizeAntigravityThinking(req.Model, translated, isClaude) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) @@ -538,7 +554,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) if errReq != nil { err = errReq return nil, err @@ -676,6 +692,12 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut to := sdktranslator.FromString("antigravity") respCtx := context.WithValue(ctx, "alt", opts.Alt) + upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + if upstreamModel == "" { + upstreamModel = req.Model + } + isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -694,7 +716,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model) payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload) - payload = normalizeAntigravityThinking(req.Model, payload) + payload = normalizeAntigravityThinking(req.Model, payload, isClaude) payload = deleteJSONField(payload, "project") payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "request.safetySettings") @@ -1308,7 +1330,7 @@ func alias2ModelName(modelName string) string { // normalizeAntigravityThinking clamps or removes thinking config based on model support. // For Claude models, it additionally ensures thinking budget < max_tokens. -func normalizeAntigravityThinking(model string, payload []byte) []byte { +func normalizeAntigravityThinking(model string, payload []byte, isClaude bool) []byte { payload = util.StripThinkingConfigIfUnsupported(model, payload) if !util.ModelSupportsThinking(model) { return payload @@ -1320,7 +1342,6 @@ func normalizeAntigravityThinking(model string, payload []byte) []byte { raw := int(budget.Int()) normalized := util.NormalizeThinkingBudget(model, raw) - isClaude := strings.Contains(strings.ToLower(model), "claude") if isClaude { effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload) if effectiveMax > 0 && normalized >= effectiveMax { diff --git a/internal/util/thinking_suffix.go b/internal/util/thinking_suffix.go index ff3b24a6..0a72b4c5 100644 --- a/internal/util/thinking_suffix.go +++ b/internal/util/thinking_suffix.go @@ -7,10 +7,11 @@ import ( ) const ( - ThinkingBudgetMetadataKey = "thinking_budget" - ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts" - ReasoningEffortMetadataKey = "reasoning_effort" - ThinkingOriginalModelMetadataKey = "thinking_original_model" + ThinkingBudgetMetadataKey = "thinking_budget" + ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts" + ReasoningEffortMetadataKey = "reasoning_effort" + ThinkingOriginalModelMetadataKey = "thinking_original_model" + ModelMappingOriginalModelMetadataKey = "model_mapping_original_model" ) // NormalizeThinkingModel parses dynamic thinking suffixes on model names and returns @@ -215,6 +216,13 @@ func ResolveOriginalModel(model string, metadata map[string]any) string { } if metadata != nil { + if v, ok := metadata[ModelMappingOriginalModelMetadataKey]; ok { + if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" { + if base := normalize(s); base != "" { + return base + } + } + } if v, ok := metadata[ThinkingOriginalModelMetadataKey]; ok { if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" { if base := normalize(s); base != "" { diff --git a/internal/watcher/config_reload.go b/internal/watcher/config_reload.go index 244f738e..4db93fc8 100644 --- a/internal/watcher/config_reload.go +++ b/internal/watcher/config_reload.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/hex" "os" + "reflect" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -126,7 +127,7 @@ func (w *Watcher) reloadConfig() bool { } authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir - forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix + forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.ModelNameMappings, newConfig.ModelNameMappings)) log.Infof("config successfully reloaded, triggering client reload") w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh) diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index d16fc1ae..125966fd 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -111,6 +111,9 @@ type Manager struct { requestRetry atomic.Int32 maxRetryInterval atomic.Int64 + // modelNameMappings stores global model name alias mappings (alias -> upstream name) keyed by channel. + modelNameMappings atomic.Value + // Optional HTTP RoundTripper provider injected by host. rtProvider RoundTripperProvider @@ -410,6 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) + execReq.Metadata = m.applyGlobalModelNameMappingMetadata(auth, execReq.Model, execReq.Metadata) resp, errExec := executor.Execute(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -471,6 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string, } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) + execReq.Metadata = m.applyGlobalModelNameMappingMetadata(auth, execReq.Model, execReq.Metadata) resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -532,6 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) + execReq.Metadata = m.applyGlobalModelNameMappingMetadata(auth, execReq.Model, execReq.Metadata) chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) if errStream != nil { rerr := &Error{Message: errStream.Error()} @@ -592,6 +598,7 @@ func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string] keys := []string{ util.ThinkingOriginalModelMetadataKey, util.GeminiOriginalModelMetadataKey, + util.ModelMappingOriginalModelMetadataKey, } var out map[string]any for _, key := range keys { diff --git a/sdk/cliproxy/auth/model_name_mappings.go b/sdk/cliproxy/auth/model_name_mappings.go new file mode 100644 index 00000000..99215fc4 --- /dev/null +++ b/sdk/cliproxy/auth/model_name_mappings.go @@ -0,0 +1,163 @@ +package auth + +import ( + "strings" + + internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" +) + +type modelNameMappingTable struct { + // reverse maps channel -> alias (lower) -> original upstream model name. + reverse map[string]map[string]string +} + +func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable { + if len(mappings) == 0 { + return &modelNameMappingTable{} + } + out := &modelNameMappingTable{ + reverse: make(map[string]map[string]string, len(mappings)), + } + for rawChannel, entries := range mappings { + channel := strings.ToLower(strings.TrimSpace(rawChannel)) + if channel == "" || len(entries) == 0 { + continue + } + rev := make(map[string]string, len(entries)) + for _, entry := range entries { + from := strings.TrimSpace(entry.From) + to := strings.TrimSpace(entry.To) + if from == "" || to == "" { + continue + } + if strings.EqualFold(from, to) { + continue + } + aliasKey := strings.ToLower(to) + if _, exists := rev[aliasKey]; exists { + continue + } + rev[aliasKey] = from + } + if len(rev) > 0 { + out.reverse[channel] = rev + } + } + if len(out.reverse) == 0 { + out.reverse = nil + } + return out +} + +// SetGlobalModelNameMappings updates the global model name mapping table used during execution. +// The mapping is applied per-auth channel to resolve the upstream model name while keeping the +// client-visible model name unchanged for translation/response formatting. +func (m *Manager) SetGlobalModelNameMappings(mappings map[string][]internalconfig.ModelNameMapping) { + if m == nil { + return + } + table := compileModelNameMappingTable(mappings) + // atomic.Value requires non-nil store values. + if table == nil { + table = &modelNameMappingTable{} + } + m.modelNameMappings.Store(table) +} + +func (m *Manager) applyGlobalModelNameMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any { + original := m.resolveGlobalUpstreamModelForAuth(auth, requestedModel) + if original == "" { + return metadata + } + if metadata != nil { + if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok { + if s, okStr := v.(string); okStr && strings.EqualFold(s, original) { + return metadata + } + } + } + out := make(map[string]any, 1) + if len(metadata) > 0 { + out = make(map[string]any, len(metadata)+1) + for k, v := range metadata { + out[k] = v + } + } + out[util.ModelMappingOriginalModelMetadataKey] = original + return out +} + +func (m *Manager) resolveGlobalUpstreamModelForAuth(auth *Auth, requestedModel string) string { + if m == nil || auth == nil { + return "" + } + channel := globalModelMappingChannelForAuth(auth) + if channel == "" { + return "" + } + key := strings.ToLower(strings.TrimSpace(requestedModel)) + if key == "" { + return "" + } + raw := m.modelNameMappings.Load() + table, _ := raw.(*modelNameMappingTable) + if table == nil || table.reverse == nil { + return "" + } + rev := table.reverse[channel] + if rev == nil { + return "" + } + original := strings.TrimSpace(rev[key]) + if original == "" || strings.EqualFold(original, requestedModel) { + return "" + } + return original +} + +func globalModelMappingChannelForAuth(auth *Auth) string { + if auth == nil { + return "" + } + provider := strings.ToLower(strings.TrimSpace(auth.Provider)) + authKind := "" + if auth.Attributes != nil { + authKind = strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"])) + } + if authKind == "" { + if kind, _ := auth.AccountInfo(); strings.EqualFold(kind, "api_key") { + authKind = "apikey" + } + } + return globalModelMappingChannel(provider, authKind) +} + +func globalModelMappingChannel(provider, authKind string) string { + switch provider { + case "gemini": + if authKind == "apikey" { + return "apikey-gemini" + } + return "gemini" + case "codex": + if authKind == "apikey" { + return "" + } + return "codex" + case "claude": + if authKind == "apikey" { + return "" + } + return "claude" + case "vertex": + if authKind == "apikey" { + return "" + } + return "vertex" + case "antigravity", "qwen", "iflow": + return provider + default: + return "" + } +} diff --git a/sdk/cliproxy/builder.go b/sdk/cliproxy/builder.go index 381a0926..ce0517b2 100644 --- a/sdk/cliproxy/builder.go +++ b/sdk/cliproxy/builder.go @@ -215,6 +215,7 @@ func (b *Builder) Build() (*Service, error) { } // Attach a default RoundTripper provider so providers can opt-in per-auth transports. coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider()) + coreManager.SetGlobalModelNameMappings(b.cfg.ModelNameMappings) service := &Service{ cfg: b.cfg, diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 6e81e401..a31f3b11 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -552,6 +552,9 @@ func (s *Service) Run(ctx context.Context) error { s.cfgMu.Lock() s.cfg = newCfg s.cfgMu.Unlock() + if s.coreManager != nil { + s.coreManager.SetGlobalModelNameMappings(newCfg.ModelNameMappings) + } s.rebindExecutors() } @@ -677,6 +680,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { return } authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"])) + if authKind == "" { + if kind, _ := a.AccountInfo(); strings.EqualFold(kind, "api_key") { + authKind = "apikey" + } + } if a.Attributes != nil { if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") { GlobalModelRegistry().UnregisterClient(a.ID) @@ -836,6 +844,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { } } } + models = applyGlobalModelNameMappings(s.cfg, provider, authKind, models) if len(models) > 0 { key := provider if key == "" { @@ -1145,6 +1154,124 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo { return out } +func globalModelMappingChannel(provider, authKind string) string { + provider = strings.ToLower(strings.TrimSpace(provider)) + authKind = strings.ToLower(strings.TrimSpace(authKind)) + switch provider { + case "gemini": + if authKind == "apikey" { + return "apikey-gemini" + } + return "gemini" + case "codex": + if authKind == "apikey" { + return "" + } + return "codex" + case "claude": + if authKind == "apikey" { + return "" + } + return "claude" + case "vertex": + if authKind == "apikey" { + return "" + } + return "vertex" + case "antigravity", "qwen", "iflow": + return provider + default: + return "" + } +} + +func rewriteModelInfoName(name, oldID, newID string) string { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return name + } + oldID = strings.TrimSpace(oldID) + newID = strings.TrimSpace(newID) + if oldID == "" || newID == "" { + return name + } + if strings.EqualFold(oldID, newID) { + return name + } + if strings.HasSuffix(trimmed, "/"+oldID) { + prefix := strings.TrimSuffix(trimmed, oldID) + return prefix + newID + } + if trimmed == "models/"+oldID { + return "models/" + newID + } + return name +} + +func applyGlobalModelNameMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo { + if cfg == nil || len(models) == 0 { + return models + } + channel := globalModelMappingChannel(provider, authKind) + if channel == "" || len(cfg.ModelNameMappings) == 0 { + return models + } + mappings := cfg.ModelNameMappings[channel] + if len(mappings) == 0 { + return models + } + forward := make(map[string]string, len(mappings)) + for i := range mappings { + from := strings.TrimSpace(mappings[i].From) + to := strings.TrimSpace(mappings[i].To) + if from == "" || to == "" { + continue + } + if strings.EqualFold(from, to) { + continue + } + key := strings.ToLower(from) + if _, exists := forward[key]; exists { + continue + } + forward[key] = to + } + if len(forward) == 0 { + return models + } + out := make([]*ModelInfo, 0, len(models)) + seen := make(map[string]struct{}, len(models)) + for _, model := range models { + if model == nil { + continue + } + id := strings.TrimSpace(model.ID) + if id == "" { + continue + } + mappedID := id + if to, ok := forward[strings.ToLower(id)]; ok && strings.TrimSpace(to) != "" { + mappedID = strings.TrimSpace(to) + } + uniqueKey := strings.ToLower(mappedID) + if _, exists := seen[uniqueKey]; exists { + continue + } + seen[uniqueKey] = struct{}{} + if mappedID == id { + out = append(out, model) + continue + } + clone := *model + clone.ID = mappedID + if clone.Name != "" { + clone.Name = rewriteModelInfoName(clone.Name, id, mappedID) + } + out = append(out, &clone) + } + return out +} + func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo { if entry == nil || len(entry.Models) == 0 { return nil diff --git a/sdk/config/config.go b/sdk/config/config.go index b471e5e0..1ae7ba20 100644 --- a/sdk/config/config.go +++ b/sdk/config/config.go @@ -16,6 +16,7 @@ type StreamingConfig = internalconfig.StreamingConfig type TLSConfig = internalconfig.TLSConfig type RemoteManagement = internalconfig.RemoteManagement type AmpCode = internalconfig.AmpCode +type ModelNameMapping = internalconfig.ModelNameMapping type PayloadConfig = internalconfig.PayloadConfig type PayloadRule = internalconfig.PayloadRule type PayloadModelRule = internalconfig.PayloadModelRule