From f092801b6188a10f05cf8c9a602a34d0a9a224f6 Mon Sep 17 00:00:00 2001 From: "huynguyen03.dev" Date: Sun, 7 Dec 2025 15:39:58 +0700 Subject: [PATCH 01/34] fix: filter whitespace-only text in Claude to OpenAI translation Skip text content blocks that are empty or contain only whitespace when translating Claude messages to OpenAI format. This fixes GLM-4.6 and other strict OpenAI-compatible providers that reject empty text with error 'text cannot be empty'. --- internal/translator/openai/claude/openai_claude_request.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/internal/translator/openai/claude/openai_claude_request.go b/internal/translator/openai/claude/openai_claude_request.go index bff306cc..2510b19c 100644 --- a/internal/translator/openai/claude/openai_claude_request.go +++ b/internal/translator/openai/claude/openai_claude_request.go @@ -8,6 +8,7 @@ package claude import ( "bytes" "encoding/json" + "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -245,8 +246,12 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) { if !part.Get("text").Exists() { return "", false } + text := part.Get("text").String() + if strings.TrimSpace(text) == "" { + return "", false + } textContent := `{"type":"text","text":""}` - textContent, _ = sjson.Set(textContent, "text", part.Get("text").String()) + textContent, _ = sjson.Set(textContent, "text", text) return textContent, true case "image": From 549c0c2c5a7318c9409ffec4e3365b3d055a068f Mon Sep 17 00:00:00 2001 From: "huynguyen03.dev" Date: Sun, 7 Dec 2025 16:08:12 +0700 Subject: [PATCH 02/34] fix: filter whitespace-only text content in Claude to OpenAI translation Remove redundant existence check since TrimSpace handles empty strings --- internal/translator/openai/claude/openai_claude_request.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/internal/translator/openai/claude/openai_claude_request.go b/internal/translator/openai/claude/openai_claude_request.go index 2510b19c..3521b2e5 100644 --- a/internal/translator/openai/claude/openai_claude_request.go +++ b/internal/translator/openai/claude/openai_claude_request.go @@ -243,9 +243,6 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) { switch partType { case "text": - if !part.Get("text").Exists() { - return "", false - } text := part.Get("text").String() if strings.TrimSpace(text) == "" { return "", false From 9c09128e00d2f384bd248c54bec8d62be25c0134 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sun, 7 Dec 2025 19:12:55 +0800 Subject: [PATCH 03/34] feat(registry): add explicit thinking support config for antigravity models --- internal/registry/model_definitions.go | 13 ++++++++- .../runtime/executor/antigravity_executor.go | 28 ++++++++----------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 36aa83bb..64e78199 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -943,8 +943,19 @@ func GetQwenModels() []*ModelInfo { } } -// GetIFlowModels returns supported models for iFlow OAuth accounts. +// GetAntigravityThinkingConfig returns the Thinking configuration for antigravity models. +// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup. +func GetAntigravityThinkingConfig() map[string]*ThinkingSupport { + return map[string]*ThinkingSupport{ + "gemini-2.5-flash": {Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, + "gemini-2.5-flash-lite": {Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, + "gemini-3-pro-preview": {Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, + "gemini-claude-sonnet-4-5-thinking": {Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, + "gemini-claude-opus-4-5-thinking": {Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, + } +} +// GetIFlowModels returns supported models for iFlow OAuth accounts. func GetIFlowModels() []*ModelInfo { entries := []struct { ID string diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 9fc4e722..ed9207f0 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -366,29 +366,25 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c } now := time.Now().Unix() + thinkingConfig := registry.GetAntigravityThinkingConfig() models := make([]*registry.ModelInfo, 0, len(result.Map())) - for id := range result.Map() { - id = modelName2Alias(id) - if id != "" { + for originalName := range result.Map() { + aliasName := modelName2Alias(originalName) + if aliasName != "" { modelInfo := ®istry.ModelInfo{ - ID: id, - Name: id, - Description: id, - DisplayName: id, - Version: id, + ID: aliasName, + Name: aliasName, + Description: aliasName, + DisplayName: aliasName, + Version: aliasName, Object: "model", Created: now, OwnedBy: antigravityAuthType, Type: antigravityAuthType, } - // Add Thinking support for thinking models - if strings.HasSuffix(id, "-thinking") || strings.Contains(id, "-thinking-") { - modelInfo.Thinking = ®istry.ThinkingSupport{ - Min: 1024, - Max: 100000, - ZeroAllowed: false, - DynamicAllowed: true, - } + // Look up Thinking support from static config using alias name + if thinking, ok := thinkingConfig[aliasName]; ok { + modelInfo.Thinking = thinking } models = append(models, modelInfo) } From a174d015f243973a329773a5bf0048578aff132a Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sun, 7 Dec 2025 19:14:05 +0800 Subject: [PATCH 04/34] feat(openai): handle thinking.budget_tokens from Anthropic-style requests --- .../chat-completions/antigravity_openai_request.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index d1914ec8..82e71758 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -88,6 +88,20 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ } } + // Claude/Anthropic API format: thinking.type == "enabled" with budget_tokens + // This allows Claude Code and other Claude API clients to pass thinking configuration + if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && util.ModelSupportsThinking(modelName) { + if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() { + if t.Get("type").String() == "enabled" { + if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { + budget := util.NormalizeThinkingBudget(modelName, int(b.Int())) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) + } + } + } + } + // For gemini-3-pro-preview, always send default thinkingConfig when none specified. // This matches the official Gemini CLI behavior which always sends: // { thinkingBudget: -1, includeThoughts: true } From afcab5efda90f05b51ffb22380fb71b6888ae5a9 Mon Sep 17 00:00:00 2001 From: huynhgiabuu Date: Sun, 7 Dec 2025 22:47:43 +0700 Subject: [PATCH 05/34] feat: add prioritize-model-mappings config option Add a configuration option to control whether model mappings take precedence over local API keys for Amp CLI requests. - Add PrioritizeModelMappings field to AmpCode config struct - When false (default): Local API keys take precedence (original behavior) - When true: Model mappings take precedence over local API keys - Add management API endpoints GET/PUT /prioritize-model-mappings This allows users who want mapping priority to enable it explicitly while preserving backward compatibility. Config example: ampcode: model-mappings: - from: claude-opus-4-5-20251101 to: gemini-claude-opus-4-5-thinking prioritize-model-mappings: true --- .../api/handlers/management/config_basic.go | 8 ++ internal/api/modules/amp/amp.go | 10 +++ internal/api/modules/amp/fallback_handlers.go | 89 +++++++++++++------ internal/api/modules/amp/routes.go | 4 +- internal/api/server.go | 4 + internal/config/config.go | 4 + 6 files changed, 90 insertions(+), 29 deletions(-) diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index ae292982..e61c695e 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -241,3 +241,11 @@ func (h *Handler) DeleteProxyURL(c *gin.Context) { h.cfg.ProxyURL = "" h.persist(c) } + +// Prioritize Model Mappings (for Amp CLI) +func (h *Handler) GetPrioritizeModelMappings(c *gin.Context) { + c.JSON(200, gin.H{"prioritize-model-mappings": h.cfg.AmpCode.PrioritizeModelMappings}) +} +func (h *Handler) PutPrioritizeModelMappings(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.PrioritizeModelMappings = v }) +} diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index dabb7404..5c7c2708 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -100,6 +100,16 @@ func (m *AmpModule) Name() string { return "amp-routing" } +// getPrioritizeModelMappings returns whether model mappings should take precedence over local API keys +func (m *AmpModule) getPrioritizeModelMappings() bool { + m.configMu.RLock() + defer m.configMu.RUnlock() + if m.lastConfig == nil { + return false + } + return m.lastConfig.PrioritizeModelMappings +} + // Register sets up Amp routes if configured. // This implements the RouteModuleV2 interface with Context. // Routes are registered only once via sync.Once for idempotent behavior. diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index 0cbe0e1a..771e2713 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -77,23 +77,29 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid // FallbackHandler wraps a standard handler with fallback logic to ampcode.com // when the model's provider is not available in CLIProxyAPI type FallbackHandler struct { - getProxy func() *httputil.ReverseProxy - modelMapper ModelMapper + getProxy func() *httputil.ReverseProxy + modelMapper ModelMapper + getPrioritizeModelMappings func() bool } // NewFallbackHandler creates a new fallback handler wrapper // The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes) func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler { return &FallbackHandler{ - getProxy: getProxy, + getProxy: getProxy, + getPrioritizeModelMappings: func() bool { return false }, } } // NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support -func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler { +func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, getPrioritize func() bool) *FallbackHandler { + if getPrioritize == nil { + getPrioritize = func() bool { return false } + } return &FallbackHandler{ - getProxy: getProxy, - modelMapper: mapper, + getProxy: getProxy, + modelMapper: mapper, + getPrioritizeModelMappings: getPrioritize, } } @@ -130,34 +136,65 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Normalize model (handles Gemini thinking suffixes) normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName) - // Check if we have providers for this model - providers := util.GetProviderName(normalizedModel) - // Track resolved model for logging (may change if mapping is applied) resolvedModel := normalizedModel usedMapping := false + var providers []string - if len(providers) == 0 { - // No providers configured - check if we have a model mapping + // Check if model mappings should take priority over local API keys + prioritizeMappings := fh.getPrioritizeModelMappings != nil && fh.getPrioritizeModelMappings() + + if prioritizeMappings { + // PRIORITY MODE: Check model mappings FIRST (takes precedence over local API keys) + // This allows users to route Amp requests to their preferred OAuth providers if fh.modelMapper != nil { if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { - // Mapping found - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - - // Get providers for the mapped model - providers = util.GetProviderName(mappedModel) - - // Continue to handler with remapped model - goto handleRequest + // Mapping found - check if we have a provider for the mapped model + mappedProviders := util.GetProviderName(mappedModel) + if len(mappedProviders) > 0 { + // Mapping found and provider available - rewrite the model in request body + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) + resolvedModel = mappedModel + usedMapping = true + providers = mappedProviders + } } } - // No mapping found - check if we have a proxy for fallback + // If no mapping applied, check for local providers + if !usedMapping { + providers = util.GetProviderName(normalizedModel) + } + } else { + // DEFAULT MODE: Check local providers first, then mappings as fallback + providers = util.GetProviderName(normalizedModel) + + if len(providers) == 0 { + // No providers configured - check if we have a model mapping + if fh.modelMapper != nil { + if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { + // Mapping found - check if we have a provider for the mapped model + mappedProviders := util.GetProviderName(mappedModel) + if len(mappedProviders) > 0 { + // Mapping found and provider available - rewrite the model in request body + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) + resolvedModel = mappedModel + usedMapping = true + providers = mappedProviders + } + } + } + } + } + + // If no providers available, fallback to ampcode.com + if len(providers) == 0 { proxy := fh.getProxy() if proxy != nil { // Log: Forwarding to ampcode.com (uses Amp credits) @@ -175,8 +212,6 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath) } - handleRequest: - // Log the routing decision providerName := "" if len(providers) > 0 { diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 6826dbbe..dedbd444 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -171,7 +171,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler) geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.getProxy() - }, m.modelMapper) + }, m.modelMapper, m.getPrioritizeModelMappings) geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) // Route POST model calls through Gemini bridge with FallbackHandler. @@ -209,7 +209,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han // Also includes model mapping support for routing unavailable models to alternatives fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.getProxy() - }, m.modelMapper) + }, m.modelMapper, m.getPrioritizeModelMappings) // Provider-specific routes under /api/provider/:provider ampProviders := engine.Group("/api/provider") diff --git a/internal/api/server.go b/internal/api/server.go index 9e1c5848..93d13557 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -520,6 +520,10 @@ func (s *Server) registerManagementRoutes() { mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth) mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth) + mgmt.GET("/prioritize-model-mappings", s.mgmt.GetPrioritizeModelMappings) + mgmt.PUT("/prioritize-model-mappings", s.mgmt.PutPrioritizeModelMappings) + mgmt.PATCH("/prioritize-model-mappings", s.mgmt.PutPrioritizeModelMappings) + mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry) diff --git a/internal/config/config.go b/internal/config/config.go index 2681d049..d7e455f6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -143,6 +143,10 @@ type AmpCode struct { // When Amp requests a model that isn't available locally, these mappings // allow routing to an alternative model that IS available. ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"` + + // PrioritizeModelMappings when true, model mappings take precedence over local API keys. + // When false (default), local API keys are used first if available. + PrioritizeModelMappings bool `yaml:"prioritize-model-mappings" json:"prioritize-model-mappings"` } // PayloadConfig defines default and override parameter rules applied to provider payloads. From e9eb4db8bb67806441d0825fe2678d21a5459200 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Mon, 8 Dec 2025 09:48:31 +0800 Subject: [PATCH 06/34] feat(auth): refresh API key during cookie authentication --- internal/auth/iflow/iflow_auth.go | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/internal/auth/iflow/iflow_auth.go b/internal/auth/iflow/iflow_auth.go index 4957f519..2978e94c 100644 --- a/internal/auth/iflow/iflow_auth.go +++ b/internal/auth/iflow/iflow_auth.go @@ -309,17 +309,23 @@ func (ia *IFlowAuth) AuthenticateWithCookie(ctx context.Context, cookie string) return nil, fmt.Errorf("iflow cookie authentication: cookie is empty") } - // First, get initial API key information using GET request + // First, get initial API key information using GET request to obtain the name keyInfo, err := ia.fetchAPIKeyInfo(ctx, cookie) if err != nil { return nil, fmt.Errorf("iflow cookie authentication: fetch initial API key info failed: %w", err) } - // Convert to token data format + // Refresh the API key using POST request + refreshedKeyInfo, err := ia.RefreshAPIKey(ctx, cookie, keyInfo.Name) + if err != nil { + return nil, fmt.Errorf("iflow cookie authentication: refresh API key failed: %w", err) + } + + // Convert to token data format using refreshed key data := &IFlowTokenData{ - APIKey: keyInfo.APIKey, - Expire: keyInfo.ExpireTime, - Email: keyInfo.Name, + APIKey: refreshedKeyInfo.APIKey, + Expire: refreshedKeyInfo.ExpireTime, + Email: refreshedKeyInfo.Name, Cookie: cookie, } From 56ed0d8d901847332b18fb6bfc055ba80748e284 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Mon, 8 Dec 2025 10:44:39 +0800 Subject: [PATCH 07/34] refactor(config): rename prioritize-model-mappings to force-model-mappings --- config.example.yaml | 2 ++ .../api/handlers/management/config_basic.go | 10 +++---- internal/api/modules/amp/amp.go | 6 ++-- internal/api/modules/amp/fallback_handlers.go | 30 +++++++++---------- internal/api/modules/amp/routes.go | 4 +-- internal/api/server.go | 6 ++-- internal/config/config.go | 4 +-- 7 files changed, 32 insertions(+), 30 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index 61f51d47..0f8679aa 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -134,6 +134,8 @@ ws-auth: false # upstream-api-key: "" # # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended) # restrict-management-to-localhost: true +# # Force model mappings to run before checking local API keys (default: false) +# force-model-mappings: false # # Amp Model Mappings # # Route unavailable Amp models to alternative models available in your local proxy. # # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5) diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index e61c695e..3702e156 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -242,10 +242,10 @@ func (h *Handler) DeleteProxyURL(c *gin.Context) { h.persist(c) } -// Prioritize Model Mappings (for Amp CLI) -func (h *Handler) GetPrioritizeModelMappings(c *gin.Context) { - c.JSON(200, gin.H{"prioritize-model-mappings": h.cfg.AmpCode.PrioritizeModelMappings}) +// Force Model Mappings (for Amp CLI) +func (h *Handler) GetForceModelMappings(c *gin.Context) { + c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings}) } -func (h *Handler) PutPrioritizeModelMappings(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.PrioritizeModelMappings = v }) +func (h *Handler) PutForceModelMappings(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) } diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index 5c7c2708..88319a78 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -100,14 +100,14 @@ func (m *AmpModule) Name() string { return "amp-routing" } -// getPrioritizeModelMappings returns whether model mappings should take precedence over local API keys -func (m *AmpModule) getPrioritizeModelMappings() bool { +// forceModelMappings returns whether model mappings should take precedence over local API keys +func (m *AmpModule) forceModelMappings() bool { m.configMu.RLock() defer m.configMu.RUnlock() if m.lastConfig == nil { return false } - return m.lastConfig.PrioritizeModelMappings + return m.lastConfig.ForceModelMappings } // Register sets up Amp routes if configured. diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index 771e2713..3ec6c85e 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -77,29 +77,29 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid // FallbackHandler wraps a standard handler with fallback logic to ampcode.com // when the model's provider is not available in CLIProxyAPI type FallbackHandler struct { - getProxy func() *httputil.ReverseProxy - modelMapper ModelMapper - getPrioritizeModelMappings func() bool + getProxy func() *httputil.ReverseProxy + modelMapper ModelMapper + forceModelMappings func() bool } // NewFallbackHandler creates a new fallback handler wrapper // The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes) func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler { return &FallbackHandler{ - getProxy: getProxy, - getPrioritizeModelMappings: func() bool { return false }, + getProxy: getProxy, + forceModelMappings: func() bool { return false }, } } // NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support -func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, getPrioritize func() bool) *FallbackHandler { - if getPrioritize == nil { - getPrioritize = func() bool { return false } +func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler { + if forceModelMappings == nil { + forceModelMappings = func() bool { return false } } return &FallbackHandler{ - getProxy: getProxy, - modelMapper: mapper, - getPrioritizeModelMappings: getPrioritize, + getProxy: getProxy, + modelMapper: mapper, + forceModelMappings: forceModelMappings, } } @@ -141,11 +141,11 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc usedMapping := false var providers []string - // Check if model mappings should take priority over local API keys - prioritizeMappings := fh.getPrioritizeModelMappings != nil && fh.getPrioritizeModelMappings() + // Check if model mappings should be forced ahead of local API keys + forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings() - if prioritizeMappings { - // PRIORITY MODE: Check model mappings FIRST (takes precedence over local API keys) + if forceMappings { + // FORCE MODE: Check model mappings FIRST (takes precedence over local API keys) // This allows users to route Amp requests to their preferred OAuth providers if fh.modelMapper != nil { if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index dedbd444..0c1fcadb 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -171,7 +171,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler) geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.getProxy() - }, m.modelMapper, m.getPrioritizeModelMappings) + }, m.modelMapper, m.forceModelMappings) geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) // Route POST model calls through Gemini bridge with FallbackHandler. @@ -209,7 +209,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han // Also includes model mapping support for routing unavailable models to alternatives fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.getProxy() - }, m.modelMapper, m.getPrioritizeModelMappings) + }, m.modelMapper, m.forceModelMappings) // Provider-specific routes under /api/provider/:provider ampProviders := engine.Group("/api/provider") diff --git a/internal/api/server.go b/internal/api/server.go index 93d13557..1cc4a4fe 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -520,9 +520,9 @@ func (s *Server) registerManagementRoutes() { mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth) mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth) - mgmt.GET("/prioritize-model-mappings", s.mgmt.GetPrioritizeModelMappings) - mgmt.PUT("/prioritize-model-mappings", s.mgmt.PutPrioritizeModelMappings) - mgmt.PATCH("/prioritize-model-mappings", s.mgmt.PutPrioritizeModelMappings) + mgmt.GET("/force-model-mappings", s.mgmt.GetForceModelMappings) + mgmt.PUT("/force-model-mappings", s.mgmt.PutForceModelMappings) + mgmt.PATCH("/force-model-mappings", s.mgmt.PutForceModelMappings) mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) diff --git a/internal/config/config.go b/internal/config/config.go index d7e455f6..f6d1eb73 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -144,9 +144,9 @@ type AmpCode struct { // allow routing to an alternative model that IS available. ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"` - // PrioritizeModelMappings when true, model mappings take precedence over local API keys. + // ForceModelMappings when true, model mappings take precedence over local API keys. // When false (default), local API keys are used first if available. - PrioritizeModelMappings bool `yaml:"prioritize-model-mappings" json:"prioritize-model-mappings"` + ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"` } // PayloadConfig defines default and override parameter rules applied to provider payloads. From 93a6e2d9206520611c56131a1306473cbbe6ea57 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Mon, 8 Dec 2025 12:03:00 +0800 Subject: [PATCH 08/34] feat(api): add comprehensive ampcode management endpoints Add new REST API endpoints under /v0/management/ampcode for managing ampcode configuration including upstream URL, API key, localhost restriction, model mappings, and force model mappings settings. - Move force-model-mappings from config_basic to config_lists - Add GET/PUT/PATCH/DELETE endpoints for all ampcode settings - Support model mapping CRUD with upsert (PATCH) capability - Add comprehensive test coverage for all ampcode endpoints --- .../api/handlers/management/config_basic.go | 8 - .../api/handlers/management/config_lists.go | 170 ++++ internal/api/server.go | 22 +- test/amp_management_test.go | 779 ++++++++++++++++++ 4 files changed, 968 insertions(+), 11 deletions(-) create mode 100644 test/amp_management_test.go diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index 3702e156..ae292982 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -241,11 +241,3 @@ func (h *Handler) DeleteProxyURL(c *gin.Context) { h.cfg.ProxyURL = "" h.persist(c) } - -// Force Model Mappings (for Amp CLI) -func (h *Handler) GetForceModelMappings(c *gin.Context) { - c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings}) -} -func (h *Handler) PutForceModelMappings(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) -} diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 8f4c4037..93a02409 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -706,3 +706,173 @@ func normalizeClaudeKey(entry *config.ClaudeKey) { } entry.Models = normalized } + +// GetAmpCode returns the complete ampcode configuration. +func (h *Handler) GetAmpCode(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"ampcode": config.AmpCode{}}) + return + } + c.JSON(200, gin.H{"ampcode": h.cfg.AmpCode}) +} + +// GetAmpUpstreamURL returns the ampcode upstream URL. +func (h *Handler) GetAmpUpstreamURL(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"upstream-url": ""}) + return + } + c.JSON(200, gin.H{"upstream-url": h.cfg.AmpCode.UpstreamURL}) +} + +// PutAmpUpstreamURL updates the ampcode upstream URL. +func (h *Handler) PutAmpUpstreamURL(c *gin.Context) { + h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) }) +} + +// DeleteAmpUpstreamURL clears the ampcode upstream URL. +func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) { + h.cfg.AmpCode.UpstreamURL = "" + h.persist(c) +} + +// GetAmpUpstreamAPIKey returns the ampcode upstream API key. +func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"upstream-api-key": ""}) + return + } + c.JSON(200, gin.H{"upstream-api-key": h.cfg.AmpCode.UpstreamAPIKey}) +} + +// PutAmpUpstreamAPIKey updates the ampcode upstream API key. +func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) { + h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) }) +} + +// DeleteAmpUpstreamAPIKey clears the ampcode upstream API key. +func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) { + h.cfg.AmpCode.UpstreamAPIKey = "" + h.persist(c) +} + +// GetAmpRestrictManagementToLocalhost returns the localhost restriction setting. +func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"restrict-management-to-localhost": true}) + return + } + c.JSON(200, gin.H{"restrict-management-to-localhost": h.cfg.AmpCode.RestrictManagementToLocalhost}) +} + +// PutAmpRestrictManagementToLocalhost updates the localhost restriction setting. +func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.RestrictManagementToLocalhost = v }) +} + +// GetAmpModelMappings returns the ampcode model mappings. +func (h *Handler) GetAmpModelMappings(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}}) + return + } + c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings}) +} + +// PutAmpModelMappings replaces all ampcode model mappings. +func (h *Handler) PutAmpModelMappings(c *gin.Context) { + var body struct { + Value []config.AmpModelMapping `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + var mappings []config.AmpModelMapping + if err2 := c.ShouldBindJSON(&mappings); err2 != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + body.Value = mappings + } + h.cfg.AmpCode.ModelMappings = body.Value + h.persist(c) +} + +// PatchAmpModelMappings adds or updates model mappings. +func (h *Handler) PatchAmpModelMappings(c *gin.Context) { + var body struct { + Value []config.AmpModelMapping `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + var mappings []config.AmpModelMapping + if err2 := c.ShouldBindJSON(&mappings); err2 != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + body.Value = mappings + } + + existing := make(map[string]int) + for i, m := range h.cfg.AmpCode.ModelMappings { + existing[strings.TrimSpace(m.From)] = i + } + + for _, newMapping := range body.Value { + from := strings.TrimSpace(newMapping.From) + if idx, ok := existing[from]; ok { + h.cfg.AmpCode.ModelMappings[idx] = newMapping + } else { + h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, newMapping) + existing[from] = len(h.cfg.AmpCode.ModelMappings) - 1 + } + } + h.persist(c) +} + +// DeleteAmpModelMappings removes specified model mappings by "from" field. +func (h *Handler) DeleteAmpModelMappings(c *gin.Context) { + var body struct { + Value []string `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + var fromList []string + if err2 := c.ShouldBindJSON(&fromList); err2 != nil { + h.cfg.AmpCode.ModelMappings = nil + h.persist(c) + return + } + body.Value = fromList + } + + if len(body.Value) == 0 { + h.cfg.AmpCode.ModelMappings = nil + h.persist(c) + return + } + + toRemove := make(map[string]bool) + for _, from := range body.Value { + toRemove[strings.TrimSpace(from)] = true + } + + newMappings := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings)) + for _, m := range h.cfg.AmpCode.ModelMappings { + if !toRemove[strings.TrimSpace(m.From)] { + newMappings = append(newMappings, m) + } + } + h.cfg.AmpCode.ModelMappings = newMappings + h.persist(c) +} + +// GetAmpForceModelMappings returns whether model mappings are forced. +func (h *Handler) GetAmpForceModelMappings(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"force-model-mappings": false}) + return + } + c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings}) +} + +// PutAmpForceModelMappings updates the force model mappings setting. +func (h *Handler) PutAmpForceModelMappings(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) +} diff --git a/internal/api/server.go b/internal/api/server.go index 1cc4a4fe..b65185a7 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -520,9 +520,25 @@ func (s *Server) registerManagementRoutes() { mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth) mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth) - mgmt.GET("/force-model-mappings", s.mgmt.GetForceModelMappings) - mgmt.PUT("/force-model-mappings", s.mgmt.PutForceModelMappings) - mgmt.PATCH("/force-model-mappings", s.mgmt.PutForceModelMappings) + mgmt.GET("/ampcode", s.mgmt.GetAmpCode) + mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL) + mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) + mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) + mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL) + mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey) + mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) + mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) + mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey) + mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost) + mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) + mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) + mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings) + mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings) + mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings) + mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings) + mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings) + mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) + mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) diff --git a/test/amp_management_test.go b/test/amp_management_test.go new file mode 100644 index 00000000..3cb8be87 --- /dev/null +++ b/test/amp_management_test.go @@ -0,0 +1,779 @@ +package test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func newAmpTestHandler(t *testing.T) (*management.Handler, string) { + t.Helper() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + UpstreamURL: "https://example.com", + UpstreamAPIKey: "test-api-key-12345", + RestrictManagementToLocalhost: true, + ForceModelMappings: false, + ModelMappings: []config.AmpModelMapping{ + {From: "gpt-4", To: "gemini-pro"}, + }, + }, + } + + if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + h := management.NewHandler(cfg, configPath, nil) + return h, configPath +} + +func setupAmpRouter(h *management.Handler) *gin.Engine { + r := gin.New() + mgmt := r.Group("/v0/management") + { + mgmt.GET("/ampcode", h.GetAmpCode) + mgmt.GET("/ampcode/upstream-url", h.GetAmpUpstreamURL) + mgmt.PUT("/ampcode/upstream-url", h.PutAmpUpstreamURL) + mgmt.DELETE("/ampcode/upstream-url", h.DeleteAmpUpstreamURL) + mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey) + mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey) + mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey) + mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost) + mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost) + mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings) + mgmt.PUT("/ampcode/model-mappings", h.PutAmpModelMappings) + mgmt.PATCH("/ampcode/model-mappings", h.PatchAmpModelMappings) + mgmt.DELETE("/ampcode/model-mappings", h.DeleteAmpModelMappings) + mgmt.GET("/ampcode/force-model-mappings", h.GetAmpForceModelMappings) + mgmt.PUT("/ampcode/force-model-mappings", h.PutAmpForceModelMappings) + } + return r +} + +func TestGetAmpCode(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]config.AmpCode + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + ampcode := resp["ampcode"] + if ampcode.UpstreamURL != "https://example.com" { + t.Errorf("expected upstream-url %q, got %q", "https://example.com", ampcode.UpstreamURL) + } + if len(ampcode.ModelMappings) != 1 { + t.Errorf("expected 1 model mapping, got %d", len(ampcode.ModelMappings)) + } +} + +func TestGetAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["upstream-url"] != "https://example.com" { + t.Errorf("expected %q, got %q", "https://example.com", resp["upstream-url"]) + } +} + +func TestPutAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "https://new-upstream.com"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +func TestDeleteAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +func TestGetAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + key := resp["upstream-api-key"].(string) + if key != "test-api-key-12345" { + t.Errorf("expected key %q, got %q", "test-api-key-12345", key) + } +} + +func TestPutAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "new-secret-key"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +func TestDeleteAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +func TestGetAmpRestrictManagementToLocalhost(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["restrict-management-to-localhost"] != true { + t.Error("expected restrict-management-to-localhost to be true") + } +} + +func TestPutAmpRestrictManagementToLocalhost(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": false}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +func TestGetAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 1 { + t.Fatalf("expected 1 mapping, got %d", len(mappings)) + } + if mappings[0].From != "gpt-4" || mappings[0].To != "gemini-pro" { + t.Errorf("unexpected mapping: %+v", mappings[0]) + } +} + +func TestPutAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "claude-3", "to": "gpt-4o"}, {"from": "gemini", "to": "claude"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +func TestPatchAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "gpt-4", "to": "updated-model"}, {"from": "new-model", "to": "target"}]}` + req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +func TestDeleteAmpModelMappings_Specific(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": ["gpt-4"]}` + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +func TestDeleteAmpModelMappings_All(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +func TestGetAmpForceModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["force-model-mappings"] != false { + t.Error("expected force-model-mappings to be false") + } +} + +func TestPutAmpForceModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": true}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +func TestPutAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "model-a", "to": "model-b"}, {"from": "model-c", "to": "model-d"}, {"from": "model-e", "to": "model-f"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d, body: %s", w.Code, w.Body.String()) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 3 { + t.Fatalf("expected 3 mappings, got %d", len(mappings)) + } + + expected := map[string]string{"model-a": "model-b", "model-c": "model-d", "model-e": "model-f"} + for _, m := range mappings { + if expected[m.From] != m.To { + t.Errorf("mapping %q -> expected %q, got %q", m.From, expected[m.From], m.To) + } + } +} + +func TestPatchAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "gpt-4", "to": "updated-target"}, {"from": "new-model", "to": "new-target"}]}` + req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PATCH failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 2 { + t.Fatalf("expected 2 mappings (1 updated + 1 new), got %d", len(mappings)) + } + + found := make(map[string]string) + for _, m := range mappings { + found[m.From] = m.To + } + + if found["gpt-4"] != "updated-target" { + t.Errorf("gpt-4 should map to updated-target, got %q", found["gpt-4"]) + } + if found["new-model"] != "new-target" { + t.Errorf("new-model should map to new-target, got %q", found["new-model"]) + } +} + +func TestDeleteAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + putBody := `{"value": [{"from": "a", "to": "1"}, {"from": "b", "to": "2"}, {"from": "c", "to": "3"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + delBody := `{"value": ["a", "c"]}` + req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 1 { + t.Fatalf("expected 1 mapping remaining, got %d", len(mappings)) + } + if mappings[0].From != "b" || mappings[0].To != "2" { + t.Errorf("expected b->2, got %s->%s", mappings[0].From, mappings[0].To) + } +} + +func TestDeleteAmpModelMappings_NonExistent(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + delBody := `{"value": ["non-existent-model"]}` + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 1 { + t.Errorf("original mapping should remain, got %d mappings", len(resp["model-mappings"])) + } +} + +func TestPutAmpModelMappings_Empty(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": []}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 0 { + t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"])) + } +} + +func TestPutAmpUpstreamURL_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "https://new-api.example.com"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-url"] != "https://new-api.example.com" { + t.Errorf("expected %q, got %q", "https://new-api.example.com", resp["upstream-url"]) + } +} + +func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-url"] != "" { + t.Errorf("expected empty string, got %q", resp["upstream-url"]) + } +} + +func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "new-secret-api-key-xyz"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-api-key"] != "new-secret-api-key-xyz" { + t.Errorf("expected %q, got %q", "new-secret-api-key-xyz", resp["upstream-api-key"]) + } +} + +func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-api-key"] != "" { + t.Errorf("expected empty string, got %q", resp["upstream-api-key"]) + } +} + +func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": false}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["restrict-management-to-localhost"] != false { + t.Error("expected false after update") + } +} + +func TestPutAmpForceModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": true}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["force-model-mappings"] != true { + t.Error("expected true after update") + } +} + +func TestComplexMappingsWorkflow(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + putBody := `{"value": [{"from": "m1", "to": "t1"}, {"from": "m2", "to": "t2"}, {"from": "m3", "to": "t3"}, {"from": "m4", "to": "t4"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + patchBody := `{"value": [{"from": "m2", "to": "t2-updated"}, {"from": "m5", "to": "t5"}]}` + req = httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(patchBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + delBody := `{"value": ["m1", "m3"]}` + req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 3 { + t.Fatalf("expected 3 mappings (m2, m4, m5), got %d", len(mappings)) + } + + expected := map[string]string{"m2": "t2-updated", "m4": "t4", "m5": "t5"} + found := make(map[string]string) + for _, m := range mappings { + found[m.From] = m.To + } + + for from, to := range expected { + if found[from] != to { + t.Errorf("mapping %s: expected %q, got %q", from, to, found[from]) + } + } +} + +func TestNilHandlerGetAmpCode(t *testing.T) { + cfg := &config.Config{} + h := management.NewHandler(cfg, "", nil) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +func TestEmptyConfigGetAmpModelMappings(t *testing.T) { + cfg := &config.Config{} + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + h := management.NewHandler(cfg, configPath, nil) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 0 { + t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"])) + } +} From 05cfa16e5fc7b28240d250397849ebd7e28a17e0 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Mon, 8 Dec 2025 14:45:35 +0800 Subject: [PATCH 09/34] refactor(api): simplify request body parsing in ampcode handlers --- .../api/handlers/management/config_lists.go | 28 ++--------- internal/api/handlers/management/handler.go | 10 ---- test/amp_management_test.go | 48 +++++++++++++++++++ 3 files changed, 53 insertions(+), 33 deletions(-) diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 93a02409..a0d0b169 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -785,12 +785,8 @@ func (h *Handler) PutAmpModelMappings(c *gin.Context) { Value []config.AmpModelMapping `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil { - var mappings []config.AmpModelMapping - if err2 := c.ShouldBindJSON(&mappings); err2 != nil { - c.JSON(400, gin.H{"error": "invalid body"}) - return - } - body.Value = mappings + c.JSON(400, gin.H{"error": "invalid body"}) + return } h.cfg.AmpCode.ModelMappings = body.Value h.persist(c) @@ -802,12 +798,8 @@ func (h *Handler) PatchAmpModelMappings(c *gin.Context) { Value []config.AmpModelMapping `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil { - var mappings []config.AmpModelMapping - if err2 := c.ShouldBindJSON(&mappings); err2 != nil { - c.JSON(400, gin.H{"error": "invalid body"}) - return - } - body.Value = mappings + c.JSON(400, gin.H{"error": "invalid body"}) + return } existing := make(map[string]int) @@ -832,17 +824,7 @@ func (h *Handler) DeleteAmpModelMappings(c *gin.Context) { var body struct { Value []string `json:"value"` } - if err := c.ShouldBindJSON(&body); err != nil { - var fromList []string - if err2 := c.ShouldBindJSON(&fromList); err2 != nil { - h.cfg.AmpCode.ModelMappings = nil - h.persist(c) - return - } - body.Value = fromList - } - - if len(body.Value) == 0 { + if err := c.ShouldBindJSON(&body); err != nil || len(body.Value) == 0 { h.cfg.AmpCode.ModelMappings = nil h.persist(c) return diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index ef6f400a..39e6b7fd 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -240,16 +240,6 @@ func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) { Value *bool `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { - var m map[string]any - if err2 := c.ShouldBindJSON(&m); err2 == nil { - for _, v := range m { - if b, ok := v.(bool); ok { - set(b) - h.persist(c) - return - } - } - } c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) return } diff --git a/test/amp_management_test.go b/test/amp_management_test.go index 3cb8be87..19450dbf 100644 --- a/test/amp_management_test.go +++ b/test/amp_management_test.go @@ -18,6 +18,7 @@ func init() { gin.SetMode(gin.TestMode) } +// newAmpTestHandler creates a test handler with default ampcode configuration. func newAmpTestHandler(t *testing.T) (*management.Handler, string) { t.Helper() tmpDir := t.TempDir() @@ -43,6 +44,7 @@ func newAmpTestHandler(t *testing.T) (*management.Handler, string) { return h, configPath } +// setupAmpRouter creates a test router with all ampcode management endpoints. func setupAmpRouter(h *management.Handler) *gin.Engine { r := gin.New() mgmt := r.Group("/v0/management") @@ -66,6 +68,7 @@ func setupAmpRouter(h *management.Handler) *gin.Engine { return r } +// TestGetAmpCode verifies GET /v0/management/ampcode returns full ampcode config. func TestGetAmpCode(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -92,6 +95,7 @@ func TestGetAmpCode(t *testing.T) { } } +// TestGetAmpUpstreamURL verifies GET /v0/management/ampcode/upstream-url returns the upstream URL. func TestGetAmpUpstreamURL(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -114,6 +118,7 @@ func TestGetAmpUpstreamURL(t *testing.T) { } } +// TestPutAmpUpstreamURL verifies PUT /v0/management/ampcode/upstream-url updates the upstream URL. func TestPutAmpUpstreamURL(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -129,6 +134,7 @@ func TestPutAmpUpstreamURL(t *testing.T) { } } +// TestDeleteAmpUpstreamURL verifies DELETE /v0/management/ampcode/upstream-url clears the upstream URL. func TestDeleteAmpUpstreamURL(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -142,6 +148,7 @@ func TestDeleteAmpUpstreamURL(t *testing.T) { } } +// TestGetAmpUpstreamAPIKey verifies GET /v0/management/ampcode/upstream-api-key returns the API key. func TestGetAmpUpstreamAPIKey(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -165,6 +172,7 @@ func TestGetAmpUpstreamAPIKey(t *testing.T) { } } +// TestPutAmpUpstreamAPIKey verifies PUT /v0/management/ampcode/upstream-api-key updates the API key. func TestPutAmpUpstreamAPIKey(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -180,6 +188,7 @@ func TestPutAmpUpstreamAPIKey(t *testing.T) { } } +// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key. func TestDeleteAmpUpstreamAPIKey(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -193,6 +202,7 @@ func TestDeleteAmpUpstreamAPIKey(t *testing.T) { } } +// TestGetAmpRestrictManagementToLocalhost verifies GET returns the localhost restriction setting. func TestGetAmpRestrictManagementToLocalhost(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -215,6 +225,7 @@ func TestGetAmpRestrictManagementToLocalhost(t *testing.T) { } } +// TestPutAmpRestrictManagementToLocalhost verifies PUT updates the localhost restriction setting. func TestPutAmpRestrictManagementToLocalhost(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -230,6 +241,7 @@ func TestPutAmpRestrictManagementToLocalhost(t *testing.T) { } } +// TestGetAmpModelMappings verifies GET /v0/management/ampcode/model-mappings returns all mappings. func TestGetAmpModelMappings(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -256,6 +268,7 @@ func TestGetAmpModelMappings(t *testing.T) { } } +// TestPutAmpModelMappings verifies PUT /v0/management/ampcode/model-mappings replaces all mappings. func TestPutAmpModelMappings(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -271,6 +284,7 @@ func TestPutAmpModelMappings(t *testing.T) { } } +// TestPatchAmpModelMappings verifies PATCH updates existing mappings and adds new ones. func TestPatchAmpModelMappings(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -286,6 +300,7 @@ func TestPatchAmpModelMappings(t *testing.T) { } } +// TestDeleteAmpModelMappings_Specific verifies DELETE removes specified mappings by "from" field. func TestDeleteAmpModelMappings_Specific(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -301,6 +316,7 @@ func TestDeleteAmpModelMappings_Specific(t *testing.T) { } } +// TestDeleteAmpModelMappings_All verifies DELETE with empty body removes all mappings. func TestDeleteAmpModelMappings_All(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -314,6 +330,7 @@ func TestDeleteAmpModelMappings_All(t *testing.T) { } } +// TestGetAmpForceModelMappings verifies GET returns the force-model-mappings setting. func TestGetAmpForceModelMappings(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -336,6 +353,7 @@ func TestGetAmpForceModelMappings(t *testing.T) { } } +// TestPutAmpForceModelMappings verifies PUT updates the force-model-mappings setting. func TestPutAmpForceModelMappings(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -351,6 +369,7 @@ func TestPutAmpForceModelMappings(t *testing.T) { } } +// TestPutAmpModelMappings_VerifyState verifies PUT replaces mappings and state is persisted. func TestPutAmpModelMappings_VerifyState(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -387,6 +406,7 @@ func TestPutAmpModelMappings_VerifyState(t *testing.T) { } } +// TestPatchAmpModelMappings_VerifyState verifies PATCH merges mappings correctly. func TestPatchAmpModelMappings_VerifyState(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -428,6 +448,7 @@ func TestPatchAmpModelMappings_VerifyState(t *testing.T) { } } +// TestDeleteAmpModelMappings_VerifyState verifies DELETE removes specific mappings and keeps others. func TestDeleteAmpModelMappings_VerifyState(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -466,6 +487,7 @@ func TestDeleteAmpModelMappings_VerifyState(t *testing.T) { } } +// TestDeleteAmpModelMappings_NonExistent verifies DELETE with non-existent mapping doesn't affect existing ones. func TestDeleteAmpModelMappings_NonExistent(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -494,6 +516,7 @@ func TestDeleteAmpModelMappings_NonExistent(t *testing.T) { } } +// TestPutAmpModelMappings_Empty verifies PUT with empty array clears all mappings. func TestPutAmpModelMappings_Empty(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -522,6 +545,7 @@ func TestPutAmpModelMappings_Empty(t *testing.T) { } } +// TestPutAmpUpstreamURL_VerifyState verifies PUT updates upstream URL and persists state. func TestPutAmpUpstreamURL_VerifyState(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -550,6 +574,7 @@ func TestPutAmpUpstreamURL_VerifyState(t *testing.T) { } } +// TestDeleteAmpUpstreamURL_VerifyState verifies DELETE clears upstream URL. func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -576,6 +601,7 @@ func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) { } } +// TestPutAmpUpstreamAPIKey_VerifyState verifies PUT updates API key and persists state. func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -604,6 +630,7 @@ func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) { } } +// TestDeleteAmpUpstreamAPIKey_VerifyState verifies DELETE clears API key. func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -630,6 +657,7 @@ func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) { } } +// TestPutAmpRestrictManagementToLocalhost_VerifyState verifies PUT updates localhost restriction. func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -658,6 +686,7 @@ func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) { } } +// TestPutAmpForceModelMappings_VerifyState verifies PUT updates force-model-mappings setting. func TestPutAmpForceModelMappings_VerifyState(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -686,6 +715,23 @@ func TestPutAmpForceModelMappings_VerifyState(t *testing.T) { } } +// TestPutBoolField_EmptyObject verifies PUT with empty object returns 400. +func TestPutBoolField_EmptyObject(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d for empty object, got %d", http.StatusBadRequest, w.Code) + } +} + +// TestComplexMappingsWorkflow tests a full workflow: PUT, PATCH, DELETE, and GET. func TestComplexMappingsWorkflow(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -735,6 +781,7 @@ func TestComplexMappingsWorkflow(t *testing.T) { } } +// TestNilHandlerGetAmpCode verifies handler works with empty config. func TestNilHandlerGetAmpCode(t *testing.T) { cfg := &config.Config{} h := management.NewHandler(cfg, "", nil) @@ -749,6 +796,7 @@ func TestNilHandlerGetAmpCode(t *testing.T) { } } +// TestEmptyConfigGetAmpModelMappings verifies GET returns empty array for fresh config. func TestEmptyConfigGetAmpModelMappings(t *testing.T) { cfg := &config.Config{} tmpDir := t.TempDir() From 92f13fc31628de7e097046592edaf186c55a3543 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Mon, 8 Dec 2025 17:21:58 +0800 Subject: [PATCH 10/34] feat(logging): add upstream API request/response capture to streaming logs --- internal/api/middleware/response_writer.go | 9 + internal/logging/request_logger.go | 203 ++++++++++++++++++--- 2 files changed, 182 insertions(+), 30 deletions(-) diff --git a/internal/api/middleware/response_writer.go b/internal/api/middleware/response_writer.go index f0d1ad26..b7259bc6 100644 --- a/internal/api/middleware/response_writer.go +++ b/internal/api/middleware/response_writer.go @@ -232,7 +232,16 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { w.streamDone = nil } + // Write API Request and Response to the streaming log before closing if w.streamWriter != nil { + apiRequest := w.extractAPIRequest(c) + if len(apiRequest) > 0 { + _ = w.streamWriter.WriteAPIRequest(apiRequest) + } + apiResponse := w.extractAPIResponse(c) + if len(apiResponse) > 0 { + _ = w.streamWriter.WriteAPIResponse(apiResponse) + } if err := w.streamWriter.Close(); err != nil { w.streamWriter = nil return err diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go index c574febb..58667cb9 100644 --- a/internal/logging/request_logger.go +++ b/internal/logging/request_logger.go @@ -84,6 +84,26 @@ type StreamingLogWriter interface { // - error: An error if writing fails, nil otherwise WriteStatus(status int, headers map[string][]string) error + // WriteAPIRequest writes the upstream API request details to the log. + // This should be called before WriteStatus to maintain proper log ordering. + // + // Parameters: + // - apiRequest: The API request data (typically includes URL, headers, body sent upstream) + // + // Returns: + // - error: An error if writing fails, nil otherwise + WriteAPIRequest(apiRequest []byte) error + + // WriteAPIResponse writes the upstream API response details to the log. + // This should be called after the streaming response is complete. + // + // Parameters: + // - apiResponse: The API response data + // + // Returns: + // - error: An error if writing fails, nil otherwise + WriteAPIResponse(apiResponse []byte) error + // Close finalizes the log file and cleans up resources. // // Returns: @@ -248,10 +268,11 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[ // Create streaming writer writer := &FileStreamingLogWriter{ - file: file, - chunkChan: make(chan []byte, 100), // Buffered channel for async writes - closeChan: make(chan struct{}), - errorChan: make(chan error, 1), + file: file, + chunkChan: make(chan []byte, 100), // Buffered channel for async writes + closeChan: make(chan struct{}), + errorChan: make(chan error, 1), + bufferedChunks: &bytes.Buffer{}, } // Start async writer goroutine @@ -628,11 +649,12 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st // FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs. // It handles asynchronous writing of streaming response chunks to a file. +// All data is buffered and written in the correct order when Close is called. type FileStreamingLogWriter struct { // file is the file where log data is written. file *os.File - // chunkChan is a channel for receiving response chunks to write. + // chunkChan is a channel for receiving response chunks to buffer. chunkChan chan []byte // closeChan is a channel for signaling when the writer is closed. @@ -641,8 +663,23 @@ type FileStreamingLogWriter struct { // errorChan is a channel for reporting errors during writing. errorChan chan error - // statusWritten indicates whether the response status has been written. + // bufferedChunks stores the response chunks in order. + bufferedChunks *bytes.Buffer + + // responseStatus stores the HTTP status code. + responseStatus int + + // statusWritten indicates whether a non-zero status was recorded. statusWritten bool + + // responseHeaders stores the response headers. + responseHeaders map[string][]string + + // apiRequest stores the upstream API request data. + apiRequest []byte + + // apiResponse stores the upstream API response data. + apiResponse []byte } // WriteChunkAsync writes a response chunk asynchronously (non-blocking). @@ -666,39 +703,65 @@ func (w *FileStreamingLogWriter) WriteChunkAsync(chunk []byte) { } } -// WriteStatus writes the response status and headers to the log. +// WriteStatus buffers the response status and headers for later writing. // // Parameters: // - status: The response status code // - headers: The response headers // // Returns: -// - error: An error if writing fails, nil otherwise +// - error: Always returns nil (buffering cannot fail) func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error { - if w.file == nil || w.statusWritten { + if status == 0 { return nil } - var content strings.Builder - content.WriteString("========================================\n") - content.WriteString("=== RESPONSE ===\n") - content.WriteString(fmt.Sprintf("Status: %d\n", status)) - - for key, values := range headers { - for _, value := range values { - content.WriteString(fmt.Sprintf("%s: %s\n", key, value)) + w.responseStatus = status + if headers != nil { + w.responseHeaders = make(map[string][]string, len(headers)) + for key, values := range headers { + headerValues := make([]string, len(values)) + copy(headerValues, values) + w.responseHeaders[key] = headerValues } } - content.WriteString("\n") + w.statusWritten = true + return nil +} - _, err := w.file.WriteString(content.String()) - if err == nil { - w.statusWritten = true +// WriteAPIRequest buffers the upstream API request details for later writing. +// +// Parameters: +// - apiRequest: The API request data (typically includes URL, headers, body sent upstream) +// +// Returns: +// - error: Always returns nil (buffering cannot fail) +func (w *FileStreamingLogWriter) WriteAPIRequest(apiRequest []byte) error { + if len(apiRequest) == 0 { + return nil } - return err + w.apiRequest = bytes.Clone(apiRequest) + return nil +} + +// WriteAPIResponse buffers the upstream API response details for later writing. +// +// Parameters: +// - apiResponse: The API response data +// +// Returns: +// - error: Always returns nil (buffering cannot fail) +func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error { + if len(apiResponse) == 0 { + return nil + } + w.apiResponse = bytes.Clone(apiResponse) + return nil } // Close finalizes the log file and cleans up resources. +// It writes all buffered data to the file in the correct order: +// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks) // // Returns: // - error: An error if closing fails, nil otherwise @@ -707,27 +770,85 @@ func (w *FileStreamingLogWriter) Close() error { close(w.chunkChan) } - // Wait for async writer to finish + // Wait for async writer to finish buffering chunks if w.closeChan != nil { <-w.closeChan w.chunkChan = nil } - if w.file != nil { - return w.file.Close() + if w.file == nil { + return nil } - return nil + // Write all content in the correct order + var content strings.Builder + + // 1. Write API REQUEST section + if len(w.apiRequest) > 0 { + if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) { + content.Write(w.apiRequest) + if !bytes.HasSuffix(w.apiRequest, []byte("\n")) { + content.WriteString("\n") + } + } else { + content.WriteString("=== API REQUEST ===\n") + content.Write(w.apiRequest) + content.WriteString("\n") + } + content.WriteString("\n") + } + + // 2. Write API RESPONSE section + if len(w.apiResponse) > 0 { + if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) { + content.Write(w.apiResponse) + if !bytes.HasSuffix(w.apiResponse, []byte("\n")) { + content.WriteString("\n") + } + } else { + content.WriteString("=== API RESPONSE ===\n") + content.Write(w.apiResponse) + content.WriteString("\n") + } + content.WriteString("\n") + } + + // 3. Write RESPONSE section (status, headers, buffered chunks) + content.WriteString("========================================\n") + content.WriteString("=== RESPONSE ===\n") + if w.statusWritten { + content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus)) + } + + for key, values := range w.responseHeaders { + for _, value := range values { + content.WriteString(fmt.Sprintf("%s: %s\n", key, value)) + } + } + content.WriteString("\n") + + // Write buffered response body chunks + if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 { + content.Write(w.bufferedChunks.Bytes()) + } + + // Write the complete content to file + if _, err := w.file.WriteString(content.String()); err != nil { + _ = w.file.Close() + return err + } + + return w.file.Close() } -// asyncWriter runs in a goroutine to handle async chunk writing. -// It continuously reads chunks from the channel and writes them to the file. +// asyncWriter runs in a goroutine to buffer chunks from the channel. +// It continuously reads chunks from the channel and buffers them for later writing. func (w *FileStreamingLogWriter) asyncWriter() { defer close(w.closeChan) for chunk := range w.chunkChan { - if w.file != nil { - _, _ = w.file.Write(chunk) + if w.bufferedChunks != nil { + w.bufferedChunks.Write(chunk) } } } @@ -754,6 +875,28 @@ func (w *NoOpStreamingLogWriter) WriteStatus(_ int, _ map[string][]string) error return nil } +// WriteAPIRequest is a no-op implementation that does nothing and always returns nil. +// +// Parameters: +// - apiRequest: The API request data (ignored) +// +// Returns: +// - error: Always returns nil +func (w *NoOpStreamingLogWriter) WriteAPIRequest(_ []byte) error { + return nil +} + +// WriteAPIResponse is a no-op implementation that does nothing and always returns nil. +// +// Parameters: +// - apiResponse: The API response data (ignored) +// +// Returns: +// - error: Always returns nil +func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error { + return nil +} + // Close is a no-op implementation that does nothing and always returns nil. // // Returns: From aee659fb66ee851ba3e9fee84b6965a508bc6536 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Mon, 8 Dec 2025 18:18:33 +0800 Subject: [PATCH 11/34] style(logging): remove redundant separator line from response section --- internal/logging/request_logger.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go index 58667cb9..f8c068c5 100644 --- a/internal/logging/request_logger.go +++ b/internal/logging/request_logger.go @@ -814,7 +814,6 @@ func (w *FileStreamingLogWriter) Close() error { } // 3. Write RESPONSE section (status, headers, buffered chunks) - content.WriteString("========================================\n") content.WriteString("=== RESPONSE ===\n") if w.statusWritten { content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus)) From a283545b6b1f252f185e8080b69f038c30920b1f Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Mon, 8 Dec 2025 20:36:17 +0800 Subject: [PATCH 12/34] feat(antigravity): enforce thinking budget limits for Claude models --- internal/registry/model_definitions.go | 31 +++++---- .../runtime/executor/antigravity_executor.go | 63 ++++++++++++++++++- .../antigravity_openai_request.go | 5 +- 3 files changed, 83 insertions(+), 16 deletions(-) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 64e78199..c82c2b67 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -943,18 +943,6 @@ func GetQwenModels() []*ModelInfo { } } -// GetAntigravityThinkingConfig returns the Thinking configuration for antigravity models. -// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup. -func GetAntigravityThinkingConfig() map[string]*ThinkingSupport { - return map[string]*ThinkingSupport{ - "gemini-2.5-flash": {Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - "gemini-2.5-flash-lite": {Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - "gemini-3-pro-preview": {Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, - "gemini-claude-sonnet-4-5-thinking": {Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, - "gemini-claude-opus-4-5-thinking": {Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, - } -} - // GetIFlowModels returns supported models for iFlow OAuth accounts. func GetIFlowModels() []*ModelInfo { entries := []struct { @@ -997,3 +985,22 @@ func GetIFlowModels() []*ModelInfo { } return models } + +// AntigravityModelConfig captures static antigravity model overrides, including +// Thinking budget limits and provider max completion tokens. +type AntigravityModelConfig struct { + Thinking *ThinkingSupport + MaxCompletionTokens int +} + +// GetAntigravityModelConfig returns static configuration for antigravity models. +// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup. +func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { + return map[string]*AntigravityModelConfig{ + "gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, + "gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, + "gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}}, + "gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + "gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + } +} diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index ed9207f0..b2aa231f 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -77,6 +77,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = normalizeAntigravityThinking(req.Model, translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -170,6 +171,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = normalizeAntigravityThinking(req.Model, translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -366,7 +368,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c } now := time.Now().Unix() - thinkingConfig := registry.GetAntigravityThinkingConfig() + modelConfig := registry.GetAntigravityModelConfig() models := make([]*registry.ModelInfo, 0, len(result.Map())) for originalName := range result.Map() { aliasName := modelName2Alias(originalName) @@ -383,8 +385,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c Type: antigravityAuthType, } // Look up Thinking support from static config using alias name - if thinking, ok := thinkingConfig[aliasName]; ok { - modelInfo.Thinking = thinking + if cfg, ok := modelConfig[aliasName]; ok { + if cfg.Thinking != nil { + modelInfo.Thinking = cfg.Thinking + } + if cfg.MaxCompletionTokens > 0 { + modelInfo.MaxCompletionTokens = cfg.MaxCompletionTokens + } } models = append(models, modelInfo) } @@ -804,3 +811,53 @@ func alias2ModelName(modelName string) string { return modelName } } + +// normalizeAntigravityThinking clamps or removes thinking config based on model support. +// For Claude models, it additionally ensures thinking budget < max_tokens. +func normalizeAntigravityThinking(model string, payload []byte) []byte { + payload = util.StripThinkingConfigIfUnsupported(model, payload) + if !util.ModelSupportsThinking(model) { + return payload + } + budget := gjson.GetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget") + if !budget.Exists() { + return payload + } + raw := int(budget.Int()) + normalized := util.NormalizeThinkingBudget(model, raw) + + isClaude := strings.Contains(strings.ToLower(model), "claude") + if isClaude { + effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload) + if effectiveMax > 0 && normalized >= effectiveMax { + normalized = effectiveMax - 1 + if normalized < 1 { + normalized = 1 + } + } + if setDefaultMax { + if res, errSet := sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax); errSet == nil { + payload = res + } + } + } + + updated, err := sjson.SetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + if err != nil { + return payload + } + return updated +} + +// antigravityEffectiveMaxTokens returns the max tokens to cap thinking: +// prefer request-provided maxOutputTokens; otherwise fall back to model default. +// The boolean indicates whether the value came from the model default (and thus should be written back). +func antigravityEffectiveMaxTokens(model string, payload []byte) (max int, fromModel bool) { + if maxTok := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxTok.Exists() && maxTok.Int() > 0 { + return int(maxTok.Int()), false + } + if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.MaxCompletionTokens > 0 { + return modelInfo.MaxCompletionTokens, true + } + return 0, false +} diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index 82e71758..1c90a803 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -111,7 +111,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } - // Temperature/top_p/top_k + // Temperature/top_p/top_k/max_tokens if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number { out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num) } @@ -121,6 +121,9 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number { out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num) } + if maxTok := gjson.GetBytes(rawJSON, "max_tokens"); maxTok.Exists() && maxTok.Type == gjson.Number { + out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num) + } // Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities // e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"] From 6ad188921c66f6a685b6d0f5df77c06275256477 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 8 Dec 2025 22:25:58 +0800 Subject: [PATCH 13/34] refactor(logging): remove unused variable in `ensureAttempt` and redundant function call --- internal/runtime/executor/logging_helpers.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/internal/runtime/executor/logging_helpers.go b/internal/runtime/executor/logging_helpers.go index 26931f53..7798b96b 100644 --- a/internal/runtime/executor/logging_helpers.go +++ b/internal/runtime/executor/logging_helpers.go @@ -157,7 +157,7 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt if ginCtx == nil { return } - attempts, attempt := ensureAttempt(ginCtx) + _, attempt := ensureAttempt(ginCtx) ensureResponseIntro(attempt) if !attempt.headersWritten { @@ -175,8 +175,6 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt } attempt.response.WriteString(string(data)) attempt.bodyHasContent = true - - updateAggregatedResponse(ginCtx, attempts) } func ginContextFrom(ctx context.Context) *gin.Context { From ab9e9442ec0a63a29ebf6f9468d330fd76280bb9 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 8 Dec 2025 22:32:29 +0800 Subject: [PATCH 14/34] v6.5.56 (#12) * feat(api): add comprehensive ampcode management endpoints Add new REST API endpoints under /v0/management/ampcode for managing ampcode configuration including upstream URL, API key, localhost restriction, model mappings, and force model mappings settings. - Move force-model-mappings from config_basic to config_lists - Add GET/PUT/PATCH/DELETE endpoints for all ampcode settings - Support model mapping CRUD with upsert (PATCH) capability - Add comprehensive test coverage for all ampcode endpoints * refactor(api): simplify request body parsing in ampcode handlers * feat(logging): add upstream API request/response capture to streaming logs * style(logging): remove redundant separator line from response section * feat(antigravity): enforce thinking budget limits for Claude models * refactor(logging): remove unused variable in `ensureAttempt` and redundant function call --------- Co-authored-by: hkfires <10558748+hkfires@users.noreply.github.com> --- .../api/handlers/management/config_basic.go | 8 - .../api/handlers/management/config_lists.go | 152 ++++ internal/api/handlers/management/handler.go | 10 - internal/api/middleware/response_writer.go | 9 + internal/api/server.go | 22 +- internal/logging/request_logger.go | 202 ++++- internal/registry/model_definitions.go | 31 +- .../runtime/executor/antigravity_executor.go | 63 +- internal/runtime/executor/logging_helpers.go | 4 +- .../antigravity_openai_request.go | 5 +- test/amp_management_test.go | 827 ++++++++++++++++++ 11 files changed, 1263 insertions(+), 70 deletions(-) create mode 100644 test/amp_management_test.go diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index c788aca4..f9069198 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -241,11 +241,3 @@ func (h *Handler) DeleteProxyURL(c *gin.Context) { h.cfg.ProxyURL = "" h.persist(c) } - -// Force Model Mappings (for Amp CLI) -func (h *Handler) GetForceModelMappings(c *gin.Context) { - c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings}) -} -func (h *Handler) PutForceModelMappings(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) -} diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 8f4c4037..a0d0b169 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -706,3 +706,155 @@ func normalizeClaudeKey(entry *config.ClaudeKey) { } entry.Models = normalized } + +// GetAmpCode returns the complete ampcode configuration. +func (h *Handler) GetAmpCode(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"ampcode": config.AmpCode{}}) + return + } + c.JSON(200, gin.H{"ampcode": h.cfg.AmpCode}) +} + +// GetAmpUpstreamURL returns the ampcode upstream URL. +func (h *Handler) GetAmpUpstreamURL(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"upstream-url": ""}) + return + } + c.JSON(200, gin.H{"upstream-url": h.cfg.AmpCode.UpstreamURL}) +} + +// PutAmpUpstreamURL updates the ampcode upstream URL. +func (h *Handler) PutAmpUpstreamURL(c *gin.Context) { + h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) }) +} + +// DeleteAmpUpstreamURL clears the ampcode upstream URL. +func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) { + h.cfg.AmpCode.UpstreamURL = "" + h.persist(c) +} + +// GetAmpUpstreamAPIKey returns the ampcode upstream API key. +func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"upstream-api-key": ""}) + return + } + c.JSON(200, gin.H{"upstream-api-key": h.cfg.AmpCode.UpstreamAPIKey}) +} + +// PutAmpUpstreamAPIKey updates the ampcode upstream API key. +func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) { + h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) }) +} + +// DeleteAmpUpstreamAPIKey clears the ampcode upstream API key. +func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) { + h.cfg.AmpCode.UpstreamAPIKey = "" + h.persist(c) +} + +// GetAmpRestrictManagementToLocalhost returns the localhost restriction setting. +func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"restrict-management-to-localhost": true}) + return + } + c.JSON(200, gin.H{"restrict-management-to-localhost": h.cfg.AmpCode.RestrictManagementToLocalhost}) +} + +// PutAmpRestrictManagementToLocalhost updates the localhost restriction setting. +func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.RestrictManagementToLocalhost = v }) +} + +// GetAmpModelMappings returns the ampcode model mappings. +func (h *Handler) GetAmpModelMappings(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}}) + return + } + c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings}) +} + +// PutAmpModelMappings replaces all ampcode model mappings. +func (h *Handler) PutAmpModelMappings(c *gin.Context) { + var body struct { + Value []config.AmpModelMapping `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + h.cfg.AmpCode.ModelMappings = body.Value + h.persist(c) +} + +// PatchAmpModelMappings adds or updates model mappings. +func (h *Handler) PatchAmpModelMappings(c *gin.Context) { + var body struct { + Value []config.AmpModelMapping `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + + existing := make(map[string]int) + for i, m := range h.cfg.AmpCode.ModelMappings { + existing[strings.TrimSpace(m.From)] = i + } + + for _, newMapping := range body.Value { + from := strings.TrimSpace(newMapping.From) + if idx, ok := existing[from]; ok { + h.cfg.AmpCode.ModelMappings[idx] = newMapping + } else { + h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, newMapping) + existing[from] = len(h.cfg.AmpCode.ModelMappings) - 1 + } + } + h.persist(c) +} + +// DeleteAmpModelMappings removes specified model mappings by "from" field. +func (h *Handler) DeleteAmpModelMappings(c *gin.Context) { + var body struct { + Value []string `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil || len(body.Value) == 0 { + h.cfg.AmpCode.ModelMappings = nil + h.persist(c) + return + } + + toRemove := make(map[string]bool) + for _, from := range body.Value { + toRemove[strings.TrimSpace(from)] = true + } + + newMappings := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings)) + for _, m := range h.cfg.AmpCode.ModelMappings { + if !toRemove[strings.TrimSpace(m.From)] { + newMappings = append(newMappings, m) + } + } + h.cfg.AmpCode.ModelMappings = newMappings + h.persist(c) +} + +// GetAmpForceModelMappings returns whether model mappings are forced. +func (h *Handler) GetAmpForceModelMappings(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"force-model-mappings": false}) + return + } + c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings}) +} + +// PutAmpForceModelMappings updates the force model mappings setting. +func (h *Handler) PutAmpForceModelMappings(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) +} diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index ef6f400a..39e6b7fd 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -240,16 +240,6 @@ func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) { Value *bool `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { - var m map[string]any - if err2 := c.ShouldBindJSON(&m); err2 == nil { - for _, v := range m { - if b, ok := v.(bool); ok { - set(b) - h.persist(c) - return - } - } - } c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) return } diff --git a/internal/api/middleware/response_writer.go b/internal/api/middleware/response_writer.go index f0d1ad26..b7259bc6 100644 --- a/internal/api/middleware/response_writer.go +++ b/internal/api/middleware/response_writer.go @@ -232,7 +232,16 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { w.streamDone = nil } + // Write API Request and Response to the streaming log before closing if w.streamWriter != nil { + apiRequest := w.extractAPIRequest(c) + if len(apiRequest) > 0 { + _ = w.streamWriter.WriteAPIRequest(apiRequest) + } + apiResponse := w.extractAPIResponse(c) + if len(apiResponse) > 0 { + _ = w.streamWriter.WriteAPIResponse(apiResponse) + } if err := w.streamWriter.Close(); err != nil { w.streamWriter = nil return err diff --git a/internal/api/server.go b/internal/api/server.go index 1f35429e..2e463de4 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -520,9 +520,25 @@ func (s *Server) registerManagementRoutes() { mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth) mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth) - mgmt.GET("/force-model-mappings", s.mgmt.GetForceModelMappings) - mgmt.PUT("/force-model-mappings", s.mgmt.PutForceModelMappings) - mgmt.PATCH("/force-model-mappings", s.mgmt.PutForceModelMappings) + mgmt.GET("/ampcode", s.mgmt.GetAmpCode) + mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL) + mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) + mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) + mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL) + mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey) + mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) + mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) + mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey) + mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost) + mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) + mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) + mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings) + mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings) + mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings) + mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings) + mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings) + mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) + mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go index c574febb..f8c068c5 100644 --- a/internal/logging/request_logger.go +++ b/internal/logging/request_logger.go @@ -84,6 +84,26 @@ type StreamingLogWriter interface { // - error: An error if writing fails, nil otherwise WriteStatus(status int, headers map[string][]string) error + // WriteAPIRequest writes the upstream API request details to the log. + // This should be called before WriteStatus to maintain proper log ordering. + // + // Parameters: + // - apiRequest: The API request data (typically includes URL, headers, body sent upstream) + // + // Returns: + // - error: An error if writing fails, nil otherwise + WriteAPIRequest(apiRequest []byte) error + + // WriteAPIResponse writes the upstream API response details to the log. + // This should be called after the streaming response is complete. + // + // Parameters: + // - apiResponse: The API response data + // + // Returns: + // - error: An error if writing fails, nil otherwise + WriteAPIResponse(apiResponse []byte) error + // Close finalizes the log file and cleans up resources. // // Returns: @@ -248,10 +268,11 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[ // Create streaming writer writer := &FileStreamingLogWriter{ - file: file, - chunkChan: make(chan []byte, 100), // Buffered channel for async writes - closeChan: make(chan struct{}), - errorChan: make(chan error, 1), + file: file, + chunkChan: make(chan []byte, 100), // Buffered channel for async writes + closeChan: make(chan struct{}), + errorChan: make(chan error, 1), + bufferedChunks: &bytes.Buffer{}, } // Start async writer goroutine @@ -628,11 +649,12 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st // FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs. // It handles asynchronous writing of streaming response chunks to a file. +// All data is buffered and written in the correct order when Close is called. type FileStreamingLogWriter struct { // file is the file where log data is written. file *os.File - // chunkChan is a channel for receiving response chunks to write. + // chunkChan is a channel for receiving response chunks to buffer. chunkChan chan []byte // closeChan is a channel for signaling when the writer is closed. @@ -641,8 +663,23 @@ type FileStreamingLogWriter struct { // errorChan is a channel for reporting errors during writing. errorChan chan error - // statusWritten indicates whether the response status has been written. + // bufferedChunks stores the response chunks in order. + bufferedChunks *bytes.Buffer + + // responseStatus stores the HTTP status code. + responseStatus int + + // statusWritten indicates whether a non-zero status was recorded. statusWritten bool + + // responseHeaders stores the response headers. + responseHeaders map[string][]string + + // apiRequest stores the upstream API request data. + apiRequest []byte + + // apiResponse stores the upstream API response data. + apiResponse []byte } // WriteChunkAsync writes a response chunk asynchronously (non-blocking). @@ -666,39 +703,65 @@ func (w *FileStreamingLogWriter) WriteChunkAsync(chunk []byte) { } } -// WriteStatus writes the response status and headers to the log. +// WriteStatus buffers the response status and headers for later writing. // // Parameters: // - status: The response status code // - headers: The response headers // // Returns: -// - error: An error if writing fails, nil otherwise +// - error: Always returns nil (buffering cannot fail) func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error { - if w.file == nil || w.statusWritten { + if status == 0 { return nil } - var content strings.Builder - content.WriteString("========================================\n") - content.WriteString("=== RESPONSE ===\n") - content.WriteString(fmt.Sprintf("Status: %d\n", status)) - - for key, values := range headers { - for _, value := range values { - content.WriteString(fmt.Sprintf("%s: %s\n", key, value)) + w.responseStatus = status + if headers != nil { + w.responseHeaders = make(map[string][]string, len(headers)) + for key, values := range headers { + headerValues := make([]string, len(values)) + copy(headerValues, values) + w.responseHeaders[key] = headerValues } } - content.WriteString("\n") + w.statusWritten = true + return nil +} - _, err := w.file.WriteString(content.String()) - if err == nil { - w.statusWritten = true +// WriteAPIRequest buffers the upstream API request details for later writing. +// +// Parameters: +// - apiRequest: The API request data (typically includes URL, headers, body sent upstream) +// +// Returns: +// - error: Always returns nil (buffering cannot fail) +func (w *FileStreamingLogWriter) WriteAPIRequest(apiRequest []byte) error { + if len(apiRequest) == 0 { + return nil } - return err + w.apiRequest = bytes.Clone(apiRequest) + return nil +} + +// WriteAPIResponse buffers the upstream API response details for later writing. +// +// Parameters: +// - apiResponse: The API response data +// +// Returns: +// - error: Always returns nil (buffering cannot fail) +func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error { + if len(apiResponse) == 0 { + return nil + } + w.apiResponse = bytes.Clone(apiResponse) + return nil } // Close finalizes the log file and cleans up resources. +// It writes all buffered data to the file in the correct order: +// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks) // // Returns: // - error: An error if closing fails, nil otherwise @@ -707,27 +770,84 @@ func (w *FileStreamingLogWriter) Close() error { close(w.chunkChan) } - // Wait for async writer to finish + // Wait for async writer to finish buffering chunks if w.closeChan != nil { <-w.closeChan w.chunkChan = nil } - if w.file != nil { - return w.file.Close() + if w.file == nil { + return nil } - return nil + // Write all content in the correct order + var content strings.Builder + + // 1. Write API REQUEST section + if len(w.apiRequest) > 0 { + if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) { + content.Write(w.apiRequest) + if !bytes.HasSuffix(w.apiRequest, []byte("\n")) { + content.WriteString("\n") + } + } else { + content.WriteString("=== API REQUEST ===\n") + content.Write(w.apiRequest) + content.WriteString("\n") + } + content.WriteString("\n") + } + + // 2. Write API RESPONSE section + if len(w.apiResponse) > 0 { + if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) { + content.Write(w.apiResponse) + if !bytes.HasSuffix(w.apiResponse, []byte("\n")) { + content.WriteString("\n") + } + } else { + content.WriteString("=== API RESPONSE ===\n") + content.Write(w.apiResponse) + content.WriteString("\n") + } + content.WriteString("\n") + } + + // 3. Write RESPONSE section (status, headers, buffered chunks) + content.WriteString("=== RESPONSE ===\n") + if w.statusWritten { + content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus)) + } + + for key, values := range w.responseHeaders { + for _, value := range values { + content.WriteString(fmt.Sprintf("%s: %s\n", key, value)) + } + } + content.WriteString("\n") + + // Write buffered response body chunks + if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 { + content.Write(w.bufferedChunks.Bytes()) + } + + // Write the complete content to file + if _, err := w.file.WriteString(content.String()); err != nil { + _ = w.file.Close() + return err + } + + return w.file.Close() } -// asyncWriter runs in a goroutine to handle async chunk writing. -// It continuously reads chunks from the channel and writes them to the file. +// asyncWriter runs in a goroutine to buffer chunks from the channel. +// It continuously reads chunks from the channel and buffers them for later writing. func (w *FileStreamingLogWriter) asyncWriter() { defer close(w.closeChan) for chunk := range w.chunkChan { - if w.file != nil { - _, _ = w.file.Write(chunk) + if w.bufferedChunks != nil { + w.bufferedChunks.Write(chunk) } } } @@ -754,6 +874,28 @@ func (w *NoOpStreamingLogWriter) WriteStatus(_ int, _ map[string][]string) error return nil } +// WriteAPIRequest is a no-op implementation that does nothing and always returns nil. +// +// Parameters: +// - apiRequest: The API request data (ignored) +// +// Returns: +// - error: Always returns nil +func (w *NoOpStreamingLogWriter) WriteAPIRequest(_ []byte) error { + return nil +} + +// WriteAPIResponse is a no-op implementation that does nothing and always returns nil. +// +// Parameters: +// - apiResponse: The API response data (ignored) +// +// Returns: +// - error: Always returns nil +func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error { + return nil +} + // Close is a no-op implementation that does nothing and always returns nil. // // Returns: diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index b25d91c2..fc7e75a1 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -943,18 +943,6 @@ func GetQwenModels() []*ModelInfo { } } -// GetAntigravityThinkingConfig returns the Thinking configuration for antigravity models. -// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup. -func GetAntigravityThinkingConfig() map[string]*ThinkingSupport { - return map[string]*ThinkingSupport{ - "gemini-2.5-flash": {Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - "gemini-2.5-flash-lite": {Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - "gemini-3-pro-preview": {Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, - "gemini-claude-sonnet-4-5-thinking": {Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, - "gemini-claude-opus-4-5-thinking": {Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, - } -} - // GetIFlowModels returns supported models for iFlow OAuth accounts. func GetIFlowModels() []*ModelInfo { entries := []struct { @@ -998,6 +986,25 @@ func GetIFlowModels() []*ModelInfo { return models } +// AntigravityModelConfig captures static antigravity model overrides, including +// Thinking budget limits and provider max completion tokens. +type AntigravityModelConfig struct { + Thinking *ThinkingSupport + MaxCompletionTokens int +} + +// GetAntigravityModelConfig returns static configuration for antigravity models. +// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup. +func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { + return map[string]*AntigravityModelConfig{ + "gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, + "gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, + "gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}}, + "gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + "gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + } +} + // GetGitHubCopilotModels returns the available models for GitHub Copilot. // These models are available through the GitHub Copilot API at api.githubcopilot.com. func GetGitHubCopilotModels() []*ModelInfo { diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index ce836a77..d83559ab 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -81,6 +81,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = normalizeAntigravityThinking(req.Model, translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -174,6 +175,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = normalizeAntigravityThinking(req.Model, translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -370,7 +372,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c } now := time.Now().Unix() - thinkingConfig := registry.GetAntigravityThinkingConfig() + modelConfig := registry.GetAntigravityModelConfig() models := make([]*registry.ModelInfo, 0, len(result.Map())) for originalName := range result.Map() { aliasName := modelName2Alias(originalName) @@ -387,8 +389,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c Type: antigravityAuthType, } // Look up Thinking support from static config using alias name - if thinking, ok := thinkingConfig[aliasName]; ok { - modelInfo.Thinking = thinking + if cfg, ok := modelConfig[aliasName]; ok { + if cfg.Thinking != nil { + modelInfo.Thinking = cfg.Thinking + } + if cfg.MaxCompletionTokens > 0 { + modelInfo.MaxCompletionTokens = cfg.MaxCompletionTokens + } } models = append(models, modelInfo) } @@ -812,3 +819,53 @@ func alias2ModelName(modelName string) string { return modelName } } + +// normalizeAntigravityThinking clamps or removes thinking config based on model support. +// For Claude models, it additionally ensures thinking budget < max_tokens. +func normalizeAntigravityThinking(model string, payload []byte) []byte { + payload = util.StripThinkingConfigIfUnsupported(model, payload) + if !util.ModelSupportsThinking(model) { + return payload + } + budget := gjson.GetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget") + if !budget.Exists() { + return payload + } + raw := int(budget.Int()) + normalized := util.NormalizeThinkingBudget(model, raw) + + isClaude := strings.Contains(strings.ToLower(model), "claude") + if isClaude { + effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload) + if effectiveMax > 0 && normalized >= effectiveMax { + normalized = effectiveMax - 1 + if normalized < 1 { + normalized = 1 + } + } + if setDefaultMax { + if res, errSet := sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax); errSet == nil { + payload = res + } + } + } + + updated, err := sjson.SetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + if err != nil { + return payload + } + return updated +} + +// antigravityEffectiveMaxTokens returns the max tokens to cap thinking: +// prefer request-provided maxOutputTokens; otherwise fall back to model default. +// The boolean indicates whether the value came from the model default (and thus should be written back). +func antigravityEffectiveMaxTokens(model string, payload []byte) (max int, fromModel bool) { + if maxTok := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxTok.Exists() && maxTok.Int() > 0 { + return int(maxTok.Int()), false + } + if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.MaxCompletionTokens > 0 { + return modelInfo.MaxCompletionTokens, true + } + return 0, false +} diff --git a/internal/runtime/executor/logging_helpers.go b/internal/runtime/executor/logging_helpers.go index 26931f53..7798b96b 100644 --- a/internal/runtime/executor/logging_helpers.go +++ b/internal/runtime/executor/logging_helpers.go @@ -157,7 +157,7 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt if ginCtx == nil { return } - attempts, attempt := ensureAttempt(ginCtx) + _, attempt := ensureAttempt(ginCtx) ensureResponseIntro(attempt) if !attempt.headersWritten { @@ -175,8 +175,6 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt } attempt.response.WriteString(string(data)) attempt.bodyHasContent = true - - updateAggregatedResponse(ginCtx, attempts) } func ginContextFrom(ctx context.Context) *gin.Context { diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index 82e71758..1c90a803 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -111,7 +111,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } - // Temperature/top_p/top_k + // Temperature/top_p/top_k/max_tokens if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number { out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num) } @@ -121,6 +121,9 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number { out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num) } + if maxTok := gjson.GetBytes(rawJSON, "max_tokens"); maxTok.Exists() && maxTok.Type == gjson.Number { + out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num) + } // Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities // e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"] diff --git a/test/amp_management_test.go b/test/amp_management_test.go new file mode 100644 index 00000000..19450dbf --- /dev/null +++ b/test/amp_management_test.go @@ -0,0 +1,827 @@ +package test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// newAmpTestHandler creates a test handler with default ampcode configuration. +func newAmpTestHandler(t *testing.T) (*management.Handler, string) { + t.Helper() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + UpstreamURL: "https://example.com", + UpstreamAPIKey: "test-api-key-12345", + RestrictManagementToLocalhost: true, + ForceModelMappings: false, + ModelMappings: []config.AmpModelMapping{ + {From: "gpt-4", To: "gemini-pro"}, + }, + }, + } + + if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + h := management.NewHandler(cfg, configPath, nil) + return h, configPath +} + +// setupAmpRouter creates a test router with all ampcode management endpoints. +func setupAmpRouter(h *management.Handler) *gin.Engine { + r := gin.New() + mgmt := r.Group("/v0/management") + { + mgmt.GET("/ampcode", h.GetAmpCode) + mgmt.GET("/ampcode/upstream-url", h.GetAmpUpstreamURL) + mgmt.PUT("/ampcode/upstream-url", h.PutAmpUpstreamURL) + mgmt.DELETE("/ampcode/upstream-url", h.DeleteAmpUpstreamURL) + mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey) + mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey) + mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey) + mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost) + mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost) + mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings) + mgmt.PUT("/ampcode/model-mappings", h.PutAmpModelMappings) + mgmt.PATCH("/ampcode/model-mappings", h.PatchAmpModelMappings) + mgmt.DELETE("/ampcode/model-mappings", h.DeleteAmpModelMappings) + mgmt.GET("/ampcode/force-model-mappings", h.GetAmpForceModelMappings) + mgmt.PUT("/ampcode/force-model-mappings", h.PutAmpForceModelMappings) + } + return r +} + +// TestGetAmpCode verifies GET /v0/management/ampcode returns full ampcode config. +func TestGetAmpCode(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]config.AmpCode + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + ampcode := resp["ampcode"] + if ampcode.UpstreamURL != "https://example.com" { + t.Errorf("expected upstream-url %q, got %q", "https://example.com", ampcode.UpstreamURL) + } + if len(ampcode.ModelMappings) != 1 { + t.Errorf("expected 1 model mapping, got %d", len(ampcode.ModelMappings)) + } +} + +// TestGetAmpUpstreamURL verifies GET /v0/management/ampcode/upstream-url returns the upstream URL. +func TestGetAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["upstream-url"] != "https://example.com" { + t.Errorf("expected %q, got %q", "https://example.com", resp["upstream-url"]) + } +} + +// TestPutAmpUpstreamURL verifies PUT /v0/management/ampcode/upstream-url updates the upstream URL. +func TestPutAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "https://new-upstream.com"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +// TestDeleteAmpUpstreamURL verifies DELETE /v0/management/ampcode/upstream-url clears the upstream URL. +func TestDeleteAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpUpstreamAPIKey verifies GET /v0/management/ampcode/upstream-api-key returns the API key. +func TestGetAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + key := resp["upstream-api-key"].(string) + if key != "test-api-key-12345" { + t.Errorf("expected key %q, got %q", "test-api-key-12345", key) + } +} + +// TestPutAmpUpstreamAPIKey verifies PUT /v0/management/ampcode/upstream-api-key updates the API key. +func TestPutAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "new-secret-key"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key. +func TestDeleteAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpRestrictManagementToLocalhost verifies GET returns the localhost restriction setting. +func TestGetAmpRestrictManagementToLocalhost(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["restrict-management-to-localhost"] != true { + t.Error("expected restrict-management-to-localhost to be true") + } +} + +// TestPutAmpRestrictManagementToLocalhost verifies PUT updates the localhost restriction setting. +func TestPutAmpRestrictManagementToLocalhost(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": false}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpModelMappings verifies GET /v0/management/ampcode/model-mappings returns all mappings. +func TestGetAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 1 { + t.Fatalf("expected 1 mapping, got %d", len(mappings)) + } + if mappings[0].From != "gpt-4" || mappings[0].To != "gemini-pro" { + t.Errorf("unexpected mapping: %+v", mappings[0]) + } +} + +// TestPutAmpModelMappings verifies PUT /v0/management/ampcode/model-mappings replaces all mappings. +func TestPutAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "claude-3", "to": "gpt-4o"}, {"from": "gemini", "to": "claude"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +// TestPatchAmpModelMappings verifies PATCH updates existing mappings and adds new ones. +func TestPatchAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "gpt-4", "to": "updated-model"}, {"from": "new-model", "to": "target"}]}` + req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +// TestDeleteAmpModelMappings_Specific verifies DELETE removes specified mappings by "from" field. +func TestDeleteAmpModelMappings_Specific(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": ["gpt-4"]}` + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestDeleteAmpModelMappings_All verifies DELETE with empty body removes all mappings. +func TestDeleteAmpModelMappings_All(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpForceModelMappings verifies GET returns the force-model-mappings setting. +func TestGetAmpForceModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["force-model-mappings"] != false { + t.Error("expected force-model-mappings to be false") + } +} + +// TestPutAmpForceModelMappings verifies PUT updates the force-model-mappings setting. +func TestPutAmpForceModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": true}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestPutAmpModelMappings_VerifyState verifies PUT replaces mappings and state is persisted. +func TestPutAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "model-a", "to": "model-b"}, {"from": "model-c", "to": "model-d"}, {"from": "model-e", "to": "model-f"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d, body: %s", w.Code, w.Body.String()) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 3 { + t.Fatalf("expected 3 mappings, got %d", len(mappings)) + } + + expected := map[string]string{"model-a": "model-b", "model-c": "model-d", "model-e": "model-f"} + for _, m := range mappings { + if expected[m.From] != m.To { + t.Errorf("mapping %q -> expected %q, got %q", m.From, expected[m.From], m.To) + } + } +} + +// TestPatchAmpModelMappings_VerifyState verifies PATCH merges mappings correctly. +func TestPatchAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "gpt-4", "to": "updated-target"}, {"from": "new-model", "to": "new-target"}]}` + req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PATCH failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 2 { + t.Fatalf("expected 2 mappings (1 updated + 1 new), got %d", len(mappings)) + } + + found := make(map[string]string) + for _, m := range mappings { + found[m.From] = m.To + } + + if found["gpt-4"] != "updated-target" { + t.Errorf("gpt-4 should map to updated-target, got %q", found["gpt-4"]) + } + if found["new-model"] != "new-target" { + t.Errorf("new-model should map to new-target, got %q", found["new-model"]) + } +} + +// TestDeleteAmpModelMappings_VerifyState verifies DELETE removes specific mappings and keeps others. +func TestDeleteAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + putBody := `{"value": [{"from": "a", "to": "1"}, {"from": "b", "to": "2"}, {"from": "c", "to": "3"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + delBody := `{"value": ["a", "c"]}` + req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 1 { + t.Fatalf("expected 1 mapping remaining, got %d", len(mappings)) + } + if mappings[0].From != "b" || mappings[0].To != "2" { + t.Errorf("expected b->2, got %s->%s", mappings[0].From, mappings[0].To) + } +} + +// TestDeleteAmpModelMappings_NonExistent verifies DELETE with non-existent mapping doesn't affect existing ones. +func TestDeleteAmpModelMappings_NonExistent(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + delBody := `{"value": ["non-existent-model"]}` + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 1 { + t.Errorf("original mapping should remain, got %d mappings", len(resp["model-mappings"])) + } +} + +// TestPutAmpModelMappings_Empty verifies PUT with empty array clears all mappings. +func TestPutAmpModelMappings_Empty(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": []}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 0 { + t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"])) + } +} + +// TestPutAmpUpstreamURL_VerifyState verifies PUT updates upstream URL and persists state. +func TestPutAmpUpstreamURL_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "https://new-api.example.com"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-url"] != "https://new-api.example.com" { + t.Errorf("expected %q, got %q", "https://new-api.example.com", resp["upstream-url"]) + } +} + +// TestDeleteAmpUpstreamURL_VerifyState verifies DELETE clears upstream URL. +func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-url"] != "" { + t.Errorf("expected empty string, got %q", resp["upstream-url"]) + } +} + +// TestPutAmpUpstreamAPIKey_VerifyState verifies PUT updates API key and persists state. +func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "new-secret-api-key-xyz"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-api-key"] != "new-secret-api-key-xyz" { + t.Errorf("expected %q, got %q", "new-secret-api-key-xyz", resp["upstream-api-key"]) + } +} + +// TestDeleteAmpUpstreamAPIKey_VerifyState verifies DELETE clears API key. +func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-api-key"] != "" { + t.Errorf("expected empty string, got %q", resp["upstream-api-key"]) + } +} + +// TestPutAmpRestrictManagementToLocalhost_VerifyState verifies PUT updates localhost restriction. +func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": false}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["restrict-management-to-localhost"] != false { + t.Error("expected false after update") + } +} + +// TestPutAmpForceModelMappings_VerifyState verifies PUT updates force-model-mappings setting. +func TestPutAmpForceModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": true}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["force-model-mappings"] != true { + t.Error("expected true after update") + } +} + +// TestPutBoolField_EmptyObject verifies PUT with empty object returns 400. +func TestPutBoolField_EmptyObject(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d for empty object, got %d", http.StatusBadRequest, w.Code) + } +} + +// TestComplexMappingsWorkflow tests a full workflow: PUT, PATCH, DELETE, and GET. +func TestComplexMappingsWorkflow(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + putBody := `{"value": [{"from": "m1", "to": "t1"}, {"from": "m2", "to": "t2"}, {"from": "m3", "to": "t3"}, {"from": "m4", "to": "t4"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + patchBody := `{"value": [{"from": "m2", "to": "t2-updated"}, {"from": "m5", "to": "t5"}]}` + req = httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(patchBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + delBody := `{"value": ["m1", "m3"]}` + req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 3 { + t.Fatalf("expected 3 mappings (m2, m4, m5), got %d", len(mappings)) + } + + expected := map[string]string{"m2": "t2-updated", "m4": "t4", "m5": "t5"} + found := make(map[string]string) + for _, m := range mappings { + found[m.From] = m.To + } + + for from, to := range expected { + if found[from] != to { + t.Errorf("mapping %s: expected %q, got %q", from, to, found[from]) + } + } +} + +// TestNilHandlerGetAmpCode verifies handler works with empty config. +func TestNilHandlerGetAmpCode(t *testing.T) { + cfg := &config.Config{} + h := management.NewHandler(cfg, "", nil) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestEmptyConfigGetAmpModelMappings verifies GET returns empty array for fresh config. +func TestEmptyConfigGetAmpModelMappings(t *testing.T) { + cfg := &config.Config{} + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + h := management.NewHandler(cfg, configPath, nil) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 0 { + t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"])) + } +} From 5c3a013cd1bd07460394a74cee4d2fa140fa7498 Mon Sep 17 00:00:00 2001 From: "vuonglv(Andy)" <46917325+vuonglv1612@users.noreply.github.com> Date: Mon, 8 Dec 2025 22:16:39 +0700 Subject: [PATCH 15/34] feat(config): add configurable host binding for server (#454) * feat(config): add configurable host binding for server --- config.example.yaml | 4 ++++ internal/api/server.go | 2 +- internal/config/config.go | 4 ++++ sdk/cliproxy/service.go | 2 +- 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index 0f8679aa..dfd7454b 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,3 +1,7 @@ +# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6). +# Use "127.0.0.1" or "localhost" to restrict access to local machine only. +host: "" + # Server port port: 8317 diff --git a/internal/api/server.go b/internal/api/server.go index b65185a7..79dcf12a 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -300,7 +300,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk // Create HTTP server s.server = &http.Server{ - Addr: fmt.Sprintf(":%d", cfg.Port), + Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), Handler: engine, } diff --git a/internal/config/config.go b/internal/config/config.go index f6d1eb73..5af74b1b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,6 +20,9 @@ import ( // Config represents the application's configuration, loaded from a YAML file. type Config struct { config.SDKConfig `yaml:",inline"` + // Host is the network host/interface on which the API server will bind. + // Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access. + Host string `yaml:"host" json:"-"` // Port is the network port on which the API server will listen. Port int `yaml:"port" json:"-"` @@ -320,6 +323,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Unmarshal the YAML data into the Config struct. var cfg Config // Set defaults before unmarshal so that absent keys keep defaults. + cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6) cfg.LoggingToFile = false cfg.UsageStatisticsEnabled = false cfg.DisableCooling = false diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 8b9a6639..13d647dd 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -498,7 +498,7 @@ func (s *Service) Run(ctx context.Context) error { }() time.Sleep(100 * time.Millisecond) - fmt.Printf("API server started successfully on: %d\n", s.cfg.Port) + fmt.Printf("API server started successfully on: %s:%d\n", s.cfg.Host, s.cfg.Port) if s.hooks.OnAfterStart != nil { s.hooks.OnAfterStart(s) From af00304b0cf09fae96913b796ea97a92fb63e80d Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 8 Dec 2025 23:28:01 +0800 Subject: [PATCH 16/34] fix(antigravity): remove `exclusiveMaximum` from JSON during key deletion --- internal/runtime/executor/antigravity_executor.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index b2aa231f..730a32fb 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -536,6 +536,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau strJSON = util.DeleteKey(strJSON, "minLength") strJSON = util.DeleteKey(strJSON, "maxLength") strJSON = util.DeleteKey(strJSON, "exclusiveMinimum") + strJSON = util.DeleteKey(strJSON, "exclusiveMaximum") paths = make([]string, 0) util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths) From 96b55acff8b673adb0863c58fa4f5fd283d15245 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 9 Dec 2025 08:27:44 +0800 Subject: [PATCH 17/34] feat(aistudio): normalize thinking budget in request translation --- internal/runtime/executor/aistudio_executor.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 61a06721..898c08c7 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -310,6 +310,10 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) payload = applyThinkingMetadata(payload, req.Metadata, req.Model) payload = util.ConvertThinkingLevelToBudget(payload) + if budget := gjson.GetBytes(payload, "generationConfig.thinkingConfig.thinkingBudget"); budget.Exists() { + normalized := util.NormalizeThinkingBudget(req.Model, int(budget.Int())) + payload, _ = sjson.SetBytes(payload, "generationConfig.thinkingConfig.thinkingBudget", normalized) + } payload = util.StripThinkingConfigIfUnsupported(req.Model, payload) payload = fixGeminiImageAspectRatio(req.Model, payload) payload = applyPayloadConfig(e.cfg, req.Model, payload) From e5312fb5a25f23183305c592602b668a81b992d0 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 9 Dec 2025 16:54:13 +0800 Subject: [PATCH 18/34] feat(antigravity): support canonical names for antigravity models --- internal/registry/model_definitions.go | 13 ++++++++----- internal/runtime/executor/antigravity_executor.go | 9 +++++++-- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index c82c2b67..77015d14 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -991,16 +991,19 @@ func GetIFlowModels() []*ModelInfo { type AntigravityModelConfig struct { Thinking *ThinkingSupport MaxCompletionTokens int + Name string } // GetAntigravityModelConfig returns static configuration for antigravity models. // Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup. func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { return map[string]*AntigravityModelConfig{ - "gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, - "gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, - "gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}}, - "gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, - "gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + "gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"}, + "gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"}, + "gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"}, + "gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-preview"}, + "gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-image-preview"}, + "gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + "gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, } } diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 730a32fb..052d4faf 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -373,9 +373,14 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c for originalName := range result.Map() { aliasName := modelName2Alias(originalName) if aliasName != "" { + cfg := modelConfig[aliasName] + modelName := aliasName + if cfg != nil && cfg.Name != "" { + modelName = cfg.Name + } modelInfo := ®istry.ModelInfo{ ID: aliasName, - Name: aliasName, + Name: modelName, Description: aliasName, DisplayName: aliasName, Version: aliasName, @@ -385,7 +390,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c Type: antigravityAuthType, } // Look up Thinking support from static config using alias name - if cfg, ok := modelConfig[aliasName]; ok { + if cfg != nil { if cfg.Thinking != nil { modelInfo.Thinking = cfg.Thinking } From c600519fa456f60bd7477334f6bfbf931930dc33 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Tue, 9 Dec 2025 17:16:30 +0800 Subject: [PATCH 19/34] refactor(logging): replace log.Fatalf with log.Errorf and add error handling paths --- cmd/server/main.go | 51 ++++++++++++------- .../api/handlers/management/auth_files.go | 39 ++++++++------ internal/auth/gemini/gemini_auth.go | 9 +++- internal/cmd/login.go | 21 ++++---- internal/cmd/run.go | 5 +- internal/cmd/vertex_import.go | 12 ++--- 6 files changed, 85 insertions(+), 52 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index bbf500e7..aec51ab8 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -139,7 +139,8 @@ func main() { wd, err := os.Getwd() if err != nil { - log.Fatalf("failed to get working directory: %v", err) + log.Errorf("failed to get working directory: %v", err) + return } // Load environment variables from .env if present. @@ -233,13 +234,15 @@ func main() { }) cancel() if err != nil { - log.Fatalf("failed to initialize postgres token store: %v", err) + log.Errorf("failed to initialize postgres token store: %v", err) + return } examplePath := filepath.Join(wd, "config.example.yaml") ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second) if errBootstrap := pgStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil { cancel() - log.Fatalf("failed to bootstrap postgres-backed config: %v", errBootstrap) + log.Errorf("failed to bootstrap postgres-backed config: %v", errBootstrap) + return } cancel() configFilePath = pgStoreInst.ConfigPath() @@ -262,7 +265,8 @@ func main() { if strings.Contains(resolvedEndpoint, "://") { parsed, errParse := url.Parse(resolvedEndpoint) if errParse != nil { - log.Fatalf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse) + log.Errorf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse) + return } switch strings.ToLower(parsed.Scheme) { case "http": @@ -270,10 +274,12 @@ func main() { case "https": useSSL = true default: - log.Fatalf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme) + log.Errorf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme) + return } if parsed.Host == "" { - log.Fatalf("object store endpoint %q is missing host information", objectStoreEndpoint) + log.Errorf("object store endpoint %q is missing host information", objectStoreEndpoint) + return } resolvedEndpoint = parsed.Host if parsed.Path != "" && parsed.Path != "/" { @@ -292,13 +298,15 @@ func main() { } objectStoreInst, err = store.NewObjectTokenStore(objCfg) if err != nil { - log.Fatalf("failed to initialize object token store: %v", err) + log.Errorf("failed to initialize object token store: %v", err) + return } examplePath := filepath.Join(wd, "config.example.yaml") ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) if errBootstrap := objectStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil { cancel() - log.Fatalf("failed to bootstrap object-backed config: %v", errBootstrap) + log.Errorf("failed to bootstrap object-backed config: %v", errBootstrap) + return } cancel() configFilePath = objectStoreInst.ConfigPath() @@ -323,7 +331,8 @@ func main() { gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword) gitStoreInst.SetBaseDir(authDir) if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil { - log.Fatalf("failed to prepare git token store: %v", errRepo) + log.Errorf("failed to prepare git token store: %v", errRepo) + return } configFilePath = gitStoreInst.ConfigPath() if configFilePath == "" { @@ -332,17 +341,21 @@ func main() { if _, statErr := os.Stat(configFilePath); errors.Is(statErr, fs.ErrNotExist) { examplePath := filepath.Join(wd, "config.example.yaml") if _, errExample := os.Stat(examplePath); errExample != nil { - log.Fatalf("failed to find template config file: %v", errExample) + log.Errorf("failed to find template config file: %v", errExample) + return } if errCopy := misc.CopyConfigTemplate(examplePath, configFilePath); errCopy != nil { - log.Fatalf("failed to bootstrap git-backed config: %v", errCopy) + log.Errorf("failed to bootstrap git-backed config: %v", errCopy) + return } if errCommit := gitStoreInst.PersistConfig(context.Background()); errCommit != nil { - log.Fatalf("failed to commit initial git-backed config: %v", errCommit) + log.Errorf("failed to commit initial git-backed config: %v", errCommit) + return } log.Infof("git-backed config initialized from template: %s", configFilePath) } else if statErr != nil { - log.Fatalf("failed to inspect git-backed config: %v", statErr) + log.Errorf("failed to inspect git-backed config: %v", statErr) + return } cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy) if err == nil { @@ -355,13 +368,15 @@ func main() { } else { wd, err = os.Getwd() if err != nil { - log.Fatalf("failed to get working directory: %v", err) + log.Errorf("failed to get working directory: %v", err) + return } configFilePath = filepath.Join(wd, "config.yaml") cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy) } if err != nil { - log.Fatalf("failed to load config: %v", err) + log.Errorf("failed to load config: %v", err) + return } if cfg == nil { cfg = &config.Config{} @@ -391,7 +406,8 @@ func main() { coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling) if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil { - log.Fatalf("failed to configure log output: %v", err) + log.Errorf("failed to configure log output: %v", err) + return } log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate) @@ -400,7 +416,8 @@ func main() { util.SetLogLevel(cfg) if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil { - log.Fatalf("failed to resolve auth directory: %v", errResolveAuthDir) + log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir) + return } else { cfg.AuthDir = resolvedAuthDir } diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 6f77fda9..e626af47 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -713,14 +713,16 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { // Generate PKCE codes pkceCodes, err := claude.GeneratePKCECodes() if err != nil { - log.Fatalf("Failed to generate PKCE codes: %v", err) + log.Errorf("Failed to generate PKCE codes: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"}) return } // Generate random state parameter state, err := misc.GenerateRandomState() if err != nil { - log.Fatalf("Failed to generate state parameter: %v", err) + log.Errorf("Failed to generate state parameter: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"}) return } @@ -730,7 +732,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { // Generate authorization URL (then override redirect_uri to reuse server port) authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes) if err != nil { - log.Fatalf("Failed to generate authorization URL: %v", err) + log.Errorf("Failed to generate authorization URL: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"}) return } @@ -872,7 +875,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - log.Fatalf("Failed to save authentication tokens: %v", errSave) + log.Errorf("Failed to save authentication tokens: %v", errSave) oauthStatus[state] = "Failed to save authentication tokens" return } @@ -1045,7 +1048,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { gemAuth := geminiAuth.NewGeminiAuth() gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) if errGetClient != nil { - log.Fatalf("failed to get authenticated client: %v", errGetClient) + log.Errorf("failed to get authenticated client: %v", errGetClient) oauthStatus[state] = "Failed to get authenticated client" return } @@ -1110,7 +1113,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - log.Fatalf("Failed to save token to file: %v", errSave) + log.Errorf("Failed to save token to file: %v", errSave) oauthStatus[state] = "Failed to save token to file" return } @@ -1131,14 +1134,16 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { // Generate PKCE codes pkceCodes, err := codex.GeneratePKCECodes() if err != nil { - log.Fatalf("Failed to generate PKCE codes: %v", err) + log.Errorf("Failed to generate PKCE codes: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"}) return } // Generate random state parameter state, err := misc.GenerateRandomState() if err != nil { - log.Fatalf("Failed to generate state parameter: %v", err) + log.Errorf("Failed to generate state parameter: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"}) return } @@ -1148,7 +1153,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { // Generate authorization URL authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes) if err != nil { - log.Fatalf("Failed to generate authorization URL: %v", err) + log.Errorf("Failed to generate authorization URL: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"}) return } @@ -1283,7 +1289,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { oauthStatus[state] = "Failed to save authentication tokens" - log.Fatalf("Failed to save authentication tokens: %v", errSave) + log.Errorf("Failed to save authentication tokens: %v", errSave) return } fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) @@ -1318,7 +1324,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { state, errState := misc.GenerateRandomState() if errState != nil { - log.Fatalf("Failed to generate state parameter: %v", errState) + log.Errorf("Failed to generate state parameter: %v", errState) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"}) return } @@ -1514,7 +1521,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - log.Fatalf("Failed to save token to file: %v", errSave) + log.Errorf("Failed to save token to file: %v", errSave) oauthStatus[state] = "Failed to save token to file" return } @@ -1543,7 +1550,8 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { // Generate authorization URL deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx) if err != nil { - log.Fatalf("Failed to generate authorization URL: %v", err) + log.Errorf("Failed to generate authorization URL: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"}) return } authURL := deviceFlow.VerificationURIComplete @@ -1570,7 +1578,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - log.Fatalf("Failed to save authentication tokens: %v", errSave) + log.Errorf("Failed to save authentication tokens: %v", errSave) oauthStatus[state] = "Failed to save authentication tokens" return } @@ -1674,7 +1682,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { oauthStatus[state] = "Failed to save authentication tokens" - log.Fatalf("Failed to save authentication tokens: %v", errSave) + log.Errorf("Failed to save authentication tokens: %v", errSave) return } @@ -2103,6 +2111,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec continue } } + _ = resp.Body.Close() return false, fmt.Errorf("project activation required: %s", errMessage) } return true, nil diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go index a6ac4507..f173c95f 100644 --- a/internal/auth/gemini/gemini_auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -76,7 +76,8 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken auth := &proxy.Auth{User: username, Password: password} dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) if errSOCKS5 != nil { - log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5) + log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) + return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5) } transport = &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -238,7 +239,11 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, // Start the server in a goroutine. go func() { if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("ListenAndServe(): %v", err) + log.Errorf("ListenAndServe(): %v", err) + select { + case errChan <- err: + default: + } } }() diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 5e5159aa..de01cec5 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -65,20 +65,20 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { authenticator := sdkAuth.NewGeminiAuthenticator() record, errLogin := authenticator.Login(ctx, cfg, loginOpts) if errLogin != nil { - log.Fatalf("Gemini authentication failed: %v", errLogin) + log.Errorf("Gemini authentication failed: %v", errLogin) return } storage, okStorage := record.Storage.(*gemini.GeminiTokenStorage) if !okStorage || storage == nil { - log.Fatal("Gemini authentication failed: unsupported token storage") + log.Error("Gemini authentication failed: unsupported token storage") return } geminiAuth := gemini.NewGeminiAuth() httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser) if errClient != nil { - log.Fatalf("Gemini authentication failed: %v", errClient) + log.Errorf("Gemini authentication failed: %v", errClient) return } @@ -86,7 +86,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { projects, errProjects := fetchGCPProjects(ctx, httpClient) if errProjects != nil { - log.Fatalf("Failed to get project list: %v", errProjects) + log.Errorf("Failed to get project list: %v", errProjects) return } @@ -98,11 +98,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn) projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects) if errSelection != nil { - log.Fatalf("Invalid project selection: %v", errSelection) + log.Errorf("Invalid project selection: %v", errSelection) return } if len(projectSelections) == 0 { - log.Fatal("No project selected; aborting login.") + log.Error("No project selected; aborting login.") return } @@ -116,7 +116,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { showProjectSelectionHelp(storage.Email, projects) return } - log.Fatalf("Failed to complete user setup: %v", errSetup) + log.Errorf("Failed to complete user setup: %v", errSetup) return } finalID := strings.TrimSpace(storage.ProjectID) @@ -133,11 +133,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { for _, pid := range activatedProjects { isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, pid) if errCheck != nil { - log.Fatalf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck) + log.Errorf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck) return } if !isChecked { - log.Fatalf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid) + log.Errorf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid) return } } @@ -153,7 +153,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { savedPath, errSave := store.Save(ctx, record) if errSave != nil { - log.Fatalf("Failed to save token to file: %v", errSave) + log.Errorf("Failed to save token to file: %v", errSave) return } @@ -555,6 +555,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec continue } } + _ = resp.Body.Close() return false, fmt.Errorf("project activation required: %s", errMessage) } return true, nil diff --git a/internal/cmd/run.go b/internal/cmd/run.go index e2f6ee80..1e968126 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -45,12 +45,13 @@ func StartService(cfg *config.Config, configPath string, localPassword string) { service, err := builder.Build() if err != nil { - log.Fatalf("failed to build proxy service: %v", err) + log.Errorf("failed to build proxy service: %v", err) + return } err = service.Run(runCtx) if err != nil && !errors.Is(err, context.Canceled) { - log.Fatalf("proxy service exited with error: %v", err) + log.Errorf("proxy service exited with error: %v", err) } } diff --git a/internal/cmd/vertex_import.go b/internal/cmd/vertex_import.go index ebb32d0c..32d782d8 100644 --- a/internal/cmd/vertex_import.go +++ b/internal/cmd/vertex_import.go @@ -29,30 +29,30 @@ func DoVertexImport(cfg *config.Config, keyPath string) { } rawPath := strings.TrimSpace(keyPath) if rawPath == "" { - log.Fatalf("vertex-import: missing service account key path") + log.Errorf("vertex-import: missing service account key path") return } data, errRead := os.ReadFile(rawPath) if errRead != nil { - log.Fatalf("vertex-import: read file failed: %v", errRead) + log.Errorf("vertex-import: read file failed: %v", errRead) return } var sa map[string]any if errUnmarshal := json.Unmarshal(data, &sa); errUnmarshal != nil { - log.Fatalf("vertex-import: invalid service account json: %v", errUnmarshal) + log.Errorf("vertex-import: invalid service account json: %v", errUnmarshal) return } // Validate and normalize private_key before saving normalizedSA, errFix := vertex.NormalizeServiceAccountMap(sa) if errFix != nil { - log.Fatalf("vertex-import: %v", errFix) + log.Errorf("vertex-import: %v", errFix) return } sa = normalizedSA email, _ := sa["client_email"].(string) projectID, _ := sa["project_id"].(string) if strings.TrimSpace(projectID) == "" { - log.Fatalf("vertex-import: project_id missing in service account json") + log.Errorf("vertex-import: project_id missing in service account json") return } if strings.TrimSpace(email) == "" { @@ -92,7 +92,7 @@ func DoVertexImport(cfg *config.Config, keyPath string) { } path, errSave := store.Save(context.Background(), record) if errSave != nil { - log.Fatalf("vertex-import: save credential failed: %v", errSave) + log.Errorf("vertex-import: save credential failed: %v", errSave) return } fmt.Printf("Vertex credentials imported: %s\n", path) From 39b6b3b289fdf9cd0d2d20e4da296ab6f16b8b75 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Tue, 9 Dec 2025 17:32:17 +0800 Subject: [PATCH 20/34] Fixed: #463 fix(antigravity): remove `$ref` and `$defs` from JSON during key deletion --- internal/runtime/executor/antigravity_executor.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 730a32fb..b74f43e1 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -537,6 +537,8 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau strJSON = util.DeleteKey(strJSON, "maxLength") strJSON = util.DeleteKey(strJSON, "exclusiveMinimum") strJSON = util.DeleteKey(strJSON, "exclusiveMaximum") + strJSON = util.DeleteKey(strJSON, "$ref") + strJSON = util.DeleteKey(strJSON, "$defs") paths = make([]string, 0) util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths) From da23ddb061a029bd51e321058183ada19f98cda9 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 9 Dec 2025 17:34:15 +0800 Subject: [PATCH 21/34] fix(gemini): normalize model listing output --- sdk/api/handlers/gemini/gemini_handlers.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/sdk/api/handlers/gemini/gemini_handlers.go b/sdk/api/handlers/gemini/gemini_handlers.go index 7ba72a93..6cd9ee62 100644 --- a/sdk/api/handlers/gemini/gemini_handlers.go +++ b/sdk/api/handlers/gemini/gemini_handlers.go @@ -48,8 +48,24 @@ func (h *GeminiAPIHandler) Models() []map[string]any { // GeminiModels handles the Gemini models listing endpoint. // It returns a JSON response containing available Gemini models and their specifications. func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) { + rawModels := h.Models() + normalizedModels := make([]map[string]any, 0, len(rawModels)) + defaultMethods := []string{"generateContent"} + for _, model := range rawModels { + normalizedModel := make(map[string]any, len(model)) + for k, v := range model { + normalizedModel[k] = v + } + if name, ok := normalizedModel["name"].(string); ok && name != "" && !strings.HasPrefix(name, "models/") { + normalizedModel["name"] = "models/" + name + } + if _, ok := normalizedModel["supportedGenerationMethods"]; !ok { + normalizedModel["supportedGenerationMethods"] = defaultMethods + } + normalizedModels = append(normalizedModels, normalizedModel) + } c.JSON(http.StatusOK, gin.H{ - "models": h.Models(), + "models": normalizedModels, }) } From 3cfe7008a2e857589b618f7eac5568b0c8d63acf Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 9 Dec 2025 17:55:21 +0800 Subject: [PATCH 22/34] fix(registry): update gpt 5.1 model names --- internal/registry/model_definitions.go | 60 +++++++++++++------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 77015d14..de547182 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -693,8 +693,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Low", - Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 Nothink", + Description: "Stable version of GPT 5.1, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -719,8 +719,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Medium", - Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 Medium", + Description: "Stable version of GPT 5.1, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -732,8 +732,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 High", - Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 High", + Description: "Stable version of GPT 5.1, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -745,8 +745,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex", - Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 Codex", + Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -758,8 +758,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex Low", - Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 Codex Low", + Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -771,8 +771,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex Medium", - Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 Codex Medium", + Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -784,8 +784,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex High", - Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 Codex High", + Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -797,8 +797,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex Mini", - Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", + DisplayName: "GPT 5.1 Codex Mini", + Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -810,8 +810,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex Mini Medium", - Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", + DisplayName: "GPT 5.1 Codex Mini Medium", + Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -823,8 +823,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex Mini High", - Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", + DisplayName: "GPT 5.1 Codex Mini High", + Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -837,8 +837,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-max", - DisplayName: "GPT 5 Codex Max", - Description: "Stable version of GPT 5 Codex Max", + DisplayName: "GPT 5.1 Codex Max", + Description: "Stable version of GPT 5.1 Codex Max", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -850,8 +850,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-max", - DisplayName: "GPT 5 Codex Max Low", - Description: "Stable version of GPT 5 Codex Max Low", + DisplayName: "GPT 5.1 Codex Max Low", + Description: "Stable version of GPT 5.1 Codex Max Low", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -863,8 +863,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-max", - DisplayName: "GPT 5 Codex Max Medium", - Description: "Stable version of GPT 5 Codex Max Medium", + DisplayName: "GPT 5.1 Codex Max Medium", + Description: "Stable version of GPT 5.1 Codex Max Medium", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -876,8 +876,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-max", - DisplayName: "GPT 5 Codex Max High", - Description: "Stable version of GPT 5 Codex Max High", + DisplayName: "GPT 5.1 Codex Max High", + Description: "Stable version of GPT 5.1 Codex Max High", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -889,8 +889,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-max", - DisplayName: "GPT 5 Codex Max XHigh", - Description: "Stable version of GPT 5 Codex Max XHigh", + DisplayName: "GPT 5.1 Codex Max XHigh", + Description: "Stable version of GPT 5.1 Codex Max XHigh", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, From 347769b3e30ff5a1562675e4f09789906bec6a10 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 9 Dec 2025 18:09:14 +0800 Subject: [PATCH 23/34] fix(openai-compat): use model id for auth model display --- sdk/cliproxy/service.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 13d647dd..1ef829d1 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -779,7 +779,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { Created: time.Now().Unix(), OwnedBy: compat.Name, Type: "openai-compatibility", - DisplayName: m.Name, + DisplayName: modelID, }) } // Register and return From 1fa5514d56c8ff7a56ecfca91ffb91839e752995 Mon Sep 17 00:00:00 2001 From: Mario Date: Tue, 9 Dec 2025 20:13:16 +0800 Subject: [PATCH 24/34] fix kiro cannot refresh the token --- internal/watcher/watcher.go | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index da152141..36276de9 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -1272,7 +1272,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } for i := range cfg.KiroKey { kk := cfg.KiroKey[i] - var accessToken, profileArn string + var accessToken, profileArn, refreshToken string // Try to load from token file first if kk.TokenFile != "" && kAuth != nil { @@ -1282,6 +1282,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } else { accessToken = tokenData.AccessToken profileArn = tokenData.ProfileArn + refreshToken = tokenData.RefreshToken } } @@ -1292,6 +1293,9 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if kk.ProfileArn != "" { profileArn = kk.ProfileArn } + if kk.RefreshToken != "" { + refreshToken = kk.RefreshToken + } if accessToken == "" { log.Warnf("kiro config[%d] missing access_token, skipping", i) @@ -1313,6 +1317,9 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if kk.AgentTaskType != "" { attrs["agent_task_type"] = kk.AgentTaskType } + if refreshToken != "" { + attrs["refresh_token"] = refreshToken + } proxyURL := strings.TrimSpace(kk.ProxyURL) a := &coreauth.Auth{ ID: id, @@ -1324,6 +1331,14 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { CreatedAt: now, UpdatedAt: now, } + + if refreshToken != "" { + if a.Metadata == nil { + a.Metadata = make(map[string]any) + } + a.Metadata["refresh_token"] = refreshToken + } + out = append(out, a) } for i := range cfg.OpenAICompatibility { From 5ec9b5e5a9c6b095ce861eeedf91f1766303556c Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 9 Dec 2025 09:25:25 +0800 Subject: [PATCH 25/34] feat(executor): normalize thinking budget across all Gemini executors --- .../runtime/executor/aistudio_executor.go | 5 +--- .../runtime/executor/gemini_cli_executor.go | 2 ++ internal/runtime/executor/gemini_executor.go | 2 ++ .../executor/gemini_vertex_executor.go | 4 +++ internal/util/gemini_thinking.go | 26 +++++++++++++++++++ 5 files changed, 35 insertions(+), 4 deletions(-) diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 898c08c7..94b48de7 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -310,10 +310,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) payload = applyThinkingMetadata(payload, req.Metadata, req.Model) payload = util.ConvertThinkingLevelToBudget(payload) - if budget := gjson.GetBytes(payload, "generationConfig.thinkingConfig.thinkingBudget"); budget.Exists() { - normalized := util.NormalizeThinkingBudget(req.Model, int(budget.Int())) - payload, _ = sjson.SetBytes(payload, "generationConfig.thinkingConfig.thinkingBudget", normalized) - } + payload = util.NormalizeGeminiThinkingBudget(req.Model, payload) payload = util.StripThinkingConfigIfUnsupported(req.Model, payload) payload = fixGeminiImageAspectRatio(req.Model, payload) payload = applyPayloadConfig(e.cfg, req.Model, payload) diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 147a1ea1..520320ec 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -64,6 +64,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth to := sdktranslator.FromString("gemini-cli") basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model) + basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload) basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload) basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload) basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload) @@ -199,6 +200,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut to := sdktranslator.FromString("gemini-cli") basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model) + basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload) basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload) basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload) basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload) diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index fc7b8e19..4184e88b 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -80,6 +80,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r to := sdktranslator.FromString("gemini") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body = applyThinkingMetadata(body, req.Metadata, req.Model) + body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) @@ -169,6 +170,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A to := sdktranslator.FromString("gemini") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body = applyThinkingMetadata(body, req.Metadata, req.Model) + body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index de4ba072..3caf1cd0 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -296,6 +296,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } + body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) @@ -391,6 +392,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } + body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) @@ -487,6 +489,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } + body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) @@ -599,6 +602,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } + body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) diff --git a/internal/util/gemini_thinking.go b/internal/util/gemini_thinking.go index 14077fa0..85f8d74d 100644 --- a/internal/util/gemini_thinking.go +++ b/internal/util/gemini_thinking.go @@ -223,6 +223,32 @@ func StripThinkingConfigIfUnsupported(model string, body []byte) []byte { return updated } +// NormalizeGeminiThinkingBudget normalizes the thinkingBudget value in a standard Gemini +// request body (generationConfig.thinkingConfig.thinkingBudget path). +func NormalizeGeminiThinkingBudget(model string, body []byte) []byte { + const budgetPath = "generationConfig.thinkingConfig.thinkingBudget" + budget := gjson.GetBytes(body, budgetPath) + if !budget.Exists() { + return body + } + normalized := NormalizeThinkingBudget(model, int(budget.Int())) + updated, _ := sjson.SetBytes(body, budgetPath, normalized) + return updated +} + +// NormalizeGeminiCLIThinkingBudget normalizes the thinkingBudget value in a Gemini CLI +// request body (request.generationConfig.thinkingConfig.thinkingBudget path). +func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte { + const budgetPath = "request.generationConfig.thinkingConfig.thinkingBudget" + budget := gjson.GetBytes(body, budgetPath) + if !budget.Exists() { + return body + } + normalized := NormalizeThinkingBudget(model, int(budget.Int())) + updated, _ := sjson.SetBytes(body, budgetPath, normalized) + return updated +} + // ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel" // and converts it to "thinkingBudget". // "high" -> 32768 From 5b6d201408f49be469c619cd83cebe370660470e Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 9 Dec 2025 09:36:36 +0800 Subject: [PATCH 26/34] refactor(translator): remove thinking budget normalization across all translators --- .../claude/antigravity_claude_request.go | 1 - .../antigravity_openai_request.go | 20 +++++++++---------- .../claude/gemini-cli_claude_request.go | 1 - .../gemini-cli_openai_request.go | 18 ++++++++--------- .../gemini/claude/gemini_claude_request.go | 1 - .../chat-completions/gemini_openai_request.go | 18 ++++++++--------- .../gemini_openai-responses_request.go | 16 +++++++-------- 7 files changed, 36 insertions(+), 39 deletions(-) diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index e1b73da0..a810ba7a 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -180,7 +180,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ if t.Get("type").String() == "enabled" { if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { budget := int(b.Int()) - budget = util.NormalizeThinkingBudget(modelName, budget) out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index 1c90a803..b3d8b04d 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -48,13 +48,13 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) case "low": - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024)) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) case "medium": - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192)) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) case "high": - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768)) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) default: out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1) @@ -66,15 +66,15 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ if !hasOfficialThinking && util.ModelSupportsThinking(modelName) { if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() { var setBudget bool - var normalized int + var budget int if v := tc.Get("thinkingBudget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } else if v := tc.Get("thinking_budget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } @@ -82,7 +82,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool()) } else if v := tc.Get("include_thoughts"); v.Exists() { out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool()) - } else if setBudget && normalized != 0 { + } else if setBudget && budget != 0 { out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } } @@ -94,7 +94,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() { if t.Get("type").String() == "enabled" { if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { - budget := util.NormalizeThinkingBudget(modelName, int(b.Int())) + budget := int(b.Int()) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go index 50fd5a25..913727ce 100644 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go @@ -165,7 +165,6 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) [] if t.Get("type").String() == "enabled" { if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { budget := int(b.Int()) - budget = util.NormalizeThinkingBudget(modelName, budget) out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go index d14f1119..0cb3cd76 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go +++ b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go @@ -48,13 +48,13 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) case "low": - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024)) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) case "medium": - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192)) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) case "high": - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768)) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) default: out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1) @@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo if !hasOfficialThinking && util.ModelSupportsThinking(modelName) { if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() { var setBudget bool - var normalized int + var budget int if v := tc.Get("thinkingBudget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } else if v := tc.Get("thinking_budget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } @@ -82,7 +82,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool()) } else if v := tc.Get("include_thoughts"); v.Exists() { out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool()) - } else if setBudget && normalized != 0 { + } else if setBudget && budget != 0 { out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } } diff --git a/internal/translator/gemini/claude/gemini_claude_request.go b/internal/translator/gemini/claude/gemini_claude_request.go index 05f9be5d..45a5a88f 100644 --- a/internal/translator/gemini/claude/gemini_claude_request.go +++ b/internal/translator/gemini/claude/gemini_claude_request.go @@ -158,7 +158,6 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) if t.Get("type").String() == "enabled" { if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { budget := int(b.Int()) - budget = util.NormalizeThinkingBudget(modelName, budget) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) } diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go index 0df8987f..8c48a5b3 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go @@ -48,13 +48,13 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1) out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true) case "low": - out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024)) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 1024) out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true) case "medium": - out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192)) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 8192) out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true) case "high": - out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768)) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 32768) out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true) default: out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1) @@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) if !hasOfficialThinking && util.ModelSupportsThinking(modelName) { if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() { var setBudget bool - var normalized int + var budget int if v := tc.Get("thinkingBudget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } else if v := tc.Get("thinking_budget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } @@ -82,7 +82,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool()) } else if v := tc.Get("include_thoughts"); v.Exists() { out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool()) - } else if setBudget && normalized != 0 { + } else if setBudget && budget != 0 { out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true) } } diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go index 4ea75c18..1df1d226 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go @@ -400,16 +400,16 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) case "minimal": - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024)) + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 1024) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) case "low": - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 4096)) + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 4096) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) case "medium": - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192)) + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 8192) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) case "high": - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768)) + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 32768) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) default: out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1) @@ -421,16 +421,16 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte if !hasOfficialThinking && util.ModelSupportsThinking(modelName) { if tc := root.Get("extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() { var setBudget bool - var normalized int + var budget int if v := tc.Get("thinking_budget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } if v := tc.Get("include_thoughts"); v.Exists() { out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool()) } else if setBudget { - if normalized != 0 { + if budget != 0 { out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) } } From 6a66b6801a8d7ea17acd7f1e5766ddab4335e42e Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 9 Dec 2025 11:26:42 +0800 Subject: [PATCH 27/34] feat(executor): enforce minimum thinking budget for antigravity models --- .../runtime/executor/antigravity_executor.go | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 155193da..683285d1 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -839,9 +839,12 @@ func normalizeAntigravityThinking(model string, payload []byte) []byte { effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload) if effectiveMax > 0 && normalized >= effectiveMax { normalized = effectiveMax - 1 - if normalized < 1 { - normalized = 1 - } + } + minBudget := antigravityMinThinkingBudget(model) + if minBudget > 0 && normalized >= 0 && normalized < minBudget { + // Budget is below minimum, remove thinking config entirely + payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.thinkingConfig") + return payload } if setDefaultMax { if res, errSet := sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax); errSet == nil { @@ -869,3 +872,12 @@ func antigravityEffectiveMaxTokens(model string, payload []byte) (max int, fromM } return 0, false } + +// antigravityMinThinkingBudget returns the minimum thinking budget for a model. +// Falls back to -1 if no model info is found. +func antigravityMinThinkingBudget(model string) int { + if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.Thinking != nil { + return modelInfo.Thinking.Min + } + return -1 +} From 9b202b6c1c44adac285440390a165b481b06d292 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 9 Dec 2025 13:23:50 +0800 Subject: [PATCH 28/34] fix(executor): centralize default thinking config --- .../runtime/executor/aistudio_executor.go | 1 + .../runtime/executor/antigravity_executor.go | 2 + .../runtime/executor/gemini_cli_executor.go | 2 + internal/runtime/executor/gemini_executor.go | 2 + .../executor/gemini_vertex_executor.go | 4 ++ .../antigravity_openai_request.go | 9 ---- .../gemini-cli_openai_request.go | 9 ---- .../gemini_openai-responses_request.go | 10 ----- internal/util/gemini_thinking.go | 41 +++++++++++++++++++ 9 files changed, 52 insertions(+), 28 deletions(-) diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 94b48de7..d37cd2c2 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -309,6 +309,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c to := sdktranslator.FromString("gemini") payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) payload = applyThinkingMetadata(payload, req.Metadata, req.Model) + payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload) payload = util.ConvertThinkingLevelToBudget(payload) payload = util.NormalizeGeminiThinkingBudget(req.Model, payload) payload = util.StripThinkingConfigIfUnsupported(req.Model, payload) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 683285d1..52b91450 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -77,6 +77,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated) baseURLs := antigravityBaseURLFallbackOrder(auth) @@ -171,6 +172,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated) baseURLs := antigravityBaseURLFallbackOrder(auth) diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 520320ec..a2e0ecec 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -64,6 +64,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth to := sdktranslator.FromString("gemini-cli") basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model) + basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload) basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload) basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload) basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload) @@ -200,6 +201,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut to := sdktranslator.FromString("gemini-cli") basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model) + basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload) basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload) basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload) basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload) diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 4184e88b..8879a4f1 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -80,6 +80,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r to := sdktranslator.FromString("gemini") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body = applyThinkingMetadata(body, req.Metadata, req.Model) + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) @@ -170,6 +171,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A to := sdktranslator.FromString("gemini") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body = applyThinkingMetadata(body, req.Metadata, req.Model) + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 3caf1cd0..c7d10a67 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -296,6 +296,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) @@ -392,6 +393,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) @@ -489,6 +491,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) @@ -602,6 +605,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index b3d8b04d..717f88f7 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -102,15 +102,6 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ } } - // For gemini-3-pro-preview, always send default thinkingConfig when none specified. - // This matches the official Gemini CLI behavior which always sends: - // { thinkingBudget: -1, includeThoughts: true } - // See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts - if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" { - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) - } - // Temperature/top_p/top_k/max_tokens if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number { out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num) diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go index 0cb3cd76..b52bf224 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go +++ b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go @@ -88,15 +88,6 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo } } - // For gemini-3-pro-preview, always send default thinkingConfig when none specified. - // This matches the official Gemini CLI behavior which always sends: - // { thinkingBudget: -1, includeThoughts: true } - // See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts - if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" { - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) - } - // Temperature/top_p/top_k if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number { out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num) diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go index 1df1d226..bdf59785 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go @@ -437,16 +437,6 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte } } - // For gemini-3-pro-preview, always send default thinkingConfig when none specified. - // This matches the official Gemini CLI behavior which always sends: - // { thinkingBudget: -1, includeThoughts: true } - // See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts - if !gjson.Get(out, "generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" { - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1) - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) - // log.Debugf("Applied default thinkingConfig for gemini-3-pro-preview (matches Gemini CLI): thinkingBudget=-1, include_thoughts=true") - } - result := []byte(out) result = common.AttachDefaultSafetySettings(result, "safetySettings") return result diff --git a/internal/util/gemini_thinking.go b/internal/util/gemini_thinking.go index 85f8d74d..fc389511 100644 --- a/internal/util/gemini_thinking.go +++ b/internal/util/gemini_thinking.go @@ -207,6 +207,47 @@ func GeminiThinkingFromMetadata(metadata map[string]any) (*int, *bool, bool) { return budgetPtr, includePtr, matched } +// modelsWithDefaultThinking lists models that should have thinking enabled by default +// when no explicit thinkingConfig is provided. +var modelsWithDefaultThinking = map[string]bool{ + "gemini-3-pro-preview": true, +} + +// ModelHasDefaultThinking returns true if the model should have thinking enabled by default. +func ModelHasDefaultThinking(model string) bool { + return modelsWithDefaultThinking[model] +} + +// ApplyDefaultThinkingIfNeeded injects default thinkingConfig for models that require it. +// For standard Gemini API format (generationConfig.thinkingConfig path). +// Returns the modified body if thinkingConfig was added, otherwise returns the original. +func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte { + if !ModelHasDefaultThinking(model) { + return body + } + if gjson.GetBytes(body, "generationConfig.thinkingConfig").Exists() { + return body + } + updated, _ := sjson.SetBytes(body, "generationConfig.thinkingConfig.thinkingBudget", -1) + updated, _ = sjson.SetBytes(updated, "generationConfig.thinkingConfig.include_thoughts", true) + return updated +} + +// ApplyDefaultThinkingIfNeededCLI injects default thinkingConfig for models that require it. +// For Gemini CLI API format (request.generationConfig.thinkingConfig path). +// Returns the modified body if thinkingConfig was added, otherwise returns the original. +func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte { + if !ModelHasDefaultThinking(model) { + return body + } + if gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() { + return body + } + updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget", -1) + updated, _ = sjson.SetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts", true) + return updated +} + // StripThinkingConfigIfUnsupported removes thinkingConfig from the request body // when the target model does not advertise Thinking capability. It cleans both // standard Gemini and Gemini CLI JSON envelopes. This acts as a final safety net From 70d6b95097ce38eb512484a7bd60667fd014221a Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:05:08 +0800 Subject: [PATCH 29/34] feat(amp): add /news.rss proxy route --- internal/api/modules/amp/routes.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 0c1fcadb..48fbbbb9 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -156,6 +156,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()} engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...) engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) + engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...) // Root-level auth routes for CLI login flow // Amp uses multiple auth routes: /auth/cli-login, /auth/callback, /auth/sign-in, /auth/logout From f25f419e5aa5503f63d612d1a0d0a9dbee8504fe Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 10 Dec 2025 00:13:20 +0800 Subject: [PATCH 30/34] fix(antigravity): remove references to `autopush` endpoint and update fallback logic --- .../runtime/executor/antigravity_executor.go | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 52b91450..a32e66ec 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -27,18 +27,18 @@ import ( ) const ( - antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com" - antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com" - antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com" - antigravityStreamPath = "/v1internal:streamGenerateContent" - antigravityGeneratePath = "/v1internal:generateContent" - antigravityModelsPath = "/v1internal:fetchAvailableModels" - antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" - antigravityClientSecret = "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" - defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64" - antigravityAuthType = "antigravity" - refreshSkew = 3000 * time.Second - streamScannerBuffer int = 20_971_520 + antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com" + // antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com" + antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com" + antigravityStreamPath = "/v1internal:streamGenerateContent" + antigravityGeneratePath = "/v1internal:generateContent" + antigravityModelsPath = "/v1internal:fetchAvailableModels" + antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" + antigravityClientSecret = "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" + defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64" + antigravityAuthType = "antigravity" + refreshSkew = 3000 * time.Second + streamScannerBuffer int = 20_971_520 ) var randSource = rand.New(rand.NewSource(time.Now().UnixNano())) @@ -661,7 +661,7 @@ func buildBaseURL(auth *cliproxyauth.Auth) string { if baseURLs := antigravityBaseURLFallbackOrder(auth); len(baseURLs) > 0 { return baseURLs[0] } - return antigravityBaseURLAutopush + return antigravityBaseURLDaily } func resolveHost(base string) string { @@ -697,7 +697,7 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string { } return []string{ antigravityBaseURLDaily, - antigravityBaseURLAutopush, + // antigravityBaseURLAutopush, antigravityBaseURLProd, } } From a594338bc57cc6df86b90a5ce10a72f2de880a07 Mon Sep 17 00:00:00 2001 From: fuko2935 Date: Tue, 9 Dec 2025 19:14:40 +0300 Subject: [PATCH 31/34] fix(registry): remove unstable kiro-auto model - Removes kiro-auto from static model registry - Removes kiro-auto mapping from executor - Fixes compatibility issues reported in #7 Fixes #7 --- internal/registry/model_definitions.go | 11 ----------- internal/runtime/executor/kiro_executor.go | 2 -- 2 files changed, 13 deletions(-) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 3c31e61f..31f08f98 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -1195,17 +1195,6 @@ func GetGitHubCopilotModels() []*ModelInfo { // GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions func GetKiroModels() []*ModelInfo { return []*ModelInfo{ - { - ID: "kiro-auto", - Object: "model", - Created: 1732752000, // 2024-11-28 - OwnedBy: "aws", - Type: "kiro", - DisplayName: "Kiro Auto", - Description: "Automatic model selection by AWS CodeWhisperer", - ContextLength: 200000, - MaxCompletionTokens: 64000, - }, { ID: "kiro-claude-opus-4.5", Object: "model", diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 0157d68c..b965c9ca 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -547,8 +547,6 @@ func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { // Agentic variants (-agentic suffix) map to the same backend model IDs. func (e *KiroExecutor) mapModelToKiro(model string) string { modelMap := map[string]string{ - // Proxy format (kiro- prefix) - "kiro-auto": "auto", "kiro-claude-opus-4.5": "claude-opus-4.5", "kiro-claude-sonnet-4.5": "claude-sonnet-4.5", "kiro-claude-sonnet-4": "claude-sonnet-4", From 6b37f33d31515ab49e2c1688d278a662566c9ec2 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 10 Dec 2025 15:27:57 +0800 Subject: [PATCH 32/34] feat(antigravity): add unique identifier for tool use blocks in response --- .../antigravity/claude/antigravity_claude_response.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index 4073f20b..42265e80 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -36,6 +37,9 @@ type Params struct { HasToolUse bool // Indicates if tool use was observed in the stream } +// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. +var toolUseIDCounter uint64 + // ConvertAntigravityResponseToClaude performs sophisticated streaming response format conversion. // This function implements a complex state machine that translates backend client responses // into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types @@ -216,7 +220,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex) - data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) data, _ = sjson.Set(data, "content_block.name", fcName) output = output + fmt.Sprintf("data: %s\n\n\n", data) From 1249b07eb85551982dae1e1058124ed30cb678ef Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 10 Dec 2025 16:02:54 +0800 Subject: [PATCH 33/34] feat(responses): add unique identifiers for responses, function calls, and tool uses --- .../chat-completions/antigravity_openai_response.go | 6 +++++- .../gemini-cli/claude/gemini-cli_claude_response.go | 6 +++++- .../chat-completions/gemini-cli_openai_response.go | 6 +++++- .../gemini/claude/gemini_claude_response.go | 6 +++++- .../chat-completions/gemini_openai_response.go | 8 ++++++-- .../responses/gemini_openai-responses_response.go | 13 ++++++++++--- .../responses/openai_openai-responses_response.go | 6 +++++- 7 files changed, 41 insertions(+), 10 deletions(-) diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go index e069f7ec..24694e1d 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" . "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions" @@ -24,6 +25,9 @@ type convertCliResponseToOpenAIChatParams struct { FunctionIndex int } +// functionCallIDCounter provides a process-wide unique counter for function call identifiers. +var functionCallIDCounter uint64 + // ConvertAntigravityResponseToOpenAI translates a single chunk of a streaming response from the // Gemini CLI API format to the OpenAI Chat Completions streaming format. // It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses. @@ -146,7 +150,7 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}` fcName := functionCallResult.Get("name").String() - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))) functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex) functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go index 733668f3..9b37c52b 100644 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -27,6 +28,9 @@ type Params struct { ResponseIndex int // Index counter for content blocks in the streaming response } +// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. +var toolUseIDCounter uint64 + // ConvertGeminiCLIResponseToClaude performs sophisticated streaming response format conversion. // This function implements a complex state machine that translates backend client responses // into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types @@ -197,7 +201,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) - data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) data, _ = sjson.Set(data, "content_block.name", fcName) output = output + fmt.Sprintf("data: %s\n\n\n", data) diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go index 9c422a07..753870f3 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go +++ b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" . "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions" @@ -24,6 +25,9 @@ type convertCliResponseToOpenAIChatParams struct { FunctionIndex int } +// functionCallIDCounter provides a process-wide unique counter for function call identifiers. +var functionCallIDCounter uint64 + // ConvertCliResponseToOpenAI translates a single chunk of a streaming response from the // Gemini CLI API format to the OpenAI Chat Completions streaming format. // It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses. @@ -146,7 +150,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}` fcName := functionCallResult.Get("name").String() - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))) functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex) functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { diff --git a/internal/translator/gemini/claude/gemini_claude_response.go b/internal/translator/gemini/claude/gemini_claude_response.go index a80171a9..8fd566df 100644 --- a/internal/translator/gemini/claude/gemini_claude_response.go +++ b/internal/translator/gemini/claude/gemini_claude_response.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -26,6 +27,9 @@ type Params struct { ResponseIndex int } +// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. +var toolUseIDCounter uint64 + // ConvertGeminiResponseToClaude performs sophisticated streaming response format conversion. // This function implements a complex state machine that translates backend client responses // into Claude-compatible Server-Sent Events (SSE) format. It manages different response types @@ -197,7 +201,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) - data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) data, _ = sjson.Set(data, "content_block.name", fcName) output = output + fmt.Sprintf("data: %s\n\n\n", data) diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go index 12e28cca..a1ebc855 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -23,6 +24,9 @@ type convertGeminiResponseToOpenAIChatParams struct { FunctionIndex int } +// functionCallIDCounter provides a process-wide unique counter for function call identifiers. +var functionCallIDCounter uint64 + // ConvertGeminiResponseToOpenAI translates a single chunk of a streaming response from the // Gemini API format to the OpenAI Chat Completions streaming format. // It processes various Gemini event types and transforms them into OpenAI-compatible JSON responses. @@ -148,7 +152,7 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}` fcName := functionCallResult.Get("name").String() - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))) functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex) functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { @@ -281,7 +285,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina } functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` fcName := functionCallResult.Get("name").String() - functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))) functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw) diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go index ce221863..e08b265d 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -37,6 +38,12 @@ type geminiToResponsesState struct { FuncCallIDs map[int]string } +// responseIDCounter provides a process-wide unique counter for synthesized response identifiers. +var responseIDCounter uint64 + +// funcCallIDCounter provides a process-wide unique counter for function call identifiers. +var funcCallIDCounter uint64 + func emitEvent(event string, payload string) string { return fmt.Sprintf("event: %s\ndata: %s", event, payload) } @@ -205,7 +212,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, st.FuncArgsBuf[idx] = &strings.Builder{} } if st.FuncCallIDs[idx] == "" { - st.FuncCallIDs[idx] = fmt.Sprintf("call_%d", time.Now().UnixNano()) + st.FuncCallIDs[idx] = fmt.Sprintf("call_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1)) } st.FuncNames[idx] = name @@ -464,7 +471,7 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string // id: prefer provider responseId, otherwise synthesize id := root.Get("responseId").String() if id == "" { - id = fmt.Sprintf("resp_%x", time.Now().UnixNano()) + id = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1)) } // Normalize to response-style id (prefix resp_ if missing) if !strings.HasPrefix(id, "resp_") { @@ -575,7 +582,7 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string if fc := p.Get("functionCall"); fc.Exists() { name := fc.Get("name").String() args := fc.Get("args") - callID := fmt.Sprintf("call_%x", time.Now().UnixNano()) + callID := fmt.Sprintf("call_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1)) outputs = append(outputs, map[string]interface{}{ "id": fmt.Sprintf("fc_%s", callID), "type": "function_call", diff --git a/internal/translator/openai/openai/responses/openai_openai-responses_response.go b/internal/translator/openai/openai/responses/openai_openai-responses_response.go index 00ec5c7f..c698b93f 100644 --- a/internal/translator/openai/openai/responses/openai_openai-responses_response.go +++ b/internal/translator/openai/openai/responses/openai_openai-responses_response.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -41,6 +42,9 @@ type oaiToResponsesState struct { UsageSeen bool } +// responseIDCounter provides a process-wide unique counter for synthesized response identifiers. +var responseIDCounter uint64 + func emitRespEvent(event string, payload string) string { return fmt.Sprintf("event: %s\ndata: %s", event, payload) } @@ -590,7 +594,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co // id: use provider id if present, otherwise synthesize id := root.Get("id").String() if id == "" { - id = fmt.Sprintf("resp_%x", time.Now().UnixNano()) + id = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1)) } resp, _ = sjson.Set(resp, "id", id) From 94d61c7b2b1d618216f2be92795c4eec50733248 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 10 Dec 2025 16:53:48 +0800 Subject: [PATCH 34/34] fix(logging): update response aggregation logic to include all attempts --- internal/runtime/executor/logging_helpers.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/runtime/executor/logging_helpers.go b/internal/runtime/executor/logging_helpers.go index 7798b96b..26931f53 100644 --- a/internal/runtime/executor/logging_helpers.go +++ b/internal/runtime/executor/logging_helpers.go @@ -157,7 +157,7 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt if ginCtx == nil { return } - _, attempt := ensureAttempt(ginCtx) + attempts, attempt := ensureAttempt(ginCtx) ensureResponseIntro(attempt) if !attempt.headersWritten { @@ -175,6 +175,8 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt } attempt.response.WriteString(string(data)) attempt.bodyHasContent = true + + updateAggregatedResponse(ginCtx, attempts) } func ginContextFrom(ctx context.Context) *gin.Context {