diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index e7132b81..940bd5e8 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -134,7 +134,43 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc } // Normalize model (handles dynamic thinking suffixes) - normalizedModel, _ := util.NormalizeThinkingModel(modelName) + normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName) + thinkingSuffix := "" + if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) { + thinkingSuffix = modelName[len(normalizedModel):] + } + + resolveMappedModel := func() (string, []string) { + if fh.modelMapper == nil { + return "", nil + } + + mappedModel := fh.modelMapper.MapModel(modelName) + if mappedModel == "" { + mappedModel = fh.modelMapper.MapModel(normalizedModel) + } + mappedModel = strings.TrimSpace(mappedModel) + if mappedModel == "" { + return "", nil + } + + // Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target + // already specifies its own thinking suffix. + if thinkingSuffix != "" { + _, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel) + if mappedThinkingMetadata == nil { + mappedModel += thinkingSuffix + } + } + + mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel) + mappedProviders := util.GetProviderName(mappedBaseModel) + if len(mappedProviders) == 0 { + return "", nil + } + + return mappedModel, mappedProviders + } // Track resolved model for logging (may change if mapping is applied) resolvedModel := normalizedModel @@ -147,21 +183,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc if forceMappings { // FORCE MODE: Check model mappings FIRST (takes precedence over local API keys) // This allows users to route Amp requests to their preferred OAuth providers - if fh.modelMapper != nil { - if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { - // Mapping found - check if we have a provider for the mapped model - mappedProviders := util.GetProviderName(mappedModel) - if len(mappedProviders) > 0 { - // Mapping found and provider available - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - providers = mappedProviders - } - } + if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { + // Mapping found and provider available - rewrite the model in request body + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) + resolvedModel = mappedModel + usedMapping = true + providers = mappedProviders } // If no mapping applied, check for local providers @@ -174,21 +204,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc if len(providers) == 0 { // No providers configured - check if we have a model mapping - if fh.modelMapper != nil { - if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { - // Mapping found - check if we have a provider for the mapped model - mappedProviders := util.GetProviderName(mappedModel) - if len(mappedProviders) > 0 { - // Mapping found and provider available - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - providers = mappedProviders - } - } + if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { + // Mapping found and provider available - rewrite the model in request body + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) + resolvedModel = mappedModel + usedMapping = true + providers = mappedProviders } } } @@ -222,14 +246,14 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Log: Model was mapped to another model log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel) logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) - rewriter := NewResponseRewriter(c.Writer, normalizedModel) + rewriter := NewResponseRewriter(c.Writer, modelName) c.Writer = rewriter // Filter Anthropic-Beta header only for local handling paths filterAntropicBetaHeader(c) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) handler(c) rewriter.Flush() - log.Debugf("amp model mapping: response %s -> %s", resolvedModel, normalizedModel) + log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName) } else if len(providers) > 0 { // Log: Using local provider (free) logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) diff --git a/internal/api/modules/amp/fallback_handlers_test.go b/internal/api/modules/amp/fallback_handlers_test.go new file mode 100644 index 00000000..a687fd11 --- /dev/null +++ b/internal/api/modules/amp/fallback_handlers_test.go @@ -0,0 +1,73 @@ +package amp + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "net/http/httputil" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" +) + +func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{ + {ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"}, + }) + defer reg.UnregisterClient("test-client-amp-fallback") + + mapper := NewModelMapper([]config.AmpModelMapping{ + {From: "gpt-5.2", To: "test/gpt-5.2"}, + }) + + fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil) + + handler := func(c *gin.Context) { + var req struct { + Model string `json:"model"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "model": req.Model, + "seen_model": req.Model, + }) + } + + r := gin.New() + r.POST("/chat/completions", fallback.WrapHandler(handler)) + + reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`) + req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d", w.Code) + } + + var resp struct { + Model string `json:"model"` + SeenModel string `json:"seen_model"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to parse response JSON: %v", err) + } + + if resp.Model != "gpt-5.2(xhigh)" { + t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model) + } + if resp.SeenModel != "test/gpt-5.2(xhigh)" { + t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel) + } +} diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go index 87384a80..bc31c4e5 100644 --- a/internal/api/modules/amp/model_mapping.go +++ b/internal/api/modules/amp/model_mapping.go @@ -59,7 +59,8 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string { } // Verify target model has available providers - providers := util.GetProviderName(targetModel) + normalizedTarget, _ := util.NormalizeThinkingModel(targetModel) + providers := util.GetProviderName(normalizedTarget) if len(providers) == 0 { log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel) return "" diff --git a/internal/api/modules/amp/model_mapping_test.go b/internal/api/modules/amp/model_mapping_test.go index 4f4e5a8e..664a17c5 100644 --- a/internal/api/modules/amp/model_mapping_test.go +++ b/internal/api/modules/amp/model_mapping_test.go @@ -71,6 +71,25 @@ func TestModelMapper_MapModel_WithProvider(t *testing.T) { } } +func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) { + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{ + {ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"}, + }) + defer reg.UnregisterClient("test-client-thinking") + + mappings := []config.AmpModelMapping{ + {From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"}, + } + + mapper := NewModelMapper(mappings) + + result := mapper.MapModel("gpt-5.2-alias") + if result != "gpt-5.2(xhigh)" { + t.Errorf("Expected gpt-5.2(xhigh), got %s", result) + } +} + func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) { reg := registry.GetGlobalRegistry() reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{ diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 2dcf2e3a..f058287b 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -162,6 +162,21 @@ func GetGeminiModels() []*ModelInfo { SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, }, + { + ID: "gemini-3-flash-preview", + Object: "model", + Created: 1765929600, + OwnedBy: "google", + Type: "gemini", + Name: "models/gemini-3-flash-preview", + Version: "3.0", + DisplayName: "Gemini 3 Flash Preview", + Description: "Gemini 3 Flash Preview", + InputTokenLimit: 1048576, + OutputTokenLimit: 65536, + SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, + Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, + }, { ID: "gemini-3-pro-image-preview", Object: "model", diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 976ec404..f73f7233 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -97,6 +97,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated) + translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -191,6 +192,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated) + translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -524,6 +526,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated) + translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) diff --git a/internal/util/gemini_schema.go b/internal/util/gemini_schema.go index b25d14e4..bbd3e339 100644 --- a/internal/util/gemini_schema.go +++ b/internal/util/gemini_schema.go @@ -296,6 +296,7 @@ func flattenTypeArrays(jsonStr string) string { func removeUnsupportedKeywords(jsonStr string) string { keywords := append(unsupportedConstraints, "$schema", "$defs", "definitions", "const", "$ref", "additionalProperties", + "propertyNames", // Gemini doesn't support property name validation ) for _, key := range keywords { for _, p := range findPaths(jsonStr, key) { diff --git a/internal/util/gemini_schema_test.go b/internal/util/gemini_schema_test.go index 655511d9..55a3c5fd 100644 --- a/internal/util/gemini_schema_test.go +++ b/internal/util/gemini_schema_test.go @@ -596,6 +596,71 @@ func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) { } } +func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) { + // propertyNames is used to validate object property names (e.g., must match a pattern) + // Gemini doesn't support this keyword and will reject requests containing it + input := `{ + "type": "object", + "properties": { + "metadata": { + "type": "object", + "propertyNames": { + "pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$" + }, + "additionalProperties": { + "type": "string" + } + } + } + }` + + expected := `{ + "type": "object", + "properties": { + "metadata": { + "type": "object" + } + } + }` + + result := CleanJSONSchemaForGemini(input) + compareJSON(t, expected, result) + + // Verify propertyNames is completely removed + if strings.Contains(result, "propertyNames") { + t.Errorf("propertyNames keyword should be removed, got: %s", result) + } +} + +func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) { + // Test deeply nested propertyNames (as seen in real Claude tool schemas) + input := `{ + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": { + "config": { + "type": "object", + "propertyNames": { + "type": "string" + } + } + } + } + } + } + }` + + result := CleanJSONSchemaForGemini(input) + + if strings.Contains(result, "propertyNames") { + t.Errorf("Nested propertyNames should be removed, got: %s", result) + } +} + func compareJSON(t *testing.T, expectedJSON, actualJSON string) { var expMap, actMap map[string]interface{} errExp := json.Unmarshal([]byte(expectedJSON), &expMap) diff --git a/internal/util/gemini_thinking.go b/internal/util/gemini_thinking.go index ba9e13ef..290d5f92 100644 --- a/internal/util/gemini_thinking.go +++ b/internal/util/gemini_thinking.go @@ -136,6 +136,12 @@ func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool) updated = rewritten } } + if it := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); it.Exists() { + updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.include_thoughts") + } + if tb := gjson.GetBytes(body, "generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() { + updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.thinkingBudget") + } return updated } @@ -167,6 +173,12 @@ func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *boo updated = rewritten } } + if it := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); it.Exists() { + updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts") + } + if tb := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() { + updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.thinkingBudget") + } return updated }