diff --git a/cmd/server/main.go b/cmd/server/main.go index a5294d2e..fe648f6c 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -166,7 +166,8 @@ func main() { wd, err := os.Getwd() if err != nil { - log.Fatalf("failed to get working directory: %v", err) + log.Errorf("failed to get working directory: %v", err) + return } // Load environment variables from .env if present. @@ -260,13 +261,15 @@ func main() { }) cancel() if err != nil { - log.Fatalf("failed to initialize postgres token store: %v", err) + log.Errorf("failed to initialize postgres token store: %v", err) + return } examplePath := filepath.Join(wd, "config.example.yaml") ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second) if errBootstrap := pgStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil { cancel() - log.Fatalf("failed to bootstrap postgres-backed config: %v", errBootstrap) + log.Errorf("failed to bootstrap postgres-backed config: %v", errBootstrap) + return } cancel() configFilePath = pgStoreInst.ConfigPath() @@ -289,7 +292,8 @@ func main() { if strings.Contains(resolvedEndpoint, "://") { parsed, errParse := url.Parse(resolvedEndpoint) if errParse != nil { - log.Fatalf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse) + log.Errorf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse) + return } switch strings.ToLower(parsed.Scheme) { case "http": @@ -297,10 +301,12 @@ func main() { case "https": useSSL = true default: - log.Fatalf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme) + log.Errorf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme) + return } if parsed.Host == "" { - log.Fatalf("object store endpoint %q is missing host information", objectStoreEndpoint) + log.Errorf("object store endpoint %q is missing host information", objectStoreEndpoint) + return } resolvedEndpoint = parsed.Host if parsed.Path != "" && parsed.Path != "/" { @@ -319,13 +325,15 @@ func main() { } objectStoreInst, err = store.NewObjectTokenStore(objCfg) if err != nil { - log.Fatalf("failed to initialize object token store: %v", err) + log.Errorf("failed to initialize object token store: %v", err) + return } examplePath := filepath.Join(wd, "config.example.yaml") ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) if errBootstrap := objectStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil { cancel() - log.Fatalf("failed to bootstrap object-backed config: %v", errBootstrap) + log.Errorf("failed to bootstrap object-backed config: %v", errBootstrap) + return } cancel() configFilePath = objectStoreInst.ConfigPath() @@ -350,7 +358,8 @@ func main() { gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword) gitStoreInst.SetBaseDir(authDir) if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil { - log.Fatalf("failed to prepare git token store: %v", errRepo) + log.Errorf("failed to prepare git token store: %v", errRepo) + return } configFilePath = gitStoreInst.ConfigPath() if configFilePath == "" { @@ -359,17 +368,21 @@ func main() { if _, statErr := os.Stat(configFilePath); errors.Is(statErr, fs.ErrNotExist) { examplePath := filepath.Join(wd, "config.example.yaml") if _, errExample := os.Stat(examplePath); errExample != nil { - log.Fatalf("failed to find template config file: %v", errExample) + log.Errorf("failed to find template config file: %v", errExample) + return } if errCopy := misc.CopyConfigTemplate(examplePath, configFilePath); errCopy != nil { - log.Fatalf("failed to bootstrap git-backed config: %v", errCopy) + log.Errorf("failed to bootstrap git-backed config: %v", errCopy) + return } if errCommit := gitStoreInst.PersistConfig(context.Background()); errCommit != nil { - log.Fatalf("failed to commit initial git-backed config: %v", errCommit) + log.Errorf("failed to commit initial git-backed config: %v", errCommit) + return } log.Infof("git-backed config initialized from template: %s", configFilePath) } else if statErr != nil { - log.Fatalf("failed to inspect git-backed config: %v", statErr) + log.Errorf("failed to inspect git-backed config: %v", statErr) + return } cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy) if err == nil { @@ -382,13 +395,15 @@ func main() { } else { wd, err = os.Getwd() if err != nil { - log.Fatalf("failed to get working directory: %v", err) + log.Errorf("failed to get working directory: %v", err) + return } configFilePath = filepath.Join(wd, "config.yaml") cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy) } if err != nil { - log.Fatalf("failed to load config: %v", err) + log.Errorf("failed to load config: %v", err) + return } if cfg == nil { cfg = &config.Config{} @@ -418,7 +433,8 @@ func main() { coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling) if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil { - log.Fatalf("failed to configure log output: %v", err) + log.Errorf("failed to configure log output: %v", err) + return } log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate) @@ -427,7 +443,8 @@ func main() { util.SetLogLevel(cfg) if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil { - log.Fatalf("failed to resolve auth directory: %v", errResolveAuthDir) + log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir) + return } else { cfg.AuthDir = resolvedAuthDir } diff --git a/config.example.yaml b/config.example.yaml index 8fb01c14..f46158c3 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,3 +1,7 @@ +# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6). +# Use "127.0.0.1" or "localhost" to restrict access to local machine only. +host: "" + # Server port port: 8317 @@ -149,6 +153,8 @@ ws-auth: false # upstream-api-key: "" # # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended) # restrict-management-to-localhost: true +# # Force model mappings to run before checking local API keys (default: false) +# force-model-mappings: false # # Amp Model Mappings # # Route unavailable Amp models to alternative models available in your local proxy. # # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 161f7a93..46d880a7 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -736,14 +736,16 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { // Generate PKCE codes pkceCodes, err := claude.GeneratePKCECodes() if err != nil { - log.Fatalf("Failed to generate PKCE codes: %v", err) + log.Errorf("Failed to generate PKCE codes: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"}) return } // Generate random state parameter state, err := misc.GenerateRandomState() if err != nil { - log.Fatalf("Failed to generate state parameter: %v", err) + log.Errorf("Failed to generate state parameter: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"}) return } @@ -753,7 +755,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { // Generate authorization URL (then override redirect_uri to reuse server port) authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes) if err != nil { - log.Fatalf("Failed to generate authorization URL: %v", err) + log.Errorf("Failed to generate authorization URL: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"}) return } @@ -895,7 +898,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - log.Fatalf("Failed to save authentication tokens: %v", errSave) + log.Errorf("Failed to save authentication tokens: %v", errSave) setOAuthStatus(state, "Failed to save authentication tokens") return } @@ -1068,7 +1071,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { gemAuth := geminiAuth.NewGeminiAuth() gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) if errGetClient != nil { - log.Fatalf("failed to get authenticated client: %v", errGetClient) + log.Errorf("failed to get authenticated client: %v", errGetClient) setOAuthStatus(state, "Failed to get authenticated client") return } @@ -1133,7 +1136,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - log.Fatalf("Failed to save token to file: %v", errSave) + log.Errorf("Failed to save token to file: %v", errSave) setOAuthStatus(state, "Failed to save token to file") return } @@ -1154,14 +1157,16 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { // Generate PKCE codes pkceCodes, err := codex.GeneratePKCECodes() if err != nil { - log.Fatalf("Failed to generate PKCE codes: %v", err) + log.Errorf("Failed to generate PKCE codes: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"}) return } // Generate random state parameter state, err := misc.GenerateRandomState() if err != nil { - log.Fatalf("Failed to generate state parameter: %v", err) + log.Errorf("Failed to generate state parameter: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"}) return } @@ -1171,7 +1176,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { // Generate authorization URL authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes) if err != nil { - log.Fatalf("Failed to generate authorization URL: %v", err) + log.Errorf("Failed to generate authorization URL: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"}) return } @@ -1305,8 +1311,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { + log.Errorf("Failed to save authentication tokens: %v", errSave) setOAuthStatus(state, "Failed to save authentication tokens") - log.Fatalf("Failed to save authentication tokens: %v", errSave) return } fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) @@ -1341,7 +1347,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { state, errState := misc.GenerateRandomState() if errState != nil { - log.Fatalf("Failed to generate state parameter: %v", errState) + log.Errorf("Failed to generate state parameter: %v", errState) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"}) return } @@ -1537,7 +1544,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - log.Fatalf("Failed to save token to file: %v", errSave) + log.Errorf("Failed to save token to file: %v", errSave) setOAuthStatus(state, "Failed to save token to file") return } @@ -1566,7 +1573,8 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { // Generate authorization URL deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx) if err != nil { - log.Fatalf("Failed to generate authorization URL: %v", err) + log.Errorf("Failed to generate authorization URL: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"}) return } authURL := deviceFlow.VerificationURIComplete @@ -1593,7 +1601,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - log.Fatalf("Failed to save authentication tokens: %v", errSave) + log.Errorf("Failed to save authentication tokens: %v", errSave) setOAuthStatus(state, "Failed to save authentication tokens") return } @@ -1696,8 +1704,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { + log.Errorf("Failed to save authentication tokens: %v", errSave) setOAuthStatus(state, "Failed to save authentication tokens") - log.Fatalf("Failed to save authentication tokens: %v", errSave) return } @@ -2126,6 +2134,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec continue } } + _ = resp.Body.Close() return false, fmt.Errorf("project activation required: %s", errMessage) } return true, nil diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 8f4c4037..a0d0b169 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -706,3 +706,155 @@ func normalizeClaudeKey(entry *config.ClaudeKey) { } entry.Models = normalized } + +// GetAmpCode returns the complete ampcode configuration. +func (h *Handler) GetAmpCode(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"ampcode": config.AmpCode{}}) + return + } + c.JSON(200, gin.H{"ampcode": h.cfg.AmpCode}) +} + +// GetAmpUpstreamURL returns the ampcode upstream URL. +func (h *Handler) GetAmpUpstreamURL(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"upstream-url": ""}) + return + } + c.JSON(200, gin.H{"upstream-url": h.cfg.AmpCode.UpstreamURL}) +} + +// PutAmpUpstreamURL updates the ampcode upstream URL. +func (h *Handler) PutAmpUpstreamURL(c *gin.Context) { + h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) }) +} + +// DeleteAmpUpstreamURL clears the ampcode upstream URL. +func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) { + h.cfg.AmpCode.UpstreamURL = "" + h.persist(c) +} + +// GetAmpUpstreamAPIKey returns the ampcode upstream API key. +func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"upstream-api-key": ""}) + return + } + c.JSON(200, gin.H{"upstream-api-key": h.cfg.AmpCode.UpstreamAPIKey}) +} + +// PutAmpUpstreamAPIKey updates the ampcode upstream API key. +func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) { + h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) }) +} + +// DeleteAmpUpstreamAPIKey clears the ampcode upstream API key. +func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) { + h.cfg.AmpCode.UpstreamAPIKey = "" + h.persist(c) +} + +// GetAmpRestrictManagementToLocalhost returns the localhost restriction setting. +func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"restrict-management-to-localhost": true}) + return + } + c.JSON(200, gin.H{"restrict-management-to-localhost": h.cfg.AmpCode.RestrictManagementToLocalhost}) +} + +// PutAmpRestrictManagementToLocalhost updates the localhost restriction setting. +func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.RestrictManagementToLocalhost = v }) +} + +// GetAmpModelMappings returns the ampcode model mappings. +func (h *Handler) GetAmpModelMappings(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}}) + return + } + c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings}) +} + +// PutAmpModelMappings replaces all ampcode model mappings. +func (h *Handler) PutAmpModelMappings(c *gin.Context) { + var body struct { + Value []config.AmpModelMapping `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + h.cfg.AmpCode.ModelMappings = body.Value + h.persist(c) +} + +// PatchAmpModelMappings adds or updates model mappings. +func (h *Handler) PatchAmpModelMappings(c *gin.Context) { + var body struct { + Value []config.AmpModelMapping `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + + existing := make(map[string]int) + for i, m := range h.cfg.AmpCode.ModelMappings { + existing[strings.TrimSpace(m.From)] = i + } + + for _, newMapping := range body.Value { + from := strings.TrimSpace(newMapping.From) + if idx, ok := existing[from]; ok { + h.cfg.AmpCode.ModelMappings[idx] = newMapping + } else { + h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, newMapping) + existing[from] = len(h.cfg.AmpCode.ModelMappings) - 1 + } + } + h.persist(c) +} + +// DeleteAmpModelMappings removes specified model mappings by "from" field. +func (h *Handler) DeleteAmpModelMappings(c *gin.Context) { + var body struct { + Value []string `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil || len(body.Value) == 0 { + h.cfg.AmpCode.ModelMappings = nil + h.persist(c) + return + } + + toRemove := make(map[string]bool) + for _, from := range body.Value { + toRemove[strings.TrimSpace(from)] = true + } + + newMappings := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings)) + for _, m := range h.cfg.AmpCode.ModelMappings { + if !toRemove[strings.TrimSpace(m.From)] { + newMappings = append(newMappings, m) + } + } + h.cfg.AmpCode.ModelMappings = newMappings + h.persist(c) +} + +// GetAmpForceModelMappings returns whether model mappings are forced. +func (h *Handler) GetAmpForceModelMappings(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"force-model-mappings": false}) + return + } + c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings}) +} + +// PutAmpForceModelMappings updates the force model mappings setting. +func (h *Handler) PutAmpForceModelMappings(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) +} diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index ef6f400a..39e6b7fd 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -240,16 +240,6 @@ func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) { Value *bool `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { - var m map[string]any - if err2 := c.ShouldBindJSON(&m); err2 == nil { - for _, v := range m { - if b, ok := v.(bool); ok { - set(b) - h.persist(c) - return - } - } - } c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) return } diff --git a/internal/api/middleware/response_writer.go b/internal/api/middleware/response_writer.go index f0d1ad26..b7259bc6 100644 --- a/internal/api/middleware/response_writer.go +++ b/internal/api/middleware/response_writer.go @@ -232,7 +232,16 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { w.streamDone = nil } + // Write API Request and Response to the streaming log before closing if w.streamWriter != nil { + apiRequest := w.extractAPIRequest(c) + if len(apiRequest) > 0 { + _ = w.streamWriter.WriteAPIRequest(apiRequest) + } + apiResponse := w.extractAPIResponse(c) + if len(apiResponse) > 0 { + _ = w.streamWriter.WriteAPIResponse(apiResponse) + } if err := w.streamWriter.Close(); err != nil { w.streamWriter = nil return err diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index dabb7404..88319a78 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -100,6 +100,16 @@ func (m *AmpModule) Name() string { return "amp-routing" } +// forceModelMappings returns whether model mappings should take precedence over local API keys +func (m *AmpModule) forceModelMappings() bool { + m.configMu.RLock() + defer m.configMu.RUnlock() + if m.lastConfig == nil { + return false + } + return m.lastConfig.ForceModelMappings +} + // Register sets up Amp routes if configured. // This implements the RouteModuleV2 interface with Context. // Routes are registered only once via sync.Once for idempotent behavior. diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index 0cbe0e1a..3ec6c85e 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -77,23 +77,29 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid // FallbackHandler wraps a standard handler with fallback logic to ampcode.com // when the model's provider is not available in CLIProxyAPI type FallbackHandler struct { - getProxy func() *httputil.ReverseProxy - modelMapper ModelMapper + getProxy func() *httputil.ReverseProxy + modelMapper ModelMapper + forceModelMappings func() bool } // NewFallbackHandler creates a new fallback handler wrapper // The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes) func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler { return &FallbackHandler{ - getProxy: getProxy, + getProxy: getProxy, + forceModelMappings: func() bool { return false }, } } // NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support -func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler { +func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler { + if forceModelMappings == nil { + forceModelMappings = func() bool { return false } + } return &FallbackHandler{ - getProxy: getProxy, - modelMapper: mapper, + getProxy: getProxy, + modelMapper: mapper, + forceModelMappings: forceModelMappings, } } @@ -130,34 +136,65 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Normalize model (handles Gemini thinking suffixes) normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName) - // Check if we have providers for this model - providers := util.GetProviderName(normalizedModel) - // Track resolved model for logging (may change if mapping is applied) resolvedModel := normalizedModel usedMapping := false + var providers []string - if len(providers) == 0 { - // No providers configured - check if we have a model mapping + // Check if model mappings should be forced ahead of local API keys + forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings() + + if forceMappings { + // FORCE MODE: Check model mappings FIRST (takes precedence over local API keys) + // This allows users to route Amp requests to their preferred OAuth providers if fh.modelMapper != nil { if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { - // Mapping found - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - - // Get providers for the mapped model - providers = util.GetProviderName(mappedModel) - - // Continue to handler with remapped model - goto handleRequest + // Mapping found - check if we have a provider for the mapped model + mappedProviders := util.GetProviderName(mappedModel) + if len(mappedProviders) > 0 { + // Mapping found and provider available - rewrite the model in request body + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) + resolvedModel = mappedModel + usedMapping = true + providers = mappedProviders + } } } - // No mapping found - check if we have a proxy for fallback + // If no mapping applied, check for local providers + if !usedMapping { + providers = util.GetProviderName(normalizedModel) + } + } else { + // DEFAULT MODE: Check local providers first, then mappings as fallback + providers = util.GetProviderName(normalizedModel) + + if len(providers) == 0 { + // No providers configured - check if we have a model mapping + if fh.modelMapper != nil { + if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { + // Mapping found - check if we have a provider for the mapped model + mappedProviders := util.GetProviderName(mappedModel) + if len(mappedProviders) > 0 { + // Mapping found and provider available - rewrite the model in request body + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) + resolvedModel = mappedModel + usedMapping = true + providers = mappedProviders + } + } + } + } + } + + // If no providers available, fallback to ampcode.com + if len(providers) == 0 { proxy := fh.getProxy() if proxy != nil { // Log: Forwarding to ampcode.com (uses Amp credits) @@ -175,8 +212,6 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath) } - handleRequest: - // Log the routing decision providerName := "" if len(providers) > 0 { diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 6826dbbe..48fbbbb9 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -156,6 +156,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()} engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...) engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) + engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...) // Root-level auth routes for CLI login flow // Amp uses multiple auth routes: /auth/cli-login, /auth/callback, /auth/sign-in, /auth/logout @@ -171,7 +172,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler) geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.getProxy() - }, m.modelMapper) + }, m.modelMapper, m.forceModelMappings) geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) // Route POST model calls through Gemini bridge with FallbackHandler. @@ -209,7 +210,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han // Also includes model mapping support for routing unavailable models to alternatives fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.getProxy() - }, m.modelMapper) + }, m.modelMapper, m.forceModelMappings) // Provider-specific routes under /api/provider/:provider ampProviders := engine.Group("/api/provider") diff --git a/internal/api/server.go b/internal/api/server.go index 72cb0313..e1cea9e9 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -300,7 +300,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk // Create HTTP server s.server = &http.Server{ - Addr: fmt.Sprintf(":%d", cfg.Port), + Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), Handler: engine, } @@ -520,6 +520,26 @@ func (s *Server) registerManagementRoutes() { mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth) mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth) + mgmt.GET("/ampcode", s.mgmt.GetAmpCode) + mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL) + mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) + mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) + mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL) + mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey) + mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) + mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) + mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey) + mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost) + mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) + mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) + mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings) + mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings) + mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings) + mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings) + mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings) + mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) + mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) + mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry) diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go index a6ac4507..f173c95f 100644 --- a/internal/auth/gemini/gemini_auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -76,7 +76,8 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken auth := &proxy.Auth{User: username, Password: password} dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) if errSOCKS5 != nil { - log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5) + log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) + return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5) } transport = &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -238,7 +239,11 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, // Start the server in a goroutine. go func() { if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("ListenAndServe(): %v", err) + log.Errorf("ListenAndServe(): %v", err) + select { + case errChan <- err: + default: + } } }() diff --git a/internal/auth/iflow/iflow_auth.go b/internal/auth/iflow/iflow_auth.go index b3431f84..1d69a821 100644 --- a/internal/auth/iflow/iflow_auth.go +++ b/internal/auth/iflow/iflow_auth.go @@ -321,17 +321,23 @@ func (ia *IFlowAuth) AuthenticateWithCookie(ctx context.Context, cookie string) return nil, fmt.Errorf("iflow cookie authentication: cookie is empty") } - // First, get initial API key information using GET request + // First, get initial API key information using GET request to obtain the name keyInfo, err := ia.fetchAPIKeyInfo(ctx, cookie) if err != nil { return nil, fmt.Errorf("iflow cookie authentication: fetch initial API key info failed: %w", err) } - // Convert to token data format + // Refresh the API key using POST request + refreshedKeyInfo, err := ia.RefreshAPIKey(ctx, cookie, keyInfo.Name) + if err != nil { + return nil, fmt.Errorf("iflow cookie authentication: refresh API key failed: %w", err) + } + + // Convert to token data format using refreshed key data := &IFlowTokenData{ - APIKey: keyInfo.APIKey, - Expire: keyInfo.ExpireTime, - Email: keyInfo.Name, + APIKey: refreshedKeyInfo.APIKey, + Expire: refreshedKeyInfo.ExpireTime, + Email: refreshedKeyInfo.Name, Cookie: cookie, } diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 5e5159aa..de01cec5 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -65,20 +65,20 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { authenticator := sdkAuth.NewGeminiAuthenticator() record, errLogin := authenticator.Login(ctx, cfg, loginOpts) if errLogin != nil { - log.Fatalf("Gemini authentication failed: %v", errLogin) + log.Errorf("Gemini authentication failed: %v", errLogin) return } storage, okStorage := record.Storage.(*gemini.GeminiTokenStorage) if !okStorage || storage == nil { - log.Fatal("Gemini authentication failed: unsupported token storage") + log.Error("Gemini authentication failed: unsupported token storage") return } geminiAuth := gemini.NewGeminiAuth() httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser) if errClient != nil { - log.Fatalf("Gemini authentication failed: %v", errClient) + log.Errorf("Gemini authentication failed: %v", errClient) return } @@ -86,7 +86,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { projects, errProjects := fetchGCPProjects(ctx, httpClient) if errProjects != nil { - log.Fatalf("Failed to get project list: %v", errProjects) + log.Errorf("Failed to get project list: %v", errProjects) return } @@ -98,11 +98,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn) projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects) if errSelection != nil { - log.Fatalf("Invalid project selection: %v", errSelection) + log.Errorf("Invalid project selection: %v", errSelection) return } if len(projectSelections) == 0 { - log.Fatal("No project selected; aborting login.") + log.Error("No project selected; aborting login.") return } @@ -116,7 +116,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { showProjectSelectionHelp(storage.Email, projects) return } - log.Fatalf("Failed to complete user setup: %v", errSetup) + log.Errorf("Failed to complete user setup: %v", errSetup) return } finalID := strings.TrimSpace(storage.ProjectID) @@ -133,11 +133,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { for _, pid := range activatedProjects { isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, pid) if errCheck != nil { - log.Fatalf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck) + log.Errorf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck) return } if !isChecked { - log.Fatalf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid) + log.Errorf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid) return } } @@ -153,7 +153,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { savedPath, errSave := store.Save(ctx, record) if errSave != nil { - log.Fatalf("Failed to save token to file: %v", errSave) + log.Errorf("Failed to save token to file: %v", errSave) return } @@ -555,6 +555,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec continue } } + _ = resp.Body.Close() return false, fmt.Errorf("project activation required: %s", errMessage) } return true, nil diff --git a/internal/cmd/run.go b/internal/cmd/run.go index e2f6ee80..1e968126 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -45,12 +45,13 @@ func StartService(cfg *config.Config, configPath string, localPassword string) { service, err := builder.Build() if err != nil { - log.Fatalf("failed to build proxy service: %v", err) + log.Errorf("failed to build proxy service: %v", err) + return } err = service.Run(runCtx) if err != nil && !errors.Is(err, context.Canceled) { - log.Fatalf("proxy service exited with error: %v", err) + log.Errorf("proxy service exited with error: %v", err) } } diff --git a/internal/cmd/vertex_import.go b/internal/cmd/vertex_import.go index ebb32d0c..32d782d8 100644 --- a/internal/cmd/vertex_import.go +++ b/internal/cmd/vertex_import.go @@ -29,30 +29,30 @@ func DoVertexImport(cfg *config.Config, keyPath string) { } rawPath := strings.TrimSpace(keyPath) if rawPath == "" { - log.Fatalf("vertex-import: missing service account key path") + log.Errorf("vertex-import: missing service account key path") return } data, errRead := os.ReadFile(rawPath) if errRead != nil { - log.Fatalf("vertex-import: read file failed: %v", errRead) + log.Errorf("vertex-import: read file failed: %v", errRead) return } var sa map[string]any if errUnmarshal := json.Unmarshal(data, &sa); errUnmarshal != nil { - log.Fatalf("vertex-import: invalid service account json: %v", errUnmarshal) + log.Errorf("vertex-import: invalid service account json: %v", errUnmarshal) return } // Validate and normalize private_key before saving normalizedSA, errFix := vertex.NormalizeServiceAccountMap(sa) if errFix != nil { - log.Fatalf("vertex-import: %v", errFix) + log.Errorf("vertex-import: %v", errFix) return } sa = normalizedSA email, _ := sa["client_email"].(string) projectID, _ := sa["project_id"].(string) if strings.TrimSpace(projectID) == "" { - log.Fatalf("vertex-import: project_id missing in service account json") + log.Errorf("vertex-import: project_id missing in service account json") return } if strings.TrimSpace(email) == "" { @@ -92,7 +92,7 @@ func DoVertexImport(cfg *config.Config, keyPath string) { } path, errSave := store.Save(context.Background(), record) if errSave != nil { - log.Fatalf("vertex-import: save credential failed: %v", errSave) + log.Errorf("vertex-import: save credential failed: %v", errSave) return } fmt.Printf("Vertex credentials imported: %s\n", path) diff --git a/internal/config/config.go b/internal/config/config.go index 1c72ece4..f9da2c29 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,6 +20,9 @@ import ( // Config represents the application's configuration, loaded from a YAML file. type Config struct { config.SDKConfig `yaml:",inline"` + // Host is the network host/interface on which the API server will bind. + // Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access. + Host string `yaml:"host" json:"-"` // Port is the network port on which the API server will listen. Port int `yaml:"port" json:"-"` @@ -151,6 +154,10 @@ type AmpCode struct { // When Amp requests a model that isn't available locally, these mappings // allow routing to an alternative model that IS available. ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"` + + // ForceModelMappings when true, model mappings take precedence over local API keys. + // When false (default), local API keys are used first if available. + ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"` } // PayloadConfig defines default and override parameter rules applied to provider payloads. @@ -349,6 +356,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Unmarshal the YAML data into the Config struct. var cfg Config // Set defaults before unmarshal so that absent keys keep defaults. + cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6) cfg.LoggingToFile = false cfg.UsageStatisticsEnabled = false cfg.DisableCooling = false diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go index c574febb..f8c068c5 100644 --- a/internal/logging/request_logger.go +++ b/internal/logging/request_logger.go @@ -84,6 +84,26 @@ type StreamingLogWriter interface { // - error: An error if writing fails, nil otherwise WriteStatus(status int, headers map[string][]string) error + // WriteAPIRequest writes the upstream API request details to the log. + // This should be called before WriteStatus to maintain proper log ordering. + // + // Parameters: + // - apiRequest: The API request data (typically includes URL, headers, body sent upstream) + // + // Returns: + // - error: An error if writing fails, nil otherwise + WriteAPIRequest(apiRequest []byte) error + + // WriteAPIResponse writes the upstream API response details to the log. + // This should be called after the streaming response is complete. + // + // Parameters: + // - apiResponse: The API response data + // + // Returns: + // - error: An error if writing fails, nil otherwise + WriteAPIResponse(apiResponse []byte) error + // Close finalizes the log file and cleans up resources. // // Returns: @@ -248,10 +268,11 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[ // Create streaming writer writer := &FileStreamingLogWriter{ - file: file, - chunkChan: make(chan []byte, 100), // Buffered channel for async writes - closeChan: make(chan struct{}), - errorChan: make(chan error, 1), + file: file, + chunkChan: make(chan []byte, 100), // Buffered channel for async writes + closeChan: make(chan struct{}), + errorChan: make(chan error, 1), + bufferedChunks: &bytes.Buffer{}, } // Start async writer goroutine @@ -628,11 +649,12 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st // FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs. // It handles asynchronous writing of streaming response chunks to a file. +// All data is buffered and written in the correct order when Close is called. type FileStreamingLogWriter struct { // file is the file where log data is written. file *os.File - // chunkChan is a channel for receiving response chunks to write. + // chunkChan is a channel for receiving response chunks to buffer. chunkChan chan []byte // closeChan is a channel for signaling when the writer is closed. @@ -641,8 +663,23 @@ type FileStreamingLogWriter struct { // errorChan is a channel for reporting errors during writing. errorChan chan error - // statusWritten indicates whether the response status has been written. + // bufferedChunks stores the response chunks in order. + bufferedChunks *bytes.Buffer + + // responseStatus stores the HTTP status code. + responseStatus int + + // statusWritten indicates whether a non-zero status was recorded. statusWritten bool + + // responseHeaders stores the response headers. + responseHeaders map[string][]string + + // apiRequest stores the upstream API request data. + apiRequest []byte + + // apiResponse stores the upstream API response data. + apiResponse []byte } // WriteChunkAsync writes a response chunk asynchronously (non-blocking). @@ -666,39 +703,65 @@ func (w *FileStreamingLogWriter) WriteChunkAsync(chunk []byte) { } } -// WriteStatus writes the response status and headers to the log. +// WriteStatus buffers the response status and headers for later writing. // // Parameters: // - status: The response status code // - headers: The response headers // // Returns: -// - error: An error if writing fails, nil otherwise +// - error: Always returns nil (buffering cannot fail) func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error { - if w.file == nil || w.statusWritten { + if status == 0 { return nil } - var content strings.Builder - content.WriteString("========================================\n") - content.WriteString("=== RESPONSE ===\n") - content.WriteString(fmt.Sprintf("Status: %d\n", status)) - - for key, values := range headers { - for _, value := range values { - content.WriteString(fmt.Sprintf("%s: %s\n", key, value)) + w.responseStatus = status + if headers != nil { + w.responseHeaders = make(map[string][]string, len(headers)) + for key, values := range headers { + headerValues := make([]string, len(values)) + copy(headerValues, values) + w.responseHeaders[key] = headerValues } } - content.WriteString("\n") + w.statusWritten = true + return nil +} - _, err := w.file.WriteString(content.String()) - if err == nil { - w.statusWritten = true +// WriteAPIRequest buffers the upstream API request details for later writing. +// +// Parameters: +// - apiRequest: The API request data (typically includes URL, headers, body sent upstream) +// +// Returns: +// - error: Always returns nil (buffering cannot fail) +func (w *FileStreamingLogWriter) WriteAPIRequest(apiRequest []byte) error { + if len(apiRequest) == 0 { + return nil } - return err + w.apiRequest = bytes.Clone(apiRequest) + return nil +} + +// WriteAPIResponse buffers the upstream API response details for later writing. +// +// Parameters: +// - apiResponse: The API response data +// +// Returns: +// - error: Always returns nil (buffering cannot fail) +func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error { + if len(apiResponse) == 0 { + return nil + } + w.apiResponse = bytes.Clone(apiResponse) + return nil } // Close finalizes the log file and cleans up resources. +// It writes all buffered data to the file in the correct order: +// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks) // // Returns: // - error: An error if closing fails, nil otherwise @@ -707,27 +770,84 @@ func (w *FileStreamingLogWriter) Close() error { close(w.chunkChan) } - // Wait for async writer to finish + // Wait for async writer to finish buffering chunks if w.closeChan != nil { <-w.closeChan w.chunkChan = nil } - if w.file != nil { - return w.file.Close() + if w.file == nil { + return nil } - return nil + // Write all content in the correct order + var content strings.Builder + + // 1. Write API REQUEST section + if len(w.apiRequest) > 0 { + if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) { + content.Write(w.apiRequest) + if !bytes.HasSuffix(w.apiRequest, []byte("\n")) { + content.WriteString("\n") + } + } else { + content.WriteString("=== API REQUEST ===\n") + content.Write(w.apiRequest) + content.WriteString("\n") + } + content.WriteString("\n") + } + + // 2. Write API RESPONSE section + if len(w.apiResponse) > 0 { + if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) { + content.Write(w.apiResponse) + if !bytes.HasSuffix(w.apiResponse, []byte("\n")) { + content.WriteString("\n") + } + } else { + content.WriteString("=== API RESPONSE ===\n") + content.Write(w.apiResponse) + content.WriteString("\n") + } + content.WriteString("\n") + } + + // 3. Write RESPONSE section (status, headers, buffered chunks) + content.WriteString("=== RESPONSE ===\n") + if w.statusWritten { + content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus)) + } + + for key, values := range w.responseHeaders { + for _, value := range values { + content.WriteString(fmt.Sprintf("%s: %s\n", key, value)) + } + } + content.WriteString("\n") + + // Write buffered response body chunks + if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 { + content.Write(w.bufferedChunks.Bytes()) + } + + // Write the complete content to file + if _, err := w.file.WriteString(content.String()); err != nil { + _ = w.file.Close() + return err + } + + return w.file.Close() } -// asyncWriter runs in a goroutine to handle async chunk writing. -// It continuously reads chunks from the channel and writes them to the file. +// asyncWriter runs in a goroutine to buffer chunks from the channel. +// It continuously reads chunks from the channel and buffers them for later writing. func (w *FileStreamingLogWriter) asyncWriter() { defer close(w.closeChan) for chunk := range w.chunkChan { - if w.file != nil { - _, _ = w.file.Write(chunk) + if w.bufferedChunks != nil { + w.bufferedChunks.Write(chunk) } } } @@ -754,6 +874,28 @@ func (w *NoOpStreamingLogWriter) WriteStatus(_ int, _ map[string][]string) error return nil } +// WriteAPIRequest is a no-op implementation that does nothing and always returns nil. +// +// Parameters: +// - apiRequest: The API request data (ignored) +// +// Returns: +// - error: Always returns nil +func (w *NoOpStreamingLogWriter) WriteAPIRequest(_ []byte) error { + return nil +} + +// WriteAPIResponse is a no-op implementation that does nothing and always returns nil. +// +// Parameters: +// - apiResponse: The API response data (ignored) +// +// Returns: +// - error: Always returns nil +func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error { + return nil +} + // Close is a no-op implementation that does nothing and always returns nil. // // Returns: diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index aa09f688..31f08f98 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -693,8 +693,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Low", - Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 Nothink", + Description: "Stable version of GPT 5.1, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -719,8 +719,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Medium", - Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 Medium", + Description: "Stable version of GPT 5.1, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -732,8 +732,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 High", - Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 High", + Description: "Stable version of GPT 5.1, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -745,8 +745,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex", - Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 Codex", + Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -758,8 +758,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex Low", - Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 Codex Low", + Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -771,8 +771,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex Medium", - Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 Codex Medium", + Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -784,8 +784,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex High", - Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", + DisplayName: "GPT 5.1 Codex High", + Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -797,8 +797,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex Mini", - Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", + DisplayName: "GPT 5.1 Codex Mini", + Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -810,8 +810,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex Mini Medium", - Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", + DisplayName: "GPT 5.1 Codex Mini Medium", + Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -823,8 +823,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5 Codex Mini High", - Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", + DisplayName: "GPT 5.1 Codex Mini High", + Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -837,8 +837,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-max", - DisplayName: "GPT 5 Codex Max", - Description: "Stable version of GPT 5 Codex Max", + DisplayName: "GPT 5.1 Codex Max", + Description: "Stable version of GPT 5.1 Codex Max", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -850,8 +850,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-max", - DisplayName: "GPT 5 Codex Max Low", - Description: "Stable version of GPT 5 Codex Max Low", + DisplayName: "GPT 5.1 Codex Max Low", + Description: "Stable version of GPT 5.1 Codex Max Low", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -863,8 +863,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-max", - DisplayName: "GPT 5 Codex Max Medium", - Description: "Stable version of GPT 5 Codex Max Medium", + DisplayName: "GPT 5.1 Codex Max Medium", + Description: "Stable version of GPT 5.1 Codex Max Medium", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -876,8 +876,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-max", - DisplayName: "GPT 5 Codex Max High", - Description: "Stable version of GPT 5 Codex Max High", + DisplayName: "GPT 5.1 Codex Max High", + Description: "Stable version of GPT 5.1 Codex Max High", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -889,8 +889,8 @@ func GetOpenAIModels() []*ModelInfo { OwnedBy: "openai", Type: "openai", Version: "gpt-5.1-max", - DisplayName: "GPT 5 Codex Max XHigh", - Description: "Stable version of GPT 5 Codex Max XHigh", + DisplayName: "GPT 5.1 Codex Max XHigh", + Description: "Stable version of GPT 5.1 Codex Max XHigh", ContextLength: 400000, MaxCompletionTokens: 128000, SupportedParameters: []string{"tools"}, @@ -944,7 +944,6 @@ func GetQwenModels() []*ModelInfo { } // GetIFlowModels returns supported models for iFlow OAuth accounts. - func GetIFlowModels() []*ModelInfo { entries := []struct { ID string @@ -987,6 +986,28 @@ func GetIFlowModels() []*ModelInfo { return models } +// AntigravityModelConfig captures static antigravity model overrides, including +// Thinking budget limits and provider max completion tokens. +type AntigravityModelConfig struct { + Thinking *ThinkingSupport + MaxCompletionTokens int + Name string +} + +// GetAntigravityModelConfig returns static configuration for antigravity models. +// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup. +func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { + return map[string]*AntigravityModelConfig{ + "gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"}, + "gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"}, + "gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"}, + "gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-preview"}, + "gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-image-preview"}, + "gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + "gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + } +} + // GetGitHubCopilotModels returns the available models for GitHub Copilot. // These models are available through the GitHub Copilot API at api.githubcopilot.com. func GetGitHubCopilotModels() []*ModelInfo { @@ -1174,17 +1195,6 @@ func GetGitHubCopilotModels() []*ModelInfo { // GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions func GetKiroModels() []*ModelInfo { return []*ModelInfo{ - { - ID: "kiro-auto", - Object: "model", - Created: 1732752000, // 2024-11-28 - OwnedBy: "aws", - Type: "kiro", - DisplayName: "Kiro Auto", - Description: "Automatic model selection by AWS CodeWhisperer", - ContextLength: 200000, - MaxCompletionTokens: 64000, - }, { ID: "kiro-claude-opus-4.5", Object: "model", diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 61a06721..d37cd2c2 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -309,7 +309,9 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c to := sdktranslator.FromString("gemini") payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) payload = applyThinkingMetadata(payload, req.Metadata, req.Model) + payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload) payload = util.ConvertThinkingLevelToBudget(payload) + payload = util.NormalizeGeminiThinkingBudget(req.Model, payload) payload = util.StripThinkingConfigIfUnsupported(req.Model, payload) payload = fixGeminiImageAspectRatio(req.Model, payload) payload = applyPayloadConfig(e.cfg, req.Model, payload) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index f2f0fdc5..a7289c64 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -28,18 +28,18 @@ import ( ) const ( - antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com" - antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com" - antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com" - antigravityStreamPath = "/v1internal:streamGenerateContent" - antigravityGeneratePath = "/v1internal:generateContent" - antigravityModelsPath = "/v1internal:fetchAvailableModels" - antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" - antigravityClientSecret = "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" - defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64" - antigravityAuthType = "antigravity" - refreshSkew = 3000 * time.Second - streamScannerBuffer int = 20_971_520 + antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com" + // antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com" + antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com" + antigravityStreamPath = "/v1internal:streamGenerateContent" + antigravityGeneratePath = "/v1internal:generateContent" + antigravityModelsPath = "/v1internal:fetchAvailableModels" + antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" + antigravityClientSecret = "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" + defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64" + antigravityAuthType = "antigravity" + refreshSkew = 3000 * time.Second + streamScannerBuffer int = 20_971_520 ) var ( @@ -81,6 +81,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) + translated = normalizeAntigravityThinking(req.Model, translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -174,6 +176,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) + translated = normalizeAntigravityThinking(req.Model, translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -370,28 +374,34 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c } now := time.Now().Unix() + modelConfig := registry.GetAntigravityModelConfig() models := make([]*registry.ModelInfo, 0, len(result.Map())) - for id := range result.Map() { - id = modelName2Alias(id) - if id != "" { + for originalName := range result.Map() { + aliasName := modelName2Alias(originalName) + if aliasName != "" { + cfg := modelConfig[aliasName] + modelName := aliasName + if cfg != nil && cfg.Name != "" { + modelName = cfg.Name + } modelInfo := ®istry.ModelInfo{ - ID: id, - Name: id, - Description: id, - DisplayName: id, - Version: id, + ID: aliasName, + Name: modelName, + Description: aliasName, + DisplayName: aliasName, + Version: aliasName, Object: "model", Created: now, OwnedBy: antigravityAuthType, Type: antigravityAuthType, } - // Add Thinking support for thinking models - if strings.HasSuffix(id, "-thinking") || strings.Contains(id, "-thinking-") { - modelInfo.Thinking = ®istry.ThinkingSupport{ - Min: 1024, - Max: 100000, - ZeroAllowed: false, - DynamicAllowed: true, + // Look up Thinking support from static config using alias name + if cfg != nil { + if cfg.Thinking != nil { + modelInfo.Thinking = cfg.Thinking + } + if cfg.MaxCompletionTokens > 0 { + modelInfo.MaxCompletionTokens = cfg.MaxCompletionTokens } } models = append(models, modelInfo) @@ -537,6 +547,9 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau strJSON = util.DeleteKey(strJSON, "minLength") strJSON = util.DeleteKey(strJSON, "maxLength") strJSON = util.DeleteKey(strJSON, "exclusiveMinimum") + strJSON = util.DeleteKey(strJSON, "exclusiveMaximum") + strJSON = util.DeleteKey(strJSON, "$ref") + strJSON = util.DeleteKey(strJSON, "$defs") paths = make([]string, 0) util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths) @@ -652,7 +665,7 @@ func buildBaseURL(auth *cliproxyauth.Auth) string { if baseURLs := antigravityBaseURLFallbackOrder(auth); len(baseURLs) > 0 { return baseURLs[0] } - return antigravityBaseURLAutopush + return antigravityBaseURLDaily } func resolveHost(base string) string { @@ -688,7 +701,7 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string { } return []string{ antigravityBaseURLDaily, - antigravityBaseURLAutopush, + // antigravityBaseURLAutopush, antigravityBaseURLProd, } } @@ -816,3 +829,65 @@ func alias2ModelName(modelName string) string { return modelName } } + +// normalizeAntigravityThinking clamps or removes thinking config based on model support. +// For Claude models, it additionally ensures thinking budget < max_tokens. +func normalizeAntigravityThinking(model string, payload []byte) []byte { + payload = util.StripThinkingConfigIfUnsupported(model, payload) + if !util.ModelSupportsThinking(model) { + return payload + } + budget := gjson.GetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget") + if !budget.Exists() { + return payload + } + raw := int(budget.Int()) + normalized := util.NormalizeThinkingBudget(model, raw) + + isClaude := strings.Contains(strings.ToLower(model), "claude") + if isClaude { + effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload) + if effectiveMax > 0 && normalized >= effectiveMax { + normalized = effectiveMax - 1 + } + minBudget := antigravityMinThinkingBudget(model) + if minBudget > 0 && normalized >= 0 && normalized < minBudget { + // Budget is below minimum, remove thinking config entirely + payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.thinkingConfig") + return payload + } + if setDefaultMax { + if res, errSet := sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax); errSet == nil { + payload = res + } + } + } + + updated, err := sjson.SetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + if err != nil { + return payload + } + return updated +} + +// antigravityEffectiveMaxTokens returns the max tokens to cap thinking: +// prefer request-provided maxOutputTokens; otherwise fall back to model default. +// The boolean indicates whether the value came from the model default (and thus should be written back). +func antigravityEffectiveMaxTokens(model string, payload []byte) (max int, fromModel bool) { + if maxTok := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxTok.Exists() && maxTok.Int() > 0 { + return int(maxTok.Int()), false + } + if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.MaxCompletionTokens > 0 { + return modelInfo.MaxCompletionTokens, true + } + return 0, false +} + +// antigravityMinThinkingBudget returns the minimum thinking budget for a model. +// Falls back to -1 if no model info is found. +func antigravityMinThinkingBudget(model string) int { + if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.Thinking != nil { + return modelInfo.Thinking.Min + } + return -1 +} diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 147a1ea1..a2e0ecec 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -64,6 +64,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth to := sdktranslator.FromString("gemini-cli") basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model) + basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload) + basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload) basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload) basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload) basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload) @@ -199,6 +201,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut to := sdktranslator.FromString("gemini-cli") basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model) + basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload) + basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload) basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload) basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload) basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload) diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index fc7b8e19..8879a4f1 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -80,6 +80,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r to := sdktranslator.FromString("gemini") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body = applyThinkingMetadata(body, req.Metadata, req.Model) + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) + body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) @@ -169,6 +171,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A to := sdktranslator.FromString("gemini") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body = applyThinkingMetadata(body, req.Metadata, req.Model) + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) + body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index de4ba072..c7d10a67 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -296,6 +296,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) + body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) @@ -391,6 +393,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) + body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) @@ -487,6 +491,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) + body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) @@ -599,6 +605,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) + body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 0157d68c..b965c9ca 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -547,8 +547,6 @@ func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { // Agentic variants (-agentic suffix) map to the same backend model IDs. func (e *KiroExecutor) mapModelToKiro(model string) string { modelMap := map[string]string{ - // Proxy format (kiro- prefix) - "kiro-auto": "auto", "kiro-claude-opus-4.5": "claude-opus-4.5", "kiro-claude-sonnet-4.5": "claude-sonnet-4.5", "kiro-claude-sonnet-4": "claude-sonnet-4", diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index e1b73da0..a810ba7a 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -180,7 +180,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ if t.Get("type").String() == "enabled" { if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { budget := int(b.Int()) - budget = util.NormalizeThinkingBudget(modelName, budget) out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index 4073f20b..42265e80 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -36,6 +37,9 @@ type Params struct { HasToolUse bool // Indicates if tool use was observed in the stream } +// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. +var toolUseIDCounter uint64 + // ConvertAntigravityResponseToClaude performs sophisticated streaming response format conversion. // This function implements a complex state machine that translates backend client responses // into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types @@ -216,7 +220,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex) - data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) data, _ = sjson.Set(data, "content_block.name", fcName) output = output + fmt.Sprintf("data: %s\n\n\n", data) diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index d1914ec8..717f88f7 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -48,13 +48,13 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) case "low": - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024)) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) case "medium": - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192)) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) case "high": - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768)) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) default: out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1) @@ -66,15 +66,15 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ if !hasOfficialThinking && util.ModelSupportsThinking(modelName) { if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() { var setBudget bool - var normalized int + var budget int if v := tc.Get("thinkingBudget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } else if v := tc.Get("thinking_budget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } @@ -82,22 +82,27 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool()) } else if v := tc.Get("include_thoughts"); v.Exists() { out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool()) - } else if setBudget && normalized != 0 { + } else if setBudget && budget != 0 { out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } } } - // For gemini-3-pro-preview, always send default thinkingConfig when none specified. - // This matches the official Gemini CLI behavior which always sends: - // { thinkingBudget: -1, includeThoughts: true } - // See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts - if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" { - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) + // Claude/Anthropic API format: thinking.type == "enabled" with budget_tokens + // This allows Claude Code and other Claude API clients to pass thinking configuration + if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && util.ModelSupportsThinking(modelName) { + if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() { + if t.Get("type").String() == "enabled" { + if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { + budget := int(b.Int()) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) + } + } + } } - // Temperature/top_p/top_k + // Temperature/top_p/top_k/max_tokens if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number { out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num) } @@ -107,6 +112,9 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number { out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num) } + if maxTok := gjson.GetBytes(rawJSON, "max_tokens"); maxTok.Exists() && maxTok.Type == gjson.Number { + out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num) + } // Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities // e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"] diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go index e069f7ec..24694e1d 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" . "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions" @@ -24,6 +25,9 @@ type convertCliResponseToOpenAIChatParams struct { FunctionIndex int } +// functionCallIDCounter provides a process-wide unique counter for function call identifiers. +var functionCallIDCounter uint64 + // ConvertAntigravityResponseToOpenAI translates a single chunk of a streaming response from the // Gemini CLI API format to the OpenAI Chat Completions streaming format. // It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses. @@ -146,7 +150,7 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}` fcName := functionCallResult.Get("name").String() - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))) functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex) functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go index 50fd5a25..913727ce 100644 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go @@ -165,7 +165,6 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) [] if t.Get("type").String() == "enabled" { if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { budget := int(b.Int()) - budget = util.NormalizeThinkingBudget(modelName, budget) out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go index e7f6275f..ba9f6801 100644 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -27,6 +28,9 @@ type Params struct { ResponseIndex int // Index counter for content blocks in the streaming response } +// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. +var toolUseIDCounter uint64 + // ConvertGeminiCLIResponseToClaude performs sophisticated streaming response format conversion. // This function implements a complex state machine that translates backend client responses // into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types @@ -180,7 +184,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) - data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) data, _ = sjson.Set(data, "content_block.name", fcName) sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data)) diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go index d14f1119..b52bf224 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go +++ b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go @@ -48,13 +48,13 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) case "low": - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024)) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) case "medium": - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192)) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) case "high": - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768)) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) default: out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1) @@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo if !hasOfficialThinking && util.ModelSupportsThinking(modelName) { if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() { var setBudget bool - var normalized int + var budget int if v := tc.Get("thinkingBudget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } else if v := tc.Get("thinking_budget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } @@ -82,21 +82,12 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool()) } else if v := tc.Get("include_thoughts"); v.Exists() { out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool()) - } else if setBudget && normalized != 0 { + } else if setBudget && budget != 0 { out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) } } } - // For gemini-3-pro-preview, always send default thinkingConfig when none specified. - // This matches the official Gemini CLI behavior which always sends: - // { thinkingBudget: -1, includeThoughts: true } - // See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts - if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" { - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true) - } - // Temperature/top_p/top_k if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number { out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num) diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go index 9c422a07..753870f3 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go +++ b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" . "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions" @@ -24,6 +25,9 @@ type convertCliResponseToOpenAIChatParams struct { FunctionIndex int } +// functionCallIDCounter provides a process-wide unique counter for function call identifiers. +var functionCallIDCounter uint64 + // ConvertCliResponseToOpenAI translates a single chunk of a streaming response from the // Gemini CLI API format to the OpenAI Chat Completions streaming format. // It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses. @@ -146,7 +150,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}` fcName := functionCallResult.Get("name").String() - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))) functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex) functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { diff --git a/internal/translator/gemini/claude/gemini_claude_request.go b/internal/translator/gemini/claude/gemini_claude_request.go index 05f9be5d..45a5a88f 100644 --- a/internal/translator/gemini/claude/gemini_claude_request.go +++ b/internal/translator/gemini/claude/gemini_claude_request.go @@ -158,7 +158,6 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) if t.Get("type").String() == "enabled" { if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { budget := int(b.Int()) - budget = util.NormalizeThinkingBudget(modelName, budget) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) } diff --git a/internal/translator/gemini/claude/gemini_claude_response.go b/internal/translator/gemini/claude/gemini_claude_response.go index a80171a9..8fd566df 100644 --- a/internal/translator/gemini/claude/gemini_claude_response.go +++ b/internal/translator/gemini/claude/gemini_claude_response.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -26,6 +27,9 @@ type Params struct { ResponseIndex int } +// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. +var toolUseIDCounter uint64 + // ConvertGeminiResponseToClaude performs sophisticated streaming response format conversion. // This function implements a complex state machine that translates backend client responses // into Claude-compatible Server-Sent Events (SSE) format. It manages different response types @@ -197,7 +201,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) - data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) data, _ = sjson.Set(data, "content_block.name", fcName) output = output + fmt.Sprintf("data: %s\n\n\n", data) diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go index 0df8987f..8c48a5b3 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go @@ -48,13 +48,13 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1) out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true) case "low": - out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024)) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 1024) out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true) case "medium": - out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192)) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 8192) out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true) case "high": - out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768)) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 32768) out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true) default: out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1) @@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) if !hasOfficialThinking && util.ModelSupportsThinking(modelName) { if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() { var setBudget bool - var normalized int + var budget int if v := tc.Get("thinkingBudget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } else if v := tc.Get("thinking_budget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } @@ -82,7 +82,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool()) } else if v := tc.Get("include_thoughts"); v.Exists() { out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool()) - } else if setBudget && normalized != 0 { + } else if setBudget && budget != 0 { out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true) } } diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go index 12e28cca..a1ebc855 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -23,6 +24,9 @@ type convertGeminiResponseToOpenAIChatParams struct { FunctionIndex int } +// functionCallIDCounter provides a process-wide unique counter for function call identifiers. +var functionCallIDCounter uint64 + // ConvertGeminiResponseToOpenAI translates a single chunk of a streaming response from the // Gemini API format to the OpenAI Chat Completions streaming format. // It processes various Gemini event types and transforms them into OpenAI-compatible JSON responses. @@ -148,7 +152,7 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}` fcName := functionCallResult.Get("name").String() - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))) functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex) functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { @@ -281,7 +285,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina } functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` fcName := functionCallResult.Get("name").String() - functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))) functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw) diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go index 4ea75c18..bdf59785 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go @@ -400,16 +400,16 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) case "minimal": - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024)) + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 1024) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) case "low": - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 4096)) + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 4096) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) case "medium": - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192)) + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 8192) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) case "high": - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768)) + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 32768) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) default: out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1) @@ -421,32 +421,22 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte if !hasOfficialThinking && util.ModelSupportsThinking(modelName) { if tc := root.Get("extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() { var setBudget bool - var normalized int + var budget int if v := tc.Get("thinking_budget"); v.Exists() { - normalized = util.NormalizeThinkingBudget(modelName, int(v.Int())) - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", normalized) + budget = int(v.Int()) + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget) setBudget = true } if v := tc.Get("include_thoughts"); v.Exists() { out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool()) } else if setBudget { - if normalized != 0 { + if budget != 0 { out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) } } } } - // For gemini-3-pro-preview, always send default thinkingConfig when none specified. - // This matches the official Gemini CLI behavior which always sends: - // { thinkingBudget: -1, includeThoughts: true } - // See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts - if !gjson.Get(out, "generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" { - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1) - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true) - // log.Debugf("Applied default thinkingConfig for gemini-3-pro-preview (matches Gemini CLI): thinkingBudget=-1, include_thoughts=true") - } - result := []byte(out) result = common.AttachDefaultSafetySettings(result, "safetySettings") return result diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go index ce221863..e08b265d 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -37,6 +38,12 @@ type geminiToResponsesState struct { FuncCallIDs map[int]string } +// responseIDCounter provides a process-wide unique counter for synthesized response identifiers. +var responseIDCounter uint64 + +// funcCallIDCounter provides a process-wide unique counter for function call identifiers. +var funcCallIDCounter uint64 + func emitEvent(event string, payload string) string { return fmt.Sprintf("event: %s\ndata: %s", event, payload) } @@ -205,7 +212,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, st.FuncArgsBuf[idx] = &strings.Builder{} } if st.FuncCallIDs[idx] == "" { - st.FuncCallIDs[idx] = fmt.Sprintf("call_%d", time.Now().UnixNano()) + st.FuncCallIDs[idx] = fmt.Sprintf("call_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1)) } st.FuncNames[idx] = name @@ -464,7 +471,7 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string // id: prefer provider responseId, otherwise synthesize id := root.Get("responseId").String() if id == "" { - id = fmt.Sprintf("resp_%x", time.Now().UnixNano()) + id = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1)) } // Normalize to response-style id (prefix resp_ if missing) if !strings.HasPrefix(id, "resp_") { @@ -575,7 +582,7 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string if fc := p.Get("functionCall"); fc.Exists() { name := fc.Get("name").String() args := fc.Get("args") - callID := fmt.Sprintf("call_%x", time.Now().UnixNano()) + callID := fmt.Sprintf("call_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1)) outputs = append(outputs, map[string]interface{}{ "id": fmt.Sprintf("fc_%s", callID), "type": "function_call", diff --git a/internal/translator/openai/claude/openai_claude_request.go b/internal/translator/openai/claude/openai_claude_request.go index bff306cc..3521b2e5 100644 --- a/internal/translator/openai/claude/openai_claude_request.go +++ b/internal/translator/openai/claude/openai_claude_request.go @@ -8,6 +8,7 @@ package claude import ( "bytes" "encoding/json" + "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -242,11 +243,12 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) { switch partType { case "text": - if !part.Get("text").Exists() { + text := part.Get("text").String() + if strings.TrimSpace(text) == "" { return "", false } textContent := `{"type":"text","text":""}` - textContent, _ = sjson.Set(textContent, "text", part.Get("text").String()) + textContent, _ = sjson.Set(textContent, "text", text) return textContent, true case "image": diff --git a/internal/translator/openai/openai/responses/openai_openai-responses_response.go b/internal/translator/openai/openai/responses/openai_openai-responses_response.go index 00ec5c7f..c698b93f 100644 --- a/internal/translator/openai/openai/responses/openai_openai-responses_response.go +++ b/internal/translator/openai/openai/responses/openai_openai-responses_response.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -41,6 +42,9 @@ type oaiToResponsesState struct { UsageSeen bool } +// responseIDCounter provides a process-wide unique counter for synthesized response identifiers. +var responseIDCounter uint64 + func emitRespEvent(event string, payload string) string { return fmt.Sprintf("event: %s\ndata: %s", event, payload) } @@ -590,7 +594,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co // id: use provider id if present, otherwise synthesize id := root.Get("id").String() if id == "" { - id = fmt.Sprintf("resp_%x", time.Now().UnixNano()) + id = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1)) } resp, _ = sjson.Set(resp, "id", id) diff --git a/internal/util/gemini_thinking.go b/internal/util/gemini_thinking.go index 14077fa0..fc389511 100644 --- a/internal/util/gemini_thinking.go +++ b/internal/util/gemini_thinking.go @@ -207,6 +207,47 @@ func GeminiThinkingFromMetadata(metadata map[string]any) (*int, *bool, bool) { return budgetPtr, includePtr, matched } +// modelsWithDefaultThinking lists models that should have thinking enabled by default +// when no explicit thinkingConfig is provided. +var modelsWithDefaultThinking = map[string]bool{ + "gemini-3-pro-preview": true, +} + +// ModelHasDefaultThinking returns true if the model should have thinking enabled by default. +func ModelHasDefaultThinking(model string) bool { + return modelsWithDefaultThinking[model] +} + +// ApplyDefaultThinkingIfNeeded injects default thinkingConfig for models that require it. +// For standard Gemini API format (generationConfig.thinkingConfig path). +// Returns the modified body if thinkingConfig was added, otherwise returns the original. +func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte { + if !ModelHasDefaultThinking(model) { + return body + } + if gjson.GetBytes(body, "generationConfig.thinkingConfig").Exists() { + return body + } + updated, _ := sjson.SetBytes(body, "generationConfig.thinkingConfig.thinkingBudget", -1) + updated, _ = sjson.SetBytes(updated, "generationConfig.thinkingConfig.include_thoughts", true) + return updated +} + +// ApplyDefaultThinkingIfNeededCLI injects default thinkingConfig for models that require it. +// For Gemini CLI API format (request.generationConfig.thinkingConfig path). +// Returns the modified body if thinkingConfig was added, otherwise returns the original. +func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte { + if !ModelHasDefaultThinking(model) { + return body + } + if gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() { + return body + } + updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget", -1) + updated, _ = sjson.SetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts", true) + return updated +} + // StripThinkingConfigIfUnsupported removes thinkingConfig from the request body // when the target model does not advertise Thinking capability. It cleans both // standard Gemini and Gemini CLI JSON envelopes. This acts as a final safety net @@ -223,6 +264,32 @@ func StripThinkingConfigIfUnsupported(model string, body []byte) []byte { return updated } +// NormalizeGeminiThinkingBudget normalizes the thinkingBudget value in a standard Gemini +// request body (generationConfig.thinkingConfig.thinkingBudget path). +func NormalizeGeminiThinkingBudget(model string, body []byte) []byte { + const budgetPath = "generationConfig.thinkingConfig.thinkingBudget" + budget := gjson.GetBytes(body, budgetPath) + if !budget.Exists() { + return body + } + normalized := NormalizeThinkingBudget(model, int(budget.Int())) + updated, _ := sjson.SetBytes(body, budgetPath, normalized) + return updated +} + +// NormalizeGeminiCLIThinkingBudget normalizes the thinkingBudget value in a Gemini CLI +// request body (request.generationConfig.thinkingConfig.thinkingBudget path). +func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte { + const budgetPath = "request.generationConfig.thinkingConfig.thinkingBudget" + budget := gjson.GetBytes(body, budgetPath) + if !budget.Exists() { + return body + } + normalized := NormalizeThinkingBudget(model, int(budget.Int())) + updated, _ := sjson.SetBytes(body, budgetPath, normalized) + return updated +} + // ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel" // and converts it to "thinkingBudget". // "high" -> 32768 diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index da152141..36276de9 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -1272,7 +1272,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } for i := range cfg.KiroKey { kk := cfg.KiroKey[i] - var accessToken, profileArn string + var accessToken, profileArn, refreshToken string // Try to load from token file first if kk.TokenFile != "" && kAuth != nil { @@ -1282,6 +1282,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } else { accessToken = tokenData.AccessToken profileArn = tokenData.ProfileArn + refreshToken = tokenData.RefreshToken } } @@ -1292,6 +1293,9 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if kk.ProfileArn != "" { profileArn = kk.ProfileArn } + if kk.RefreshToken != "" { + refreshToken = kk.RefreshToken + } if accessToken == "" { log.Warnf("kiro config[%d] missing access_token, skipping", i) @@ -1313,6 +1317,9 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if kk.AgentTaskType != "" { attrs["agent_task_type"] = kk.AgentTaskType } + if refreshToken != "" { + attrs["refresh_token"] = refreshToken + } proxyURL := strings.TrimSpace(kk.ProxyURL) a := &coreauth.Auth{ ID: id, @@ -1324,6 +1331,14 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { CreatedAt: now, UpdatedAt: now, } + + if refreshToken != "" { + if a.Metadata == nil { + a.Metadata = make(map[string]any) + } + a.Metadata["refresh_token"] = refreshToken + } + out = append(out, a) } for i := range cfg.OpenAICompatibility { diff --git a/sdk/api/handlers/gemini/gemini_handlers.go b/sdk/api/handlers/gemini/gemini_handlers.go index 7ba72a93..6cd9ee62 100644 --- a/sdk/api/handlers/gemini/gemini_handlers.go +++ b/sdk/api/handlers/gemini/gemini_handlers.go @@ -48,8 +48,24 @@ func (h *GeminiAPIHandler) Models() []map[string]any { // GeminiModels handles the Gemini models listing endpoint. // It returns a JSON response containing available Gemini models and their specifications. func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) { + rawModels := h.Models() + normalizedModels := make([]map[string]any, 0, len(rawModels)) + defaultMethods := []string{"generateContent"} + for _, model := range rawModels { + normalizedModel := make(map[string]any, len(model)) + for k, v := range model { + normalizedModel[k] = v + } + if name, ok := normalizedModel["name"].(string); ok && name != "" && !strings.HasPrefix(name, "models/") { + normalizedModel["name"] = "models/" + name + } + if _, ok := normalizedModel["supportedGenerationMethods"]; !ok { + normalizedModel["supportedGenerationMethods"] = defaultMethods + } + normalizedModels = append(normalizedModels, normalizedModel) + } c.JSON(http.StatusOK, gin.H{ - "models": h.Models(), + "models": normalizedModels, }) } diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 4e897d52..251457a6 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -502,7 +502,7 @@ func (s *Service) Run(ctx context.Context) error { }() time.Sleep(100 * time.Millisecond) - fmt.Printf("API server started successfully on: %d\n", s.cfg.Port) + fmt.Printf("API server started successfully on: %s:%d\n", s.cfg.Host, s.cfg.Port) if s.hooks.OnAfterStart != nil { s.hooks.OnAfterStart(s) @@ -788,7 +788,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { Created: time.Now().Unix(), OwnedBy: compat.Name, Type: "openai-compatibility", - DisplayName: m.Name, + DisplayName: modelID, }) } // Register and return diff --git a/test/amp_management_test.go b/test/amp_management_test.go new file mode 100644 index 00000000..19450dbf --- /dev/null +++ b/test/amp_management_test.go @@ -0,0 +1,827 @@ +package test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// newAmpTestHandler creates a test handler with default ampcode configuration. +func newAmpTestHandler(t *testing.T) (*management.Handler, string) { + t.Helper() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + UpstreamURL: "https://example.com", + UpstreamAPIKey: "test-api-key-12345", + RestrictManagementToLocalhost: true, + ForceModelMappings: false, + ModelMappings: []config.AmpModelMapping{ + {From: "gpt-4", To: "gemini-pro"}, + }, + }, + } + + if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + h := management.NewHandler(cfg, configPath, nil) + return h, configPath +} + +// setupAmpRouter creates a test router with all ampcode management endpoints. +func setupAmpRouter(h *management.Handler) *gin.Engine { + r := gin.New() + mgmt := r.Group("/v0/management") + { + mgmt.GET("/ampcode", h.GetAmpCode) + mgmt.GET("/ampcode/upstream-url", h.GetAmpUpstreamURL) + mgmt.PUT("/ampcode/upstream-url", h.PutAmpUpstreamURL) + mgmt.DELETE("/ampcode/upstream-url", h.DeleteAmpUpstreamURL) + mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey) + mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey) + mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey) + mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost) + mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost) + mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings) + mgmt.PUT("/ampcode/model-mappings", h.PutAmpModelMappings) + mgmt.PATCH("/ampcode/model-mappings", h.PatchAmpModelMappings) + mgmt.DELETE("/ampcode/model-mappings", h.DeleteAmpModelMappings) + mgmt.GET("/ampcode/force-model-mappings", h.GetAmpForceModelMappings) + mgmt.PUT("/ampcode/force-model-mappings", h.PutAmpForceModelMappings) + } + return r +} + +// TestGetAmpCode verifies GET /v0/management/ampcode returns full ampcode config. +func TestGetAmpCode(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]config.AmpCode + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + ampcode := resp["ampcode"] + if ampcode.UpstreamURL != "https://example.com" { + t.Errorf("expected upstream-url %q, got %q", "https://example.com", ampcode.UpstreamURL) + } + if len(ampcode.ModelMappings) != 1 { + t.Errorf("expected 1 model mapping, got %d", len(ampcode.ModelMappings)) + } +} + +// TestGetAmpUpstreamURL verifies GET /v0/management/ampcode/upstream-url returns the upstream URL. +func TestGetAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["upstream-url"] != "https://example.com" { + t.Errorf("expected %q, got %q", "https://example.com", resp["upstream-url"]) + } +} + +// TestPutAmpUpstreamURL verifies PUT /v0/management/ampcode/upstream-url updates the upstream URL. +func TestPutAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "https://new-upstream.com"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +// TestDeleteAmpUpstreamURL verifies DELETE /v0/management/ampcode/upstream-url clears the upstream URL. +func TestDeleteAmpUpstreamURL(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpUpstreamAPIKey verifies GET /v0/management/ampcode/upstream-api-key returns the API key. +func TestGetAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + key := resp["upstream-api-key"].(string) + if key != "test-api-key-12345" { + t.Errorf("expected key %q, got %q", "test-api-key-12345", key) + } +} + +// TestPutAmpUpstreamAPIKey verifies PUT /v0/management/ampcode/upstream-api-key updates the API key. +func TestPutAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "new-secret-key"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key. +func TestDeleteAmpUpstreamAPIKey(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpRestrictManagementToLocalhost verifies GET returns the localhost restriction setting. +func TestGetAmpRestrictManagementToLocalhost(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["restrict-management-to-localhost"] != true { + t.Error("expected restrict-management-to-localhost to be true") + } +} + +// TestPutAmpRestrictManagementToLocalhost verifies PUT updates the localhost restriction setting. +func TestPutAmpRestrictManagementToLocalhost(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": false}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpModelMappings verifies GET /v0/management/ampcode/model-mappings returns all mappings. +func TestGetAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 1 { + t.Fatalf("expected 1 mapping, got %d", len(mappings)) + } + if mappings[0].From != "gpt-4" || mappings[0].To != "gemini-pro" { + t.Errorf("unexpected mapping: %+v", mappings[0]) + } +} + +// TestPutAmpModelMappings verifies PUT /v0/management/ampcode/model-mappings replaces all mappings. +func TestPutAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "claude-3", "to": "gpt-4o"}, {"from": "gemini", "to": "claude"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +// TestPatchAmpModelMappings verifies PATCH updates existing mappings and adds new ones. +func TestPatchAmpModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "gpt-4", "to": "updated-model"}, {"from": "new-model", "to": "target"}]}` + req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } +} + +// TestDeleteAmpModelMappings_Specific verifies DELETE removes specified mappings by "from" field. +func TestDeleteAmpModelMappings_Specific(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": ["gpt-4"]}` + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestDeleteAmpModelMappings_All verifies DELETE with empty body removes all mappings. +func TestDeleteAmpModelMappings_All(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestGetAmpForceModelMappings verifies GET returns the force-model-mappings setting. +func TestGetAmpForceModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["force-model-mappings"] != false { + t.Error("expected force-model-mappings to be false") + } +} + +// TestPutAmpForceModelMappings verifies PUT updates the force-model-mappings setting. +func TestPutAmpForceModelMappings(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": true}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestPutAmpModelMappings_VerifyState verifies PUT replaces mappings and state is persisted. +func TestPutAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "model-a", "to": "model-b"}, {"from": "model-c", "to": "model-d"}, {"from": "model-e", "to": "model-f"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d, body: %s", w.Code, w.Body.String()) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 3 { + t.Fatalf("expected 3 mappings, got %d", len(mappings)) + } + + expected := map[string]string{"model-a": "model-b", "model-c": "model-d", "model-e": "model-f"} + for _, m := range mappings { + if expected[m.From] != m.To { + t.Errorf("mapping %q -> expected %q, got %q", m.From, expected[m.From], m.To) + } + } +} + +// TestPatchAmpModelMappings_VerifyState verifies PATCH merges mappings correctly. +func TestPatchAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": [{"from": "gpt-4", "to": "updated-target"}, {"from": "new-model", "to": "new-target"}]}` + req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PATCH failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 2 { + t.Fatalf("expected 2 mappings (1 updated + 1 new), got %d", len(mappings)) + } + + found := make(map[string]string) + for _, m := range mappings { + found[m.From] = m.To + } + + if found["gpt-4"] != "updated-target" { + t.Errorf("gpt-4 should map to updated-target, got %q", found["gpt-4"]) + } + if found["new-model"] != "new-target" { + t.Errorf("new-model should map to new-target, got %q", found["new-model"]) + } +} + +// TestDeleteAmpModelMappings_VerifyState verifies DELETE removes specific mappings and keeps others. +func TestDeleteAmpModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + putBody := `{"value": [{"from": "a", "to": "1"}, {"from": "b", "to": "2"}, {"from": "c", "to": "3"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + delBody := `{"value": ["a", "c"]}` + req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 1 { + t.Fatalf("expected 1 mapping remaining, got %d", len(mappings)) + } + if mappings[0].From != "b" || mappings[0].To != "2" { + t.Errorf("expected b->2, got %s->%s", mappings[0].From, mappings[0].To) + } +} + +// TestDeleteAmpModelMappings_NonExistent verifies DELETE with non-existent mapping doesn't affect existing ones. +func TestDeleteAmpModelMappings_NonExistent(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + delBody := `{"value": ["non-existent-model"]}` + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 1 { + t.Errorf("original mapping should remain, got %d mappings", len(resp["model-mappings"])) + } +} + +// TestPutAmpModelMappings_Empty verifies PUT with empty array clears all mappings. +func TestPutAmpModelMappings_Empty(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": []}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 0 { + t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"])) + } +} + +// TestPutAmpUpstreamURL_VerifyState verifies PUT updates upstream URL and persists state. +func TestPutAmpUpstreamURL_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "https://new-api.example.com"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-url"] != "https://new-api.example.com" { + t.Errorf("expected %q, got %q", "https://new-api.example.com", resp["upstream-url"]) + } +} + +// TestDeleteAmpUpstreamURL_VerifyState verifies DELETE clears upstream URL. +func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-url"] != "" { + t.Errorf("expected empty string, got %q", resp["upstream-url"]) + } +} + +// TestPutAmpUpstreamAPIKey_VerifyState verifies PUT updates API key and persists state. +func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": "new-secret-api-key-xyz"}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-api-key"] != "new-secret-api-key-xyz" { + t.Errorf("expected %q, got %q", "new-secret-api-key-xyz", resp["upstream-api-key"]) + } +} + +// TestDeleteAmpUpstreamAPIKey_VerifyState verifies DELETE clears API key. +func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("DELETE failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["upstream-api-key"] != "" { + t.Errorf("expected empty string, got %q", resp["upstream-api-key"]) + } +} + +// TestPutAmpRestrictManagementToLocalhost_VerifyState verifies PUT updates localhost restriction. +func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": false}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["restrict-management-to-localhost"] != false { + t.Error("expected false after update") + } +} + +// TestPutAmpForceModelMappings_VerifyState verifies PUT updates force-model-mappings setting. +func TestPutAmpForceModelMappings_VerifyState(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value": true}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("PUT failed: status %d", w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string]bool + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if resp["force-model-mappings"] != true { + t.Error("expected true after update") + } +} + +// TestPutBoolField_EmptyObject verifies PUT with empty object returns 400. +func TestPutBoolField_EmptyObject(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d for empty object, got %d", http.StatusBadRequest, w.Code) + } +} + +// TestComplexMappingsWorkflow tests a full workflow: PUT, PATCH, DELETE, and GET. +func TestComplexMappingsWorkflow(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + putBody := `{"value": [{"from": "m1", "to": "t1"}, {"from": "m2", "to": "t2"}, {"from": "m3", "to": "t3"}, {"from": "m4", "to": "t4"}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + patchBody := `{"value": [{"from": "m2", "to": "t2-updated"}, {"from": "m5", "to": "t5"}]}` + req = httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(patchBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + delBody := `{"value": ["m1", "m3"]}` + req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + mappings := resp["model-mappings"] + if len(mappings) != 3 { + t.Fatalf("expected 3 mappings (m2, m4, m5), got %d", len(mappings)) + } + + expected := map[string]string{"m2": "t2-updated", "m4": "t4", "m5": "t5"} + found := make(map[string]string) + for _, m := range mappings { + found[m.From] = m.To + } + + for from, to := range expected { + if found[from] != to { + t.Errorf("mapping %s: expected %q, got %q", from, to, found[from]) + } + } +} + +// TestNilHandlerGetAmpCode verifies handler works with empty config. +func TestNilHandlerGetAmpCode(t *testing.T) { + cfg := &config.Config{} + h := management.NewHandler(cfg, "", nil) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestEmptyConfigGetAmpModelMappings verifies GET returns empty array for fresh config. +func TestEmptyConfigGetAmpModelMappings(t *testing.T) { + cfg := &config.Config{} + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + h := management.NewHandler(cfg, configPath, nil) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if len(resp["model-mappings"]) != 0 { + t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"])) + } +}