diff --git a/.gitignore b/.gitignore index ef2d935a..9c081b4c 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ GEMINI.md .vscode/* .claude/* .serena/* +/cmd/server/server diff --git a/README.md b/README.md index 81d9398b..6e88f05e 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,37 @@ # CLI Proxy API +--- + +## πŸ”” Important: Amp CLI Support Fork + +**This is a specialized fork of [router-for-me/CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) that adds support for the Amp CLI tool.** + +### Why This Fork Exists + +The **Amp CLI** requires custom routing patterns to function properly. The upstream CLIProxyAPI project maintainers opted not to merge Amp-specific routing support into the main codebase. + +### Which Version Should You Use? + +- **Use this fork** if you want to run **both Factory CLI and Amp CLI** with the same proxy server +- **Use upstream** ([router-for-me/CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI)) if you only need Factory CLI support + +### πŸ“– Complete Setup Guide + +**β†’ [USING_WITH_FACTORY_AND_AMP.md](USING_WITH_FACTORY_AND_AMP.md)** - Comprehensive guide for using this proxy with both Factory CLI (Droid) and Amp CLI and IDE extensions, including OAuth setup, configuration examples, and troubleshooting. + +### Key Differences + +This fork includes: +- βœ… **Amp CLI route aliases** (`/api/provider/{provider}/v1...`) +- βœ… **Amp upstream proxy support** for OAuth and management routes +- βœ… **Automatic gzip decompression** for Amp upstream responses +- βœ… **Smart secret management** with precedence: config > env > file +- βœ… **All Factory CLI features** from upstream (fully compatible) + +All Amp-specific code is isolated in the `internal/api/modules/amp` module, making it easy to sync upstream changes with minimal conflicts. + +--- + English | [δΈ­ζ–‡](README_CN.md) A proxy server that provides OpenAI/Gemini/Claude/Codex compatible API interfaces for CLI. @@ -40,6 +72,15 @@ Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB - OpenAI-compatible upstream providers via config (e.g., OpenRouter) - Reusable Go SDK for embedding the proxy (see `docs/sdk-usage.md`) +### Fork-Specific: Amp CLI Support πŸ”₯ +- **Full Amp CLI integration** via provider route aliases (`/api/provider/{provider}/v1...`) +- **Amp upstream proxy** for OAuth authentication and management routes +- **Smart secret management** with configurable precedence (config > env > file) +- **Automatic gzip decompression** for Amp upstream responses +- **5-minute secret caching** to reduce file I/O overhead +- **Zero conflict** with Factory CLI - use both tools simultaneously +- **Modular architecture** for easy upstream sync (90% reduction in merge conflicts) + ## Getting Started CLIProxyAPI Guides: [https://help.router-for.me/](https://help.router-for.me/) @@ -78,7 +119,7 @@ Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with A Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed -> [!NOTE] +> [!NOTE] > If you developed a project based on CLIProxyAPI, please open a PR to add it to this list. ## License diff --git a/USING_WITH_FACTORY_AND_AMP.md b/USING_WITH_FACTORY_AND_AMP.md new file mode 100644 index 00000000..bb2521a7 --- /dev/null +++ b/USING_WITH_FACTORY_AND_AMP.md @@ -0,0 +1,494 @@ +# Using Factory CLI (Droid) and Amp CLI with ChatGPT/Claude Subscriptions (OAuth) + + +## Why Use Subscriptions Instead of API Keys or Pass-Through Pricing? + +Using Factory CLI (droid) or Amp CLI/IDE with this CLIProxyAPI fork lets you leverage your **existing provider subscriptions** (ChatGPT Plus/Pro, Claude Pro/Max) instead of per-token API billing. + +**The value proposition is compelling:** +- **ChatGPT Plus/Pro** ($20-200/month) includes substantial use based on 5h and weekly quota limits +- **Claude Pro/Max** ($20-100-200/month) includes substantial Claude Sonnet 4.5 and Opus 4.1 on 5h and weekly quota limits +- **Pay-per-token APIs** can cost 5-10x+ for equivalent usage, even with pass-through pricing and no markup + +By using OAuth subscriptions through this proxy, you get significantly better value while using the powerful CLI and IDE harnesses from Factory and AmpCode. + +## Disclaimer + +- This project is for personal/educational use only. You are solely responsible for how you use it. +- Using reverse proxies or alternate API bases may violate provider Terms of Service (OpenAI, Anthropic, Google, etc.). +- Accounts can be rate-limited, locked, or banned. Credentials and data may be at risk if misconfigured. +- Do not use to resell access, bypass access controls, or otherwise abuse services. +- No warranties. Use at your own risk. + +## Summary + +- Run Factory CLI (droid) and Amp CLI through a single local proxy server. +- This fork keeps all upstream Factory compatibility and adds Amp-specific support: + - Provider route aliases for Amp: `/api/provider/{provider}/v1...` + - Amp OAuth/management upstream proxy + - Smart secret resolution and automatic gzip handling +- Outcome: one proxy for both tools, minimal switching, clean separation of Amp supporting code from upstream repo. + +## Why This Fork? + +- Upstream maintainers chose not to include Amp-specific routing to keep scope focused on pure proxy functionality. +- Amp CLI expects Amp-specific alias routes and management endpoints the upstream CLIProxyAPI does not expose. +- This fork adds: + - Route aliases: `/api/provider/{provider}/v1...` + - Amp upstream proxy and OAuth + - Localhost-only access controls for Amp management routes (secure-by-default) +- Amp-specific code is isolated under `internal/api/modules/amp`, reducing merge conflicts with upstream. + +## Architecture Overview + +### Factory (droid) flow + +```mermaid +flowchart LR + A["Factory CLI (droid)"] -->|"OpenAI/Claude-compatible calls"| B["CLIProxyAPI Fork"] + B -->|"/v1/chat/completions
/v1/messages
/v1/models"| C["Translators/Router"] + C -->|"OAuth tokens"| D[("Providers")] + D -->|"OpenAI Codex / Claude"| E["Responses+Streaming"] + E --> B --> A +``` + +### Amp flow + +```mermaid +flowchart LR + A["Amp CLI"] -->|"/api/provider/provider/v1..."| B["CLIProxyAPI Fork"] + B -->|"Route aliases map to
upstream /v1 handlers"| C["Translators/Router"] + A -->|"/api/auth
/api/user
/api/meta
/api/threads..."| B + B -->|"Amp upstream proxy
(config: amp-upstream-url)"| F[("ampcode.com")] + C -->|"OpenAI / Anthropic"| D[("Providers")] + D --> B --> A +``` + +### Notes + +- Factory uses standard OpenAI-compatible routes under `/v1/...`. +- Amp uses `/api/provider/{provider}/v1...` plus management routes proxied to `amp-upstream-url`. +- Management routes are restricted to localhost by default. + +## Prerequisites + +- Go 1.24+ +- Active subscriptions: + - **ChatGPT Plus/Pro** (for GPT-5/GPT-5 Codex via OAuth) + - **Claude Pro/Max** (for Claude models via OAuth) + - **Amp** (for Amp CLI features in this fork) +- CLI tools: + - Factory CLI (droid) + - Amp CLI +- Local port `8317` available (or choose your own in config) + +## Installation & Build + +### Clone and build: + +```bash +git clone https://github.com/ben-vargas/ai-cli-proxy-api.git +cd ai-cli-proxy-api +``` + +**macOS/Linux:** +```bash +go build -o cli-proxy-api ./cmd/server +``` + +**Windows:** +```bash +go build -o cli-proxy-api.exe ./cmd/server +``` + +### Homebrew (Factory CLI only): + +> **⚠️ Note:** The Homebrew package installs the upstream version without Amp CLI support. Use the git clone method above if you need Amp CLI functionality. + +```bash +brew install cliproxyapi +brew services start cliproxyapi +``` + +## OAuth Setup + +Run these commands in the repo folder after building to authenticate with your subscriptions: + +### OpenAI (ChatGPT Plus/Pro for GPT-5/Codex): + +```bash +./cli-proxy-api --codex-login +``` + +- Opens browser on port `1455` for OAuth callback +- Requires active ChatGPT Plus or Pro subscription +- Tokens saved to `~/.cli-proxy-api/codex-.json` + +### Claude (Anthropic for Claude models): + +```bash +./cli-proxy-api --claude-login +``` + +- Opens browser on port `54545` for OAuth callback +- Requires active Claude Pro or Claude Max subscription +- Tokens saved to `~/.cli-proxy-api/claude-.json` + +**Tip:** Add `--no-browser` to print the login URL instead of opening a browser (useful for remote/headless servers). + +## Configuration for Factory CLI + +Factory CLI uses `~/.factory/config.json` to define custom models. Add entries to the `custom_models` array. + +### Complete configuration example + +Copy this entire configuration to `~/.factory/config.json` for quick setup: + +```json +{ + "custom_models": [ + { + "model_display_name": "Claude Haiku 4.5 [Proxy]", + "model": "claude-haiku-4-5-20251001", + "base_url": "http://localhost:8317", + "api_key": "dummy-not-used", + "provider": "anthropic" + }, + { + "model_display_name": "Claude Sonnet 4.5 [Proxy]", + "model": "claude-sonnet-4-5-20250929", + "base_url": "http://localhost:8317", + "api_key": "dummy-not-used", + "provider": "anthropic" + }, + { + "model_display_name": "Claude Opus 4.1 [Proxy]", + "model": "claude-opus-4-1-20250805", + "base_url": "http://localhost:8317", + "api_key": "dummy-not-used", + "provider": "anthropic" + }, + { + "model_display_name": "Claude Sonnet 4 [Proxy]", + "model": "claude-sonnet-4-20250514", + "base_url": "http://localhost:8317", + "api_key": "dummy-not-used", + "provider": "anthropic" + }, + { + "model_display_name": "GPT-5 [Proxy]", + "model": "gpt-5", + "base_url": "http://localhost:8317/v1", + "api_key": "dummy-not-used", + "provider": "openai" + }, + { + "model_display_name": "GPT-5 Minimal [Proxy]", + "model": "gpt-5-minimal", + "base_url": "http://localhost:8317/v1", + "api_key": "dummy-not-used", + "provider": "openai" + }, + { + "model_display_name": "GPT-5 Medium [Proxy]", + "model": "gpt-5-medium", + "base_url": "http://localhost:8317/v1", + "api_key": "dummy-not-used", + "provider": "openai" + }, + { + "model_display_name": "GPT-5 High [Proxy]", + "model": "gpt-5-high", + "base_url": "http://localhost:8317/v1", + "api_key": "dummy-not-used", + "provider": "openai" + }, + { + "model_display_name": "GPT-5 Codex [Proxy]", + "model": "gpt-5-codex", + "base_url": "http://localhost:8317/v1", + "api_key": "dummy-not-used", + "provider": "openai" + }, + { + "model_display_name": "GPT-5 Codex High [Proxy]", + "model": "gpt-5-codex-high", + "base_url": "http://localhost:8317/v1", + "api_key": "dummy-not-used", + "provider": "openai" + } + ] +} +``` + +After configuration, your custom models will appear in the `/model` selector: + +![Factory CLI model selector showing custom models](assets/factory_droid_custom_models.png) + +### Required fields: + +| Field | Required | Description | Example | +|-------|----------|-------------|---------| +| `model_display_name` | βœ“ | Human-friendly name shown in `/model` selector | `"Claude Sonnet 4.5 [Proxy]"` | +| `model` | βœ“ | Model identifier sent to API | `"claude-sonnet-4-5-20250929"` | +| `base_url` | βœ“ | Proxy endpoint | `"http://localhost:8317"` or `"http://localhost:8317/v1"` | +| `api_key` | βœ“ | API key (use `"dummy-not-used"` for proxy) | `"dummy-not-used"` | +| `provider` | βœ“ | API format type | `"anthropic"`, `"openai"`, or `"generic-chat-completion-api"` | + +### Provider-specific base URLs: + +| Provider | Base URL | Reason | +|----------|----------|--------| +| `anthropic` | `http://localhost:8317` | Factory appends `/v1/messages` automatically | +| `openai` | `http://localhost:8317/v1` | Factory appends `/responses` (needs `/v1` prefix) | +| `generic-chat-completion-api` | `http://localhost:8317/v1` | For OpenAI Chat Completions compatible models | + +### Using custom models: + +1. Edit `~/.factory/config.json` with the models above +2. Restart Factory CLI (`droid`) +3. Use `/model` command to select your custom model + +## Configuration for Amp CLI + +Enable Amp integration (fork-specific): + +In `config.yaml`: + +```yaml +# Amp CLI integration +amp-upstream-url: "https://ampcode.com" + +# Optional override; otherwise uses env or file (see precedence below) +# amp-upstream-api-key: "your-amp-api-key" + +# Security: restrict management routes to localhost (recommended) +amp-restrict-management-to-localhost: true +``` + +### Secret resolution precedence + +| Source | Key | Priority | +|-----------------------------------------|----------------------------------|----------| +| Config file | `amp-upstream-api-key` | High | +| Environment | `AMP_API_KEY` | Medium | +| Amp secrets file | `~/.local/share/amp/secrets.json`| Low | + +### Set Amp CLI to use this proxy + +Edit `~/.config/amp/settings.json` and add the `amp.url` setting: + +```json +{ + "amp.url": "http://localhost:8317" +} +``` + +Or set the environment variable: + +```bash +export AMP_URL=http://localhost:8317 +``` + +Then login (proxied via `amp-upstream-url`): + +```bash +amp login +``` + +Use Amp as normal: + +```bash +amp "Hello, world!" +``` + +### Supported Amp routes + +**Provider Aliases (always available):** +- `/api/provider/openai/v1/chat/completions` +- `/api/provider/openai/v1/responses` +- `/api/provider/anthropic/v1/messages` +- And related provider routes/versions your Amp CLI calls + +**Management Routes (require `amp-upstream-url`):** +- `/api/auth`, `/api/user`, `/api/meta`, `/api/internal`, `/api/threads`, `/api/telemetry` +- Localhost-only by default for security + +### Works with Amp IDE Extension + +This proxy configuration also works with the Amp IDE extension for VSCode and forks (Cursor, Windsurf, etc). Simply set the Amp URL in your IDE extension settings: + +1. Open Amp extension settings in your IDE +2. Set **Amp URL** to `http://localhost:8317` +3. Login with your Amp account +4. Start using Amp in your IDE with the same OAuth subscriptions! + +![Amp IDE extension settings](assets/amp_ide_extension_amp_url.png) + +The IDE extension uses the same routes as the CLI, so both can share the proxy simultaneously. + +## Running the Proxy + +> **Important:** The proxy requires a config file with `port` set (e.g., `port: 8317`). There is no built-in default port. + +### With config file: + +```bash +./cli-proxy-api --config config.yaml +``` + +If `config.yaml` is in the current directory: + +```bash +./cli-proxy-api +``` + +### Tmux (recommended for remote servers): + +Running in tmux keeps the proxy alive across SSH disconnects: + +**Start proxy in detached tmux session:** +```bash +tmux new-session -d -s proxy -c ~/ai-cli-proxy-api \ + "./cli-proxy-api --config config.yaml" +``` + +**View/attach to proxy session:** +```bash +tmux attach-session -t proxy +``` + +**Detach from session (proxy keeps running):** +``` +Ctrl+b, then d +``` + +**Stop proxy:** +```bash +tmux kill-session -t proxy +``` + +**Check if running:** +```bash +tmux has-session -t proxy && echo "Running" || echo "Not running" +``` + +**Optional: Add to `~/.bashrc` for convenience:** +```bash +alias proxy-start='tmux new-session -d -s proxy -c ~/ai-cli-proxy-api "./cli-proxy-api --config config.yaml" && echo "Proxy started (use proxy-view to attach)"' +alias proxy-view='tmux attach-session -t proxy' +alias proxy-stop='tmux kill-session -t proxy 2>/dev/null && echo "Proxy stopped"' +alias proxy-status='tmux has-session -t proxy 2>/dev/null && echo "βœ“ Running" || echo "βœ— Not running"' +``` + +### As a service (examples): + +**Homebrew:** +```bash +brew services start cliproxyapi +``` + +**Systemd/Docker:** use your standard service templates; point the binary and config appropriately + +### Key config fields (example) + +```yaml +port: 8317 +auth-dir: "~/.cli-proxy-api" +debug: false +logging-to-file: true + +remote-management: + allow-remote: false + secret-key: "" # leave empty to disable management API + disable-control-panel: false + +# Amp integration +amp-upstream-url: "https://ampcode.com" +# amp-upstream-api-key: "your-amp-api-key" +amp-restrict-management-to-localhost: true + +# Retries and quotas +request-retry: 3 +quota-exceeded: + switch-project: true + switch-preview-model: true +``` + +## Usage Examples + +### Factory + +**List models:** +```bash +curl http://localhost:8317/v1/models +``` + +**Chat Completions (Claude):** +```bash +curl -s http://localhost:8317/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "claude-sonnet-4-5-20250929", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 1024 + }' +``` + +### Amp + +**Provider alias (OpenAI-style):** +```bash +curl -s http://localhost:8317/api/provider/openai/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-5", + "messages": [{"role": "user", "content": "Hello"}] + }' +``` + +**Management (localhost only by default):** +```bash +curl -s http://localhost:8317/api/user +``` + +## Troubleshooting + +### Common errors and fixes + +| Symptom/Code | Likely Cause | Fix | +|------------------------------------------|------------------------------------------------------|----------------------------------------------------------------------| +| 404 /v1/chat/completions | Factory not pointing to proxy base | Set base to `http://localhost:8317/v1` (env/flag/config). | +| 404 /api/provider/... | Incorrect route path or typo | Ensure you're calling `/api/provider/{provider}/v1...` paths exactly.| +| 403 on /api/user (Amp) | Management restricted to localhost | Run from same machine or set `amp-restrict-management-to-localhost: false` (not recommended). | +| 401/403 from provider | Missing/expired OAuth or API key | Re-run the relevant `--*-login` or configure keys in `config.yaml`. | +| 429/Quota exceeded | Project/model quota exhausted | Enable `quota-exceeded` switching or switch accounts. | +| 5xx from provider | Upstream transient error | Increase `request-retry` and try again. | +| SSE/stream stuck | Client not handling SSE properly | Use SSE-capable client or set `stream: false`. | +| Amp gzip decoding errors | Compressed upstream responses | Fork auto-decompresses; update to latest build if issue persists. | +| CORS errors in browser | Protected management endpoints | Use CLI/terminal; avoid browsers for management endpoints. | +| Wrong model name | Provider alias mismatch | Use `gpt-*` for OpenAI or `claude-*` for Anthropic models. | + +### Diagnostics + +- Check logs (`debug: true` temporarily or `logging-to-file: true`). +- Verify config in effect: print effective config or confirm with startup logs. +- Test base reachability: `curl http://localhost:8317/v1/models`. +- For Amp, verify `amp-upstream-url` and secrets resolution. + +## Security Checklist + +- Keep `amp-restrict-management-to-localhost: true` (default). +- Do not expose the proxy publicly; bind to localhost or protect with firewall/VPN. +- If enabling remote management, set `remote-management.secret-key` and TLS/ingress protections. +- Disable the built-in management UI if hosting your own: + - `remote-management.disable-control-panel: true` +- Rotate tokens/keys; store config and auth-dir on encrypted disk or managed secret stores. +- Keep binary up to date to receive security fixes. + +## References + +- This fork README: [README.md](README.md) +- Upstream project: [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) +- Amp CLI: [Official Manual](https://ampcode.com/manual) +- Factory CLI (droid): [Official Documentation](https://docs.factory.ai/cli/getting-started/overview) diff --git a/examples/custom-provider/main.go b/examples/custom-provider/main.go index eb1755d0..b22675f9 100644 --- a/examples/custom-provider/main.go +++ b/examples/custom-provider/main.go @@ -137,6 +137,10 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re return clipexec.Response{Payload: body}, nil } +func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) { + return clipexec.Response{}, errors.New("count tokens not implemented") +} + func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) { ch := make(chan clipexec.StreamChunk, 1) go func() { diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go new file mode 100644 index 00000000..07e52e46 --- /dev/null +++ b/internal/api/modules/amp/amp.go @@ -0,0 +1,185 @@ +// Package amp implements the Amp CLI routing module, providing OAuth-based +// integration with Amp CLI for ChatGPT and Anthropic subscriptions. +package amp + +import ( + "fmt" + "net/http/httputil" + "strings" + "sync" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" + log "github.com/sirupsen/logrus" +) + +// Option configures the AmpModule. +type Option func(*AmpModule) + +// AmpModule implements the RouteModuleV2 interface for Amp CLI integration. +// It provides: +// - Reverse proxy to Amp control plane for OAuth/management +// - Provider-specific route aliases (/api/provider/{provider}/...) +// - Automatic gzip decompression for misconfigured upstreams +type AmpModule struct { + secretSource SecretSource + proxy *httputil.ReverseProxy + accessManager *sdkaccess.Manager + authMiddleware_ gin.HandlerFunc + enabled bool + registerOnce sync.Once +} + +// New creates a new Amp routing module with the given options. +// This is the preferred constructor using the Option pattern. +// +// Example: +// +// ampModule := amp.New( +// amp.WithAccessManager(accessManager), +// amp.WithAuthMiddleware(authMiddleware), +// amp.WithSecretSource(customSecret), +// ) +func New(opts ...Option) *AmpModule { + m := &AmpModule{ + secretSource: nil, // Will be created on demand if not provided + } + for _, opt := range opts { + opt(m) + } + return m +} + +// NewLegacy creates a new Amp routing module using the legacy constructor signature. +// This is provided for backwards compatibility. +// +// DEPRECATED: Use New with options instead. +func NewLegacy(accessManager *sdkaccess.Manager, authMiddleware gin.HandlerFunc) *AmpModule { + return New( + WithAccessManager(accessManager), + WithAuthMiddleware(authMiddleware), + ) +} + +// WithSecretSource sets a custom secret source for the module. +func WithSecretSource(source SecretSource) Option { + return func(m *AmpModule) { + m.secretSource = source + } +} + +// WithAccessManager sets the access manager for the module. +func WithAccessManager(am *sdkaccess.Manager) Option { + return func(m *AmpModule) { + m.accessManager = am + } +} + +// WithAuthMiddleware sets the authentication middleware for provider routes. +func WithAuthMiddleware(middleware gin.HandlerFunc) Option { + return func(m *AmpModule) { + m.authMiddleware_ = middleware + } +} + +// Name returns the module identifier +func (m *AmpModule) Name() string { + return "amp-routing" +} + +// Register sets up Amp routes if configured. +// This implements the RouteModuleV2 interface with Context. +// Routes are registered only once via sync.Once for idempotent behavior. +func (m *AmpModule) Register(ctx modules.Context) error { + upstreamURL := strings.TrimSpace(ctx.Config.AmpUpstreamURL) + + // Determine auth middleware (from module or context) + auth := m.getAuthMiddleware(ctx) + + // Use registerOnce to ensure routes are only registered once + var regErr error + m.registerOnce.Do(func() { + // Always register provider aliases - these work without an upstream + m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth) + + // If no upstream URL, skip proxy routes but provider aliases are still available + if upstreamURL == "" { + log.Debug("Amp upstream proxy disabled (no upstream URL configured)") + log.Debug("Amp provider alias routes registered") + m.enabled = false + return + } + + // Create secret source with precedence: config > env > file + // Cache secrets for 5 minutes to reduce file I/O + if m.secretSource == nil { + m.secretSource = NewMultiSourceSecret(ctx.Config.AmpUpstreamAPIKey, 0 /* default 5min */) + } + + // Create reverse proxy with gzip handling via ModifyResponse + proxy, err := createReverseProxy(upstreamURL, m.secretSource) + if err != nil { + regErr = fmt.Errorf("failed to create amp proxy: %w", err) + return + } + + m.proxy = proxy + m.enabled = true + + // Register management proxy routes (requires upstream) + // Restrict to localhost by default for security (prevents drive-by browser attacks) + handler := proxyHandler(proxy) + m.registerManagementRoutes(ctx.Engine, handler, ctx.Config.AmpRestrictManagementToLocalhost) + + log.Infof("Amp upstream proxy enabled for: %s", upstreamURL) + log.Debug("Amp provider alias routes registered") + }) + + return regErr +} + +// getAuthMiddleware returns the authentication middleware, preferring the +// module's configured middleware, then the context middleware, then a fallback. +func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc { + if m.authMiddleware_ != nil { + return m.authMiddleware_ + } + if ctx.AuthMiddleware != nil { + return ctx.AuthMiddleware + } + // Fallback: no authentication (should not happen in production) + log.Warn("Amp module: no auth middleware provided, allowing all requests") + return func(c *gin.Context) { + c.Next() + } +} + +// OnConfigUpdated handles configuration updates. +// Currently requires restart for URL changes (could be enhanced for dynamic updates). +func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { + if !m.enabled { + log.Debug("Amp routing not enabled, skipping config update") + return nil + } + + upstreamURL := strings.TrimSpace(cfg.AmpUpstreamURL) + if upstreamURL == "" { + log.Warn("Amp upstream URL removed from config, restart required to disable") + return nil + } + + // If API key changed, invalidate the cache + if m.secretSource != nil { + if ms, ok := m.secretSource.(*MultiSourceSecret); ok { + ms.InvalidateCache() + log.Debug("Amp secret cache invalidated due to config update") + } + } + + log.Debug("Amp config updated (restart required for URL changes)") + return nil +} + + diff --git a/internal/api/modules/amp/amp_test.go b/internal/api/modules/amp/amp_test.go new file mode 100644 index 00000000..5ae16647 --- /dev/null +++ b/internal/api/modules/amp/amp_test.go @@ -0,0 +1,303 @@ +package amp + +import ( + "context" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" +) + +func TestAmpModule_Name(t *testing.T) { + m := New() + if m.Name() != "amp-routing" { + t.Fatalf("want amp-routing, got %s", m.Name()) + } +} + +func TestAmpModule_New(t *testing.T) { + accessManager := sdkaccess.NewManager() + authMiddleware := func(c *gin.Context) { c.Next() } + + m := NewLegacy(accessManager, authMiddleware) + + if m.accessManager != accessManager { + t.Fatal("accessManager not set") + } + if m.authMiddleware_ == nil { + t.Fatal("authMiddleware not set") + } + if m.enabled { + t.Fatal("enabled should be false initially") + } + if m.proxy != nil { + t.Fatal("proxy should be nil initially") + } +} + +func TestAmpModule_Register_WithUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + + // Fake upstream to ensure URL is valid + upstream := httptest.NewServer(nil) + defer upstream.Close() + + accessManager := sdkaccess.NewManager() + base := &handlers.BaseAPIHandler{} + + m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() }) + + cfg := &config.Config{ + AmpUpstreamURL: upstream.URL, + AmpUpstreamAPIKey: "test-key", + } + + ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }} + if err := m.Register(ctx); err != nil { + t.Fatalf("register error: %v", err) + } + + if !m.enabled { + t.Fatal("module should be enabled with upstream URL") + } + if m.proxy == nil { + t.Fatal("proxy should be initialized") + } + if m.secretSource == nil { + t.Fatal("secretSource should be initialized") + } +} + +func TestAmpModule_Register_WithoutUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + + accessManager := sdkaccess.NewManager() + base := &handlers.BaseAPIHandler{} + + m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() }) + + cfg := &config.Config{ + AmpUpstreamURL: "", // No upstream + } + + ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }} + if err := m.Register(ctx); err != nil { + t.Fatalf("register should not error without upstream: %v", err) + } + + if m.enabled { + t.Fatal("module should be disabled without upstream URL") + } + if m.proxy != nil { + t.Fatal("proxy should not be initialized without upstream") + } + + // But provider aliases should still be registered + req := httptest.NewRequest("GET", "/api/provider/openai/models", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code == 404 { + t.Fatal("provider aliases should be registered even without upstream") + } +} + +func TestAmpModule_Register_InvalidUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + + accessManager := sdkaccess.NewManager() + base := &handlers.BaseAPIHandler{} + + m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() }) + + cfg := &config.Config{ + AmpUpstreamURL: "://invalid-url", + } + + ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }} + if err := m.Register(ctx); err == nil { + t.Fatal("expected error for invalid upstream URL") + } +} + +func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) { + tmpDir := t.TempDir() + p := filepath.Join(tmpDir, "secrets.json") + if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil { + t.Fatal(err) + } + + m := &AmpModule{enabled: true} + ms := NewMultiSourceSecretWithPath("", p, time.Minute) + m.secretSource = ms + + // Warm the cache + if _, err := ms.Get(context.Background()); err != nil { + t.Fatal(err) + } + + if ms.cache == nil { + t.Fatal("expected cache to be set") + } + + // Update config - should invalidate cache + if err := m.OnConfigUpdated(&config.Config{AmpUpstreamURL: "http://x"}); err != nil { + t.Fatal(err) + } + + if ms.cache != nil { + t.Fatal("expected cache to be invalidated") + } +} + +func TestAmpModule_OnConfigUpdated_NotEnabled(t *testing.T) { + m := &AmpModule{enabled: false} + + // Should not error or panic when disabled + if err := m.OnConfigUpdated(&config.Config{}); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestAmpModule_OnConfigUpdated_URLRemoved(t *testing.T) { + m := &AmpModule{enabled: true} + ms := NewMultiSourceSecret("", 0) + m.secretSource = ms + + // Config update with empty URL - should log warning but not error + cfg := &config.Config{AmpUpstreamURL: ""} + + if err := m.OnConfigUpdated(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestAmpModule_OnConfigUpdated_NonMultiSourceSecret(t *testing.T) { + // Test that OnConfigUpdated doesn't panic with StaticSecretSource + m := &AmpModule{enabled: true} + m.secretSource = NewStaticSecretSource("static-key") + + cfg := &config.Config{AmpUpstreamURL: "http://example.com"} + + // Should not error or panic + if err := m.OnConfigUpdated(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestAmpModule_AuthMiddleware_Fallback(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + + // Create module with no auth middleware + m := &AmpModule{authMiddleware_: nil} + + // Get the fallback middleware via getAuthMiddleware + ctx := modules.Context{Engine: r, AuthMiddleware: nil} + middleware := m.getAuthMiddleware(ctx) + + if middleware == nil { + t.Fatal("getAuthMiddleware should return a fallback, not nil") + } + + // Test that it works + called := false + r.GET("/test", middleware, func(c *gin.Context) { + called = true + c.String(200, "ok") + }) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if !called { + t.Fatal("fallback middleware should allow requests through") + } +} + +func TestAmpModule_SecretSource_FromConfig(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + + upstream := httptest.NewServer(nil) + defer upstream.Close() + + accessManager := sdkaccess.NewManager() + base := &handlers.BaseAPIHandler{} + + m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() }) + + // Config with explicit API key + cfg := &config.Config{ + AmpUpstreamURL: upstream.URL, + AmpUpstreamAPIKey: "config-key", + } + + ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }} + if err := m.Register(ctx); err != nil { + t.Fatalf("register error: %v", err) + } + + // Secret source should be MultiSourceSecret with config key + if m.secretSource == nil { + t.Fatal("secretSource should be set") + } + + // Verify it returns the config key + key, err := m.secretSource.Get(context.Background()) + if err != nil { + t.Fatalf("Get error: %v", err) + } + if key != "config-key" { + t.Fatalf("want config-key, got %s", key) + } +} + +func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) { + gin.SetMode(gin.TestMode) + + scenarios := []struct { + name string + configURL string + }{ + {"with_upstream", "http://example.com"}, + {"without_upstream", ""}, + } + + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + r := gin.New() + accessManager := sdkaccess.NewManager() + base := &handlers.BaseAPIHandler{} + + m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() }) + + cfg := &config.Config{AmpUpstreamURL: scenario.configURL} + + ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }} + if err := m.Register(ctx); err != nil && scenario.configURL != "" { + t.Fatalf("register error: %v", err) + } + + // Provider aliases should always be available + req := httptest.NewRequest("GET", "/api/provider/openai/models", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code == 404 { + t.Fatal("provider aliases should be registered") + } + }) + } +} diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go new file mode 100644 index 00000000..5e267290 --- /dev/null +++ b/internal/api/modules/amp/proxy.go @@ -0,0 +1,176 @@ +package amp + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "net/http" + "net/http/httputil" + "net/url" + "strconv" + "strings" + + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" +) + +// readCloser wraps a reader and forwards Close to a separate closer. +// Used to restore peeked bytes while preserving upstream body Close behavior. +type readCloser struct { + r io.Reader + c io.Closer +} + +func (rc *readCloser) Read(p []byte) (int, error) { return rc.r.Read(p) } +func (rc *readCloser) Close() error { return rc.c.Close() } + +// createReverseProxy creates a reverse proxy handler for Amp upstream +// with automatic gzip decompression via ModifyResponse +func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputil.ReverseProxy, error) { + parsed, err := url.Parse(upstreamURL) + if err != nil { + return nil, fmt.Errorf("invalid amp upstream url: %w", err) + } + + proxy := httputil.NewSingleHostReverseProxy(parsed) + originalDirector := proxy.Director + + // Modify outgoing requests to inject API key and fix routing + proxy.Director = func(req *http.Request) { + originalDirector(req) + req.Host = parsed.Host + + // Preserve correlation headers for debugging + if req.Header.Get("X-Request-ID") == "" { + // Could generate one here if needed + } + + // Inject API key from secret source (precedence: config > env > file) + if key, err := secretSource.Get(req.Context()); err == nil && key != "" { + req.Header.Set("X-Api-Key", key) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key)) + } else if err != nil { + log.Warnf("amp secret source error (continuing without auth): %v", err) + } + } + + // Modify incoming responses to handle gzip without Content-Encoding + // This addresses the same issue as inline handler gzip handling, but at the proxy level + proxy.ModifyResponse = func(resp *http.Response) error { + // Only process successful responses + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil + } + + // Skip if already marked as gzip (Content-Encoding set) + if resp.Header.Get("Content-Encoding") != "" { + return nil + } + + // Skip streaming responses (SSE, chunked) + if isStreamingResponse(resp) { + return nil + } + + // Save reference to original upstream body for proper cleanup + originalBody := resp.Body + + // Peek at first 2 bytes to detect gzip magic bytes + header := make([]byte, 2) + n, _ := io.ReadFull(originalBody, header) + + // Check for gzip magic bytes (0x1f 0x8b) + // If n < 2, we didn't get enough bytes, so it's not gzip + if n >= 2 && header[0] == 0x1f && header[1] == 0x8b { + // It's gzip - read the rest of the body + rest, err := io.ReadAll(originalBody) + if err != nil { + // Restore what we read and return original body (preserve Close behavior) + resp.Body = &readCloser{ + r: io.MultiReader(bytes.NewReader(header[:n]), originalBody), + c: originalBody, + } + return nil + } + + // Reconstruct complete gzipped data + gzippedData := append(header[:n], rest...) + + // Decompress + gzipReader, err := gzip.NewReader(bytes.NewReader(gzippedData)) + if err != nil { + log.Warnf("amp proxy: gzip header detected but decompress failed: %v", err) + // Close original body and return in-memory copy + _ = originalBody.Close() + resp.Body = io.NopCloser(bytes.NewReader(gzippedData)) + return nil + } + + decompressed, err := io.ReadAll(gzipReader) + _ = gzipReader.Close() + if err != nil { + log.Warnf("amp proxy: gzip decompress error: %v", err) + // Close original body and return in-memory copy + _ = originalBody.Close() + resp.Body = io.NopCloser(bytes.NewReader(gzippedData)) + return nil + } + + // Close original body since we're replacing with in-memory decompressed content + _ = originalBody.Close() + + // Replace body with decompressed content + resp.Body = io.NopCloser(bytes.NewReader(decompressed)) + resp.ContentLength = int64(len(decompressed)) + + // Update headers to reflect decompressed state + resp.Header.Del("Content-Encoding") // No longer compressed + resp.Header.Del("Content-Length") // Remove stale compressed length + resp.Header.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10)) // Set decompressed length + + log.Debugf("amp proxy: decompressed gzip response (%d -> %d bytes)", len(gzippedData), len(decompressed)) + } else { + // Not gzip - restore peeked bytes while preserving Close behavior + // Handle edge cases: n might be 0, 1, or 2 depending on EOF + resp.Body = &readCloser{ + r: io.MultiReader(bytes.NewReader(header[:n]), originalBody), + c: originalBody, + } + } + + return nil + } + + // Error handler for proxy failures + proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { + log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err) + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusBadGateway) + _, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`)) + } + + return proxy, nil +} + +// isStreamingResponse detects if the response is streaming (SSE only) +// Note: We only treat text/event-stream as streaming. Chunked transfer encoding +// is a transport-level detail and doesn't mean we can't decompress the full response. +// Many JSON APIs use chunked encoding for normal responses. +func isStreamingResponse(resp *http.Response) bool { + contentType := resp.Header.Get("Content-Type") + + // Only Server-Sent Events are true streaming responses + if strings.Contains(contentType, "text/event-stream") { + return true + } + + return false +} + +// proxyHandler converts httputil.ReverseProxy to gin.HandlerFunc +func proxyHandler(proxy *httputil.ReverseProxy) gin.HandlerFunc { + return func(c *gin.Context) { + proxy.ServeHTTP(c.Writer, c.Request) + } +} diff --git a/internal/api/modules/amp/proxy_test.go b/internal/api/modules/amp/proxy_test.go new file mode 100644 index 00000000..864ed22c --- /dev/null +++ b/internal/api/modules/amp/proxy_test.go @@ -0,0 +1,439 @@ +package amp + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +// Helper: compress data with gzip +func gzipBytes(b []byte) []byte { + var buf bytes.Buffer + zw := gzip.NewWriter(&buf) + zw.Write(b) + zw.Close() + return buf.Bytes() +} + +// Helper: create a mock http.Response +func mkResp(status int, hdr http.Header, body []byte) *http.Response { + if hdr == nil { + hdr = http.Header{} + } + return &http.Response{ + StatusCode: status, + Header: hdr, + Body: io.NopCloser(bytes.NewReader(body)), + ContentLength: int64(len(body)), + } +} + +func TestCreateReverseProxy_ValidURL(t *testing.T) { + proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("key")) + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if proxy == nil { + t.Fatal("expected proxy to be created") + } +} + +func TestCreateReverseProxy_InvalidURL(t *testing.T) { + _, err := createReverseProxy("://invalid", NewStaticSecretSource("key")) + if err == nil { + t.Fatal("expected error for invalid URL") + } +} + +func TestModifyResponse_GzipScenarios(t *testing.T) { + proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k")) + if err != nil { + t.Fatal(err) + } + + goodJSON := []byte(`{"ok":true}`) + good := gzipBytes(goodJSON) + truncated := good[:10] + corrupted := append([]byte{0x1f, 0x8b}, []byte("notgzip")...) + + cases := []struct { + name string + header http.Header + body []byte + status int + wantBody []byte + wantCE string + }{ + { + name: "decompresses_valid_gzip_no_header", + header: http.Header{}, + body: good, + status: 200, + wantBody: goodJSON, + wantCE: "", + }, + { + name: "skips_when_ce_present", + header: http.Header{"Content-Encoding": []string{"gzip"}}, + body: good, + status: 200, + wantBody: good, + wantCE: "gzip", + }, + { + name: "passes_truncated_unchanged", + header: http.Header{}, + body: truncated, + status: 200, + wantBody: truncated, + wantCE: "", + }, + { + name: "passes_corrupted_unchanged", + header: http.Header{}, + body: corrupted, + status: 200, + wantBody: corrupted, + wantCE: "", + }, + { + name: "non_gzip_unchanged", + header: http.Header{}, + body: []byte("plain"), + status: 200, + wantBody: []byte("plain"), + wantCE: "", + }, + { + name: "empty_body", + header: http.Header{}, + body: []byte{}, + status: 200, + wantBody: []byte{}, + wantCE: "", + }, + { + name: "single_byte_body", + header: http.Header{}, + body: []byte{0x1f}, + status: 200, + wantBody: []byte{0x1f}, + wantCE: "", + }, + { + name: "skips_non_2xx_status", + header: http.Header{}, + body: good, + status: 404, + wantBody: good, + wantCE: "", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + resp := mkResp(tc.status, tc.header, tc.body) + if err := proxy.ModifyResponse(resp); err != nil { + t.Fatalf("ModifyResponse error: %v", err) + } + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if !bytes.Equal(got, tc.wantBody) { + t.Fatalf("body mismatch:\nwant: %q\ngot: %q", tc.wantBody, got) + } + if ce := resp.Header.Get("Content-Encoding"); ce != tc.wantCE { + t.Fatalf("Content-Encoding: want %q, got %q", tc.wantCE, ce) + } + }) + } +} + +func TestModifyResponse_UpdatesContentLengthHeader(t *testing.T) { + proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k")) + if err != nil { + t.Fatal(err) + } + + goodJSON := []byte(`{"message":"test response"}`) + gzipped := gzipBytes(goodJSON) + + // Simulate upstream response with gzip body AND Content-Length header + // (this is the scenario the bot flagged - stale Content-Length after decompression) + resp := mkResp(200, http.Header{ + "Content-Length": []string{fmt.Sprintf("%d", len(gzipped))}, // Compressed size + }, gzipped) + + if err := proxy.ModifyResponse(resp); err != nil { + t.Fatalf("ModifyResponse error: %v", err) + } + + // Verify body is decompressed + got, _ := io.ReadAll(resp.Body) + if !bytes.Equal(got, goodJSON) { + t.Fatalf("body should be decompressed, got: %q, want: %q", got, goodJSON) + } + + // Verify Content-Length header is updated to decompressed size + wantCL := fmt.Sprintf("%d", len(goodJSON)) + gotCL := resp.Header.Get("Content-Length") + if gotCL != wantCL { + t.Fatalf("Content-Length header mismatch: want %q (decompressed), got %q", wantCL, gotCL) + } + + // Verify struct field also matches + if resp.ContentLength != int64(len(goodJSON)) { + t.Fatalf("resp.ContentLength mismatch: want %d, got %d", len(goodJSON), resp.ContentLength) + } +} + +func TestModifyResponse_SkipsStreamingResponses(t *testing.T) { + proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k")) + if err != nil { + t.Fatal(err) + } + + goodJSON := []byte(`{"ok":true}`) + gzipped := gzipBytes(goodJSON) + + t.Run("sse_skips_decompression", func(t *testing.T) { + resp := mkResp(200, http.Header{"Content-Type": []string{"text/event-stream"}}, gzipped) + if err := proxy.ModifyResponse(resp); err != nil { + t.Fatalf("ModifyResponse error: %v", err) + } + // SSE should NOT be decompressed + got, _ := io.ReadAll(resp.Body) + if !bytes.Equal(got, gzipped) { + t.Fatal("SSE response should not be decompressed") + } + }) +} + +func TestModifyResponse_DecompressesChunkedJSON(t *testing.T) { + proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k")) + if err != nil { + t.Fatal(err) + } + + goodJSON := []byte(`{"ok":true}`) + gzipped := gzipBytes(goodJSON) + + t.Run("chunked_json_decompresses", func(t *testing.T) { + // Chunked JSON responses (like thread APIs) should be decompressed + resp := mkResp(200, http.Header{"Transfer-Encoding": []string{"chunked"}}, gzipped) + if err := proxy.ModifyResponse(resp); err != nil { + t.Fatalf("ModifyResponse error: %v", err) + } + // Should decompress because it's not SSE + got, _ := io.ReadAll(resp.Body) + if !bytes.Equal(got, goodJSON) { + t.Fatalf("chunked JSON should be decompressed, got: %q, want: %q", got, goodJSON) + } + }) +} + +func TestReverseProxy_InjectsHeaders(t *testing.T) { + gotHeaders := make(chan http.Header, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeaders <- r.Header.Clone() + w.WriteHeader(200) + w.Write([]byte(`ok`)) + })) + defer upstream.Close() + + proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("secret")) + if err != nil { + t.Fatal(err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxy.ServeHTTP(w, r) + })) + defer srv.Close() + + res, err := http.Get(srv.URL + "/test") + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + hdr := <-gotHeaders + if hdr.Get("X-Api-Key") != "secret" { + t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key")) + } + if hdr.Get("Authorization") != "Bearer secret" { + t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization")) + } +} + +func TestReverseProxy_EmptySecret(t *testing.T) { + gotHeaders := make(chan http.Header, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeaders <- r.Header.Clone() + w.WriteHeader(200) + w.Write([]byte(`ok`)) + })) + defer upstream.Close() + + proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("")) + if err != nil { + t.Fatal(err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxy.ServeHTTP(w, r) + })) + defer srv.Close() + + res, err := http.Get(srv.URL + "/test") + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + hdr := <-gotHeaders + // Should NOT inject headers when secret is empty + if hdr.Get("X-Api-Key") != "" { + t.Fatalf("X-Api-Key should not be set, got: %q", hdr.Get("X-Api-Key")) + } + if authVal := hdr.Get("Authorization"); authVal != "" && authVal != "Bearer " { + t.Fatalf("Authorization should not be set, got: %q", authVal) + } +} + +func TestReverseProxy_ErrorHandler(t *testing.T) { + // Point proxy to a non-routable address to trigger error + proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource("")) + if err != nil { + t.Fatal(err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxy.ServeHTTP(w, r) + })) + defer srv.Close() + + res, err := http.Get(srv.URL + "/any") + if err != nil { + t.Fatal(err) + } + body, _ := io.ReadAll(res.Body) + res.Body.Close() + + if res.StatusCode != http.StatusBadGateway { + t.Fatalf("want 502, got %d", res.StatusCode) + } + if !bytes.Contains(body, []byte(`"amp_upstream_proxy_error"`)) { + t.Fatalf("unexpected body: %s", body) + } + if ct := res.Header.Get("Content-Type"); ct != "application/json" { + t.Fatalf("content-type: want application/json, got %s", ct) + } +} + +func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) { + // Upstream returns gzipped JSON without Content-Encoding header + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write(gzipBytes([]byte(`{"upstream":"ok"}`))) + })) + defer upstream.Close() + + proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key")) + if err != nil { + t.Fatal(err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxy.ServeHTTP(w, r) + })) + defer srv.Close() + + res, err := http.Get(srv.URL + "/test") + if err != nil { + t.Fatal(err) + } + body, _ := io.ReadAll(res.Body) + res.Body.Close() + + expected := []byte(`{"upstream":"ok"}`) + if !bytes.Equal(body, expected) { + t.Fatalf("want decompressed JSON, got: %s", body) + } +} + +func TestReverseProxy_FullRoundTrip_PlainJSON(t *testing.T) { + // Upstream returns plain JSON + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{"plain":"json"}`)) + })) + defer upstream.Close() + + proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key")) + if err != nil { + t.Fatal(err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxy.ServeHTTP(w, r) + })) + defer srv.Close() + + res, err := http.Get(srv.URL + "/test") + if err != nil { + t.Fatal(err) + } + body, _ := io.ReadAll(res.Body) + res.Body.Close() + + expected := []byte(`{"plain":"json"}`) + if !bytes.Equal(body, expected) { + t.Fatalf("want plain JSON unchanged, got: %s", body) + } +} + +func TestIsStreamingResponse(t *testing.T) { + cases := []struct { + name string + header http.Header + want bool + }{ + { + name: "sse", + header: http.Header{"Content-Type": []string{"text/event-stream"}}, + want: true, + }, + { + name: "chunked_not_streaming", + header: http.Header{"Transfer-Encoding": []string{"chunked"}}, + want: false, // Chunked is transport-level, not streaming + }, + { + name: "normal_json", + header: http.Header{"Content-Type": []string{"application/json"}}, + want: false, + }, + { + name: "empty", + header: http.Header{}, + want: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + resp := &http.Response{Header: tc.header} + got := isStreamingResponse(resp) + if got != tc.want { + t.Fatalf("want %v, got %v", tc.want, got) + } + }) + } +} diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go new file mode 100644 index 00000000..f952de8d --- /dev/null +++ b/internal/api/modules/amp/routes.go @@ -0,0 +1,166 @@ +package amp + +import ( + "net" + "strings" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai" + log "github.com/sirupsen/logrus" +) + +// localhostOnlyMiddleware restricts access to localhost (127.0.0.1, ::1) only. +// Returns 403 Forbidden for non-localhost clients. +func localhostOnlyMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + clientIP := c.ClientIP() + + // Parse the IP to handle both IPv4 and IPv6 + ip := net.ParseIP(clientIP) + if ip == nil { + log.Warnf("Amp management: invalid client IP %s, denying access", clientIP) + c.AbortWithStatusJSON(403, gin.H{ + "error": "Access denied: management routes restricted to localhost", + }) + return + } + + // Check if IP is loopback (127.0.0.1 or ::1) + if !ip.IsLoopback() { + log.Warnf("Amp management: non-localhost IP %s attempted access, denying", clientIP) + c.AbortWithStatusJSON(403, gin.H{ + "error": "Access denied: management routes restricted to localhost", + }) + return + } + + c.Next() + } +} + +// noCORSMiddleware disables CORS for management routes to prevent browser-based attacks. +// This overwrites any global CORS headers set by the server. +func noCORSMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + // Remove CORS headers to prevent cross-origin access from browsers + c.Header("Access-Control-Allow-Origin", "") + c.Header("Access-Control-Allow-Methods", "") + c.Header("Access-Control-Allow-Headers", "") + c.Header("Access-Control-Allow-Credentials", "") + + // For OPTIONS preflight, deny with 403 + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(403) + return + } + + c.Next() + } +} + +// registerManagementRoutes registers Amp management proxy routes +// These routes proxy through to the Amp control plane for OAuth, user management, etc. +// If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1. +func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) { + ampAPI := engine.Group("/api") + + // Always disable CORS for management routes to prevent browser-based attacks + ampAPI.Use(noCORSMiddleware()) + + // Apply localhost-only restriction if configured + if restrictToLocalhost { + ampAPI.Use(localhostOnlyMiddleware()) + log.Info("Amp management routes restricted to localhost only (CORS disabled)") + } else { + log.Warn("⚠️ Amp management routes are NOT restricted to localhost - this is insecure!") + } + + // Management routes - these are proxied directly to Amp upstream + ampAPI.Any("/internal", proxyHandler) + ampAPI.Any("/internal/*path", proxyHandler) + ampAPI.Any("/user", proxyHandler) + ampAPI.Any("/user/*path", proxyHandler) + ampAPI.Any("/auth", proxyHandler) + ampAPI.Any("/auth/*path", proxyHandler) + ampAPI.Any("/meta", proxyHandler) + ampAPI.Any("/meta/*path", proxyHandler) + ampAPI.Any("/ads", proxyHandler) + ampAPI.Any("/telemetry", proxyHandler) + ampAPI.Any("/telemetry/*path", proxyHandler) + ampAPI.Any("/threads", proxyHandler) + ampAPI.Any("/threads/*path", proxyHandler) + ampAPI.Any("/otel", proxyHandler) + ampAPI.Any("/otel/*path", proxyHandler) + + // Google v1beta1 passthrough (Gemini native API) + ampAPI.Any("/provider/google/v1beta1/*path", proxyHandler) +} + +// registerProviderAliases registers /api/provider/{provider}/... routes +// These allow Amp CLI to route requests like: +// +// /api/provider/openai/v1/chat/completions +// /api/provider/anthropic/v1/messages +// /api/provider/google/v1beta/models +func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) { + // Create handler instances for different providers + openaiHandlers := openai.NewOpenAIAPIHandler(baseHandler) + geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler) + claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler) + openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler) + + // Provider-specific routes under /api/provider/:provider + ampProviders := engine.Group("/api/provider") + if auth != nil { + ampProviders.Use(auth) + } + + provider := ampProviders.Group("/:provider") + + // Dynamic models handler - routes to appropriate provider based on path parameter + ampModelsHandler := func(c *gin.Context) { + providerName := strings.ToLower(c.Param("provider")) + + switch providerName { + case "anthropic": + claudeCodeHandlers.ClaudeModels(c) + case "google": + geminiHandlers.GeminiModels(c) + default: + // Default to OpenAI-compatible (works for openai, groq, cerebras, etc.) + openaiHandlers.OpenAIModels(c) + } + } + + // Root-level routes (for providers that omit /v1, like groq/cerebras) + provider.GET("/models", ampModelsHandler) + provider.POST("/chat/completions", openaiHandlers.ChatCompletions) + provider.POST("/completions", openaiHandlers.Completions) + provider.POST("/responses", openaiResponsesHandlers.Responses) + + // /v1 routes (OpenAI/Claude-compatible endpoints) + v1Amp := provider.Group("/v1") + { + v1Amp.GET("/models", ampModelsHandler) + + // OpenAI-compatible endpoints + v1Amp.POST("/chat/completions", openaiHandlers.ChatCompletions) + v1Amp.POST("/completions", openaiHandlers.Completions) + v1Amp.POST("/responses", openaiResponsesHandlers.Responses) + + // Claude/Anthropic-compatible endpoints + v1Amp.POST("/messages", claudeCodeHandlers.ClaudeMessages) + v1Amp.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens) + } + + // /v1beta routes (Gemini native API) + v1betaAmp := provider.Group("/v1beta") + { + v1betaAmp.GET("/models", geminiHandlers.GeminiModels) + v1betaAmp.POST("/models/:action", geminiHandlers.GeminiHandler) + v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler) + } +} diff --git a/internal/api/modules/amp/routes_test.go b/internal/api/modules/amp/routes_test.go new file mode 100644 index 00000000..953b93bd --- /dev/null +++ b/internal/api/modules/amp/routes_test.go @@ -0,0 +1,216 @@ +package amp + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" +) + +func TestRegisterManagementRoutes(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + + // Spy to track if proxy handler was called + proxyCalled := false + proxyHandler := func(c *gin.Context) { + proxyCalled = true + c.String(200, "proxied") + } + + m := &AmpModule{} + m.registerManagementRoutes(r, proxyHandler, false) // false = don't restrict to localhost in tests + + managementPaths := []string{ + "/api/internal", + "/api/internal/some/path", + "/api/user", + "/api/user/profile", + "/api/auth", + "/api/auth/login", + "/api/meta", + "/api/telemetry", + "/api/threads", + "/api/otel", + "/api/provider/google/v1beta1/models", + } + + for _, path := range managementPaths { + t.Run(path, func(t *testing.T) { + proxyCalled = false + req := httptest.NewRequest(http.MethodGet, path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code == http.StatusNotFound { + t.Fatalf("route %s not registered", path) + } + if !proxyCalled { + t.Fatalf("proxy handler not called for %s", path) + } + }) + } +} + +func TestRegisterProviderAliases_AllProvidersRegistered(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + + // Minimal base handler setup (no need to initialize, just check routing) + base := &handlers.BaseAPIHandler{} + + // Track if auth middleware was called + authCalled := false + authMiddleware := func(c *gin.Context) { + authCalled = true + c.Header("X-Auth", "ok") + // Abort with success to avoid calling the actual handler (which needs full setup) + c.AbortWithStatus(http.StatusOK) + } + + m := &AmpModule{authMiddleware_: authMiddleware} + m.registerProviderAliases(r, base, authMiddleware) + + paths := []struct { + path string + method string + }{ + {"/api/provider/openai/models", http.MethodGet}, + {"/api/provider/anthropic/models", http.MethodGet}, + {"/api/provider/google/models", http.MethodGet}, + {"/api/provider/groq/models", http.MethodGet}, + {"/api/provider/openai/chat/completions", http.MethodPost}, + {"/api/provider/anthropic/v1/messages", http.MethodPost}, + {"/api/provider/google/v1beta/models", http.MethodGet}, + } + + for _, tc := range paths { + t.Run(tc.path, func(t *testing.T) { + authCalled = false + req := httptest.NewRequest(tc.method, tc.path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code == http.StatusNotFound { + t.Fatalf("route %s %s not registered", tc.method, tc.path) + } + if !authCalled { + t.Fatalf("auth middleware not executed for %s", tc.path) + } + if w.Header().Get("X-Auth") != "ok" { + t.Fatalf("auth middleware header not set for %s", tc.path) + } + }) + } +} + +func TestRegisterProviderAliases_DynamicModelsHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + + base := &handlers.BaseAPIHandler{} + + m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }} + m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }) + + providers := []string{"openai", "anthropic", "google", "groq", "cerebras"} + + for _, provider := range providers { + t.Run(provider, func(t *testing.T) { + path := "/api/provider/" + provider + "/models" + req := httptest.NewRequest(http.MethodGet, path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Should not 404 + if w.Code == http.StatusNotFound { + t.Fatalf("models route not found for provider: %s", provider) + } + }) + } +} + +func TestRegisterProviderAliases_V1Routes(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + + base := &handlers.BaseAPIHandler{} + + m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }} + m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }) + + v1Paths := []struct { + path string + method string + }{ + {"/api/provider/openai/v1/models", http.MethodGet}, + {"/api/provider/openai/v1/chat/completions", http.MethodPost}, + {"/api/provider/openai/v1/completions", http.MethodPost}, + {"/api/provider/anthropic/v1/messages", http.MethodPost}, + {"/api/provider/anthropic/v1/messages/count_tokens", http.MethodPost}, + } + + for _, tc := range v1Paths { + t.Run(tc.path, func(t *testing.T) { + req := httptest.NewRequest(tc.method, tc.path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code == http.StatusNotFound { + t.Fatalf("v1 route %s %s not registered", tc.method, tc.path) + } + }) + } +} + +func TestRegisterProviderAliases_V1BetaRoutes(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + + base := &handlers.BaseAPIHandler{} + + m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }} + m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }) + + v1betaPaths := []struct { + path string + method string + }{ + {"/api/provider/google/v1beta/models", http.MethodGet}, + {"/api/provider/google/v1beta/models/generateContent", http.MethodPost}, + } + + for _, tc := range v1betaPaths { + t.Run(tc.path, func(t *testing.T) { + req := httptest.NewRequest(tc.method, tc.path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code == http.StatusNotFound { + t.Fatalf("v1beta route %s %s not registered", tc.method, tc.path) + } + }) + } +} + +func TestRegisterProviderAliases_NoAuthMiddleware(t *testing.T) { + // Test that routes still register even if auth middleware is nil (fallback behavior) + gin.SetMode(gin.TestMode) + r := gin.New() + + base := &handlers.BaseAPIHandler{} + + m := &AmpModule{authMiddleware_: nil} // No auth middleware + m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }) + + req := httptest.NewRequest(http.MethodGet, "/api/provider/openai/models", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Should still work (with fallback no-op auth) + if w.Code == http.StatusNotFound { + t.Fatal("routes should register even without auth middleware") + } +} diff --git a/internal/api/modules/amp/secret.go b/internal/api/modules/amp/secret.go new file mode 100644 index 00000000..a4af1414 --- /dev/null +++ b/internal/api/modules/amp/secret.go @@ -0,0 +1,155 @@ +package amp + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +// SecretSource provides Amp API keys with configurable precedence and caching +type SecretSource interface { + Get(ctx context.Context) (string, error) +} + +// cachedSecret holds a secret value with expiration +type cachedSecret struct { + value string + expiresAt time.Time +} + +// MultiSourceSecret implements precedence-based secret lookup: +// 1. Explicit config value (highest priority) +// 2. Environment variable AMP_API_KEY +// 3. File-based secret (lowest priority) +type MultiSourceSecret struct { + explicitKey string + envKey string + filePath string + cacheTTL time.Duration + + mu sync.RWMutex + cache *cachedSecret +} + +// NewMultiSourceSecret creates a secret source with precedence and caching +func NewMultiSourceSecret(explicitKey string, cacheTTL time.Duration) *MultiSourceSecret { + if cacheTTL == 0 { + cacheTTL = 5 * time.Minute // Default 5 minute cache + } + + home, _ := os.UserHomeDir() + filePath := filepath.Join(home, ".local", "share", "amp", "secrets.json") + + return &MultiSourceSecret{ + explicitKey: strings.TrimSpace(explicitKey), + envKey: "AMP_API_KEY", + filePath: filePath, + cacheTTL: cacheTTL, + } +} + +// NewMultiSourceSecretWithPath creates a secret source with a custom file path (for testing) +func NewMultiSourceSecretWithPath(explicitKey string, filePath string, cacheTTL time.Duration) *MultiSourceSecret { + if cacheTTL == 0 { + cacheTTL = 5 * time.Minute + } + + return &MultiSourceSecret{ + explicitKey: strings.TrimSpace(explicitKey), + envKey: "AMP_API_KEY", + filePath: filePath, + cacheTTL: cacheTTL, + } +} + +// Get retrieves the Amp API key using precedence: config > env > file +// Results are cached for cacheTTL duration to avoid excessive file reads +func (s *MultiSourceSecret) Get(ctx context.Context) (string, error) { + // Precedence 1: Explicit config key (highest priority, no caching needed) + if s.explicitKey != "" { + return s.explicitKey, nil + } + + // Precedence 2: Environment variable + if envValue := strings.TrimSpace(os.Getenv(s.envKey)); envValue != "" { + return envValue, nil + } + + // Precedence 3: File-based secret (lowest priority, cached) + // Check cache first + s.mu.RLock() + if s.cache != nil && time.Now().Before(s.cache.expiresAt) { + value := s.cache.value + s.mu.RUnlock() + return value, nil + } + s.mu.RUnlock() + + // Cache miss or expired - read from file + key, err := s.readFromFile() + if err != nil { + // Cache empty result to avoid repeated file reads on missing files + s.updateCache("") + return "", err + } + + // Cache the result + s.updateCache(key) + return key, nil +} + +// readFromFile reads the Amp API key from the secrets file +func (s *MultiSourceSecret) readFromFile() (string, error) { + content, err := os.ReadFile(s.filePath) + if err != nil { + if os.IsNotExist(err) { + return "", nil // Missing file is not an error, just no key available + } + return "", fmt.Errorf("failed to read amp secrets from %s: %w", s.filePath, err) + } + + var secrets map[string]string + if err := json.Unmarshal(content, &secrets); err != nil { + return "", fmt.Errorf("failed to parse amp secrets from %s: %w", s.filePath, err) + } + + key := strings.TrimSpace(secrets["apiKey@https://ampcode.com/"]) + return key, nil +} + +// updateCache updates the cached secret value +func (s *MultiSourceSecret) updateCache(value string) { + s.mu.Lock() + defer s.mu.Unlock() + s.cache = &cachedSecret{ + value: value, + expiresAt: time.Now().Add(s.cacheTTL), + } +} + +// InvalidateCache clears the cached secret, forcing a fresh read on next Get +func (s *MultiSourceSecret) InvalidateCache() { + s.mu.Lock() + defer s.mu.Unlock() + s.cache = nil +} + +// StaticSecretSource returns a fixed API key (for testing) +type StaticSecretSource struct { + key string +} + +// NewStaticSecretSource creates a secret source with a fixed key +func NewStaticSecretSource(key string) *StaticSecretSource { + return &StaticSecretSource{key: strings.TrimSpace(key)} +} + +// Get returns the static API key +func (s *StaticSecretSource) Get(ctx context.Context) (string, error) { + return s.key, nil +} diff --git a/internal/api/modules/amp/secret_test.go b/internal/api/modules/amp/secret_test.go new file mode 100644 index 00000000..9c3e820a --- /dev/null +++ b/internal/api/modules/amp/secret_test.go @@ -0,0 +1,280 @@ +package amp + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "sync" + "testing" + "time" +) + +func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) { + ctx := context.Background() + + cases := []struct { + name string + configKey string + envKey string + fileJSON string + want string + }{ + {"config_wins", "cfg", "env", `{"apiKey@https://ampcode.com/":"file"}`, "cfg"}, + {"env_wins_when_no_cfg", "", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"}, + {"file_when_no_cfg_env", "", "", `{"apiKey@https://ampcode.com/":"file"}`, "file"}, + {"empty_cfg_trims_then_env", " ", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"}, + {"empty_env_then_file", "", " ", `{"apiKey@https://ampcode.com/":"file"}`, "file"}, + {"missing_file_returns_empty", "", "", "", ""}, + {"all_empty_returns_empty", " ", " ", `{"apiKey@https://ampcode.com/":" "}`, ""}, + } + + for _, tc := range cases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + tmpDir := t.TempDir() + secretsPath := filepath.Join(tmpDir, "secrets.json") + + if tc.fileJSON != "" { + if err := os.WriteFile(secretsPath, []byte(tc.fileJSON), 0600); err != nil { + t.Fatal(err) + } + } + + t.Setenv("AMP_API_KEY", tc.envKey) + + s := NewMultiSourceSecretWithPath(tc.configKey, secretsPath, 100*time.Millisecond) + got, err := s.Get(ctx) + if err != nil && tc.fileJSON != "" && json.Valid([]byte(tc.fileJSON)) { + t.Fatalf("unexpected error: %v", err) + } + if got != tc.want { + t.Fatalf("want %q, got %q", tc.want, got) + } + }) + } +} + +func TestMultiSourceSecret_CacheBehavior(t *testing.T) { + ctx := context.Background() + tmpDir := t.TempDir() + p := filepath.Join(tmpDir, "secrets.json") + + // Initial value + if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil { + t.Fatal(err) + } + + s := NewMultiSourceSecretWithPath("", p, 50*time.Millisecond) + + // First read - should return v1 + got1, err := s.Get(ctx) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if got1 != "v1" { + t.Fatalf("expected v1, got %s", got1) + } + + // Change file; within TTL we should still see v1 (cached) + if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v2"}`), 0600); err != nil { + t.Fatal(err) + } + got2, _ := s.Get(ctx) + if got2 != "v1" { + t.Fatalf("cache hit expected v1, got %s", got2) + } + + // After TTL expires, should see v2 + time.Sleep(60 * time.Millisecond) + got3, _ := s.Get(ctx) + if got3 != "v2" { + t.Fatalf("cache miss expected v2, got %s", got3) + } + + // Invalidate forces re-read immediately + if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v3"}`), 0600); err != nil { + t.Fatal(err) + } + s.InvalidateCache() + got4, _ := s.Get(ctx) + if got4 != "v3" { + t.Fatalf("invalidate expected v3, got %s", got4) + } +} + +func TestMultiSourceSecret_FileHandling(t *testing.T) { + ctx := context.Background() + + t.Run("missing_file_no_error", func(t *testing.T) { + s := NewMultiSourceSecretWithPath("", "/nonexistent/path/secrets.json", 100*time.Millisecond) + got, err := s.Get(ctx) + if err != nil { + t.Fatalf("expected no error for missing file, got: %v", err) + } + if got != "" { + t.Fatalf("expected empty string, got %q", got) + } + }) + + t.Run("invalid_json", func(t *testing.T) { + tmpDir := t.TempDir() + p := filepath.Join(tmpDir, "secrets.json") + if err := os.WriteFile(p, []byte(`{invalid json`), 0600); err != nil { + t.Fatal(err) + } + + s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) + _, err := s.Get(ctx) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + }) + + t.Run("missing_key_in_json", func(t *testing.T) { + tmpDir := t.TempDir() + p := filepath.Join(tmpDir, "secrets.json") + if err := os.WriteFile(p, []byte(`{"other":"value"}`), 0600); err != nil { + t.Fatal(err) + } + + s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) + got, err := s.Get(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "" { + t.Fatalf("expected empty string for missing key, got %q", got) + } + }) + + t.Run("empty_key_value", func(t *testing.T) { + tmpDir := t.TempDir() + p := filepath.Join(tmpDir, "secrets.json") + if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":" "}`), 0600); err != nil { + t.Fatal(err) + } + + s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) + got, _ := s.Get(ctx) + if got != "" { + t.Fatalf("expected empty after trim, got %q", got) + } + }) +} + +func TestMultiSourceSecret_Concurrency(t *testing.T) { + tmpDir := t.TempDir() + p := filepath.Join(tmpDir, "secrets.json") + if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"concurrent"}`), 0600); err != nil { + t.Fatal(err) + } + + s := NewMultiSourceSecretWithPath("", p, 5*time.Second) + ctx := context.Background() + + // Spawn many goroutines calling Get concurrently + const goroutines = 50 + const iterations = 100 + + var wg sync.WaitGroup + errors := make(chan error, goroutines) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations; j++ { + val, err := s.Get(ctx) + if err != nil { + errors <- err + return + } + if val != "concurrent" { + errors <- err + return + } + } + }() + } + + wg.Wait() + close(errors) + + for err := range errors { + t.Errorf("concurrency error: %v", err) + } +} + +func TestStaticSecretSource(t *testing.T) { + ctx := context.Background() + + t.Run("returns_provided_key", func(t *testing.T) { + s := NewStaticSecretSource("test-key-123") + got, err := s.Get(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "test-key-123" { + t.Fatalf("want test-key-123, got %q", got) + } + }) + + t.Run("trims_whitespace", func(t *testing.T) { + s := NewStaticSecretSource(" test-key ") + got, err := s.Get(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "test-key" { + t.Fatalf("want test-key, got %q", got) + } + }) + + t.Run("empty_string", func(t *testing.T) { + s := NewStaticSecretSource("") + got, err := s.Get(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "" { + t.Fatalf("want empty string, got %q", got) + } + }) +} + +func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) { + // Test that missing file results are cached to avoid repeated file reads + tmpDir := t.TempDir() + p := filepath.Join(tmpDir, "nonexistent.json") + + s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) + ctx := context.Background() + + // First call - file doesn't exist, should cache empty result + got1, err := s.Get(ctx) + if err != nil { + t.Fatalf("expected no error for missing file, got: %v", err) + } + if got1 != "" { + t.Fatalf("expected empty string, got %q", got1) + } + + // Create the file now + if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"new-value"}`), 0600); err != nil { + t.Fatal(err) + } + + // Second call - should still return empty (cached), not read the new file + got2, _ := s.Get(ctx) + if got2 != "" { + t.Fatalf("cache should return empty, got %q", got2) + } + + // After TTL expires, should see the new value + time.Sleep(110 * time.Millisecond) + got3, _ := s.Get(ctx) + if got3 != "new-value" { + t.Fatalf("after cache expiry, expected new-value, got %q", got3) + } +} diff --git a/internal/api/modules/modules.go b/internal/api/modules/modules.go new file mode 100644 index 00000000..8c5447d9 --- /dev/null +++ b/internal/api/modules/modules.go @@ -0,0 +1,92 @@ +// Package modules provides a pluggable routing module system for extending +// the API server with optional features without modifying core routing logic. +package modules + +import ( + "fmt" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" +) + +// Context encapsulates the dependencies exposed to routing modules during +// registration. Modules can use the Gin engine to attach routes, the shared +// BaseAPIHandler for constructing SDK-specific handlers, and the resolved +// authentication middleware for protecting routes that require API keys. +type Context struct { + Engine *gin.Engine + BaseHandler *handlers.BaseAPIHandler + Config *config.Config + AuthMiddleware gin.HandlerFunc +} + +// RouteModule represents a pluggable routing module that can register routes +// and handle configuration updates independently of the core server. +// +// DEPRECATED: Use RouteModuleV2 for new modules. This interface is kept for +// backwards compatibility and will be removed in a future version. +type RouteModule interface { + // Name returns a human-readable identifier for the module + Name() string + + // Register sets up routes and handlers for this module. + // It receives the Gin engine, base handlers, and current configuration. + // Returns an error if registration fails (errors are logged but don't stop the server). + Register(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, cfg *config.Config) error + + // OnConfigUpdated is called when the configuration is reloaded. + // Modules can respond to configuration changes here. + // Returns an error if the update cannot be applied. + OnConfigUpdated(cfg *config.Config) error +} + +// RouteModuleV2 represents a pluggable bundle of routes that can integrate with +// the API server without modifying its core routing logic. Implementations can +// attach routes during Register and react to configuration updates via +// OnConfigUpdated. +// +// This is the preferred interface for new modules. It uses Context for cleaner +// dependency injection and supports idempotent registration. +type RouteModuleV2 interface { + // Name returns a unique identifier for logging and diagnostics. + Name() string + + // Register wires the module's routes into the provided Gin engine. Modules + // should treat multiple calls as idempotent and avoid duplicate route + // registration when invoked more than once. + Register(ctx Context) error + + // OnConfigUpdated notifies the module when the server configuration changes + // via hot reload. Implementations can refresh cached state or emit warnings. + OnConfigUpdated(cfg *config.Config) error +} + +// RegisterModule is a helper that registers a module using either the V1 or V2 +// interface. This allows gradual migration from V1 to V2 without breaking +// existing modules. +// +// Example usage: +// +// ctx := modules.Context{ +// Engine: engine, +// BaseHandler: baseHandler, +// Config: cfg, +// AuthMiddleware: authMiddleware, +// } +// if err := modules.RegisterModule(ctx, ampModule); err != nil { +// log.Errorf("Failed to register module: %v", err) +// } +func RegisterModule(ctx Context, mod interface{}) error { + // Try V2 interface first (preferred) + if v2, ok := mod.(RouteModuleV2); ok { + return v2.Register(ctx) + } + + // Fall back to V1 interface for backwards compatibility + if v1, ok := mod.(RouteModule); ok { + return v1.Register(ctx.Engine, ctx.BaseHandler, ctx.Config) + } + + return fmt.Errorf("unsupported module type %T (must implement RouteModule or RouteModuleV2)", mod) +} diff --git a/internal/api/server.go b/internal/api/server.go index 78672f02..2c545be7 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -21,6 +21,8 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/access" managementHandlers "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management" "github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware" + "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules" + ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset" @@ -261,6 +263,20 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk // Setup routes s.setupRoutes() + + // Register Amp module using V2 interface with Context + ampModule := ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager)) + ctx := modules.Context{ + Engine: engine, + BaseHandler: s.handlers, + Config: cfg, + AuthMiddleware: AuthMiddleware(accessManager), + } + if err := modules.RegisterModule(ctx, ampModule); err != nil { + log.Errorf("Failed to register Amp module: %v", err) + } + + // Apply additional router configurators from options if optionState.routerConfigurator != nil { optionState.routerConfigurator(engine, s.handlers, cfg) } diff --git a/internal/api/server_test.go b/internal/api/server_test.go new file mode 100644 index 00000000..06653210 --- /dev/null +++ b/internal/api/server_test.go @@ -0,0 +1,111 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + gin "github.com/gin-gonic/gin" + proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +func newTestServer(t *testing.T) *Server { + t.Helper() + + gin.SetMode(gin.TestMode) + + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o700); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + + cfg := &proxyconfig.Config{ + SDKConfig: sdkconfig.SDKConfig{ + APIKeys: []string{"test-key"}, + }, + Port: 0, + AuthDir: authDir, + Debug: true, + LoggingToFile: false, + UsageStatisticsEnabled: false, + } + + authManager := auth.NewManager(nil, nil, nil) + accessManager := sdkaccess.NewManager() + + configPath := filepath.Join(tmpDir, "config.yaml") + return NewServer(cfg, authManager, accessManager, configPath) +} + +func TestAmpProviderModelRoutes(t *testing.T) { + testCases := []struct { + name string + path string + wantStatus int + wantContains string + }{ + { + name: "openai root models", + path: "/api/provider/openai/models", + wantStatus: http.StatusOK, + wantContains: `"object":"list"`, + }, + { + name: "groq root models", + path: "/api/provider/groq/models", + wantStatus: http.StatusOK, + wantContains: `"object":"list"`, + }, + { + name: "openai models", + path: "/api/provider/openai/v1/models", + wantStatus: http.StatusOK, + wantContains: `"object":"list"`, + }, + { + name: "anthropic models", + path: "/api/provider/anthropic/v1/models", + wantStatus: http.StatusOK, + wantContains: `"data"`, + }, + { + name: "google models v1", + path: "/api/provider/google/v1/models", + wantStatus: http.StatusOK, + wantContains: `"models"`, + }, + { + name: "google models v1beta", + path: "/api/provider/google/v1beta/models", + wantStatus: http.StatusOK, + wantContains: `"models"`, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + server := newTestServer(t) + + req := httptest.NewRequest(http.MethodGet, tc.path, nil) + req.Header.Set("Authorization", "Bearer test-key") + + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != tc.wantStatus { + t.Fatalf("unexpected status code for %s: got %d want %d; body=%s", tc.path, rr.Code, tc.wantStatus, rr.Body.String()) + } + if body := rr.Body.String(); !strings.Contains(body, tc.wantContains) { + t.Fatalf("response body for %s missing %q: %s", tc.path, tc.wantContains, body) + } + }) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 58d4b20c..ec97064e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -23,6 +23,17 @@ type Config struct { // Port is the network port on which the API server will listen. Port int `yaml:"port" json:"-"` + // AmpUpstreamURL defines the upstream Amp control plane used for non-provider calls. + AmpUpstreamURL string `yaml:"amp-upstream-url" json:"amp-upstream-url"` + + // AmpUpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls. + AmpUpstreamAPIKey string `yaml:"amp-upstream-api-key" json:"amp-upstream-api-key"` + + // AmpRestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.) + // to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by + // browser attacks and remote access to management endpoints. Default: true (recommended). + AmpRestrictManagementToLocalhost bool `yaml:"amp-restrict-management-to-localhost" json:"amp-restrict-management-to-localhost"` + // AuthDir is the directory where authentication token files are stored. AuthDir string `yaml:"auth-dir" json:"-"` @@ -258,6 +269,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { cfg.LoggingToFile = false cfg.UsageStatisticsEnabled = false cfg.DisableCooling = false + cfg.AmpRestrictManagementToLocalhost = true // Default to secure: only localhost access if err = yaml.Unmarshal(data, &cfg); err != nil { if optional { // In cloud deploy mode, if YAML parsing fails, return empty config instead of error. diff --git a/sdk/api/handlers/claude/code_handlers.go b/sdk/api/handlers/claude/code_handlers.go index 7fac9d74..63ea6065 100644 --- a/sdk/api/handlers/claude/code_handlers.go +++ b/sdk/api/handlers/claude/code_handlers.go @@ -8,9 +8,12 @@ package claude import ( "bufio" + "bytes" + "compress/gzip" "context" "encoding/json" "fmt" + "io" "net/http" "time" @@ -19,6 +22,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) @@ -153,6 +157,23 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO cliCancel(errMsg.Error) return } + + // Decompress gzipped responses - Claude API sometimes returns gzip without Content-Encoding header + // This fixes title generation and other non-streaming responses that arrive compressed + if len(resp) >= 2 && resp[0] == 0x1f && resp[1] == 0x8b { + gzReader, err := gzip.NewReader(bytes.NewReader(resp)) + if err != nil { + log.Warnf("failed to decompress gzipped Claude response: %v", err) + } else { + defer gzReader.Close() + if decompressed, err := io.ReadAll(gzReader); err != nil { + log.Warnf("failed to read decompressed Claude response: %v", err) + } else { + resp = decompressed + } + } + } + _, _ = c.Writer.Write(resp) cliCancel() } diff --git a/sdk/api/httpx/gzip.go b/sdk/api/httpx/gzip.go new file mode 100644 index 00000000..09ecc01d --- /dev/null +++ b/sdk/api/httpx/gzip.go @@ -0,0 +1,33 @@ +// Package httpx provides HTTP transport utilities for SDK clients, +// including automatic gzip decompression for misconfigured upstreams. +package httpx + +import ( + "bytes" + "compress/gzip" + "io" +) + +// DecodePossibleGzip inspects the raw response body and transparently +// decompresses it when the payload is gzip compressed. Some upstream +// providers return gzip data without a Content-Encoding header, which +// confuses clients expecting JSON. This helper restores the original +// JSON bytes while leaving plain responses untouched. +// +// This function is preserved for backward compatibility but new code +// should use GzipFixupTransport instead. +func DecodePossibleGzip(raw []byte) ([]byte, error) { + if len(raw) >= 2 && raw[0] == 0x1f && raw[1] == 0x8b { + reader, err := gzip.NewReader(bytes.NewReader(raw)) + if err != nil { + return nil, err + } + decompressed, err := io.ReadAll(reader) + _ = reader.Close() + if err != nil { + return nil, err + } + return decompressed, nil + } + return raw, nil +} diff --git a/sdk/api/httpx/transport.go b/sdk/api/httpx/transport.go new file mode 100644 index 00000000..25be69df --- /dev/null +++ b/sdk/api/httpx/transport.go @@ -0,0 +1,177 @@ +package httpx + +import ( + "bytes" + "compress/gzip" + "io" + "net/http" + "strings" + + log "github.com/sirupsen/logrus" +) + +// GzipFixupTransport wraps an http.RoundTripper to auto-decode gzip responses +// that don't properly set Content-Encoding header. +// +// Some upstream providers (especially when proxied) return gzip-compressed +// responses without setting the Content-Encoding: gzip header, which causes +// Go's http client to pass the compressed bytes directly to the application. +// +// This transport detects gzip magic bytes and transparently decompresses +// the response while preserving streaming behavior for SSE and chunked responses. +type GzipFixupTransport struct { + // Base is the underlying transport. If nil, http.DefaultTransport is used. + Base http.RoundTripper +} + +// RoundTrip implements http.RoundTripper +func (t *GzipFixupTransport) RoundTrip(req *http.Request) (*http.Response, error) { + base := t.Base + if base == nil { + base = http.DefaultTransport + } + + resp, err := base.RoundTrip(req) + if err != nil || resp == nil { + return resp, err + } + + // Skip if Go already decompressed it + if resp.Uncompressed { + return resp, nil + } + + // Skip if Content-Encoding is already set (properly configured upstream) + if resp.Header.Get("Content-Encoding") != "" { + return resp, nil + } + + // Skip streaming responses - they need different handling + if isStreamingResponse(resp) { + // For streaming responses, wrap with a streaming gzip detector + // that can handle chunked gzip data + resp.Body = &streamingGzipDetector{ + inner: resp.Body, + } + return resp, nil + } + + // For non-streaming responses, peek and decompress if needed + resp.Body = &gzipDetectingReader{ + inner: resp.Body, + } + + return resp, nil +} + +// isStreamingResponse checks if response is SSE or chunked +func isStreamingResponse(resp *http.Response) bool { + contentType := resp.Header.Get("Content-Type") + + // Check for Server-Sent Events + if strings.Contains(contentType, "text/event-stream") { + return true + } + + // Check for chunked transfer encoding + if strings.Contains(strings.ToLower(resp.Header.Get("Transfer-Encoding")), "chunked") { + return true + } + + return false +} + +// gzipDetectingReader is an io.ReadCloser that detects gzip magic bytes +// on first read and switches to gzip decompression if detected. +// This is used for non-streaming responses. +type gzipDetectingReader struct { + inner io.ReadCloser + reader io.Reader + once bool +} + +func (g *gzipDetectingReader) Read(p []byte) (int, error) { + if !g.once { + g.once = true + + // Peek at first 2 bytes to detect gzip magic bytes + buf := make([]byte, 2) + n, err := io.ReadFull(g.inner, buf) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + // Can't peek, use original reader + g.reader = io.MultiReader(bytes.NewReader(buf[:n]), g.inner) + return g.reader.Read(p) + } + + if n >= 2 && buf[0] == 0x1f && buf[1] == 0x8b { + // It's gzipped, create gzip reader + multiReader := io.MultiReader(bytes.NewReader(buf[:n]), g.inner) + gzipReader, err := gzip.NewReader(multiReader) + if err != nil { + log.Warnf("gzip header detected but reader creation failed: %v", err) + g.reader = multiReader + } else { + g.reader = gzipReader + } + } else { + // Not gzipped, combine peeked bytes with rest + g.reader = io.MultiReader(bytes.NewReader(buf[:n]), g.inner) + } + } + + return g.reader.Read(p) +} + +func (g *gzipDetectingReader) Close() error { + if closer, ok := g.reader.(io.Closer); ok { + _ = closer.Close() + } + return g.inner.Close() +} + +// streamingGzipDetector is similar to gzipDetectingReader but designed for +// streaming responses. It doesn't buffer; it wraps with a streaming gzip reader. +type streamingGzipDetector struct { + inner io.ReadCloser + reader io.Reader + once bool +} + +func (s *streamingGzipDetector) Read(p []byte) (int, error) { + if !s.once { + s.once = true + + // Peek at first 2 bytes + buf := make([]byte, 2) + n, err := io.ReadFull(s.inner, buf) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + s.reader = io.MultiReader(bytes.NewReader(buf[:n]), s.inner) + return s.reader.Read(p) + } + + if n >= 2 && buf[0] == 0x1f && buf[1] == 0x8b { + // It's gzipped - wrap with streaming gzip reader + multiReader := io.MultiReader(bytes.NewReader(buf[:n]), s.inner) + gzipReader, err := gzip.NewReader(multiReader) + if err != nil { + log.Warnf("streaming gzip header detected but reader creation failed: %v", err) + s.reader = multiReader + } else { + s.reader = gzipReader + log.Debug("streaming gzip decompression enabled") + } + } else { + // Not gzipped + s.reader = io.MultiReader(bytes.NewReader(buf[:n]), s.inner) + } + } + + return s.reader.Read(p) +} + +func (s *streamingGzipDetector) Close() error { + if closer, ok := s.reader.(io.Closer); ok { + _ = closer.Close() + } + return s.inner.Close() +}