diff --git a/cmd/server/main.go b/cmd/server/main.go index a0501402..c5182c4a 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -434,7 +434,7 @@ func main() { usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled) coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling) - if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil { + if err = logging.ConfigureLogOutput(cfg); err != nil { log.Errorf("failed to configure log output: %v", err) return } diff --git a/config.example.yaml b/config.example.yaml index 09b2df02..bbde75b6 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -171,9 +171,9 @@ ws-auth: false # headers: # X-Custom-Header: "custom-value" # models: # optional: map aliases to upstream model names -# - name: "gemini-2.0-flash" # upstream model name +# - name: "gemini-2.5-flash" # upstream model name # alias: "vertex-flash" # client-visible alias -# - name: "gemini-1.5-pro" +# - name: "gemini-2.5-pro" # alias: "vertex-pro" # Amp Integration @@ -203,12 +203,42 @@ ws-auth: false # # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5) # # but you have a similar model available (e.g., Claude Sonnet 4). # model-mappings: -# - from: "claude-opus-4.5" # Model requested by Amp CLI -# to: "claude-sonnet-4" # Route to this available model instead -# - from: "gpt-5" -# to: "gemini-2.5-pro" -# - from: "claude-3-opus-20240229" -# to: "claude-3-5-sonnet-20241022" +# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI +# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead +# - from: "claude-sonnet-4-5-20250929" +# to: "gemini-claude-sonnet-4-5-thinking" +# - from: "claude-haiku-4-5-20251001" +# to: "gemini-2.5-flash" + +# Global OAuth model name mappings (per channel) +# These mappings rename model IDs for both model listing and request routing. +# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow. +# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode. +# oauth-model-mappings: +# gemini-cli: +# - name: "gemini-2.5-pro" # original model name under this channel +# alias: "g2.5p" # client-visible alias +# vertex: +# - name: "gemini-2.5-pro" +# alias: "g2.5p" +# aistudio: +# - name: "gemini-2.5-pro" +# alias: "g2.5p" +# antigravity: +# - name: "gemini-3-pro-preview" +# alias: "g3p" +# claude: +# - name: "claude-sonnet-4-5-20250929" +# alias: "cs4.5" +# codex: +# - name: "gpt-5" +# alias: "g5" +# qwen: +# - name: "qwen3-coder-plus" +# alias: "qwen-plus" +# iflow: +# - name: "glm-4.7" +# alias: "glm-god" # OAuth provider excluded models # oauth-excluded-models: diff --git a/internal/api/server.go b/internal/api/server.go index bbac2e55..4615894c 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -877,7 +877,7 @@ func (s *Server) UpdateClients(cfg *config.Config) { } if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB { - if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil { + if err := logging.ConfigureLogOutput(cfg); err != nil { log.Errorf("failed to reconfigure log output: %v", err) } else { if oldCfg == nil { diff --git a/internal/config/config.go b/internal/config/config.go index 260871de..6beba5cd 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -98,6 +98,14 @@ type Config struct { // OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries. OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"` + // OAuthModelMappings defines global model name mappings for OAuth/file-backed auth channels. + // These mappings affect both model listing and model routing for supported channels: + // gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow. + // + // NOTE: This does not apply to existing per-credential model alias features under: + // gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode. + OAuthModelMappings map[string][]ModelNameMapping `yaml:"oauth-model-mappings,omitempty" json:"oauth-model-mappings,omitempty"` + // Payload defines default and override rules for provider payload parameters. Payload PayloadConfig `yaml:"payload" json:"payload"` @@ -149,6 +157,13 @@ type RoutingConfig struct { Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"` } +// ModelNameMapping defines a model ID rename mapping for a specific channel. +// It maps the original model name (Name) to the client-visible alias (Alias). +type ModelNameMapping struct { + Name string `yaml:"name" json:"name"` + Alias string `yaml:"alias" json:"alias"` +} + // AmpModelMapping defines a model name mapping for Amp CLI requests. // When Amp requests a model that isn't available locally, this mapping // allows routing to an alternative model that IS available. @@ -506,6 +521,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Normalize OAuth provider model exclusion map. cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels) + // Normalize global OAuth model name mappings. + cfg.SanitizeOAuthModelMappings() + if cfg.legacyMigrationPending { fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...") if !optional && configFile != "" { @@ -522,6 +540,50 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { return &cfg, nil } +// SanitizeOAuthModelMappings normalizes and deduplicates global OAuth model name mappings. +// It trims whitespace, normalizes channel keys to lower-case, drops empty entries, +// and ensures (From, To) pairs are unique within each channel. +func (cfg *Config) SanitizeOAuthModelMappings() { + if cfg == nil || len(cfg.OAuthModelMappings) == 0 { + return + } + out := make(map[string][]ModelNameMapping, len(cfg.OAuthModelMappings)) + for rawChannel, mappings := range cfg.OAuthModelMappings { + channel := strings.ToLower(strings.TrimSpace(rawChannel)) + if channel == "" || len(mappings) == 0 { + continue + } + seenName := make(map[string]struct{}, len(mappings)) + seenAlias := make(map[string]struct{}, len(mappings)) + clean := make([]ModelNameMapping, 0, len(mappings)) + for _, mapping := range mappings { + name := strings.TrimSpace(mapping.Name) + alias := strings.TrimSpace(mapping.Alias) + if name == "" || alias == "" { + continue + } + if strings.EqualFold(name, alias) { + continue + } + nameKey := strings.ToLower(name) + aliasKey := strings.ToLower(alias) + if _, ok := seenName[nameKey]; ok { + continue + } + if _, ok := seenAlias[aliasKey]; ok { + continue + } + seenName[nameKey] = struct{}{} + seenAlias[aliasKey] = struct{}{} + clean = append(clean, ModelNameMapping{Name: name, Alias: alias}) + } + if len(clean) > 0 { + out[channel] = clean + } + } + cfg.OAuthModelMappings = out +} + // SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are // not actionable, specifically those missing a BaseURL. It trims whitespace before // evaluation and preserves the relative order of remaining entries. diff --git a/internal/logging/global_logger.go b/internal/logging/global_logger.go index a588bea4..e305ec70 100644 --- a/internal/logging/global_logger.go +++ b/internal/logging/global_logger.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" "gopkg.in/natefinch/lumberjack.v2" @@ -84,10 +85,30 @@ func SetupBaseLogger() { }) } +// isDirWritable checks if the specified directory exists and is writable by attempting to create and remove a test file. +func isDirWritable(dir string) bool { + info, err := os.Stat(dir) + if err != nil || !info.IsDir() { + return false + } + + testFile := filepath.Join(dir, ".perm_test") + f, err := os.Create(testFile) + if err != nil { + return false + } + + defer func() { + _ = f.Close() + _ = os.Remove(testFile) + }() + return true +} + // ConfigureLogOutput switches the global log destination between rotating files and stdout. // When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory // until the total size is within the limit. -func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error { +func ConfigureLogOutput(cfg *config.Config) error { SetupBaseLogger() writerMu.Lock() @@ -96,10 +117,12 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error { logDir := "logs" if base := util.WritablePath(); base != "" { logDir = filepath.Join(base, "logs") + } else if !isDirWritable(logDir) { + logDir = filepath.Join(cfg.AuthDir, "logs") } protectedPath := "" - if loggingToFile { + if cfg.LoggingToFile { if err := os.MkdirAll(logDir, 0o755); err != nil { return fmt.Errorf("logging: failed to create log directory: %w", err) } @@ -123,7 +146,7 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error { log.SetOutput(os.Stdout) } - configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath) + configureLogDirCleanerLocked(logDir, cfg.LogsMaxTotalSizeMB, protectedPath) return nil } diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 2b4ec748..9ade4fbb 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -76,7 +76,12 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au // Execute performs a non-streaming request to the Antigravity API. func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { - if strings.Contains(req.Model, "claude") { + upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + if upstreamModel == "" { + upstreamModel = req.Model + } + isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + if isClaude { return e.executeClaudeNonStream(ctx, auth, req, opts) } @@ -98,7 +103,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) - translated = normalizeAntigravityThinking(req.Model, translated) + translated = normalizeAntigravityThinking(req.Model, translated, isClaude) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) @@ -109,7 +114,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, false, opts.Alt, baseURL) if errReq != nil { err = errReq return resp, err @@ -190,10 +195,15 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * to := sdktranslator.FromString("antigravity") translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + if upstreamModel == "" { + upstreamModel = req.Model + } + translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) - translated = normalizeAntigravityThinking(req.Model, translated) + translated = normalizeAntigravityThinking(req.Model, translated, true) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) @@ -204,7 +214,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) if errReq != nil { err = errReq return resp, err @@ -524,10 +534,16 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya to := sdktranslator.FromString("antigravity") translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + if upstreamModel == "" { + upstreamModel = req.Model + } + isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) - translated = normalizeAntigravityThinking(req.Model, translated) + translated = normalizeAntigravityThinking(req.Model, translated, isClaude) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) @@ -538,7 +554,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) if errReq != nil { err = errReq return nil, err @@ -676,6 +692,12 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut to := sdktranslator.FromString("antigravity") respCtx := context.WithValue(ctx, "alt", opts.Alt) + upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + if upstreamModel == "" { + upstreamModel = req.Model + } + isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -694,7 +716,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model) payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload) - payload = normalizeAntigravityThinking(req.Model, payload) + payload = normalizeAntigravityThinking(req.Model, payload, isClaude) payload = deleteJSONField(payload, "project") payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "request.safetySettings") @@ -1308,7 +1330,7 @@ func alias2ModelName(modelName string) string { // normalizeAntigravityThinking clamps or removes thinking config based on model support. // For Claude models, it additionally ensures thinking budget < max_tokens. -func normalizeAntigravityThinking(model string, payload []byte) []byte { +func normalizeAntigravityThinking(model string, payload []byte, isClaude bool) []byte { payload = util.StripThinkingConfigIfUnsupported(model, payload) if !util.ModelSupportsThinking(model) { return payload @@ -1320,7 +1342,6 @@ func normalizeAntigravityThinking(model string, payload []byte) []byte { raw := int(budget.Int()) normalized := util.NormalizeThinkingBudget(model, raw) - isClaude := strings.Contains(strings.ToLower(model), "claude") if isClaude { effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload) if effectiveMax > 0 && normalized >= effectiveMax { diff --git a/internal/util/thinking_suffix.go b/internal/util/thinking_suffix.go index ff3b24a6..0a72b4c5 100644 --- a/internal/util/thinking_suffix.go +++ b/internal/util/thinking_suffix.go @@ -7,10 +7,11 @@ import ( ) const ( - ThinkingBudgetMetadataKey = "thinking_budget" - ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts" - ReasoningEffortMetadataKey = "reasoning_effort" - ThinkingOriginalModelMetadataKey = "thinking_original_model" + ThinkingBudgetMetadataKey = "thinking_budget" + ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts" + ReasoningEffortMetadataKey = "reasoning_effort" + ThinkingOriginalModelMetadataKey = "thinking_original_model" + ModelMappingOriginalModelMetadataKey = "model_mapping_original_model" ) // NormalizeThinkingModel parses dynamic thinking suffixes on model names and returns @@ -215,6 +216,13 @@ func ResolveOriginalModel(model string, metadata map[string]any) string { } if metadata != nil { + if v, ok := metadata[ModelMappingOriginalModelMetadataKey]; ok { + if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" { + if base := normalize(s); base != "" { + return base + } + } + } if v, ok := metadata[ThinkingOriginalModelMetadataKey]; ok { if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" { if base := normalize(s); base != "" { diff --git a/internal/watcher/config_reload.go b/internal/watcher/config_reload.go index 244f738e..370ee4e1 100644 --- a/internal/watcher/config_reload.go +++ b/internal/watcher/config_reload.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/hex" "os" + "reflect" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -126,7 +127,7 @@ func (w *Watcher) reloadConfig() bool { } authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir - forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix + forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelMappings, newConfig.OAuthModelMappings)) log.Infof("config successfully reloaded, triggering client reload") w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh) diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index cf4e39ad..27e940e8 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -111,6 +111,9 @@ type Manager struct { requestRetry atomic.Int32 maxRetryInterval atomic.Int64 + // modelNameMappings stores global model name alias mappings (alias -> upstream name) keyed by channel. + modelNameMappings atomic.Value + // Optional HTTP RoundTripper provider injected by host. rtProvider RoundTripperProvider @@ -410,6 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) + execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) resp, errExec := executor.Execute(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -471,6 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string, } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) + execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -532,6 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) + execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) if errStream != nil { rerr := &Error{Message: errStream.Error()} @@ -592,6 +598,7 @@ func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string] keys := []string{ util.ThinkingOriginalModelMetadataKey, util.GeminiOriginalModelMetadataKey, + util.ModelMappingOriginalModelMetadataKey, } var out map[string]any for _, key := range keys { diff --git a/sdk/cliproxy/auth/model_name_mappings.go b/sdk/cliproxy/auth/model_name_mappings.go new file mode 100644 index 00000000..483cb9c9 --- /dev/null +++ b/sdk/cliproxy/auth/model_name_mappings.go @@ -0,0 +1,172 @@ +package auth + +import ( + "strings" + + internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" +) + +type modelNameMappingTable struct { + // reverse maps channel -> alias (lower) -> original upstream model name. + reverse map[string]map[string]string +} + +func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable { + if len(mappings) == 0 { + return &modelNameMappingTable{} + } + out := &modelNameMappingTable{ + reverse: make(map[string]map[string]string, len(mappings)), + } + for rawChannel, entries := range mappings { + channel := strings.ToLower(strings.TrimSpace(rawChannel)) + if channel == "" || len(entries) == 0 { + continue + } + rev := make(map[string]string, len(entries)) + for _, entry := range entries { + name := strings.TrimSpace(entry.Name) + alias := strings.TrimSpace(entry.Alias) + if name == "" || alias == "" { + continue + } + if strings.EqualFold(name, alias) { + continue + } + aliasKey := strings.ToLower(alias) + if _, exists := rev[aliasKey]; exists { + continue + } + rev[aliasKey] = name + } + if len(rev) > 0 { + out.reverse[channel] = rev + } + } + if len(out.reverse) == 0 { + out.reverse = nil + } + return out +} + +// SetOAuthModelMappings updates the OAuth model name mapping table used during execution. +// The mapping is applied per-auth channel to resolve the upstream model name while keeping the +// client-visible model name unchanged for translation/response formatting. +func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.ModelNameMapping) { + if m == nil { + return + } + table := compileModelNameMappingTable(mappings) + // atomic.Value requires non-nil store values. + if table == nil { + table = &modelNameMappingTable{} + } + m.modelNameMappings.Store(table) +} + +func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any { + original := m.resolveOAuthUpstreamModel(auth, requestedModel) + if original == "" { + return metadata + } + if metadata != nil { + if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok { + if s, okStr := v.(string); okStr && strings.EqualFold(s, original) { + return metadata + } + } + } + out := make(map[string]any, 1) + if len(metadata) > 0 { + out = make(map[string]any, len(metadata)+1) + for k, v := range metadata { + out[k] = v + } + } + out[util.ModelMappingOriginalModelMetadataKey] = original + return out +} + +func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string { + if m == nil || auth == nil { + return "" + } + channel := modelMappingChannel(auth) + if channel == "" { + return "" + } + key := strings.ToLower(strings.TrimSpace(requestedModel)) + if key == "" { + return "" + } + raw := m.modelNameMappings.Load() + table, _ := raw.(*modelNameMappingTable) + if table == nil || table.reverse == nil { + return "" + } + rev := table.reverse[channel] + if rev == nil { + return "" + } + original := strings.TrimSpace(rev[key]) + if original == "" || strings.EqualFold(original, requestedModel) { + return "" + } + return original +} + +// modelMappingChannel extracts the OAuth model mapping channel from an Auth object. +// It determines the provider and auth kind from the Auth's attributes and delegates +// to OAuthModelMappingChannel for the actual channel resolution. +func modelMappingChannel(auth *Auth) string { + if auth == nil { + return "" + } + provider := strings.ToLower(strings.TrimSpace(auth.Provider)) + authKind := "" + if auth.Attributes != nil { + authKind = strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"])) + } + if authKind == "" { + if kind, _ := auth.AccountInfo(); strings.EqualFold(kind, "api_key") { + authKind = "apikey" + } + } + return OAuthModelMappingChannel(provider, authKind) +} + +// OAuthModelMappingChannel returns the OAuth model mapping channel name for a given provider +// and auth kind. Returns empty string if the provider/authKind combination doesn't support +// OAuth model mappings (e.g., API key authentication). +// +// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow. +func OAuthModelMappingChannel(provider, authKind string) string { + provider = strings.ToLower(strings.TrimSpace(provider)) + authKind = strings.ToLower(strings.TrimSpace(authKind)) + switch provider { + case "gemini": + // gemini provider uses gemini-api-key config, not oauth-model-mappings. + // OAuth-based gemini auth is converted to "gemini-cli" by the synthesizer. + return "" + case "vertex": + if authKind == "apikey" { + return "" + } + return "vertex" + case "claude": + if authKind == "apikey" { + return "" + } + return "claude" + case "codex": + if authKind == "apikey" { + return "" + } + return "codex" + case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow": + return provider + default: + return "" + } +} diff --git a/sdk/cliproxy/builder.go b/sdk/cliproxy/builder.go index 381a0926..51d5dbac 100644 --- a/sdk/cliproxy/builder.go +++ b/sdk/cliproxy/builder.go @@ -215,6 +215,7 @@ func (b *Builder) Build() (*Service, error) { } // Attach a default RoundTripper provider so providers can opt-in per-auth transports. coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider()) + coreManager.SetOAuthModelMappings(b.cfg.OAuthModelMappings) service := &Service{ cfg: b.cfg, diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 49941135..0927eaa6 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -556,6 +556,9 @@ func (s *Service) Run(ctx context.Context) error { s.cfgMu.Lock() s.cfg = newCfg s.cfgMu.Unlock() + if s.coreManager != nil { + s.coreManager.SetOAuthModelMappings(newCfg.OAuthModelMappings) + } s.rebindExecutors() } @@ -681,6 +684,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { return } authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"])) + if authKind == "" { + if kind, _ := a.AccountInfo(); strings.EqualFold(kind, "api_key") { + authKind = "apikey" + } + } if a.Attributes != nil { if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") { GlobalModelRegistry().UnregisterClient(a.ID) @@ -845,6 +853,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { } } } + models = applyOAuthModelMappings(s.cfg, provider, authKind, models) if len(models) > 0 { key := provider if key == "" { @@ -1154,6 +1163,93 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo { return out } +func rewriteModelInfoName(name, oldID, newID string) string { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return name + } + oldID = strings.TrimSpace(oldID) + newID = strings.TrimSpace(newID) + if oldID == "" || newID == "" { + return name + } + if strings.EqualFold(oldID, newID) { + return name + } + if strings.HasSuffix(trimmed, "/"+oldID) { + prefix := strings.TrimSuffix(trimmed, oldID) + return prefix + newID + } + if trimmed == "models/"+oldID { + return "models/" + newID + } + return name +} + +func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo { + if cfg == nil || len(models) == 0 { + return models + } + channel := coreauth.OAuthModelMappingChannel(provider, authKind) + if channel == "" || len(cfg.OAuthModelMappings) == 0 { + return models + } + mappings := cfg.OAuthModelMappings[channel] + if len(mappings) == 0 { + return models + } + forward := make(map[string]string, len(mappings)) + for i := range mappings { + name := strings.TrimSpace(mappings[i].Name) + alias := strings.TrimSpace(mappings[i].Alias) + if name == "" || alias == "" { + continue + } + if strings.EqualFold(name, alias) { + continue + } + key := strings.ToLower(name) + if _, exists := forward[key]; exists { + continue + } + forward[key] = alias + } + if len(forward) == 0 { + return models + } + out := make([]*ModelInfo, 0, len(models)) + seen := make(map[string]struct{}, len(models)) + for _, model := range models { + if model == nil { + continue + } + id := strings.TrimSpace(model.ID) + if id == "" { + continue + } + mappedID := id + if to, ok := forward[strings.ToLower(id)]; ok && strings.TrimSpace(to) != "" { + mappedID = strings.TrimSpace(to) + } + uniqueKey := strings.ToLower(mappedID) + if _, exists := seen[uniqueKey]; exists { + continue + } + seen[uniqueKey] = struct{}{} + if mappedID == id { + out = append(out, model) + continue + } + clone := *model + clone.ID = mappedID + if clone.Name != "" { + clone.Name = rewriteModelInfoName(clone.Name, id, mappedID) + } + out = append(out, &clone) + } + return out +} + func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo { if entry == nil || len(entry.Models) == 0 { return nil diff --git a/sdk/config/config.go b/sdk/config/config.go index b471e5e0..1ae7ba20 100644 --- a/sdk/config/config.go +++ b/sdk/config/config.go @@ -16,6 +16,7 @@ type StreamingConfig = internalconfig.StreamingConfig type TLSConfig = internalconfig.TLSConfig type RemoteManagement = internalconfig.RemoteManagement type AmpCode = internalconfig.AmpCode +type ModelNameMapping = internalconfig.ModelNameMapping type PayloadConfig = internalconfig.PayloadConfig type PayloadRule = internalconfig.PayloadRule type PayloadModelRule = internalconfig.PayloadModelRule