mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-09 15:25:17 +00:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
608cd8ee3d | ||
|
|
9f41894573 | ||
|
|
bb45fee1cf | ||
|
|
af00304b0c | ||
|
|
5c3a013cd1 | ||
|
|
ab9e9442ec | ||
|
|
6ad188921c | ||
|
|
15ed98d6a9 | ||
|
|
a283545b6b | ||
|
|
3efbd865a8 | ||
|
|
aee659fb66 | ||
|
|
5aa386d8b9 | ||
|
|
0adc0ee6aa | ||
|
|
92f13fc316 | ||
|
|
05cfa16e5f | ||
|
|
93a6e2d920 | ||
|
|
6b49580716 | ||
|
|
de77903915 | ||
|
|
56ed0d8d90 |
@@ -1,3 +1,7 @@
|
||||
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
|
||||
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
|
||||
host: ""
|
||||
|
||||
# Server port
|
||||
port: 8317
|
||||
|
||||
@@ -149,6 +153,8 @@ ws-auth: false
|
||||
# upstream-api-key: ""
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
|
||||
# restrict-management-to-localhost: true
|
||||
# # Force model mappings to run before checking local API keys (default: false)
|
||||
# force-model-mappings: false
|
||||
# # Amp Model Mappings
|
||||
# # Route unavailable Amp models to alternative models available in your local proxy.
|
||||
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
||||
|
||||
@@ -241,11 +241,3 @@ func (h *Handler) DeleteProxyURL(c *gin.Context) {
|
||||
h.cfg.ProxyURL = ""
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// Prioritize Model Mappings (for Amp CLI)
|
||||
func (h *Handler) GetPrioritizeModelMappings(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"prioritize-model-mappings": h.cfg.AmpCode.PrioritizeModelMappings})
|
||||
}
|
||||
func (h *Handler) PutPrioritizeModelMappings(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.PrioritizeModelMappings = v })
|
||||
}
|
||||
|
||||
@@ -706,3 +706,155 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
|
||||
}
|
||||
entry.Models = normalized
|
||||
}
|
||||
|
||||
// GetAmpCode returns the complete ampcode configuration.
|
||||
func (h *Handler) GetAmpCode(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(200, gin.H{"ampcode": config.AmpCode{}})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"ampcode": h.cfg.AmpCode})
|
||||
}
|
||||
|
||||
// GetAmpUpstreamURL returns the ampcode upstream URL.
|
||||
func (h *Handler) GetAmpUpstreamURL(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(200, gin.H{"upstream-url": ""})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"upstream-url": h.cfg.AmpCode.UpstreamURL})
|
||||
}
|
||||
|
||||
// PutAmpUpstreamURL updates the ampcode upstream URL.
|
||||
func (h *Handler) PutAmpUpstreamURL(c *gin.Context) {
|
||||
h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) })
|
||||
}
|
||||
|
||||
// DeleteAmpUpstreamURL clears the ampcode upstream URL.
|
||||
func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) {
|
||||
h.cfg.AmpCode.UpstreamURL = ""
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// GetAmpUpstreamAPIKey returns the ampcode upstream API key.
|
||||
func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(200, gin.H{"upstream-api-key": ""})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"upstream-api-key": h.cfg.AmpCode.UpstreamAPIKey})
|
||||
}
|
||||
|
||||
// PutAmpUpstreamAPIKey updates the ampcode upstream API key.
|
||||
func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) {
|
||||
h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) })
|
||||
}
|
||||
|
||||
// DeleteAmpUpstreamAPIKey clears the ampcode upstream API key.
|
||||
func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) {
|
||||
h.cfg.AmpCode.UpstreamAPIKey = ""
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// GetAmpRestrictManagementToLocalhost returns the localhost restriction setting.
|
||||
func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(200, gin.H{"restrict-management-to-localhost": true})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"restrict-management-to-localhost": h.cfg.AmpCode.RestrictManagementToLocalhost})
|
||||
}
|
||||
|
||||
// PutAmpRestrictManagementToLocalhost updates the localhost restriction setting.
|
||||
func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.RestrictManagementToLocalhost = v })
|
||||
}
|
||||
|
||||
// GetAmpModelMappings returns the ampcode model mappings.
|
||||
func (h *Handler) GetAmpModelMappings(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings})
|
||||
}
|
||||
|
||||
// PutAmpModelMappings replaces all ampcode model mappings.
|
||||
func (h *Handler) PutAmpModelMappings(c *gin.Context) {
|
||||
var body struct {
|
||||
Value []config.AmpModelMapping `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
h.cfg.AmpCode.ModelMappings = body.Value
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// PatchAmpModelMappings adds or updates model mappings.
|
||||
func (h *Handler) PatchAmpModelMappings(c *gin.Context) {
|
||||
var body struct {
|
||||
Value []config.AmpModelMapping `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
existing := make(map[string]int)
|
||||
for i, m := range h.cfg.AmpCode.ModelMappings {
|
||||
existing[strings.TrimSpace(m.From)] = i
|
||||
}
|
||||
|
||||
for _, newMapping := range body.Value {
|
||||
from := strings.TrimSpace(newMapping.From)
|
||||
if idx, ok := existing[from]; ok {
|
||||
h.cfg.AmpCode.ModelMappings[idx] = newMapping
|
||||
} else {
|
||||
h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, newMapping)
|
||||
existing[from] = len(h.cfg.AmpCode.ModelMappings) - 1
|
||||
}
|
||||
}
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// DeleteAmpModelMappings removes specified model mappings by "from" field.
|
||||
func (h *Handler) DeleteAmpModelMappings(c *gin.Context) {
|
||||
var body struct {
|
||||
Value []string `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || len(body.Value) == 0 {
|
||||
h.cfg.AmpCode.ModelMappings = nil
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
|
||||
toRemove := make(map[string]bool)
|
||||
for _, from := range body.Value {
|
||||
toRemove[strings.TrimSpace(from)] = true
|
||||
}
|
||||
|
||||
newMappings := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings))
|
||||
for _, m := range h.cfg.AmpCode.ModelMappings {
|
||||
if !toRemove[strings.TrimSpace(m.From)] {
|
||||
newMappings = append(newMappings, m)
|
||||
}
|
||||
}
|
||||
h.cfg.AmpCode.ModelMappings = newMappings
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// GetAmpForceModelMappings returns whether model mappings are forced.
|
||||
func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(200, gin.H{"force-model-mappings": false})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings})
|
||||
}
|
||||
|
||||
// PutAmpForceModelMappings updates the force model mappings setting.
|
||||
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
|
||||
}
|
||||
|
||||
@@ -240,16 +240,6 @@ func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) {
|
||||
Value *bool `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
var m map[string]any
|
||||
if err2 := c.ShouldBindJSON(&m); err2 == nil {
|
||||
for _, v := range m {
|
||||
if b, ok := v.(bool); ok {
|
||||
set(b)
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -232,7 +232,16 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
w.streamDone = nil
|
||||
}
|
||||
|
||||
// Write API Request and Response to the streaming log before closing
|
||||
if w.streamWriter != nil {
|
||||
apiRequest := w.extractAPIRequest(c)
|
||||
if len(apiRequest) > 0 {
|
||||
_ = w.streamWriter.WriteAPIRequest(apiRequest)
|
||||
}
|
||||
apiResponse := w.extractAPIResponse(c)
|
||||
if len(apiResponse) > 0 {
|
||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||
}
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
|
||||
@@ -100,14 +100,14 @@ func (m *AmpModule) Name() string {
|
||||
return "amp-routing"
|
||||
}
|
||||
|
||||
// getPrioritizeModelMappings returns whether model mappings should take precedence over local API keys
|
||||
func (m *AmpModule) getPrioritizeModelMappings() bool {
|
||||
// forceModelMappings returns whether model mappings should take precedence over local API keys
|
||||
func (m *AmpModule) forceModelMappings() bool {
|
||||
m.configMu.RLock()
|
||||
defer m.configMu.RUnlock()
|
||||
if m.lastConfig == nil {
|
||||
return false
|
||||
}
|
||||
return m.lastConfig.PrioritizeModelMappings
|
||||
return m.lastConfig.ForceModelMappings
|
||||
}
|
||||
|
||||
// Register sets up Amp routes if configured.
|
||||
|
||||
@@ -77,29 +77,29 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
||||
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
||||
// when the model's provider is not available in CLIProxyAPI
|
||||
type FallbackHandler struct {
|
||||
getProxy func() *httputil.ReverseProxy
|
||||
modelMapper ModelMapper
|
||||
getPrioritizeModelMappings func() bool
|
||||
getProxy func() *httputil.ReverseProxy
|
||||
modelMapper ModelMapper
|
||||
forceModelMappings func() bool
|
||||
}
|
||||
|
||||
// NewFallbackHandler creates a new fallback handler wrapper
|
||||
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
||||
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
||||
return &FallbackHandler{
|
||||
getProxy: getProxy,
|
||||
getPrioritizeModelMappings: func() bool { return false },
|
||||
getProxy: getProxy,
|
||||
forceModelMappings: func() bool { return false },
|
||||
}
|
||||
}
|
||||
|
||||
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
||||
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, getPrioritize func() bool) *FallbackHandler {
|
||||
if getPrioritize == nil {
|
||||
getPrioritize = func() bool { return false }
|
||||
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
|
||||
if forceModelMappings == nil {
|
||||
forceModelMappings = func() bool { return false }
|
||||
}
|
||||
return &FallbackHandler{
|
||||
getProxy: getProxy,
|
||||
modelMapper: mapper,
|
||||
getPrioritizeModelMappings: getPrioritize,
|
||||
getProxy: getProxy,
|
||||
modelMapper: mapper,
|
||||
forceModelMappings: forceModelMappings,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,11 +141,11 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
usedMapping := false
|
||||
var providers []string
|
||||
|
||||
// Check if model mappings should take priority over local API keys
|
||||
prioritizeMappings := fh.getPrioritizeModelMappings != nil && fh.getPrioritizeModelMappings()
|
||||
// Check if model mappings should be forced ahead of local API keys
|
||||
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
|
||||
|
||||
if prioritizeMappings {
|
||||
// PRIORITY MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||
if forceMappings {
|
||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||
// This allows users to route Amp requests to their preferred OAuth providers
|
||||
if fh.modelMapper != nil {
|
||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||
|
||||
@@ -171,7 +171,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
|
||||
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
return m.getProxy()
|
||||
}, m.modelMapper, m.getPrioritizeModelMappings)
|
||||
}, m.modelMapper, m.forceModelMappings)
|
||||
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
||||
|
||||
// Route POST model calls through Gemini bridge with FallbackHandler.
|
||||
@@ -209,7 +209,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
// Also includes model mapping support for routing unavailable models to alternatives
|
||||
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
return m.getProxy()
|
||||
}, m.modelMapper, m.getPrioritizeModelMappings)
|
||||
}, m.modelMapper, m.forceModelMappings)
|
||||
|
||||
// Provider-specific routes under /api/provider/:provider
|
||||
ampProviders := engine.Group("/api/provider")
|
||||
|
||||
@@ -300,7 +300,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
|
||||
// Create HTTP server
|
||||
s.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.Port),
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Handler: engine,
|
||||
}
|
||||
|
||||
@@ -520,9 +520,25 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
|
||||
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
|
||||
|
||||
mgmt.GET("/prioritize-model-mappings", s.mgmt.GetPrioritizeModelMappings)
|
||||
mgmt.PUT("/prioritize-model-mappings", s.mgmt.PutPrioritizeModelMappings)
|
||||
mgmt.PATCH("/prioritize-model-mappings", s.mgmt.PutPrioritizeModelMappings)
|
||||
mgmt.GET("/ampcode", s.mgmt.GetAmpCode)
|
||||
mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL)
|
||||
mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
|
||||
mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
|
||||
mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL)
|
||||
mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey)
|
||||
mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
|
||||
mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
|
||||
mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey)
|
||||
mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost)
|
||||
mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
|
||||
mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
|
||||
mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings)
|
||||
mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings)
|
||||
mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings)
|
||||
mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings)
|
||||
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
|
||||
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||
|
||||
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
||||
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
||||
|
||||
@@ -20,6 +20,9 @@ import (
|
||||
// Config represents the application's configuration, loaded from a YAML file.
|
||||
type Config struct {
|
||||
config.SDKConfig `yaml:",inline"`
|
||||
// Host is the network host/interface on which the API server will bind.
|
||||
// Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access.
|
||||
Host string `yaml:"host" json:"-"`
|
||||
// Port is the network port on which the API server will listen.
|
||||
Port int `yaml:"port" json:"-"`
|
||||
|
||||
@@ -152,9 +155,9 @@ type AmpCode struct {
|
||||
// allow routing to an alternative model that IS available.
|
||||
ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"`
|
||||
|
||||
// PrioritizeModelMappings when true, model mappings take precedence over local API keys.
|
||||
// ForceModelMappings when true, model mappings take precedence over local API keys.
|
||||
// When false (default), local API keys are used first if available.
|
||||
PrioritizeModelMappings bool `yaml:"prioritize-model-mappings" json:"prioritize-model-mappings"`
|
||||
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
|
||||
}
|
||||
|
||||
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
||||
@@ -353,6 +356,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Unmarshal the YAML data into the Config struct.
|
||||
var cfg Config
|
||||
// Set defaults before unmarshal so that absent keys keep defaults.
|
||||
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
|
||||
cfg.LoggingToFile = false
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
|
||||
@@ -84,6 +84,26 @@ type StreamingLogWriter interface {
|
||||
// - error: An error if writing fails, nil otherwise
|
||||
WriteStatus(status int, headers map[string][]string) error
|
||||
|
||||
// WriteAPIRequest writes the upstream API request details to the log.
|
||||
// This should be called before WriteStatus to maintain proper log ordering.
|
||||
//
|
||||
// Parameters:
|
||||
// - apiRequest: The API request data (typically includes URL, headers, body sent upstream)
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if writing fails, nil otherwise
|
||||
WriteAPIRequest(apiRequest []byte) error
|
||||
|
||||
// WriteAPIResponse writes the upstream API response details to the log.
|
||||
// This should be called after the streaming response is complete.
|
||||
//
|
||||
// Parameters:
|
||||
// - apiResponse: The API response data
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if writing fails, nil otherwise
|
||||
WriteAPIResponse(apiResponse []byte) error
|
||||
|
||||
// Close finalizes the log file and cleans up resources.
|
||||
//
|
||||
// Returns:
|
||||
@@ -248,10 +268,11 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
||||
|
||||
// Create streaming writer
|
||||
writer := &FileStreamingLogWriter{
|
||||
file: file,
|
||||
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
|
||||
closeChan: make(chan struct{}),
|
||||
errorChan: make(chan error, 1),
|
||||
file: file,
|
||||
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
|
||||
closeChan: make(chan struct{}),
|
||||
errorChan: make(chan error, 1),
|
||||
bufferedChunks: &bytes.Buffer{},
|
||||
}
|
||||
|
||||
// Start async writer goroutine
|
||||
@@ -628,11 +649,12 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
|
||||
|
||||
// FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs.
|
||||
// It handles asynchronous writing of streaming response chunks to a file.
|
||||
// All data is buffered and written in the correct order when Close is called.
|
||||
type FileStreamingLogWriter struct {
|
||||
// file is the file where log data is written.
|
||||
file *os.File
|
||||
|
||||
// chunkChan is a channel for receiving response chunks to write.
|
||||
// chunkChan is a channel for receiving response chunks to buffer.
|
||||
chunkChan chan []byte
|
||||
|
||||
// closeChan is a channel for signaling when the writer is closed.
|
||||
@@ -641,8 +663,23 @@ type FileStreamingLogWriter struct {
|
||||
// errorChan is a channel for reporting errors during writing.
|
||||
errorChan chan error
|
||||
|
||||
// statusWritten indicates whether the response status has been written.
|
||||
// bufferedChunks stores the response chunks in order.
|
||||
bufferedChunks *bytes.Buffer
|
||||
|
||||
// responseStatus stores the HTTP status code.
|
||||
responseStatus int
|
||||
|
||||
// statusWritten indicates whether a non-zero status was recorded.
|
||||
statusWritten bool
|
||||
|
||||
// responseHeaders stores the response headers.
|
||||
responseHeaders map[string][]string
|
||||
|
||||
// apiRequest stores the upstream API request data.
|
||||
apiRequest []byte
|
||||
|
||||
// apiResponse stores the upstream API response data.
|
||||
apiResponse []byte
|
||||
}
|
||||
|
||||
// WriteChunkAsync writes a response chunk asynchronously (non-blocking).
|
||||
@@ -666,39 +703,65 @@ func (w *FileStreamingLogWriter) WriteChunkAsync(chunk []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
// WriteStatus writes the response status and headers to the log.
|
||||
// WriteStatus buffers the response status and headers for later writing.
|
||||
//
|
||||
// Parameters:
|
||||
// - status: The response status code
|
||||
// - headers: The response headers
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if writing fails, nil otherwise
|
||||
// - error: Always returns nil (buffering cannot fail)
|
||||
func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error {
|
||||
if w.file == nil || w.statusWritten {
|
||||
if status == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var content strings.Builder
|
||||
content.WriteString("========================================\n")
|
||||
content.WriteString("=== RESPONSE ===\n")
|
||||
content.WriteString(fmt.Sprintf("Status: %d\n", status))
|
||||
|
||||
for key, values := range headers {
|
||||
for _, value := range values {
|
||||
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
|
||||
w.responseStatus = status
|
||||
if headers != nil {
|
||||
w.responseHeaders = make(map[string][]string, len(headers))
|
||||
for key, values := range headers {
|
||||
headerValues := make([]string, len(values))
|
||||
copy(headerValues, values)
|
||||
w.responseHeaders[key] = headerValues
|
||||
}
|
||||
}
|
||||
content.WriteString("\n")
|
||||
w.statusWritten = true
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := w.file.WriteString(content.String())
|
||||
if err == nil {
|
||||
w.statusWritten = true
|
||||
// WriteAPIRequest buffers the upstream API request details for later writing.
|
||||
//
|
||||
// Parameters:
|
||||
// - apiRequest: The API request data (typically includes URL, headers, body sent upstream)
|
||||
//
|
||||
// Returns:
|
||||
// - error: Always returns nil (buffering cannot fail)
|
||||
func (w *FileStreamingLogWriter) WriteAPIRequest(apiRequest []byte) error {
|
||||
if len(apiRequest) == 0 {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
w.apiRequest = bytes.Clone(apiRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteAPIResponse buffers the upstream API response details for later writing.
|
||||
//
|
||||
// Parameters:
|
||||
// - apiResponse: The API response data
|
||||
//
|
||||
// Returns:
|
||||
// - error: Always returns nil (buffering cannot fail)
|
||||
func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
|
||||
if len(apiResponse) == 0 {
|
||||
return nil
|
||||
}
|
||||
w.apiResponse = bytes.Clone(apiResponse)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close finalizes the log file and cleans up resources.
|
||||
// It writes all buffered data to the file in the correct order:
|
||||
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if closing fails, nil otherwise
|
||||
@@ -707,27 +770,84 @@ func (w *FileStreamingLogWriter) Close() error {
|
||||
close(w.chunkChan)
|
||||
}
|
||||
|
||||
// Wait for async writer to finish
|
||||
// Wait for async writer to finish buffering chunks
|
||||
if w.closeChan != nil {
|
||||
<-w.closeChan
|
||||
w.chunkChan = nil
|
||||
}
|
||||
|
||||
if w.file != nil {
|
||||
return w.file.Close()
|
||||
if w.file == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
// Write all content in the correct order
|
||||
var content strings.Builder
|
||||
|
||||
// 1. Write API REQUEST section
|
||||
if len(w.apiRequest) > 0 {
|
||||
if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) {
|
||||
content.Write(w.apiRequest)
|
||||
if !bytes.HasSuffix(w.apiRequest, []byte("\n")) {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
content.WriteString("=== API REQUEST ===\n")
|
||||
content.Write(w.apiRequest)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
// 2. Write API RESPONSE section
|
||||
if len(w.apiResponse) > 0 {
|
||||
if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) {
|
||||
content.Write(w.apiResponse)
|
||||
if !bytes.HasSuffix(w.apiResponse, []byte("\n")) {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
content.WriteString("=== API RESPONSE ===\n")
|
||||
content.Write(w.apiResponse)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
// 3. Write RESPONSE section (status, headers, buffered chunks)
|
||||
content.WriteString("=== RESPONSE ===\n")
|
||||
if w.statusWritten {
|
||||
content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus))
|
||||
}
|
||||
|
||||
for key, values := range w.responseHeaders {
|
||||
for _, value := range values {
|
||||
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
|
||||
}
|
||||
}
|
||||
content.WriteString("\n")
|
||||
|
||||
// Write buffered response body chunks
|
||||
if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 {
|
||||
content.Write(w.bufferedChunks.Bytes())
|
||||
}
|
||||
|
||||
// Write the complete content to file
|
||||
if _, err := w.file.WriteString(content.String()); err != nil {
|
||||
_ = w.file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return w.file.Close()
|
||||
}
|
||||
|
||||
// asyncWriter runs in a goroutine to handle async chunk writing.
|
||||
// It continuously reads chunks from the channel and writes them to the file.
|
||||
// asyncWriter runs in a goroutine to buffer chunks from the channel.
|
||||
// It continuously reads chunks from the channel and buffers them for later writing.
|
||||
func (w *FileStreamingLogWriter) asyncWriter() {
|
||||
defer close(w.closeChan)
|
||||
|
||||
for chunk := range w.chunkChan {
|
||||
if w.file != nil {
|
||||
_, _ = w.file.Write(chunk)
|
||||
if w.bufferedChunks != nil {
|
||||
w.bufferedChunks.Write(chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -754,6 +874,28 @@ func (w *NoOpStreamingLogWriter) WriteStatus(_ int, _ map[string][]string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteAPIRequest is a no-op implementation that does nothing and always returns nil.
|
||||
//
|
||||
// Parameters:
|
||||
// - apiRequest: The API request data (ignored)
|
||||
//
|
||||
// Returns:
|
||||
// - error: Always returns nil
|
||||
func (w *NoOpStreamingLogWriter) WriteAPIRequest(_ []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteAPIResponse is a no-op implementation that does nothing and always returns nil.
|
||||
//
|
||||
// Parameters:
|
||||
// - apiResponse: The API response data (ignored)
|
||||
//
|
||||
// Returns:
|
||||
// - error: Always returns nil
|
||||
func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close is a no-op implementation that does nothing and always returns nil.
|
||||
//
|
||||
// Returns:
|
||||
|
||||
@@ -943,18 +943,6 @@ func GetQwenModels() []*ModelInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// GetAntigravityThinkingConfig returns the Thinking configuration for antigravity models.
|
||||
// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup.
|
||||
func GetAntigravityThinkingConfig() map[string]*ThinkingSupport {
|
||||
return map[string]*ThinkingSupport{
|
||||
"gemini-2.5-flash": {Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
||||
"gemini-2.5-flash-lite": {Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
||||
"gemini-3-pro-preview": {Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
"gemini-claude-sonnet-4-5-thinking": {Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
"gemini-claude-opus-4-5-thinking": {Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
}
|
||||
}
|
||||
|
||||
// GetIFlowModels returns supported models for iFlow OAuth accounts.
|
||||
func GetIFlowModels() []*ModelInfo {
|
||||
entries := []struct {
|
||||
@@ -998,6 +986,25 @@ func GetIFlowModels() []*ModelInfo {
|
||||
return models
|
||||
}
|
||||
|
||||
// AntigravityModelConfig captures static antigravity model overrides, including
|
||||
// Thinking budget limits and provider max completion tokens.
|
||||
type AntigravityModelConfig struct {
|
||||
Thinking *ThinkingSupport
|
||||
MaxCompletionTokens int
|
||||
}
|
||||
|
||||
// GetAntigravityModelConfig returns static configuration for antigravity models.
|
||||
// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup.
|
||||
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
return map[string]*AntigravityModelConfig{
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
|
||||
"gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
}
|
||||
}
|
||||
|
||||
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
|
||||
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
||||
func GetGitHubCopilotModels() []*ModelInfo {
|
||||
|
||||
@@ -81,6 +81,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -174,6 +175,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -370,7 +372,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
thinkingConfig := registry.GetAntigravityThinkingConfig()
|
||||
modelConfig := registry.GetAntigravityModelConfig()
|
||||
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
||||
for originalName := range result.Map() {
|
||||
aliasName := modelName2Alias(originalName)
|
||||
@@ -387,8 +389,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
Type: antigravityAuthType,
|
||||
}
|
||||
// Look up Thinking support from static config using alias name
|
||||
if thinking, ok := thinkingConfig[aliasName]; ok {
|
||||
modelInfo.Thinking = thinking
|
||||
if cfg, ok := modelConfig[aliasName]; ok {
|
||||
if cfg.Thinking != nil {
|
||||
modelInfo.Thinking = cfg.Thinking
|
||||
}
|
||||
if cfg.MaxCompletionTokens > 0 {
|
||||
modelInfo.MaxCompletionTokens = cfg.MaxCompletionTokens
|
||||
}
|
||||
}
|
||||
models = append(models, modelInfo)
|
||||
}
|
||||
@@ -533,6 +540,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
strJSON = util.DeleteKey(strJSON, "minLength")
|
||||
strJSON = util.DeleteKey(strJSON, "maxLength")
|
||||
strJSON = util.DeleteKey(strJSON, "exclusiveMinimum")
|
||||
strJSON = util.DeleteKey(strJSON, "exclusiveMaximum")
|
||||
|
||||
paths = make([]string, 0)
|
||||
util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths)
|
||||
@@ -812,3 +820,53 @@ func alias2ModelName(modelName string) string {
|
||||
return modelName
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
|
||||
// For Claude models, it additionally ensures thinking budget < max_tokens.
|
||||
func normalizeAntigravityThinking(model string, payload []byte) []byte {
|
||||
payload = util.StripThinkingConfigIfUnsupported(model, payload)
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
return payload
|
||||
}
|
||||
budget := gjson.GetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||
if !budget.Exists() {
|
||||
return payload
|
||||
}
|
||||
raw := int(budget.Int())
|
||||
normalized := util.NormalizeThinkingBudget(model, raw)
|
||||
|
||||
isClaude := strings.Contains(strings.ToLower(model), "claude")
|
||||
if isClaude {
|
||||
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
|
||||
if effectiveMax > 0 && normalized >= effectiveMax {
|
||||
normalized = effectiveMax - 1
|
||||
if normalized < 1 {
|
||||
normalized = 1
|
||||
}
|
||||
}
|
||||
if setDefaultMax {
|
||||
if res, errSet := sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax); errSet == nil {
|
||||
payload = res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
updated, err := sjson.SetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
|
||||
if err != nil {
|
||||
return payload
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// antigravityEffectiveMaxTokens returns the max tokens to cap thinking:
|
||||
// prefer request-provided maxOutputTokens; otherwise fall back to model default.
|
||||
// The boolean indicates whether the value came from the model default (and thus should be written back).
|
||||
func antigravityEffectiveMaxTokens(model string, payload []byte) (max int, fromModel bool) {
|
||||
if maxTok := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxTok.Exists() && maxTok.Int() > 0 {
|
||||
return int(maxTok.Int()), false
|
||||
}
|
||||
if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
|
||||
return modelInfo.MaxCompletionTokens, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
@@ -157,7 +157,7 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
attempts, attempt := ensureAttempt(ginCtx)
|
||||
_, attempt := ensureAttempt(ginCtx)
|
||||
ensureResponseIntro(attempt)
|
||||
|
||||
if !attempt.headersWritten {
|
||||
@@ -175,8 +175,6 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
|
||||
}
|
||||
attempt.response.WriteString(string(data))
|
||||
attempt.bodyHasContent = true
|
||||
|
||||
updateAggregatedResponse(ginCtx, attempts)
|
||||
}
|
||||
|
||||
func ginContextFrom(ctx context.Context) *gin.Context {
|
||||
|
||||
@@ -111,7 +111,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
}
|
||||
|
||||
// Temperature/top_p/top_k
|
||||
// Temperature/top_p/top_k/max_tokens
|
||||
if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num)
|
||||
}
|
||||
@@ -121,6 +121,9 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number {
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num)
|
||||
}
|
||||
if maxTok := gjson.GetBytes(rawJSON, "max_tokens"); maxTok.Exists() && maxTok.Type == gjson.Number {
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num)
|
||||
}
|
||||
|
||||
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
|
||||
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
||||
|
||||
@@ -502,7 +502,7 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
fmt.Printf("API server started successfully on: %d\n", s.cfg.Port)
|
||||
fmt.Printf("API server started successfully on: %s:%d\n", s.cfg.Host, s.cfg.Port)
|
||||
|
||||
if s.hooks.OnAfterStart != nil {
|
||||
s.hooks.OnAfterStart(s)
|
||||
|
||||
827
test/amp_management_test.go
Normal file
827
test/amp_management_test.go
Normal file
@@ -0,0 +1,827 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// newAmpTestHandler creates a test handler with default ampcode configuration.
|
||||
func newAmpTestHandler(t *testing.T) (*management.Handler, string) {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "https://example.com",
|
||||
UpstreamAPIKey: "test-api-key-12345",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ForceModelMappings: false,
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "gpt-4", To: "gemini-pro"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil {
|
||||
t.Fatalf("failed to write config file: %v", err)
|
||||
}
|
||||
|
||||
h := management.NewHandler(cfg, configPath, nil)
|
||||
return h, configPath
|
||||
}
|
||||
|
||||
// setupAmpRouter creates a test router with all ampcode management endpoints.
|
||||
func setupAmpRouter(h *management.Handler) *gin.Engine {
|
||||
r := gin.New()
|
||||
mgmt := r.Group("/v0/management")
|
||||
{
|
||||
mgmt.GET("/ampcode", h.GetAmpCode)
|
||||
mgmt.GET("/ampcode/upstream-url", h.GetAmpUpstreamURL)
|
||||
mgmt.PUT("/ampcode/upstream-url", h.PutAmpUpstreamURL)
|
||||
mgmt.DELETE("/ampcode/upstream-url", h.DeleteAmpUpstreamURL)
|
||||
mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey)
|
||||
mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey)
|
||||
mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey)
|
||||
mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost)
|
||||
mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost)
|
||||
mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings)
|
||||
mgmt.PUT("/ampcode/model-mappings", h.PutAmpModelMappings)
|
||||
mgmt.PATCH("/ampcode/model-mappings", h.PatchAmpModelMappings)
|
||||
mgmt.DELETE("/ampcode/model-mappings", h.DeleteAmpModelMappings)
|
||||
mgmt.GET("/ampcode/force-model-mappings", h.GetAmpForceModelMappings)
|
||||
mgmt.PUT("/ampcode/force-model-mappings", h.PutAmpForceModelMappings)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// TestGetAmpCode verifies GET /v0/management/ampcode returns full ampcode config.
|
||||
func TestGetAmpCode(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var resp map[string]config.AmpCode
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
ampcode := resp["ampcode"]
|
||||
if ampcode.UpstreamURL != "https://example.com" {
|
||||
t.Errorf("expected upstream-url %q, got %q", "https://example.com", ampcode.UpstreamURL)
|
||||
}
|
||||
if len(ampcode.ModelMappings) != 1 {
|
||||
t.Errorf("expected 1 model mapping, got %d", len(ampcode.ModelMappings))
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAmpUpstreamURL verifies GET /v0/management/ampcode/upstream-url returns the upstream URL.
|
||||
func TestGetAmpUpstreamURL(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp["upstream-url"] != "https://example.com" {
|
||||
t.Errorf("expected %q, got %q", "https://example.com", resp["upstream-url"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestPutAmpUpstreamURL verifies PUT /v0/management/ampcode/upstream-url updates the upstream URL.
|
||||
func TestPutAmpUpstreamURL(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": "https://new-upstream.com"}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteAmpUpstreamURL verifies DELETE /v0/management/ampcode/upstream-url clears the upstream URL.
|
||||
func TestDeleteAmpUpstreamURL(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAmpUpstreamAPIKey verifies GET /v0/management/ampcode/upstream-api-key returns the API key.
|
||||
func TestGetAmpUpstreamAPIKey(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
key := resp["upstream-api-key"].(string)
|
||||
if key != "test-api-key-12345" {
|
||||
t.Errorf("expected key %q, got %q", "test-api-key-12345", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPutAmpUpstreamAPIKey verifies PUT /v0/management/ampcode/upstream-api-key updates the API key.
|
||||
func TestPutAmpUpstreamAPIKey(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": "new-secret-key"}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key.
|
||||
func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAmpRestrictManagementToLocalhost verifies GET returns the localhost restriction setting.
|
||||
func TestGetAmpRestrictManagementToLocalhost(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var resp map[string]bool
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp["restrict-management-to-localhost"] != true {
|
||||
t.Error("expected restrict-management-to-localhost to be true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPutAmpRestrictManagementToLocalhost verifies PUT updates the localhost restriction setting.
|
||||
func TestPutAmpRestrictManagementToLocalhost(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": false}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAmpModelMappings verifies GET /v0/management/ampcode/model-mappings returns all mappings.
|
||||
func TestGetAmpModelMappings(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var resp map[string][]config.AmpModelMapping
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
mappings := resp["model-mappings"]
|
||||
if len(mappings) != 1 {
|
||||
t.Fatalf("expected 1 mapping, got %d", len(mappings))
|
||||
}
|
||||
if mappings[0].From != "gpt-4" || mappings[0].To != "gemini-pro" {
|
||||
t.Errorf("unexpected mapping: %+v", mappings[0])
|
||||
}
|
||||
}
|
||||
|
||||
// TestPutAmpModelMappings verifies PUT /v0/management/ampcode/model-mappings replaces all mappings.
|
||||
func TestPutAmpModelMappings(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": [{"from": "claude-3", "to": "gpt-4o"}, {"from": "gemini", "to": "claude"}]}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestPatchAmpModelMappings verifies PATCH updates existing mappings and adds new ones.
|
||||
func TestPatchAmpModelMappings(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": [{"from": "gpt-4", "to": "updated-model"}, {"from": "new-model", "to": "target"}]}`
|
||||
req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteAmpModelMappings_Specific verifies DELETE removes specified mappings by "from" field.
|
||||
func TestDeleteAmpModelMappings_Specific(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": ["gpt-4"]}`
|
||||
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteAmpModelMappings_All verifies DELETE with empty body removes all mappings.
|
||||
func TestDeleteAmpModelMappings_All(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAmpForceModelMappings verifies GET returns the force-model-mappings setting.
|
||||
func TestGetAmpForceModelMappings(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var resp map[string]bool
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp["force-model-mappings"] != false {
|
||||
t.Error("expected force-model-mappings to be false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPutAmpForceModelMappings verifies PUT updates the force-model-mappings setting.
|
||||
func TestPutAmpForceModelMappings(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": true}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPutAmpModelMappings_VerifyState verifies PUT replaces mappings and state is persisted.
|
||||
func TestPutAmpModelMappings_VerifyState(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": [{"from": "model-a", "to": "model-b"}, {"from": "model-c", "to": "model-d"}, {"from": "model-e", "to": "model-f"}]}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("PUT failed: status %d, body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string][]config.AmpModelMapping
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
mappings := resp["model-mappings"]
|
||||
if len(mappings) != 3 {
|
||||
t.Fatalf("expected 3 mappings, got %d", len(mappings))
|
||||
}
|
||||
|
||||
expected := map[string]string{"model-a": "model-b", "model-c": "model-d", "model-e": "model-f"}
|
||||
for _, m := range mappings {
|
||||
if expected[m.From] != m.To {
|
||||
t.Errorf("mapping %q -> expected %q, got %q", m.From, expected[m.From], m.To)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestPatchAmpModelMappings_VerifyState verifies PATCH merges mappings correctly.
|
||||
func TestPatchAmpModelMappings_VerifyState(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": [{"from": "gpt-4", "to": "updated-target"}, {"from": "new-model", "to": "new-target"}]}`
|
||||
req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("PATCH failed: status %d", w.Code)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string][]config.AmpModelMapping
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
mappings := resp["model-mappings"]
|
||||
if len(mappings) != 2 {
|
||||
t.Fatalf("expected 2 mappings (1 updated + 1 new), got %d", len(mappings))
|
||||
}
|
||||
|
||||
found := make(map[string]string)
|
||||
for _, m := range mappings {
|
||||
found[m.From] = m.To
|
||||
}
|
||||
|
||||
if found["gpt-4"] != "updated-target" {
|
||||
t.Errorf("gpt-4 should map to updated-target, got %q", found["gpt-4"])
|
||||
}
|
||||
if found["new-model"] != "new-target" {
|
||||
t.Errorf("new-model should map to new-target, got %q", found["new-model"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteAmpModelMappings_VerifyState verifies DELETE removes specific mappings and keeps others.
|
||||
func TestDeleteAmpModelMappings_VerifyState(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
putBody := `{"value": [{"from": "a", "to": "1"}, {"from": "b", "to": "2"}, {"from": "c", "to": "3"}]}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
delBody := `{"value": ["a", "c"]}`
|
||||
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("DELETE failed: status %d", w.Code)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string][]config.AmpModelMapping
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
mappings := resp["model-mappings"]
|
||||
if len(mappings) != 1 {
|
||||
t.Fatalf("expected 1 mapping remaining, got %d", len(mappings))
|
||||
}
|
||||
if mappings[0].From != "b" || mappings[0].To != "2" {
|
||||
t.Errorf("expected b->2, got %s->%s", mappings[0].From, mappings[0].To)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteAmpModelMappings_NonExistent verifies DELETE with non-existent mapping doesn't affect existing ones.
|
||||
func TestDeleteAmpModelMappings_NonExistent(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
delBody := `{"value": ["non-existent-model"]}`
|
||||
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string][]config.AmpModelMapping
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if len(resp["model-mappings"]) != 1 {
|
||||
t.Errorf("original mapping should remain, got %d mappings", len(resp["model-mappings"]))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPutAmpModelMappings_Empty verifies PUT with empty array clears all mappings.
|
||||
func TestPutAmpModelMappings_Empty(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": []}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string][]config.AmpModelMapping
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if len(resp["model-mappings"]) != 0 {
|
||||
t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"]))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPutAmpUpstreamURL_VerifyState verifies PUT updates upstream URL and persists state.
|
||||
func TestPutAmpUpstreamURL_VerifyState(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": "https://new-api.example.com"}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("PUT failed: status %d", w.Code)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp["upstream-url"] != "https://new-api.example.com" {
|
||||
t.Errorf("expected %q, got %q", "https://new-api.example.com", resp["upstream-url"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteAmpUpstreamURL_VerifyState verifies DELETE clears upstream URL.
|
||||
func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("DELETE failed: status %d", w.Code)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp["upstream-url"] != "" {
|
||||
t.Errorf("expected empty string, got %q", resp["upstream-url"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestPutAmpUpstreamAPIKey_VerifyState verifies PUT updates API key and persists state.
|
||||
func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": "new-secret-api-key-xyz"}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("PUT failed: status %d", w.Code)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp["upstream-api-key"] != "new-secret-api-key-xyz" {
|
||||
t.Errorf("expected %q, got %q", "new-secret-api-key-xyz", resp["upstream-api-key"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteAmpUpstreamAPIKey_VerifyState verifies DELETE clears API key.
|
||||
func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("DELETE failed: status %d", w.Code)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp["upstream-api-key"] != "" {
|
||||
t.Errorf("expected empty string, got %q", resp["upstream-api-key"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestPutAmpRestrictManagementToLocalhost_VerifyState verifies PUT updates localhost restriction.
|
||||
func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": false}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("PUT failed: status %d", w.Code)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string]bool
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp["restrict-management-to-localhost"] != false {
|
||||
t.Error("expected false after update")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPutAmpForceModelMappings_VerifyState verifies PUT updates force-model-mappings setting.
|
||||
func TestPutAmpForceModelMappings_VerifyState(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value": true}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("PUT failed: status %d", w.Code)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string]bool
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if resp["force-model-mappings"] != true {
|
||||
t.Error("expected true after update")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPutBoolField_EmptyObject verifies PUT with empty object returns 400.
|
||||
func TestPutBoolField_EmptyObject(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status %d for empty object, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestComplexMappingsWorkflow tests a full workflow: PUT, PATCH, DELETE, and GET.
|
||||
func TestComplexMappingsWorkflow(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
putBody := `{"value": [{"from": "m1", "to": "t1"}, {"from": "m2", "to": "t2"}, {"from": "m3", "to": "t3"}, {"from": "m4", "to": "t4"}]}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
patchBody := `{"value": [{"from": "m2", "to": "t2-updated"}, {"from": "m5", "to": "t5"}]}`
|
||||
req = httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(patchBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
delBody := `{"value": ["m1", "m3"]}`
|
||||
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var resp map[string][]config.AmpModelMapping
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
mappings := resp["model-mappings"]
|
||||
if len(mappings) != 3 {
|
||||
t.Fatalf("expected 3 mappings (m2, m4, m5), got %d", len(mappings))
|
||||
}
|
||||
|
||||
expected := map[string]string{"m2": "t2-updated", "m4": "t4", "m5": "t5"}
|
||||
found := make(map[string]string)
|
||||
for _, m := range mappings {
|
||||
found[m.From] = m.To
|
||||
}
|
||||
|
||||
for from, to := range expected {
|
||||
if found[from] != to {
|
||||
t.Errorf("mapping %s: expected %q, got %q", from, to, found[from])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestNilHandlerGetAmpCode verifies handler works with empty config.
|
||||
func TestNilHandlerGetAmpCode(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
h := management.NewHandler(cfg, "", nil)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptyConfigGetAmpModelMappings verifies GET returns empty array for fresh config.
|
||||
func TestEmptyConfigGetAmpModelMappings(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
|
||||
h := management.NewHandler(cfg, configPath, nil)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var resp map[string][]config.AmpModelMapping
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if len(resp["model-mappings"]) != 0 {
|
||||
t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"]))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user