diff --git a/config.example.yaml b/config.example.yaml index 8457e103..71863d49 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -55,6 +55,28 @@ quota-exceeded: # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: false +# Amp CLI Integration +# Configure upstream URL for Amp CLI OAuth and management features +#amp-upstream-url: "https://ampcode.com" + +# Optional: Override API key for Amp upstream (otherwise uses env or file) +#amp-upstream-api-key: "" + +# Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended) +#amp-restrict-management-to-localhost: true + +# Amp Model Mappings +# Route unavailable Amp models to alternative models available in your local proxy. +# Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5) +# but you have a similar model available (e.g., Claude Sonnet 4). +#amp-model-mappings: +# - from: "claude-opus-4.5" # Model requested by Amp CLI +# to: "claude-sonnet-4" # Route to this available model instead +# - from: "gpt-5" +# to: "gemini-2.5-pro" +# - from: "claude-3-opus-20240229" +# to: "claude-3-5-sonnet-20241022" + # Gemini API keys (preferred) #gemini-api-key: # - api-key: "AIzaSy...01" diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index 0086d179..b5a139f6 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -23,11 +23,13 @@ type Option func(*AmpModule) // - Reverse proxy to Amp control plane for OAuth/management // - Provider-specific route aliases (/api/provider/{provider}/...) // - Automatic gzip decompression for misconfigured upstreams +// - Model mapping for routing unavailable models to alternatives type AmpModule struct { secretSource SecretSource proxy *httputil.ReverseProxy accessManager *sdkaccess.Manager authMiddleware_ gin.HandlerFunc + modelMapper *DefaultModelMapper enabled bool registerOnce sync.Once } @@ -101,6 +103,9 @@ func (m *AmpModule) Register(ctx modules.Context) error { // Use registerOnce to ensure routes are only registered once var regErr error m.registerOnce.Do(func() { + // Initialize model mapper from config (for routing unavailable models to alternatives) + m.modelMapper = NewModelMapper(ctx.Config.AmpModelMappings) + // Always register provider aliases - these work without an upstream m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth) @@ -159,8 +164,13 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc { // OnConfigUpdated handles configuration updates. // Currently requires restart for URL changes (could be enhanced for dynamic updates). func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { + // Update model mappings (hot-reload supported) + if m.modelMapper != nil { + m.modelMapper.UpdateMappings(cfg.AmpModelMappings) + } + if !m.enabled { - log.Debug("Amp routing not enabled, skipping config update") + log.Debug("Amp routing not enabled, skipping other config updates") return nil } @@ -181,3 +191,8 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { log.Debug("Amp config updated (restart required for URL changes)") return nil } + +// GetModelMapper returns the model mapper instance (for testing/debugging). +func (m *AmpModule) GetModelMapper() *DefaultModelMapper { + return m.modelMapper +} diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index e7b28986..17c60708 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -6,16 +6,75 @@ import ( "io" "net/http/httputil" "strings" + "time" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" ) +// AmpRouteType represents the type of routing decision made for an Amp request +type AmpRouteType string + +const ( + // RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free) + RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER" + // RouteTypeModelMapping indicates the request was remapped to another available model (free) + RouteTypeModelMapping AmpRouteType = "MODEL_MAPPING" + // RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits) + RouteTypeAmpCredits AmpRouteType = "AMP_CREDITS" + // RouteTypeNoProvider indicates no provider or fallback available + RouteTypeNoProvider AmpRouteType = "NO_PROVIDER" +) + +// logAmpRouting logs the routing decision for an Amp request with structured fields +func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) { + fields := log.Fields{ + "component": "amp-routing", + "route_type": string(routeType), + "requested_model": requestedModel, + "path": path, + "timestamp": time.Now().Format(time.RFC3339), + } + + if resolvedModel != "" && resolvedModel != requestedModel { + fields["resolved_model"] = resolvedModel + } + if provider != "" { + fields["provider"] = provider + } + + switch routeType { + case RouteTypeLocalProvider: + fields["cost"] = "free" + fields["source"] = "local_oauth" + log.WithFields(fields).Infof("[AMP] Using local provider for model: %s", requestedModel) + + case RouteTypeModelMapping: + fields["cost"] = "free" + fields["source"] = "local_oauth" + fields["mapping"] = requestedModel + " -> " + resolvedModel + log.WithFields(fields).Infof("[AMP] Model mapped: %s -> %s", requestedModel, resolvedModel) + + case RouteTypeAmpCredits: + fields["cost"] = "amp_credits" + fields["source"] = "ampcode.com" + fields["model_id"] = requestedModel // Explicit model_id for easy config reference + log.WithFields(fields).Warnf("[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"\"}]", requestedModel, requestedModel) + + case RouteTypeNoProvider: + fields["cost"] = "none" + fields["source"] = "error" + fields["model_id"] = requestedModel // Explicit model_id for easy config reference + log.WithFields(fields).Warnf("[AMP] No provider available for model_id: %s", requestedModel) + } +} + // FallbackHandler wraps a standard handler with fallback logic to ampcode.com // when the model's provider is not available in CLIProxyAPI type FallbackHandler struct { - getProxy func() *httputil.ReverseProxy + getProxy func() *httputil.ReverseProxy + modelMapper ModelMapper } // NewFallbackHandler creates a new fallback handler wrapper @@ -26,10 +85,25 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler } } +// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support +func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler { + return &FallbackHandler{ + getProxy: getProxy, + modelMapper: mapper, + } +} + +// SetModelMapper sets the model mapper for this handler (allows late binding) +func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) { + fh.modelMapper = mapper +} + // WrapHandler wraps a gin.HandlerFunc with fallback logic // If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc { return func(c *gin.Context) { + requestPath := c.Request.URL.Path + // Read the request body to extract the model name bodyBytes, err := io.ReadAll(c.Request.Body) if err != nil { @@ -55,12 +129,33 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Check if we have providers for this model providers := util.GetProviderName(normalizedModel) + // Track resolved model for logging (may change if mapping is applied) + resolvedModel := normalizedModel + usedMapping := false + if len(providers) == 0 { - // No providers configured - check if we have a proxy for fallback + // No providers configured - check if we have a model mapping + if fh.modelMapper != nil { + if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { + // Mapping found - rewrite the model in request body + bodyBytes = rewriteModelInBody(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + resolvedModel = mappedModel + usedMapping = true + + // Get providers for the mapped model + providers = util.GetProviderName(mappedModel) + + // Continue to handler with remapped model + goto handleRequest + } + } + + // No mapping found - check if we have a proxy for fallback proxy := fh.getProxy() if proxy != nil { - // Fallback to ampcode.com - log.Infof("amp fallback: model %s has no configured provider, forwarding to ampcode.com", modelName) + // Log: Forwarding to ampcode.com (uses Amp credits) + logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath) // Restore body again for the proxy c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) @@ -71,7 +166,23 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc } // No proxy available, let the normal handler return the error - log.Debugf("amp fallback: model %s has no configured provider and no proxy available", modelName) + logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath) + } + + handleRequest: + + // Log the routing decision + providerName := "" + if len(providers) > 0 { + providerName = providers[0] + } + + if usedMapping { + // Log: Model was mapped to another model + logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) + } else if len(providers) > 0 { + // Log: Using local provider (free) + logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) } // Providers available or no proxy for fallback, restore body and use normal handler @@ -91,6 +202,27 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc } } +// rewriteModelInBody replaces the model name in a JSON request body +func rewriteModelInBody(body []byte, newModel string) []byte { + var payload map[string]interface{} + if err := json.Unmarshal(body, &payload); err != nil { + log.Warnf("amp model mapping: failed to parse body for rewrite: %v", err) + return body + } + + if _, exists := payload["model"]; exists { + payload["model"] = newModel + newBody, err := json.Marshal(payload) + if err != nil { + log.Warnf("amp model mapping: failed to marshal rewritten body: %v", err) + return body + } + return newBody + } + + return body +} + // extractModelFromRequest attempts to extract the model name from various request formats func extractModelFromRequest(body []byte, c *gin.Context) string { // First try to parse from JSON body (OpenAI, Claude, etc.) diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go new file mode 100644 index 00000000..c07f41c4 --- /dev/null +++ b/internal/api/modules/amp/model_mapping.go @@ -0,0 +1,113 @@ +// Package amp provides model mapping functionality for routing Amp CLI requests +// to alternative models when the requested model is not available locally. +package amp + +import ( + "strings" + "sync" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +// ModelMapper provides model name mapping/aliasing for Amp CLI requests. +// When an Amp request comes in for a model that isn't available locally, +// this mapper can redirect it to an alternative model that IS available. +type ModelMapper interface { + // MapModel returns the target model name if a mapping exists and the target + // model has available providers. Returns empty string if no mapping applies. + MapModel(requestedModel string) string + + // UpdateMappings refreshes the mapping configuration (for hot-reload). + UpdateMappings(mappings []config.AmpModelMapping) +} + +// DefaultModelMapper implements ModelMapper with thread-safe mapping storage. +type DefaultModelMapper struct { + mu sync.RWMutex + mappings map[string]string // from -> to (normalized lowercase keys) +} + +// NewModelMapper creates a new model mapper with the given initial mappings. +func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper { + m := &DefaultModelMapper{ + mappings: make(map[string]string), + } + m.UpdateMappings(mappings) + return m +} + +// MapModel checks if a mapping exists for the requested model and if the +// target model has available local providers. Returns the mapped model name +// or empty string if no valid mapping exists. +func (m *DefaultModelMapper) MapModel(requestedModel string) string { + if requestedModel == "" { + return "" + } + + m.mu.RLock() + defer m.mu.RUnlock() + + // Normalize the requested model for lookup + normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel)) + + // Check for direct mapping + targetModel, exists := m.mappings[normalizedRequest] + if !exists { + return "" + } + + // Verify target model has available providers + providers := util.GetProviderName(targetModel) + if len(providers) == 0 { + log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel) + return "" + } + + // Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go + log.Debugf("amp model mapping: resolved %s -> %s", requestedModel, targetModel) + return targetModel +} + +// UpdateMappings refreshes the mapping configuration from config. +// This is called during initialization and on config hot-reload. +func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) { + m.mu.Lock() + defer m.mu.Unlock() + + // Clear and rebuild mappings + m.mappings = make(map[string]string, len(mappings)) + + for _, mapping := range mappings { + from := strings.TrimSpace(mapping.From) + to := strings.TrimSpace(mapping.To) + + if from == "" || to == "" { + log.Warnf("amp model mapping: skipping invalid mapping (from=%q, to=%q)", from, to) + continue + } + + // Store with normalized lowercase key for case-insensitive lookup + normalizedFrom := strings.ToLower(from) + m.mappings[normalizedFrom] = to + + log.Debugf("amp model mapping registered: %s -> %s", from, to) + } + + if len(m.mappings) > 0 { + log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings)) + } +} + +// GetMappings returns a copy of current mappings (for debugging/status). +func (m *DefaultModelMapper) GetMappings() map[string]string { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string]string, len(m.mappings)) + for k, v := range m.mappings { + result[k] = v + } + return result +} diff --git a/internal/api/modules/amp/model_mapping_test.go b/internal/api/modules/amp/model_mapping_test.go new file mode 100644 index 00000000..c11d61bd --- /dev/null +++ b/internal/api/modules/amp/model_mapping_test.go @@ -0,0 +1,186 @@ +package amp + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" +) + +func TestNewModelMapper(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + {From: "gpt-5", To: "gemini-2.5-pro"}, + } + + mapper := NewModelMapper(mappings) + if mapper == nil { + t.Fatal("Expected non-nil mapper") + } + + result := mapper.GetMappings() + if len(result) != 2 { + t.Errorf("Expected 2 mappings, got %d", len(result)) + } +} + +func TestNewModelMapper_Empty(t *testing.T) { + mapper := NewModelMapper(nil) + if mapper == nil { + t.Fatal("Expected non-nil mapper") + } + + result := mapper.GetMappings() + if len(result) != 0 { + t.Errorf("Expected 0 mappings, got %d", len(result)) + } +} + +func TestModelMapper_MapModel_NoProvider(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + // Without a registered provider for the target, mapping should return empty + result := mapper.MapModel("claude-opus-4.5") + if result != "" { + t.Errorf("Expected empty result when target has no provider, got %s", result) + } +} + +func TestModelMapper_MapModel_WithProvider(t *testing.T) { + // Register a mock provider for the target model + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client", "claude", []*registry.ModelInfo{ + {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, + }) + defer reg.UnregisterClient("test-client") + + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + // With a registered provider, mapping should work + result := mapper.MapModel("claude-opus-4.5") + if result != "claude-sonnet-4" { + t.Errorf("Expected claude-sonnet-4, got %s", result) + } +} + +func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) { + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{ + {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, + }) + defer reg.UnregisterClient("test-client2") + + mappings := []config.AmpModelMapping{ + {From: "Claude-Opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + // Should match case-insensitively + result := mapper.MapModel("claude-opus-4.5") + if result != "claude-sonnet-4" { + t.Errorf("Expected claude-sonnet-4, got %s", result) + } +} + +func TestModelMapper_MapModel_NotFound(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + // Unknown model should return empty + result := mapper.MapModel("unknown-model") + if result != "" { + t.Errorf("Expected empty for unknown model, got %s", result) + } +} + +func TestModelMapper_MapModel_EmptyInput(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + result := mapper.MapModel("") + if result != "" { + t.Errorf("Expected empty for empty input, got %s", result) + } +} + +func TestModelMapper_UpdateMappings(t *testing.T) { + mapper := NewModelMapper(nil) + + // Initially empty + if len(mapper.GetMappings()) != 0 { + t.Error("Expected 0 initial mappings") + } + + // Update with new mappings + mapper.UpdateMappings([]config.AmpModelMapping{ + {From: "model-a", To: "model-b"}, + {From: "model-c", To: "model-d"}, + }) + + result := mapper.GetMappings() + if len(result) != 2 { + t.Errorf("Expected 2 mappings after update, got %d", len(result)) + } + + // Update again should replace, not append + mapper.UpdateMappings([]config.AmpModelMapping{ + {From: "model-x", To: "model-y"}, + }) + + result = mapper.GetMappings() + if len(result) != 1 { + t.Errorf("Expected 1 mapping after second update, got %d", len(result)) + } +} + +func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) { + mapper := NewModelMapper(nil) + + mapper.UpdateMappings([]config.AmpModelMapping{ + {From: "", To: "model-b"}, // Invalid: empty from + {From: "model-a", To: ""}, // Invalid: empty to + {From: " ", To: "model-b"}, // Invalid: whitespace from + {From: "model-c", To: "model-d"}, // Valid + }) + + result := mapper.GetMappings() + if len(result) != 1 { + t.Errorf("Expected 1 valid mapping, got %d", len(result)) + } +} + +func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "model-a", To: "model-b"}, + } + + mapper := NewModelMapper(mappings) + + // Get mappings and modify the returned map + result := mapper.GetMappings() + result["new-key"] = "new-value" + + // Original should be unchanged + original := mapper.GetMappings() + if len(original) != 1 { + t.Errorf("Expected original to have 1 mapping, got %d", len(original)) + } + if _, exists := original["new-key"]; exists { + t.Error("Original map was modified") + } +} diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 8e5189ad..8bd739bb 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -162,9 +162,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han // Create fallback handler wrapper that forwards to ampcode.com when provider not found // Uses lazy evaluation to access proxy (which is created after routes are registered) - fallbackHandler := NewFallbackHandler(func() *httputil.ReverseProxy { + // Also includes model mapping support for routing unavailable models to alternatives + fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.proxy - }) + }, m.modelMapper) // Provider-specific routes under /api/provider/:provider ampProviders := engine.Group("/api/provider") diff --git a/internal/config/config.go b/internal/config/config.go index 31920075..8612b3e5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -37,6 +37,12 @@ type Config struct { // browser attacks and remote access to management endpoints. Default: true (recommended). AmpRestrictManagementToLocalhost bool `yaml:"amp-restrict-management-to-localhost" json:"amp-restrict-management-to-localhost"` + // AmpModelMappings defines model name mappings for Amp CLI requests. + // When Amp requests a model that isn't available locally, these mappings + // allow routing to an alternative model that IS available. + // Example: Map "claude-opus-4.5" -> "claude-sonnet-4" when opus isn't available. + AmpModelMappings []AmpModelMapping `yaml:"amp-model-mappings" json:"amp-model-mappings"` + // AuthDir is the directory where authentication token files are stored. AuthDir string `yaml:"auth-dir" json:"-"` @@ -115,6 +121,18 @@ type QuotaExceeded struct { SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"` } +// AmpModelMapping defines a model name mapping for Amp CLI requests. +// When Amp requests a model that isn't available locally, this mapping +// allows routing to an alternative model that IS available. +type AmpModelMapping struct { + // From is the model name that Amp CLI requests (e.g., "claude-opus-4.5"). + From string `yaml:"from" json:"from"` + + // To is the target model name to route to (e.g., "claude-sonnet-4"). + // The target model must have available providers in the registry. + To string `yaml:"to" json:"to"` +} + // PayloadConfig defines default and override parameter rules applied to provider payloads. type PayloadConfig struct { // Default defines rules that only set parameters when they are missing in the payload.