Compare commits

...

25 Commits

Author SHA1 Message Date
Luis Pater
f241124599 Merge branch 'router-for-me:main' into main 2025-12-06 00:43:02 +08:00
Luis Pater
c44c46dd80 Fixed: #421
feat(antigravity): implement project ID retrieval and integration in payload processing
2025-12-06 00:40:55 +08:00
Luis Pater
aa810ee719 Merge branch 'router-for-me:main' into main 2025-12-05 23:06:46 +08:00
Luis Pater
412148af0e feat(antigravity): add function ID to FunctionCall and FunctionResponse models 2025-12-05 23:05:35 +08:00
Luis Pater
5d2baf6058 Merge branch 'router-for-me:main' into main 2025-12-05 21:37:55 +08:00
Luis Pater
d28258501a Merge pull request #423 from router-for-me/amp
fix(amp): suppress ErrAbortHandler panics in reverse proxy handler
2025-12-05 21:36:01 +08:00
hkfires
55cd31fb96 fix(amp): suppress ErrAbortHandler panics in reverse proxy handler 2025-12-05 21:28:58 +08:00
Luis Pater
d138df07bf Merge branch 'router-for-me:main' into main 2025-12-05 21:28:32 +08:00
Luis Pater
c5df8e7897 Merge pull request #422 from router-for-me/amp
Amp
2025-12-05 21:25:43 +08:00
Luis Pater
d4d529833d **refactor(antigravity): handle anyOf property, remove exclusiveMinimum, and comment unused prod URL** 2025-12-05 21:24:12 +08:00
hkfires
caa48e7c6f fix(amp): improve proxy state management and request logging behavior 2025-12-05 21:09:53 +08:00
hkfires
acdfb3bceb feat(amp): add root-level /threads routes for CLI compatibility 2025-12-05 18:14:10 +08:00
hkfires
89d68962b1 fix(amp): filter amp request logging to only provider endpoint 2025-12-05 18:14:09 +08:00
Luis Pater
691cdb6bdf **fix(api): update GitHub release URL and user agent for CLIProxyAPIPlus** 2025-12-05 10:32:28 +08:00
Luis Pater
8064cba288 Merge branch 'router-for-me:main' into main 2025-12-05 10:31:30 +08:00
Luis Pater
361443db10 **feat(api): add GetLatestVersion endpoint to fetch latest release version from GitHub** 2025-12-05 10:29:12 +08:00
Luis Pater
d6352dd4d4 **feat(util): add DeleteKey function and update antigravity executor for Claude model compatibility** 2025-12-05 01:55:45 +08:00
Luis Pater
f8aba62860 Merge branch 'router-for-me:main' into main 2025-12-05 00:45:51 +08:00
Luis Pater
a7eeb06f3d Merge pull request #418 from router-for-me/amp
Amp
2025-12-05 00:43:15 +08:00
hkfires
9426be7a5c fix(amp): update log message wording for disabled proxy state 2025-12-04 21:36:16 +08:00
hkfires
4a135f1986 feat(amp): add hot-reload support for upstream URL and localhost restriction 2025-12-04 21:30:59 +08:00
hkfires
c4c02f4ad0 feat(amp): add partial reload support with config change detection 2025-12-04 21:30:59 +08:00
Luis Pater
b87b9b455f Merge pull request #416 from router-for-me/amp
Amp
2025-12-04 20:52:33 +08:00
hkfires
db03ae9663 feat(watcher): add AmpCode config change detection 2025-12-04 19:50:54 +08:00
hkfires
969ff6bb68 fix(amp): update explicit API key on config change 2025-12-04 19:32:44 +08:00
20 changed files with 691 additions and 123 deletions

View File

@@ -1,16 +1,28 @@
package management package management
import ( import (
"encoding/json"
"fmt"
"io" "io"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
const (
latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest"
latestReleaseUserAgent = "CLIProxyAPIPlus"
)
func (h *Handler) GetConfig(c *gin.Context) { func (h *Handler) GetConfig(c *gin.Context) {
if h == nil || h.cfg == nil { if h == nil || h.cfg == nil {
c.JSON(200, gin.H{}) c.JSON(200, gin.H{})
@@ -20,6 +32,66 @@ func (h *Handler) GetConfig(c *gin.Context) {
c.JSON(200, &cfgCopy) c.JSON(200, &cfgCopy)
} }
type releaseInfo struct {
TagName string `json:"tag_name"`
Name string `json:"name"`
}
// GetLatestVersion returns the latest release version from GitHub without downloading assets.
func (h *Handler) GetLatestVersion(c *gin.Context) {
client := &http.Client{Timeout: 10 * time.Second}
proxyURL := ""
if h != nil && h.cfg != nil {
proxyURL = strings.TrimSpace(h.cfg.ProxyURL)
}
if proxyURL != "" {
sdkCfg := &sdkconfig.SDKConfig{ProxyURL: proxyURL}
util.SetProxy(sdkCfg, client)
}
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, latestReleaseURL, nil)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "request_create_failed", "message": err.Error()})
return
}
req.Header.Set("Accept", "application/vnd.github+json")
req.Header.Set("User-Agent", latestReleaseUserAgent)
resp, err := client.Do(req)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "request_failed", "message": err.Error()})
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.WithError(errClose).Debug("failed to close latest version response body")
}
}()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
c.JSON(http.StatusBadGateway, gin.H{"error": "unexpected_status", "message": fmt.Sprintf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))})
return
}
var info releaseInfo
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "decode_failed", "message": errDecode.Error()})
return
}
version := strings.TrimSpace(info.TagName)
if version == "" {
version = strings.TrimSpace(info.Name)
}
if version == "" {
c.JSON(http.StatusBadGateway, gin.H{"error": "invalid_response", "message": "missing release version"})
return
}
c.JSON(http.StatusOK, gin.H{"latest-version": version})
}
func WriteConfig(path string, data []byte) error { func WriteConfig(path string, data []byte) error {
data = config.NormalizeCommentIndentation(data) data = config.NormalizeCommentIndentation(data)
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)

View File

@@ -112,5 +112,10 @@ func shouldLogRequest(path string) bool {
if strings.HasPrefix(path, "/v0/management") || strings.HasPrefix(path, "/management") { if strings.HasPrefix(path, "/v0/management") || strings.HasPrefix(path, "/management") {
return false return false
} }
if strings.HasPrefix(path, "/api") {
return strings.HasPrefix(path, "/api/provider")
}
return true return true
} }

View File

@@ -27,11 +27,20 @@ type Option func(*AmpModule)
type AmpModule struct { type AmpModule struct {
secretSource SecretSource secretSource SecretSource
proxy *httputil.ReverseProxy proxy *httputil.ReverseProxy
proxyMu sync.RWMutex // protects proxy for hot-reload
accessManager *sdkaccess.Manager accessManager *sdkaccess.Manager
authMiddleware_ gin.HandlerFunc authMiddleware_ gin.HandlerFunc
modelMapper *DefaultModelMapper modelMapper *DefaultModelMapper
enabled bool enabled bool
registerOnce sync.Once registerOnce sync.Once
// restrictToLocalhost controls localhost-only access for management routes (hot-reloadable)
restrictToLocalhost bool
restrictMu sync.RWMutex
// configMu protects lastConfig for partial reload comparison
configMu sync.RWMutex
lastConfig *config.AmpCode
} }
// New creates a new Amp routing module with the given options. // New creates a new Amp routing module with the given options.
@@ -107,9 +116,19 @@ func (m *AmpModule) Register(ctx modules.Context) error {
// Initialize model mapper from config (for routing unavailable models to alternatives) // Initialize model mapper from config (for routing unavailable models to alternatives)
m.modelMapper = NewModelMapper(settings.ModelMappings) m.modelMapper = NewModelMapper(settings.ModelMappings)
// Store initial config for partial reload comparison
settingsCopy := settings
m.lastConfig = &settingsCopy
// Initialize localhost restriction setting (hot-reloadable)
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)
// Always register provider aliases - these work without an upstream // Always register provider aliases - these work without an upstream
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth) m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
// Register management proxy routes once; middleware will gate access when upstream is unavailable.
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler)
// If no upstream URL, skip proxy routes but provider aliases are still available // If no upstream URL, skip proxy routes but provider aliases are still available
if upstreamURL == "" { if upstreamURL == "" {
log.Debug("amp upstream proxy disabled (no upstream URL configured)") log.Debug("amp upstream proxy disabled (no upstream URL configured)")
@@ -118,28 +137,11 @@ func (m *AmpModule) Register(ctx modules.Context) error {
return return
} }
// Create secret source with precedence: config > env > file if err := m.enableUpstreamProxy(upstreamURL, &settings); err != nil {
// Cache secrets for 5 minutes to reduce file I/O
if m.secretSource == nil {
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
}
// Create reverse proxy with gzip handling via ModifyResponse
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
if err != nil {
regErr = fmt.Errorf("failed to create amp proxy: %w", err) regErr = fmt.Errorf("failed to create amp proxy: %w", err)
return return
} }
m.proxy = proxy
m.enabled = true
// Register management proxy routes (requires upstream)
// Restrict to localhost by default for security (prevents drive-by browser attacks)
handler := proxyHandler(proxy)
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, handler, settings.RestrictManagementToLocalhost)
log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
log.Debug("amp provider alias routes registered") log.Debug("amp provider alias routes registered")
}) })
@@ -162,44 +164,169 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
} }
} }
// OnConfigUpdated handles configuration updates. // OnConfigUpdated handles configuration updates with partial reload support.
// Currently requires restart for URL changes (could be enhanced for dynamic updates). // Only updates components that have actually changed to avoid unnecessary work.
// Supports hot-reload for: model-mappings, upstream-api-key, upstream-url, restrict-management-to-localhost.
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
settings := cfg.AmpCode newSettings := cfg.AmpCode
// Update model mappings (hot-reload supported) // Get previous config for comparison
if m.modelMapper != nil { m.configMu.RLock()
m.modelMapper.UpdateMappings(settings.ModelMappings) oldSettings := m.lastConfig
if m.enabled { m.configMu.RUnlock()
log.Infof("amp config updated: reloading %d model mapping(s)", len(settings.ModelMappings))
}
} else if m.enabled {
log.Warnf("amp model mapper not initialized, skipping model mapping update")
}
if !m.enabled { if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
return nil m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
} if !newSettings.RestrictManagementToLocalhost {
log.Warnf("amp management routes now accessible from any IP - this is insecure!")
upstreamURL := strings.TrimSpace(settings.UpstreamURL)
if upstreamURL == "" {
log.Warn("amp upstream URL removed from config, restart required to disable")
return nil
}
// If API key changed, invalidate the cache
if m.secretSource != nil {
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
ms.InvalidateCache()
log.Debug("amp secret cache invalidated due to config update")
} }
} }
log.Debug("amp config updated (restart required for URL changes)") newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
oldUpstreamURL := ""
if oldSettings != nil {
oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL)
}
if !m.enabled && newUpstreamURL != "" {
if err := m.enableUpstreamProxy(newUpstreamURL, &newSettings); err != nil {
log.Errorf("amp config: failed to enable upstream proxy for %s: %v", newUpstreamURL, err)
}
}
// Check model mappings change
modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings)
if modelMappingsChanged {
if m.modelMapper != nil {
m.modelMapper.UpdateMappings(newSettings.ModelMappings)
} else if m.enabled {
log.Warnf("amp model mapper not initialized, skipping model mapping update")
}
}
if m.enabled {
// Check upstream URL change - now supports hot-reload
if newUpstreamURL == "" && oldUpstreamURL != "" {
m.setProxy(nil)
m.enabled = false
} else if oldUpstreamURL != "" && newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
// Recreate proxy with new URL
proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
if err != nil {
log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
} else {
m.setProxy(proxy)
}
}
// Check API key change
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
if apiKeyChanged {
if m.secretSource != nil {
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache()
}
}
}
}
// Store current config for next comparison
m.configMu.Lock()
settingsCopy := newSettings // copy struct
m.lastConfig = &settingsCopy
m.configMu.Unlock()
return nil return nil
} }
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
if m.secretSource == nil {
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache()
}
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
if err != nil {
return err
}
m.setProxy(proxy)
m.enabled = true
log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
return nil
}
// hasModelMappingsChanged compares old and new model mappings.
func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.AmpCode) bool {
if old == nil {
return len(new.ModelMappings) > 0
}
if len(old.ModelMappings) != len(new.ModelMappings) {
return true
}
// Build map for efficient comparison
oldMap := make(map[string]string, len(old.ModelMappings))
for _, mapping := range old.ModelMappings {
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To)
}
for _, mapping := range new.ModelMappings {
from := strings.TrimSpace(mapping.From)
to := strings.TrimSpace(mapping.To)
if oldTo, exists := oldMap[from]; !exists || oldTo != to {
return true
}
}
return false
}
// hasAPIKeyChanged compares old and new API keys.
func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool {
oldKey := ""
if old != nil {
oldKey = strings.TrimSpace(old.UpstreamAPIKey)
}
newKey := strings.TrimSpace(new.UpstreamAPIKey)
return oldKey != newKey
}
// GetModelMapper returns the model mapper instance (for testing/debugging). // GetModelMapper returns the model mapper instance (for testing/debugging).
func (m *AmpModule) GetModelMapper() *DefaultModelMapper { func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
return m.modelMapper return m.modelMapper
} }
// getProxy returns the current proxy instance (thread-safe for hot-reload).
func (m *AmpModule) getProxy() *httputil.ReverseProxy {
m.proxyMu.RLock()
defer m.proxyMu.RUnlock()
return m.proxy
}
// setProxy updates the proxy instance (thread-safe for hot-reload).
func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) {
m.proxyMu.Lock()
defer m.proxyMu.Unlock()
m.proxy = proxy
}
// IsRestrictedToLocalhost returns whether management routes are restricted to localhost.
func (m *AmpModule) IsRestrictedToLocalhost() bool {
m.restrictMu.RLock()
defer m.restrictMu.RUnlock()
return m.restrictToLocalhost
}
// setRestrictToLocalhost updates the localhost restriction setting.
func (m *AmpModule) setRestrictToLocalhost(restrict bool) {
m.restrictMu.Lock()
defer m.restrictMu.Unlock()
m.restrictToLocalhost = restrict
}

View File

@@ -48,25 +48,25 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
case RouteTypeLocalProvider: case RouteTypeLocalProvider:
fields["cost"] = "free" fields["cost"] = "free"
fields["source"] = "local_oauth" fields["source"] = "local_oauth"
log.WithFields(fields).Infof("[amp] using local provider for model: %s", requestedModel) log.WithFields(fields).Debugf("amp using local provider for model: %s", requestedModel)
case RouteTypeModelMapping: case RouteTypeModelMapping:
fields["cost"] = "free" fields["cost"] = "free"
fields["source"] = "local_oauth" fields["source"] = "local_oauth"
fields["mapping"] = requestedModel + " -> " + resolvedModel fields["mapping"] = requestedModel + " -> " + resolvedModel
log.WithFields(fields).Infof("[amp] model mapped: %s -> %s", requestedModel, resolvedModel) // model mapping already logged in mapper; avoid duplicate here
case RouteTypeAmpCredits: case RouteTypeAmpCredits:
fields["cost"] = "amp_credits" fields["cost"] = "amp_credits"
fields["source"] = "ampcode.com" fields["source"] = "ampcode.com"
fields["model_id"] = requestedModel // Explicit model_id for easy config reference fields["model_id"] = requestedModel // Explicit model_id for easy config reference
log.WithFields(fields).Warnf("[amp] forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel) log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
case RouteTypeNoProvider: case RouteTypeNoProvider:
fields["cost"] = "none" fields["cost"] = "none"
fields["source"] = "error" fields["source"] = "error"
fields["model_id"] = requestedModel // Explicit model_id for easy config reference fields["model_id"] = requestedModel // Explicit model_id for easy config reference
log.WithFields(fields).Warnf("[amp] no provider available for model_id: %s", requestedModel) log.WithFields(fields).Warnf("no provider available for model_id: %s", requestedModel)
} }
} }

View File

@@ -152,9 +152,9 @@ func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) {
mapper := NewModelMapper(nil) mapper := NewModelMapper(nil)
mapper.UpdateMappings([]config.AmpModelMapping{ mapper.UpdateMappings([]config.AmpModelMapping{
{From: "", To: "model-b"}, // Invalid: empty from {From: "", To: "model-b"}, // Invalid: empty from
{From: "model-a", To: ""}, // Invalid: empty to {From: "model-a", To: ""}, // Invalid: empty to
{From: " ", To: "model-b"}, // Invalid: whitespace from {From: " ", To: "model-b"}, // Invalid: whitespace from
{From: "model-c", To: "model-d"}, // Valid {From: "model-c", To: "model-d"}, // Valid
}) })

View File

@@ -1,11 +1,14 @@
package amp package amp
import ( import (
"errors"
"net" "net"
"net/http"
"net/http/httputil" "net/http/httputil"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
@@ -14,15 +17,16 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// localhostOnlyMiddleware restricts access to localhost (127.0.0.1, ::1) only. // localhostOnlyMiddleware returns a middleware that dynamically checks the module's
// Returns 403 Forbidden for non-localhost clients. // localhost restriction setting. This allows hot-reload of the restriction without restarting.
// func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
// Security: Uses RemoteAddr (actual TCP connection) instead of ClientIP() to prevent
// header spoofing attacks via X-Forwarded-For or similar headers. This means the
// middleware will not work correctly behind reverse proxies - users deploying behind
// nginx/Cloudflare should disable this feature and use firewall rules instead.
func localhostOnlyMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// Check current setting (hot-reloadable)
if !m.IsRestrictedToLocalhost() {
c.Next()
return
}
// Use actual TCP connection address (RemoteAddr) to prevent header spoofing // Use actual TCP connection address (RemoteAddr) to prevent header spoofing
// This cannot be forged by X-Forwarded-For or other client-controlled headers // This cannot be forged by X-Forwarded-For or other client-controlled headers
remoteAddr := c.Request.RemoteAddr remoteAddr := c.Request.RemoteAddr
@@ -77,23 +81,58 @@ func noCORSMiddleware() gin.HandlerFunc {
} }
} }
// managementAvailabilityMiddleware short-circuits management routes when the upstream
// proxy is disabled, preventing noisy localhost warnings and accidental exposure.
func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if m.getProxy() == nil {
logging.SkipGinRequestLogging(c)
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{
"error": "amp upstream proxy not available",
})
return
}
c.Next()
}
}
// registerManagementRoutes registers Amp management proxy routes // registerManagementRoutes registers Amp management proxy routes
// These routes proxy through to the Amp control plane for OAuth, user management, etc. // These routes proxy through to the Amp control plane for OAuth, user management, etc.
// If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1. // Uses dynamic middleware and proxy getter for hot-reload support.
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) { func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler) {
ampAPI := engine.Group("/api") ampAPI := engine.Group("/api")
// Always disable CORS for management routes to prevent browser-based attacks // Always disable CORS for management routes to prevent browser-based attacks
ampAPI.Use(noCORSMiddleware()) ampAPI.Use(m.managementAvailabilityMiddleware(), noCORSMiddleware())
// Apply localhost-only restriction if configured // Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
if restrictToLocalhost { ampAPI.Use(m.localhostOnlyMiddleware())
ampAPI.Use(localhostOnlyMiddleware())
log.Info("amp management routes restricted to localhost only (CORS disabled)") if !m.IsRestrictedToLocalhost() {
} else {
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!") log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
} }
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
proxyHandler := func(c *gin.Context) {
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
defer func() {
if rec := recover(); rec != nil {
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
// Upstream already wrote the status (often 404) before the client/stream ended.
return
}
panic(rec)
}
}()
proxy := m.getProxy()
if proxy == nil {
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
return
}
proxy.ServeHTTP(c.Writer, c.Request)
}
// Management routes - these are proxied directly to Amp upstream // Management routes - these are proxied directly to Amp upstream
ampAPI.Any("/internal", proxyHandler) ampAPI.Any("/internal", proxyHandler)
ampAPI.Any("/internal/*path", proxyHandler) ampAPI.Any("/internal/*path", proxyHandler)
@@ -114,11 +153,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
ampAPI.Any("/tab/*path", proxyHandler) ampAPI.Any("/tab/*path", proxyHandler)
// Root-level routes that AMP CLI expects without /api prefix // Root-level routes that AMP CLI expects without /api prefix
// These need the same security middleware as the /api/* routes // These need the same security middleware as the /api/* routes (dynamic for hot-reload)
rootMiddleware := []gin.HandlerFunc{noCORSMiddleware()} rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
if restrictToLocalhost { engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
rootMiddleware = append(rootMiddleware, localhostOnlyMiddleware())
}
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
// Root-level auth routes for CLI login flow // Root-level auth routes for CLI login flow
@@ -134,7 +171,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler) geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
geminiBridge := createGeminiBridgeHandler(geminiHandlers) geminiBridge := createGeminiBridgeHandler(geminiHandlers)
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy { geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
return m.proxy return m.getProxy()
}) })
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
@@ -177,10 +214,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler) openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
// Create fallback handler wrapper that forwards to ampcode.com when provider not found // Create fallback handler wrapper that forwards to ampcode.com when provider not found
// Uses lazy evaluation to access proxy (which is created after routes are registered) // Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
// Also includes model mapping support for routing unavailable models to alternatives // Also includes model mapping support for routing unavailable models to alternatives
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.proxy return m.getProxy()
}, m.modelMapper) }, m.modelMapper)
// Provider-specific routes under /api/provider/:provider // Provider-specific routes under /api/provider/:provider

View File

@@ -13,16 +13,26 @@ func TestRegisterManagementRoutes(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
// Spy to track if proxy handler was called // Create module with proxy for testing
proxyCalled := false m := &AmpModule{
proxyHandler := func(c *gin.Context) { restrictToLocalhost: false, // disable localhost restriction for tests
proxyCalled = true
c.String(200, "proxied")
} }
m := &AmpModule{} // Create a mock proxy that tracks calls
proxyCalled := false
mockProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxyCalled = true
w.WriteHeader(200)
w.Write([]byte("proxied"))
}))
defer mockProxy.Close()
// Create real proxy to mock server
proxy, _ := createReverseProxy(mockProxy.URL, NewStaticSecretSource(""))
m.setProxy(proxy)
base := &handlers.BaseAPIHandler{} base := &handlers.BaseAPIHandler{}
m.registerManagementRoutes(r, base, proxyHandler, false) // false = don't restrict to localhost in tests m.registerManagementRoutes(r, base)
managementPaths := []struct { managementPaths := []struct {
path string path string
@@ -37,13 +47,14 @@ func TestRegisterManagementRoutes(t *testing.T) {
{"/api/meta", http.MethodGet}, {"/api/meta", http.MethodGet},
{"/api/telemetry", http.MethodGet}, {"/api/telemetry", http.MethodGet},
{"/api/threads", http.MethodGet}, {"/api/threads", http.MethodGet},
{"/threads/", http.MethodGet},
{"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix) {"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix)
{"/api/otel", http.MethodGet}, {"/api/otel", http.MethodGet},
{"/api/tab", http.MethodGet}, {"/api/tab", http.MethodGet},
{"/api/tab/some/path", http.MethodGet}, {"/api/tab/some/path", http.MethodGet},
{"/auth", http.MethodGet}, // Root-level auth route {"/auth", http.MethodGet}, // Root-level auth route
{"/auth/cli-login", http.MethodGet}, // CLI login flow {"/auth/cli-login", http.MethodGet}, // CLI login flow
{"/auth/callback", http.MethodGet}, // OAuth callback {"/auth/callback", http.MethodGet}, // OAuth callback
// Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST // Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST
{"/api/provider/google/v1beta1/models", http.MethodGet}, {"/api/provider/google/v1beta1/models", http.MethodGet},
{"/api/provider/google/v1beta1/models", http.MethodPost}, {"/api/provider/google/v1beta1/models", http.MethodPost},
@@ -231,8 +242,13 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
// Apply localhost-only middleware // Create module with localhost restriction enabled
r.Use(localhostOnlyMiddleware()) m := &AmpModule{
restrictToLocalhost: true,
}
// Apply dynamic localhost-only middleware
r.Use(m.localhostOnlyMiddleware())
r.GET("/test", func(c *gin.Context) { r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok") c.String(http.StatusOK, "ok")
}) })
@@ -305,3 +321,53 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
}) })
} }
} }
func TestLocalhostOnlyMiddleware_HotReload(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
// Create module with localhost restriction initially enabled
m := &AmpModule{
restrictToLocalhost: true,
}
// Apply dynamic localhost-only middleware
r.Use(m.localhostOnlyMiddleware())
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
// Test 1: Remote IP should be blocked when restriction is enabled
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.100:12345"
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("Expected 403 when restriction enabled, got %d", w.Code)
}
// Test 2: Hot-reload - disable restriction
m.setRestrictToLocalhost(false)
req = httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.100:12345"
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected 200 after disabling restriction, got %d", w.Code)
}
// Test 3: Hot-reload - re-enable restriction
m.setRestrictToLocalhost(true)
req = httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.100:12345"
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("Expected 403 after re-enabling restriction, got %d", w.Code)
}
}

View File

@@ -139,6 +139,17 @@ func (s *MultiSourceSecret) InvalidateCache() {
s.cache = nil s.cache = nil
} }
// UpdateExplicitKey refreshes the config-provided key and clears cache.
func (s *MultiSourceSecret) UpdateExplicitKey(key string) {
if s == nil {
return
}
s.mu.Lock()
s.explicitKey = strings.TrimSpace(key)
s.cache = nil
s.mu.Unlock()
}
// StaticSecretSource returns a fixed API key (for testing) // StaticSecretSource returns a fixed API key (for testing)
type StaticSecretSource struct { type StaticSecretSource struct {
key string key string

View File

@@ -472,6 +472,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/config", s.mgmt.GetConfig) mgmt.GET("/config", s.mgmt.GetConfig)
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML) mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML) mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
mgmt.GET("/latest-version", s.mgmt.GetLatestVersion)
mgmt.GET("/debug", s.mgmt.GetDebug) mgmt.GET("/debug", s.mgmt.GetDebug)
mgmt.PUT("/debug", s.mgmt.PutDebug) mgmt.PUT("/debug", s.mgmt.PutDebug)

View File

@@ -85,6 +85,9 @@ type InlineData struct {
// FunctionCall represents a tool call requested by the model. // FunctionCall represents a tool call requested by the model.
// It includes the function name and its arguments that the model wants to execute. // It includes the function name and its arguments that the model wants to execute.
type FunctionCall struct { type FunctionCall struct {
// ID is the identifier of the function to be called.
ID string `json:"id,omitempty"`
// Name is the identifier of the function to be called. // Name is the identifier of the function to be called.
Name string `json:"name"` Name string `json:"name"`
@@ -95,6 +98,9 @@ type FunctionCall struct {
// FunctionResponse represents the result of a tool execution. // FunctionResponse represents the result of a tool execution.
// This is sent back to the model after a tool call has been processed. // This is sent back to the model after a tool call has been processed.
type FunctionResponse struct { type FunctionResponse struct {
// ID is the identifier of the function to be called.
ID string `json:"id,omitempty"`
// Name is the identifier of the function that was called. // Name is the identifier of the function that was called.
Name string `json:"name"` Name string `json:"name"`

View File

@@ -14,6 +14,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const skipGinLogKey = "__gin_skip_request_logging__"
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses // GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
// using logrus. It captures request details including method, path, status code, latency, // using logrus. It captures request details including method, path, status code, latency,
// client IP, and any error messages, formatting them in a Gin-style log format. // client IP, and any error messages, formatting them in a Gin-style log format.
@@ -28,6 +30,10 @@ func GinLogrusLogger() gin.HandlerFunc {
c.Next() c.Next()
if shouldSkipGinRequestLogging(c) {
return
}
if raw != "" { if raw != "" {
path = path + "?" + raw path = path + "?" + raw
} }
@@ -77,3 +83,24 @@ func GinLogrusRecovery() gin.HandlerFunc {
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
}) })
} }
// SkipGinRequestLogging marks the provided Gin context so that GinLogrusLogger
// will skip emitting a log line for the associated request.
func SkipGinRequestLogging(c *gin.Context) {
if c == nil {
return
}
c.Set(skipGinLogKey, true)
}
func shouldSkipGinRequestLogging(c *gin.Context) bool {
if c == nil {
return false
}
val, exists := c.Get(skipGinLogKey)
if !exists {
return false
}
flag, ok := val.(bool)
return ok && flag
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -508,8 +509,46 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
requestURL.WriteString(url.QueryEscape(alt)) requestURL.WriteString(url.QueryEscape(alt))
} }
payload = geminiToAntigravity(modelName, payload) // Extract project_id from auth metadata if available
projectID := ""
if auth != nil && auth.Metadata != nil {
if pid, ok := auth.Metadata["project_id"].(string); ok {
projectID = strings.TrimSpace(pid)
}
}
payload = geminiToAntigravity(modelName, payload, projectID)
payload, _ = sjson.SetBytes(payload, "model", alias2ModelName(modelName)) payload, _ = sjson.SetBytes(payload, "model", alias2ModelName(modelName))
if strings.Contains(modelName, "claude") {
strJSON := string(payload)
paths := make([]string, 0)
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
for _, p := range paths {
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
}
strJSON = util.DeleteKey(strJSON, "$schema")
strJSON = util.DeleteKey(strJSON, "maxItems")
strJSON = util.DeleteKey(strJSON, "minItems")
strJSON = util.DeleteKey(strJSON, "minLength")
strJSON = util.DeleteKey(strJSON, "maxLength")
strJSON = util.DeleteKey(strJSON, "exclusiveMinimum")
paths = make([]string, 0)
util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths)
for _, p := range paths {
anyOf := gjson.Get(strJSON, p)
if anyOf.IsArray() {
anyOfItems := anyOf.Array()
if len(anyOfItems) > 0 {
strJSON, _ = sjson.SetRaw(strJSON, p[:len(p)-len(".anyOf")], anyOfItems[0].Raw)
}
}
}
payload = []byte(strJSON)
}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload)) httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
if errReq != nil { if errReq != nil {
return nil, errReq return nil, errReq
@@ -644,9 +683,9 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
return []string{base} return []string{base}
} }
return []string{ return []string{
antigravityBaseURLProd,
antigravityBaseURLDaily, antigravityBaseURLDaily,
antigravityBaseURLAutopush, antigravityBaseURLAutopush,
antigravityBaseURLProd,
} }
} }
@@ -670,10 +709,16 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
return "" return ""
} }
func geminiToAntigravity(modelName string, payload []byte) []byte { func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
template, _ := sjson.Set(string(payload), "model", modelName) template, _ := sjson.Set(string(payload), "model", modelName)
template, _ = sjson.Set(template, "userAgent", "antigravity") template, _ = sjson.Set(template, "userAgent", "antigravity")
template, _ = sjson.Set(template, "project", generateProjectID())
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
if projectID != "" {
template, _ = sjson.Set(template, "project", projectID)
} else {
template, _ = sjson.Set(template, "project", generateProjectID())
}
template, _ = sjson.Set(template, "requestId", generateRequestID()) template, _ = sjson.Set(template, "requestId", generateRequestID())
template, _ = sjson.Set(template, "request.sessionId", generateSessionID()) template, _ = sjson.Set(template, "request.sessionId", generateSessionID())

View File

@@ -89,10 +89,11 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" { } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
functionName := contentResult.Get("name").String() functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String() functionArgs := contentResult.Get("input").String()
functionID := contentResult.Get("id").String()
var args map[string]any var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil { if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
clientContent.Parts = append(clientContent.Parts, client.Part{ clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{Name: functionName, Args: args}, FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
ThoughtSignature: geminiCLIClaudeThoughtSignature, ThoughtSignature: geminiCLIClaudeThoughtSignature,
}) })
} }
@@ -105,7 +106,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-") funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
} }
responseData := contentResult.Get("content").Raw responseData := contentResult.Get("content").Raw
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}} functionResponse := client.FunctionResponse{ID: toolCallID, Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse}) clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
} }
} }

View File

@@ -141,35 +141,38 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
params.ResponseType = 2 // Set state to thinking params.ResponseType = 2 // Set state to thinking
} }
} else { } else {
// Process regular text content (user-visible output) finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason")
// Continue existing text block if already in content state if partTextResult.String() != "" || !finishReasonResult.Exists() {
if params.ResponseType == 1 { // Process regular text content (user-visible output)
output = output + "event: content_block_delta\n" // Continue existing text block if already in content state
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String()) if params.ResponseType == 1 {
output = output + fmt.Sprintf("data: %s\n\n\n", data) output = output + "event: content_block_delta\n"
} else { data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
// Transition from another state to text content output = output + fmt.Sprintf("data: %s\n\n\n", data)
// First, close any existing content block } else {
if params.ResponseType != 0 { // Transition from another state to text content
if params.ResponseType == 2 { // First, close any existing content block
// output = output + "event: content_block_delta\n" if params.ResponseType != 0 {
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex) if params.ResponseType == 2 {
// output = output + "\n\n\n" // output = output + "event: content_block_delta\n"
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++
} }
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++
}
// Start a new text content block // Start a new text content block
output = output + "event: content_block_start\n" output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex) output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
output = output + "\n\n\n" output = output + "\n\n\n"
output = output + "event: content_block_delta\n" output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String()) data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data) output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 1 // Set state to content params.ResponseType = 1 // Set state to content
}
} }
} }
} else if functionCallResult.Exists() { } else if functionCallResult.Exists() {

View File

@@ -251,6 +251,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
fid := tc.Get("id").String() fid := tc.Get("id").String()
fname := tc.Get("function.name").String() fname := tc.Get("function.name").String()
fargs := tc.Get("function.arguments").String() fargs := tc.Get("function.arguments").String()
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs)) node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
@@ -266,6 +267,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
pp := 0 pp := 0
for _, fid := range fIDs { for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok { if name, ok := tcID2Name[fid]; ok {
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name) toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
resp := toolResponses[fid] resp := toolResponses[fid]
if resp == "" { if resp == "" {

View File

@@ -327,7 +327,7 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
func mustMarshalJSON(v interface{}) string { func mustMarshalJSON(v interface{}) string {
data, err := json.Marshal(v) data, err := json.Marshal(v)
if err != nil { if err != nil {
panic(err) return ""
} }
return string(data) return string(data)
} }

View File

@@ -249,6 +249,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
functionCall := `{"functionCall":{"name":"","args":{}}}` functionCall := `{"functionCall":{"name":"","args":{}}}`
functionCall, _ = sjson.Set(functionCall, "functionCall.name", name) functionCall, _ = sjson.Set(functionCall, "functionCall.name", name)
functionCall, _ = sjson.Set(functionCall, "thoughtSignature", geminiResponsesThoughtSignature) functionCall, _ = sjson.Set(functionCall, "thoughtSignature", geminiResponsesThoughtSignature)
functionCall, _ = sjson.Set(functionCall, "functionCall.id", item.Get("call_id").String())
// Parse arguments JSON string and set as args object // Parse arguments JSON string and set as args object
if arguments != "" { if arguments != "" {
@@ -285,6 +286,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
} }
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.name", functionName) functionResponse, _ = sjson.Set(functionResponse, "functionResponse.name", functionName)
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.id", callID)
// Set the raw JSON output directly (preserves string encoding) // Set the raw JSON output directly (preserves string encoding)
if outputRaw != "" && outputRaw != "null" { if outputRaw != "" && outputRaw != "null" {

View File

@@ -79,6 +79,15 @@ func RenameKey(jsonStr, oldKeyPath, newKeyPath string) (string, error) {
return finalJson, nil return finalJson, nil
} }
func DeleteKey(jsonStr, keyName string) string {
paths := make([]string, 0)
Walk(gjson.Parse(jsonStr), "", keyName, &paths)
for _, p := range paths {
jsonStr, _ = sjson.Delete(jsonStr, p)
}
return jsonStr
}
// FixJSON converts non-standard JSON that uses single quotes for strings into // FixJSON converts non-standard JSON that uses single quotes for strings into
// RFC 8259-compliant JSON by converting those single-quoted strings to // RFC 8259-compliant JSON by converting those single-quoted strings to
// double-quoted strings with proper escaping. // double-quoted strings with proper escaping.

View File

@@ -570,6 +570,35 @@ func summarizeExcludedModels(list []string) excludedModelsSummary {
} }
} }
type ampModelMappingsSummary struct {
hash string
count int
}
func summarizeAmpModelMappings(mappings []config.AmpModelMapping) ampModelMappingsSummary {
if len(mappings) == 0 {
return ampModelMappingsSummary{}
}
entries := make([]string, 0, len(mappings))
for _, mapping := range mappings {
from := strings.TrimSpace(mapping.From)
to := strings.TrimSpace(mapping.To)
if from == "" && to == "" {
continue
}
entries = append(entries, from+"->"+to)
}
if len(entries) == 0 {
return ampModelMappingsSummary{}
}
sort.Strings(entries)
sum := sha256.Sum256([]byte(strings.Join(entries, "|")))
return ampModelMappingsSummary{
hash: hex.EncodeToString(sum[:]),
count: len(entries),
}
}
func summarizeOAuthExcludedModels(entries map[string][]string) map[string]excludedModelsSummary { func summarizeOAuthExcludedModels(entries map[string][]string) map[string]excludedModelsSummary {
if len(entries) == 0 { if len(entries) == 0 {
return nil return nil
@@ -1762,6 +1791,31 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
} }
} }
// AmpCode settings (redacted where needed)
oldAmpURL := strings.TrimSpace(oldCfg.AmpCode.UpstreamURL)
newAmpURL := strings.TrimSpace(newCfg.AmpCode.UpstreamURL)
if oldAmpURL != newAmpURL {
changes = append(changes, fmt.Sprintf("ampcode.upstream-url: %s -> %s", oldAmpURL, newAmpURL))
}
oldAmpKey := strings.TrimSpace(oldCfg.AmpCode.UpstreamAPIKey)
newAmpKey := strings.TrimSpace(newCfg.AmpCode.UpstreamAPIKey)
switch {
case oldAmpKey == "" && newAmpKey != "":
changes = append(changes, "ampcode.upstream-api-key: added")
case oldAmpKey != "" && newAmpKey == "":
changes = append(changes, "ampcode.upstream-api-key: removed")
case oldAmpKey != newAmpKey:
changes = append(changes, "ampcode.upstream-api-key: updated")
}
if oldCfg.AmpCode.RestrictManagementToLocalhost != newCfg.AmpCode.RestrictManagementToLocalhost {
changes = append(changes, fmt.Sprintf("ampcode.restrict-management-to-localhost: %t -> %t", oldCfg.AmpCode.RestrictManagementToLocalhost, newCfg.AmpCode.RestrictManagementToLocalhost))
}
oldMappings := summarizeAmpModelMappings(oldCfg.AmpCode.ModelMappings)
newMappings := summarizeAmpModelMappings(newCfg.AmpCode.ModelMappings)
if oldMappings.hash != newMappings.hash {
changes = append(changes, fmt.Sprintf("ampcode.model-mappings: updated (%d -> %d entries)", oldMappings.count, newMappings.count))
}
if entries, _ := diffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 { if entries, _ := diffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
changes = append(changes, entries...) changes = append(changes, entries...)
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@@ -127,6 +128,18 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
} }
} }
// Fetch project ID via loadCodeAssist (same approach as Gemini CLI)
projectID := ""
if tokenResp.AccessToken != "" {
fetchedProjectID, errProject := fetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
if errProject != nil {
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
} else {
projectID = fetchedProjectID
log.Infof("antigravity: obtained project ID %s", projectID)
}
}
now := time.Now() now := time.Now()
metadata := map[string]any{ metadata := map[string]any{
"type": "antigravity", "type": "antigravity",
@@ -139,6 +152,9 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
if email != "" { if email != "" {
metadata["email"] = email metadata["email"] = email
} }
if projectID != "" {
metadata["project_id"] = projectID
}
fileName := sanitizeAntigravityFileName(email) fileName := sanitizeAntigravityFileName(email)
label := email label := email
@@ -147,6 +163,9 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
} }
fmt.Println("Antigravity authentication successful") fmt.Println("Antigravity authentication successful")
if projectID != "" {
fmt.Printf("Using GCP project: %s\n", projectID)
}
return &coreauth.Auth{ return &coreauth.Auth{
ID: fileName, ID: fileName,
Provider: "antigravity", Provider: "antigravity",
@@ -291,3 +310,84 @@ func sanitizeAntigravityFileName(email string) string {
replacer := strings.NewReplacer("@", "_", ".", "_") replacer := strings.NewReplacer("@", "_", ".", "_")
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email)) return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
} }
// Antigravity API constants for project discovery
const (
antigravityAPIEndpoint = "https://cloudcode-pa.googleapis.com"
antigravityAPIVersion = "v1internal"
antigravityAPIUserAgent = "google-api-nodejs-client/9.15.1"
antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
)
// fetchAntigravityProjectID retrieves the project ID for the authenticated user via loadCodeAssist.
// This uses the same approach as Gemini CLI to get the cloudaicompanionProject.
func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
// Call loadCodeAssist to get the project
loadReqBody := map[string]any{
"metadata": map[string]string{
"ideType": "IDE_UNSPECIFIED",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(loadReqBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", antigravityAPIEndpoint, antigravityAPIVersion)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", antigravityAPIUserAgent)
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
req.Header.Set("Client-Metadata", antigravityClientMetadata)
resp, errDo := httpClient.Do(req)
if errDo != nil {
return "", fmt.Errorf("execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
}
}()
bodyBytes, errRead := io.ReadAll(resp.Body)
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
}
var loadResp map[string]any
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
// Extract projectID from response
projectID := ""
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
projectID = strings.TrimSpace(id)
}
if projectID == "" {
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
if id, okID := projectMap["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
}
}
if projectID == "" {
return "", fmt.Errorf("no cloudaicompanionProject in response")
}
return projectID, nil
}