Compare commits

...

72 Commits

Author SHA1 Message Date
Luis Pater
5c95129884 Merge branch 'router-for-me:main' into main 2025-12-30 11:48:29 +08:00
Luis Pater
9fa2a7e9df Merge pull request #782 from router-for-me/modelmappings
refactor(config): rename model-name-mappings to oauth-model-mappings
2025-12-30 11:40:12 +08:00
hkfires
d443c86620 refactor(config): rename model mapping fields from from/to to name/alias 2025-12-30 11:07:59 +08:00
hkfires
7be3f1c36c refactor(config): rename model-name-mappings to oauth-model-mappings 2025-12-30 11:07:58 +08:00
Luis Pater
f6ab6d97b9 fix(logging): add isDirWritable utility to enhance log dir validation in ConfigureLogOutput 2025-12-30 10:48:25 +08:00
Luis Pater
bc866bac49 fix(logging): refactor ConfigureLogOutput to accept config object and adjust log directory handling 2025-12-30 10:28:25 +08:00
Luis Pater
50e6d845f4 feat(cliproxy): introduce global model name mappings for improved aliasing and routing 2025-12-30 08:13:06 +08:00
Luis Pater
a8cb01819d Merge pull request #772 from soffchen/main
fix: Implement fallback log directory for file logging on read-only system
2025-12-30 02:24:49 +08:00
Luis Pater
1a99cfded4 Merge branch 'router-for-me:main' into main 2025-12-30 00:38:41 +08:00
Luis Pater
530273906b Merge pull request #776 from router-for-me/fix-ag-claude
fix(antigravity): inject required placeholder when properties exist w…
2025-12-30 00:37:01 +08:00
Supra4E8C
06ddf575d9 fix(antigravity): inject required placeholder when properties exist without required 2025-12-29 23:55:59 +08:00
Luis Pater
cf369d4684 Merge branch 'router-for-me:main' into main 2025-12-29 22:41:44 +08:00
hkfires
3099114cbb refactor(api): simplify codex id token claims extraction 2025-12-29 19:48:02 +08:00
Soff
44b63f0767 fix: Return an error if the user home directory cannot be determined for the fallback log path. 2025-12-29 18:46:15 +08:00
Soff Chen
6705d20194 fix: Implement fallback log directory for file logging on read-only systems. 2025-12-29 18:35:48 +08:00
Chén Mù
a38a9c0b0f Merge pull request #770 from router-for-me/api
feat(api): add id token claims extraction for codex auth entries
2025-12-29 00:44:41 -08:00
hkfires
8286caa366 feat(api): add id token claims extraction for codex auth entries 2025-12-29 16:34:16 +08:00
Chén Mù
bd1ec8424d Merge pull request #767 from router-for-me/amp
feat(amp): add per-client upstream API key mapping support
2025-12-28 22:10:11 -08:00
hkfires
225e2c6797 feat(amp): add per-client upstream API key mapping support 2025-12-29 12:26:25 +08:00
Luis Pater
d8fc485513 fix(translators): correct key path for system_instruction.parts in Claude request logic 2025-12-29 11:54:26 +08:00
hkfires
f137eb0ac4 chore: add codex, agents, and opencode dirs to ignore files 2025-12-29 08:42:29 +08:00
Chén Mù
f39a460487 Merge pull request #761 from router-for-me/log
fix(logging): improve request/response capture
2025-12-28 16:13:10 -08:00
Luis Pater
ee171bc563 feat(api): add ManagementTokenRequester interface for management token request endpoints 2025-12-29 02:42:29 +08:00
hkfires
a95428f204 fix(handlers): preserve upstream response logs before duplicate detection 2025-12-28 22:35:36 +08:00
hkfires
3ca5fb1046 fix(handlers): match raw error text before JSON body for duplicate detection 2025-12-28 19:35:36 +08:00
hkfires
a091d12f4e fix(logging): improve request/response capture 2025-12-28 19:04:31 +08:00
Luis Pater
e3d8d726e6 Merge branch 'router-for-me:main' into main 2025-12-28 15:09:33 +08:00
Luis Pater
457924828a Merge pull request #757 from ben-vargas/fix-thinking-toolchoice-conflict
Fix: disable thinking when tool_choice forces tool use
2025-12-28 14:04:30 +08:00
Ben Vargas
aca2ef6359 Fix: disable thinking when tool_choice forces tool use
Anthropic API does not allow extended thinking when tool_choice is set
to "any" or a specific tool. This was causing 400 errors when using
features like Amp's /handoff command which forces tool_choice.

Added disableThinkingIfToolChoiceForced() that removes thinking config
when incompatible tool_choice is detected, applied to both streaming
and non-streaming paths.

Fixes router-for-me/CLIProxyAPI#630
2025-12-27 16:31:37 -07:00
Luis Pater
ade7194792 feat(management): add generic API call handler to management endpoints 2025-12-28 04:40:32 +08:00
Luis Pater
0f51e73baa Merge branch 'router-for-me:main' into main 2025-12-28 03:07:58 +08:00
Luis Pater
3a436e116a feat(cliproxy): implement model aliasing and hashing for Codex configurations, enhance request routing logic, and normalize Codex model entries 2025-12-28 03:06:51 +08:00
Luis Pater
d06e2dc83c Merge branch 'router-for-me:main' into main 2025-12-28 02:10:16 +08:00
Luis Pater
336867853b Merge pull request #756 from leaph/check-ai-thinking-settings
feat(iflow): add model-specific thinking configs for GLM-4.7 and Mini…
2025-12-28 02:08:27 +08:00
leaph
6403ff4ec4 feat(iflow): add model-specific thinking configs for GLM-4.7 and MiniMax-M2.1
- GLM-4.7: Uses extra_body={"thinking": {"type": "enabled"}, "clear_thinking": false}
- MiniMax-M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
- Added preserveReasoningContentInMessages() to support re-injection of reasoning
  content in assistant message history for multi-turn conversations
- Added ThinkingSupport to MiniMax-M2.1 model definition
2025-12-27 18:39:15 +01:00
Luis Pater
d222469b44 Update issue templates 2025-12-28 01:22:42 +08:00
Luis Pater
790a17ce98 Merge pull request #70 from router-for-me/plus
v6.6.60
2025-12-28 00:57:14 +08:00
Luis Pater
d473c952fb Merge branch 'main' into plus 2025-12-28 00:56:04 +08:00
Luis Pater
7646a2b877 Fixed: #749
fix(translators): ensure `gjson.String` content is non-empty before setting `parts` in OpenAI request logic
2025-12-28 00:54:26 +08:00
Luis Pater
62090f2568 Merge pull request #750 from router-for-me/config
fix(config): preserve original config structure and avoid default value pollution
2025-12-27 22:10:01 +08:00
Luis Pater
d35152bbef Merge branch 'router-for-me:main' into main 2025-12-27 22:03:50 +08:00
Luis Pater
c281f4cbaf Fixed: #747
fix(translators): rename and integrate `usageMetadata` as `cpaUsageMetadata` in Claude processing logic
2025-12-27 22:02:11 +08:00
hkfires
09455f9e85 fix(config): make streaming keepalive and retries ints 2025-12-27 20:56:47 +08:00
hkfires
c8e72ba0dc fix(config): smart merge writes non-default new keys only 2025-12-27 20:28:54 +08:00
hkfires
375ef252ab docs(config): clarify merge mapping behavior 2025-12-27 19:30:21 +08:00
hkfires
ee552f8720 chore(config): update ignore patterns 2025-12-27 19:13:14 +08:00
hkfires
2e88c4858e fix(config): avoid adding new keys when merging 2025-12-27 19:00:47 +08:00
Luis Pater
3f50da85c1 Merge pull request #745 from router-for-me/auth
fix(auth): make provider rotation atomic
2025-12-27 13:01:22 +08:00
hkfires
8be06255f7 fix(auth): make provider rotation atomic 2025-12-27 12:56:48 +08:00
Luis Pater
60936b5185 Merge branch 'router-for-me:main' into main 2025-12-27 03:57:03 +08:00
Luis Pater
72274099aa Fixed: #738
fix(translators): refine prompt token calculation by incorporating cached tokens in Claude response handling
2025-12-27 03:56:11 +08:00
Luis Pater
b7f7b3a1d8 Merge branch 'router-for-me:main' into main 2025-12-27 01:26:33 +08:00
Luis Pater
dcae098e23 Fixed: #736
fix(translators): handle gjson string types in Claude request processing to ensure consistent content parsing
2025-12-27 01:25:47 +08:00
Luis Pater
618606966f Merge pull request #68 from router-for-me/plus
v6.6.56
2025-12-26 12:14:42 +08:00
Luis Pater
05f249d77f Merge branch 'main' into plus 2025-12-26 12:14:35 +08:00
Luis Pater
2eb05ec640 Merge pull request #727 from nguyenphutrong/main
docs: add Quotio to community projects
2025-12-26 11:53:09 +08:00
Luis Pater
3ce0d76aa4 feat(usage): add import/export functionality for usage statistics and enhance deduplication logic 2025-12-26 11:49:51 +08:00
Trong Nguyen
a00b79d9be docs(readme): add Quotio to community projects section 2025-12-26 10:46:05 +07:00
Luis Pater
9fe6a215e6 Merge branch 'router-for-me:main' into main 2025-12-26 05:10:19 +08:00
Luis Pater
33e53a2a56 fix(translators): ensure correct handling and output of multimodal assistant content across request handlers 2025-12-26 05:08:04 +08:00
Luis Pater
cd5b80785f Merge pull request #722 from hungthai1401/bugfix/remove-extra-args
Fixed incorrect function signature call to `NewBaseAPIHandlers`
2025-12-26 02:56:56 +08:00
Thai Nguyen Hung
54f71aa273 fix(test): remove extra argument from ExecuteStreamWithAuthManager call 2025-12-25 21:55:35 +07:00
Luis Pater
3f949b7f84 Merge pull request #704 from tinyc0der/add-index
fix(openai): add index field to image response for LiteLLM compatibility
2025-12-25 21:35:12 +08:00
Luis Pater
cf8b2dcc85 Merge pull request #67 from router-for-me/plus
v6.6.54
2025-12-25 21:07:45 +08:00
Luis Pater
8e24d9dc34 Merge branch 'main' into plus 2025-12-25 21:07:37 +08:00
Luis Pater
443c4538bb feat(config): add commercial-mode to optimize HTTP middleware for lower memory usage 2025-12-25 21:05:01 +08:00
TinyCoder
a7fc2ee4cf refactor(image): avoid using json.Marshal 2025-12-25 14:21:01 +07:00
Luis Pater
8e749ac22d docs(readme): update GLM model version from 4.6 to 4.7 in README and README_CN 2025-12-24 23:59:48 +08:00
Luis Pater
69e09d9bc7 docs(readme): update GLM model version from 4.6 to 4.7 in README and README_CN 2025-12-24 23:46:27 +08:00
Luis Pater
ed57d82bc1 Merge branch 'router-for-me:main' into main 2025-12-24 23:31:09 +08:00
Luis Pater
06ad527e8c Fixed: #696
fix(translators): adjust prompt token calculation by subtracting cached tokens across Gemini, OpenAI, and Claude handlers
2025-12-24 23:29:18 +08:00
TinyCoder
671558a822 fix(openai): add index field to image response for LiteLLM compatibility
LiteLLM's Pydantic model requires an index field in each image object.
Without it, responses fail validation with "images.0.index Field required".
2025-12-24 17:43:31 +07:00
57 changed files with 2760 additions and 256 deletions

View File

@@ -13,8 +13,6 @@ Dockerfile
docs/* docs/*
README.md README.md
README_CN.md README_CN.md
MANAGEMENT_API.md
MANAGEMENT_API_CN.md
LICENSE LICENSE
# Runtime data folders (should be mounted as volumes) # Runtime data folders (should be mounted as volumes)
@@ -25,10 +23,14 @@ config.yaml
# Development/editor # Development/editor
bin/* bin/*
.claude/*
.vscode/* .vscode/*
.claude/*
.codex/*
.gemini/* .gemini/*
.serena/* .serena/*
.agent/* .agent/*
.agents/*
.opencode/*
.bmad/* .bmad/*
_bmad/* _bmad/*
_bmad-output/*

View File

@@ -7,6 +7,13 @@ assignees: ''
--- ---
**Is it a request payload issue?**
[ ] Yes, this is a request payload issue. I am using a client/cURL to send a request payload, but I received an unexpected error.
[ ] No, it's another issue.
**If it's a request payload issue, you MUST know**
Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to attach the request log or curl payload.
**Describe the bug** **Describe the bug**
A clear and concise description of what the bug is. A clear and concise description of what the bug is.

11
.gitignore vendored
View File

@@ -12,11 +12,15 @@ bin/*
logs/* logs/*
conv/* conv/*
temp/* temp/*
refs/*
# Storage backends
pgstore/* pgstore/*
gitstore/* gitstore/*
objectstore/* objectstore/*
# Static assets
static/* static/*
refs/*
# Authentication data # Authentication data
auths/* auths/*
@@ -30,12 +34,17 @@ GEMINI.md
# Tooling metadata # Tooling metadata
.vscode/* .vscode/*
.codex/*
.claude/* .claude/*
.gemini/* .gemini/*
.serena/* .serena/*
.agent/* .agent/*
.agents/*
.agents/*
.opencode/*
.bmad/* .bmad/*
_bmad/* _bmad/*
_bmad-output/*
.mcp/cache/ .mcp/cache/
# macOS # macOS

View File

@@ -434,7 +434,7 @@ func main() {
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled) usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling) coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil { if err = logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to configure log output: %v", err) log.Errorf("failed to configure log output: %v", err)
return return
} }

View File

@@ -35,10 +35,14 @@ auth-dir: "~/.cli-proxy-api"
api-keys: api-keys:
- "your-api-key-1" - "your-api-key-1"
- "your-api-key-2" - "your-api-key-2"
- "your-api-key-3"
# Enable debug logging # Enable debug logging
debug: false debug: false
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false
# Open OAuth URLs in incognito/private browser mode. # Open OAuth URLs in incognito/private browser mode.
# Useful when you want to login with a different account without logging out from your current session. # Useful when you want to login with a different account without logging out from your current session.
# Default: false (but Kiro auth defaults to true for multi-account support) # Default: false (but Kiro auth defaults to true for multi-account support)
@@ -106,6 +110,9 @@ ws-auth: false
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# models:
# - name: "gpt-5-codex" # upstream model name
# alias: "codex-latest" # client alias mapped to the upstream model
# excluded-models: # excluded-models:
# - "gpt-5.1" # exclude specific models (exact match) # - "gpt-5.1" # exclude specific models (exact match)
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex) # - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
@@ -164,9 +171,9 @@ ws-auth: false
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names # models: # optional: map aliases to upstream model names
# - name: "gemini-2.0-flash" # upstream model name # - name: "gemini-2.5-flash" # upstream model name
# alias: "vertex-flash" # client-visible alias # alias: "vertex-flash" # client-visible alias
# - name: "gemini-1.5-pro" # - name: "gemini-2.5-pro"
# alias: "vertex-pro" # alias: "vertex-pro"
# Amp Integration # Amp Integration
@@ -175,6 +182,18 @@ ws-auth: false
# upstream-url: "https://ampcode.com" # upstream-url: "https://ampcode.com"
# # Optional: Override API key for Amp upstream (otherwise uses env or file) # # Optional: Override API key for Amp upstream (otherwise uses env or file)
# upstream-api-key: "" # upstream-api-key: ""
# # Per-client upstream API key mapping
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
# # Useful when different clients need to use different Amp accounts/quotas.
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
# upstream-api-keys:
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
# api-keys: # Client keys that use this upstream key
# - "your-api-key-1"
# - "your-api-key-2"
# - upstream-api-key: "amp_key_for_team_b"
# api-keys:
# - "your-api-key-3"
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false) # # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
# restrict-management-to-localhost: false # restrict-management-to-localhost: false
# # Force model mappings to run before checking local API keys (default: false) # # Force model mappings to run before checking local API keys (default: false)
@@ -184,12 +203,42 @@ ws-auth: false
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5) # # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
# # but you have a similar model available (e.g., Claude Sonnet 4). # # but you have a similar model available (e.g., Claude Sonnet 4).
# model-mappings: # model-mappings:
# - from: "claude-opus-4.5" # Model requested by Amp CLI # - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
# to: "claude-sonnet-4" # Route to this available model instead # to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
# - from: "gpt-5" # - from: "claude-sonnet-4-5-20250929"
# to: "gemini-2.5-pro" # to: "gemini-claude-sonnet-4-5-thinking"
# - from: "claude-3-opus-20240229" # - from: "claude-haiku-4-5-20251001"
# to: "claude-3-5-sonnet-20241022" # to: "gemini-2.5-flash"
# Global OAuth model name mappings (per channel)
# These mappings rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# oauth-model-mappings:
# gemini-cli:
# - name: "gemini-2.5-pro" # original model name under this channel
# alias: "g2.5p" # client-visible alias
# vertex:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# aistudio:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# antigravity:
# - name: "gemini-3-pro-preview"
# alias: "g3p"
# claude:
# - name: "claude-sonnet-4-5-20250929"
# alias: "cs4.5"
# codex:
# - name: "gpt-5"
# alias: "g5"
# qwen:
# - name: "qwen3-coder-plus"
# alias: "qwen-plus"
# iflow:
# - name: "glm-4.7"
# alias: "glm-god"
# OAuth provider excluded models # OAuth provider excluded models
# oauth-excluded-models: # oauth-excluded-models:

View File

@@ -0,0 +1,538 @@
package management
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const defaultAPICallTimeout = 60 * time.Second
const (
geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOAuthClientSecret = "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"
)
var geminiOAuthScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
type apiCallRequest struct {
AuthIndexSnake *string `json:"auth_index"`
AuthIndexCamel *string `json:"authIndex"`
AuthIndexPascal *string `json:"AuthIndex"`
Method string `json:"method"`
URL string `json:"url"`
Header map[string]string `json:"header"`
Data string `json:"data"`
}
type apiCallResponse struct {
StatusCode int `json:"status_code"`
Header map[string][]string `json:"header"`
Body string `json:"body"`
}
// APICall makes a generic HTTP request on behalf of the management API caller.
// It is protected by the management middleware.
//
// Endpoint:
//
// POST /v0/management/api-call
//
// Authentication:
//
// Same as other management APIs (requires a management key and remote-management rules).
// You can provide the key via:
// - Authorization: Bearer <key>
// - X-Management-Key: <key>
//
// Request JSON:
// - auth_index / authIndex / AuthIndex (optional):
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
// If omitted or not found, credential-specific proxy/token substitution is skipped.
// - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
// - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
// - header (optional): Request headers map.
// Supports magic variable "$TOKEN$" which is replaced using the selected credential:
// 1) metadata.access_token
// 2) attributes.api_key
// 3) metadata.token / metadata.id_token / metadata.cookie
// Example: {"Authorization":"Bearer $TOKEN$"}.
// Note: if you need to override the HTTP Host header, set header["Host"].
// - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
//
// Proxy selection (highest priority first):
// 1. Selected credential proxy_url
// 2. Global config proxy-url
// 3. Direct connect (environment proxies are not used)
//
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
// - status_code: Upstream HTTP status code.
// - header: Upstream response headers.
// - body: Upstream response body as string.
//
// Example:
//
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
// -H "Authorization: Bearer <MANAGEMENT_KEY>" \
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
//
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
// -H "Authorization: Bearer 831227" \
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
func (h *Handler) APICall(c *gin.Context) {
var body apiCallRequest
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
method := strings.ToUpper(strings.TrimSpace(body.Method))
if method == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
return
}
urlStr := strings.TrimSpace(body.URL)
if urlStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
return
}
parsedURL, errParseURL := url.Parse(urlStr)
if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
return
}
authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
auth := h.authByIndex(authIndex)
reqHeaders := body.Header
if reqHeaders == nil {
reqHeaders = map[string]string{}
}
var hostOverride string
var token string
var tokenResolved bool
var tokenErr error
for key, value := range reqHeaders {
if !strings.Contains(value, "$TOKEN$") {
continue
}
if !tokenResolved {
token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
tokenResolved = true
}
if auth != nil && token == "" {
if tokenErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
return
}
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
return
}
if token == "" {
continue
}
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
}
var requestBody io.Reader
if body.Data != "" {
requestBody = strings.NewReader(body.Data)
}
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
if errNewRequest != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
return
}
for key, value := range reqHeaders {
if strings.EqualFold(key, "host") {
hostOverride = strings.TrimSpace(value)
continue
}
req.Header.Set(key, value)
}
if hostOverride != "" {
req.Host = hostOverride
}
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
}
httpClient.Transport = h.apiCallTransport(auth)
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.WithError(errDo).Debug("management APICall request failed")
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
respBody, errReadAll := io.ReadAll(resp.Body)
if errReadAll != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
return
}
c.JSON(http.StatusOK, apiCallResponse{
StatusCode: resp.StatusCode,
Header: resp.Header,
Body: string(respBody),
})
}
func firstNonEmptyString(values ...*string) string {
for _, v := range values {
if v == nil {
continue
}
if out := strings.TrimSpace(*v); out != "" {
return out
}
}
return ""
}
func tokenValueForAuth(auth *coreauth.Auth) string {
if auth == nil {
return ""
}
if v := tokenValueFromMetadata(auth.Metadata); v != "" {
return v
}
if auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
return v
}
}
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
if v := tokenValueFromMetadata(shared.MetadataSnapshot()); v != "" {
return v
}
}
return ""
}
func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
if auth == nil {
return "", nil
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if provider == "gemini-cli" {
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
return token, errToken
}
return tokenValueForAuth(auth), nil
}
func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
if ctx == nil {
ctx = context.Background()
}
if auth == nil {
return "", nil
}
metadata, updater := geminiOAuthMetadata(auth)
if len(metadata) == 0 {
return "", fmt.Errorf("gemini oauth metadata missing")
}
base := make(map[string]any)
if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
base = cloneMap(tokenRaw)
}
var token oauth2.Token
if len(base) > 0 {
if raw, errMarshal := json.Marshal(base); errMarshal == nil {
_ = json.Unmarshal(raw, &token)
}
}
if token.AccessToken == "" {
token.AccessToken = stringValue(metadata, "access_token")
}
if token.RefreshToken == "" {
token.RefreshToken = stringValue(metadata, "refresh_token")
}
if token.TokenType == "" {
token.TokenType = stringValue(metadata, "token_type")
}
if token.Expiry.IsZero() {
if expiry := stringValue(metadata, "expiry"); expiry != "" {
if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
token.Expiry = ts
}
}
}
conf := &oauth2.Config{
ClientID: geminiOAuthClientID,
ClientSecret: geminiOAuthClientSecret,
Scopes: geminiOAuthScopes,
Endpoint: google.Endpoint,
}
ctxToken := ctx
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
src := conf.TokenSource(ctxToken, &token)
currentToken, errToken := src.Token()
if errToken != nil {
return "", errToken
}
merged := buildOAuthTokenMap(base, currentToken)
fields := buildOAuthTokenFields(currentToken, merged)
if updater != nil {
updater(fields)
}
return strings.TrimSpace(currentToken.AccessToken), nil
}
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
if auth == nil {
return nil, nil
}
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
snapshot := shared.MetadataSnapshot()
return snapshot, func(fields map[string]any) { shared.MergeMetadata(fields) }
}
return auth.Metadata, func(fields map[string]any) {
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
for k, v := range fields {
auth.Metadata[k] = v
}
}
}
func stringValue(metadata map[string]any, key string) string {
if len(metadata) == 0 || key == "" {
return ""
}
if v, ok := metadata[key].(string); ok {
return strings.TrimSpace(v)
}
return ""
}
func cloneMap(in map[string]any) map[string]any {
if len(in) == 0 {
return nil
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
merged := cloneMap(base)
if merged == nil {
merged = make(map[string]any)
}
if tok == nil {
return merged
}
if raw, errMarshal := json.Marshal(tok); errMarshal == nil {
var tokenMap map[string]any
if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal == nil {
for k, v := range tokenMap {
merged[k] = v
}
}
}
return merged
}
func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
fields := make(map[string]any, 5)
if tok != nil && tok.AccessToken != "" {
fields["access_token"] = tok.AccessToken
}
if tok != nil && tok.TokenType != "" {
fields["token_type"] = tok.TokenType
}
if tok != nil && tok.RefreshToken != "" {
fields["refresh_token"] = tok.RefreshToken
}
if tok != nil && !tok.Expiry.IsZero() {
fields["expiry"] = tok.Expiry.Format(time.RFC3339)
}
if len(merged) > 0 {
fields["token"] = cloneMap(merged)
}
return fields
}
func tokenValueFromMetadata(metadata map[string]any) string {
if len(metadata) == 0 {
return ""
}
if v, ok := metadata["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if tokenRaw, ok := metadata["token"]; ok && tokenRaw != nil {
switch typed := tokenRaw.(type) {
case string:
if v := strings.TrimSpace(typed); v != "" {
return v
}
case map[string]any:
if v, ok := typed["access_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := typed["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
case map[string]string:
if v := strings.TrimSpace(typed["access_token"]); v != "" {
return v
}
if v := strings.TrimSpace(typed["accessToken"]); v != "" {
return v
}
}
}
if v, ok := metadata["token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["id_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["cookie"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
return ""
}
func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
authIndex = strings.TrimSpace(authIndex)
if authIndex == "" || h == nil || h.authManager == nil {
return nil
}
auths := h.authManager.List()
for _, auth := range auths {
if auth == nil {
continue
}
auth.EnsureIndex()
if auth.Index == authIndex {
return auth
}
}
return nil
}
func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
var proxyCandidates []string
if auth != nil {
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
if h != nil && h.cfg != nil {
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
for _, proxyStr := range proxyCandidates {
if transport := buildProxyTransport(proxyStr); transport != nil {
return transport
}
}
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok || transport == nil {
return &http.Transport{Proxy: nil}
}
clone := transport.Clone()
clone.Proxy = nil
return clone
}
func buildProxyTransport(proxyStr string) *http.Transport {
proxyStr = strings.TrimSpace(proxyStr)
if proxyStr == "" {
return nil
}
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.WithError(errParse).Debug("parse proxy URL failed")
return nil
}
if proxyURL.Scheme == "" || proxyURL.Host == "" {
log.Debug("proxy URL missing scheme/host")
return nil
}
if proxyURL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
return nil
}
return &http.Transport{
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
}
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
}

View File

@@ -431,9 +431,46 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
log.WithError(err).Warnf("failed to stat auth file %s", path) log.WithError(err).Warnf("failed to stat auth file %s", path)
} }
} }
if claims := extractCodexIDTokenClaims(auth); claims != nil {
entry["id_token"] = claims
}
return entry return entry
} }
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
if auth == nil || auth.Metadata == nil {
return nil
}
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
return nil
}
idTokenRaw, ok := auth.Metadata["id_token"].(string)
if !ok {
return nil
}
idToken := strings.TrimSpace(idTokenRaw)
if idToken == "" {
return nil
}
claims, err := codex.ParseJWTToken(idToken)
if err != nil || claims == nil {
return nil
}
result := gin.H{}
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" {
result["chatgpt_account_id"] = v
}
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
result["plan_type"] = v
}
if len(result) == 0 {
return nil
}
return result
}
func authEmail(auth *coreauth.Auth) string { func authEmail(auth *coreauth.Auth) string {
if auth == nil { if auth == nil {
return "" return ""

View File

@@ -597,11 +597,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
filtered := make([]config.CodexKey, 0, len(arr)) filtered := make([]config.CodexKey, 0, len(arr))
for i := range arr { for i := range arr {
entry := arr[i] entry := arr[i]
entry.APIKey = strings.TrimSpace(entry.APIKey) normalizeCodexKey(&entry)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
if entry.BaseURL == "" { if entry.BaseURL == "" {
continue continue
} }
@@ -613,12 +609,13 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
} }
func (h *Handler) PatchCodexKey(c *gin.Context) { func (h *Handler) PatchCodexKey(c *gin.Context) {
type codexKeyPatch struct { type codexKeyPatch struct {
APIKey *string `json:"api-key"` APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"` Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"` BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"` ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"` Models *[]config.CodexModel `json:"models"`
ExcludedModels *[]string `json:"excluded-models"` Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
} }
var body struct { var body struct {
Index *int `json:"index"` Index *int `json:"index"`
@@ -667,12 +664,16 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
if body.Value.ProxyURL != nil { if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
} }
if body.Value.Models != nil {
entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...)
}
if body.Value.Headers != nil { if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers) entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
} }
if body.Value.ExcludedModels != nil { if body.Value.ExcludedModels != nil {
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
} }
normalizeCodexKey(&entry)
h.cfg.CodexKey[targetIndex] = entry h.cfg.CodexKey[targetIndex] = entry
h.cfg.SanitizeCodexKeys() h.cfg.SanitizeCodexKeys()
h.persist(c) h.persist(c)
@@ -762,6 +763,32 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
entry.Models = normalized entry.Models = normalized
} }
func normalizeCodexKey(entry *config.CodexKey) {
if entry == nil {
return
}
entry.APIKey = strings.TrimSpace(entry.APIKey)
entry.Prefix = strings.TrimSpace(entry.Prefix)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
if len(entry.Models) == 0 {
return
}
normalized := make([]config.CodexModel, 0, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
model.Name = strings.TrimSpace(model.Name)
model.Alias = strings.TrimSpace(model.Alias)
if model.Name == "" && model.Alias == "" {
continue
}
normalized = append(normalized, model)
}
entry.Models = normalized
}
// GetAmpCode returns the complete ampcode configuration. // GetAmpCode returns the complete ampcode configuration.
func (h *Handler) GetAmpCode(c *gin.Context) { func (h *Handler) GetAmpCode(c *gin.Context) {
if h == nil || h.cfg == nil { if h == nil || h.cfg == nil {
@@ -913,3 +940,151 @@ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) { func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
} }
// GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping.
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
return
}
c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
}
// PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings.
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
// Normalize entries: trim whitespace, filter empty
normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value)
h.cfg.AmpCode.UpstreamAPIKeys = normalized
h.persist(c)
}
// PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries.
// Matching is done by upstream-api-key value.
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
existing := make(map[string]int)
for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i
}
for _, newEntry := range body.Value {
upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
normalizedEntry := config.AmpUpstreamAPIKeyEntry{
UpstreamAPIKey: upstreamKey,
APIKeys: normalizeAPIKeysList(newEntry.APIKeys),
}
if idx, ok := existing[upstreamKey]; ok {
h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry
} else {
h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry)
existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
}
}
h.persist(c)
}
// DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries.
// Body must be JSON: {"value": ["<upstream-api-key>", ...]}.
// If "value" is an empty array, clears all entries.
// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change.
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []string `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
if body.Value == nil {
c.JSON(400, gin.H{"error": "missing value"})
return
}
// Empty array means clear all
if len(body.Value) == 0 {
h.cfg.AmpCode.UpstreamAPIKeys = nil
h.persist(c)
return
}
toRemove := make(map[string]bool)
for _, key := range body.Value {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
continue
}
toRemove[trimmed] = true
}
if len(toRemove) == 0 {
c.JSON(400, gin.H{"error": "empty value"})
return
}
newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] {
newEntries = append(newEntries, entry)
}
}
h.cfg.AmpCode.UpstreamAPIKeys = newEntries
h.persist(c)
}
// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries.
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
if len(entries) == 0 {
return nil
}
out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
for _, entry := range entries {
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
apiKeys := normalizeAPIKeysList(entry.APIKeys)
out = append(out, config.AmpUpstreamAPIKeyEntry{
UpstreamAPIKey: upstreamKey,
APIKeys: apiKeys,
})
}
if len(out) == 0 {
return nil
}
return out
}
// normalizeAPIKeysList trims and filters empty strings from a list of API keys.
func normalizeAPIKeysList(keys []string) []string {
if len(keys) == 0 {
return nil
}
out := make([]string, 0, len(keys))
for _, k := range keys {
trimmed := strings.TrimSpace(k)
if trimmed != "" {
out = append(out, trimmed)
}
}
if len(out) == 0 {
return nil
}
return out
}

View File

@@ -59,6 +59,11 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
} }
} }
// NewHandler creates a new management handler instance.
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
return NewHandler(cfg, "", manager)
}
// SetConfig updates the in-memory config reference when the server hot-reloads. // SetConfig updates the in-memory config reference when the server hot-reloads.
func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg } func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg }

View File

@@ -1,12 +1,25 @@
package management package management
import ( import (
"encoding/json"
"net/http" "net/http"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage" "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
) )
type usageExportPayload struct {
Version int `json:"version"`
ExportedAt time.Time `json:"exported_at"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
type usageImportPayload struct {
Version int `json:"version"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
// GetUsageStatistics returns the in-memory request statistics snapshot. // GetUsageStatistics returns the in-memory request statistics snapshot.
func (h *Handler) GetUsageStatistics(c *gin.Context) { func (h *Handler) GetUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot var snapshot usage.StatisticsSnapshot
@@ -18,3 +31,49 @@ func (h *Handler) GetUsageStatistics(c *gin.Context) {
"failed_requests": snapshot.FailureCount, "failed_requests": snapshot.FailureCount,
}) })
} }
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
if h != nil && h.usageStats != nil {
snapshot = h.usageStats.Snapshot()
}
c.JSON(http.StatusOK, usageExportPayload{
Version: 1,
ExportedAt: time.Now().UTC(),
Usage: snapshot,
})
}
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
if h == nil || h.usageStats == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
return
}
data, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
return
}
var payload usageImportPayload
if err := json.Unmarshal(data, &payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
return
}
if payload.Version != 0 && payload.Version != 1 {
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
return
}
result := h.usageStats.MergeSnapshot(payload.Usage)
snapshot := h.usageStats.Snapshot()
c.JSON(http.StatusOK, gin.H{
"added": result.Added,
"skipped": result.Skipped,
"total_requests": snapshot.TotalRequests,
"failed_requests": snapshot.FailureCount,
})
}

View File

@@ -227,11 +227,20 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
} }
} }
// Check API key change // Check API key change (both default and per-client mappings)
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings) apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
if apiKeyChanged { upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
if apiKeyChanged || upstreamAPIKeysChanged {
if m.secretSource != nil { if m.secretSource != nil {
if ms, ok := m.secretSource.(*MultiSourceSecret); ok { if ms, ok := m.secretSource.(*MappedSecretSource); ok {
if apiKeyChanged {
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache()
}
if upstreamAPIKeysChanged {
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
}
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey) ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache() ms.InvalidateCache()
} }
@@ -251,10 +260,22 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error { func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
if m.secretSource == nil { if m.secretSource == nil {
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */) // Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
mappedSource := NewMappedSecretSource(defaultSource)
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
m.secretSource = mappedSource
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache()
ms.UpdateMappings(settings.UpstreamAPIKeys)
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok { } else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
ms.UpdateExplicitKey(settings.UpstreamAPIKey) ms.UpdateExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache() ms.InvalidateCache()
mappedSource := NewMappedSecretSource(ms)
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
m.secretSource = mappedSource
} }
proxy, err := createReverseProxy(upstreamURL, m.secretSource) proxy, err := createReverseProxy(upstreamURL, m.secretSource)
@@ -313,6 +334,66 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
return oldKey != newKey return oldKey != newKey
} }
// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
if old == nil {
return len(new.UpstreamAPIKeys) > 0
}
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
return true
}
// Build map for comparison: upstreamKey -> set of clientKeys
type entryInfo struct {
upstreamKey string
clientKeys map[string]struct{}
}
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
for i, entry := range old.UpstreamAPIKeys {
clientKeys := make(map[string]struct{}, len(entry.APIKeys))
for _, k := range entry.APIKeys {
trimmed := strings.TrimSpace(k)
if trimmed == "" {
continue
}
clientKeys[trimmed] = struct{}{}
}
oldEntries[i] = entryInfo{
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
clientKeys: clientKeys,
}
}
for i, newEntry := range new.UpstreamAPIKeys {
if i >= len(oldEntries) {
return true
}
oldE := oldEntries[i]
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
return true
}
newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
for _, k := range newEntry.APIKeys {
trimmed := strings.TrimSpace(k)
if trimmed == "" {
continue
}
newKeys[trimmed] = struct{}{}
}
if len(newKeys) != len(oldE.clientKeys) {
return true
}
for k := range newKeys {
if _, ok := oldE.clientKeys[k]; !ok {
return true
}
}
}
return false
}
// GetModelMapper returns the model mapper instance (for testing/debugging). // GetModelMapper returns the model mapper instance (for testing/debugging).
func (m *AmpModule) GetModelMapper() *DefaultModelMapper { func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
return m.modelMapper return m.modelMapper

View File

@@ -312,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
}) })
} }
} }
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
},
}
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
}
}
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
},
}
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected no change when only whitespace/empty entries differ")
}
}

View File

@@ -18,6 +18,33 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func removeQueryValuesMatching(req *http.Request, key string, match string) {
if req == nil || req.URL == nil || match == "" {
return
}
q := req.URL.Query()
values, ok := q[key]
if !ok || len(values) == 0 {
return
}
kept := make([]string, 0, len(values))
for _, v := range values {
if v == match {
continue
}
kept = append(kept, v)
}
if len(kept) == 0 {
q.Del(key)
} else {
q[key] = kept
}
req.URL.RawQuery = q.Encode()
}
// readCloser wraps a reader and forwards Close to a separate closer. // readCloser wraps a reader and forwards Close to a separate closer.
// Used to restore peeked bytes while preserving upstream body Close behavior. // Used to restore peeked bytes while preserving upstream body Close behavior.
type readCloser struct { type readCloser struct {
@@ -48,6 +75,14 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// We will set our own Authorization using the configured upstream-api-key // We will set our own Authorization using the configured upstream-api-key
req.Header.Del("Authorization") req.Header.Del("Authorization")
req.Header.Del("X-Api-Key") req.Header.Del("X-Api-Key")
req.Header.Del("X-Goog-Api-Key")
// Remove query-based credentials if they match the authenticated client API key.
// This prevents leaking client auth material to the Amp upstream while avoiding
// breaking unrelated upstream query parameters.
clientKey := getClientAPIKeyFromContext(req.Context())
removeQueryValuesMatching(req, "key", clientKey)
removeQueryValuesMatching(req, "auth_token", clientKey)
// Preserve correlation headers for debugging // Preserve correlation headers for debugging
if req.Header.Get("X-Request-ID") == "" { if req.Header.Get("X-Request-ID") == "" {

View File

@@ -3,11 +3,15 @@ package amp
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
) )
// Helper: compress data with gzip // Helper: compress data with gzip
@@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) {
} }
} }
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
type captured struct {
headers http.Header
query string
}
got := make(chan captured, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", "Bearer client-key")
req.Header.Set("X-Api-Key", "client-key")
req.Header.Set("X-Goog-Api-Key", "client-key")
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
c := <-got
// These are client-provided credentials and must not reach the upstream.
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
}
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
}
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
}
// Query-based credentials should be stripped only when they match the authenticated client key.
// Should keep unrelated values and parameters.
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
}
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
}
}
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "u1" {
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer u1" {
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "default" {
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer default" {
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_ErrorHandler(t *testing.T) { func TestReverseProxy_ErrorHandler(t *testing.T) {
// Point proxy to a non-routable address to trigger error // Point proxy to a non-routable address to trigger error
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource("")) proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))

View File

@@ -1,6 +1,7 @@
package amp package amp
import ( import (
"context"
"errors" "errors"
"net" "net"
"net/http" "net/http"
@@ -16,6 +17,37 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// clientAPIKeyContextKey is the context key used to pass the client API key
// from gin.Context to the request context for SecretSource lookup.
type clientAPIKeyContextKey struct{}
// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
// into the request context so that SecretSource can look it up for per-client upstream routing.
func clientAPIKeyMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Extract the client API key from gin context (set by AuthMiddleware)
if apiKey, exists := c.Get("apiKey"); exists {
if keyStr, ok := apiKey.(string); ok && keyStr != "" {
// Inject into request context for SecretSource.Get(ctx) to read
ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
c.Request = c.Request.WithContext(ctx)
}
}
c.Next()
}
}
// getClientAPIKeyFromContext retrieves the client API key from request context.
// Returns empty string if not present.
func getClientAPIKeyFromContext(ctx context.Context) string {
if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
if keyStr, ok := val.(string); ok {
return keyStr
}
}
return ""
}
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's // localhostOnlyMiddleware returns a middleware that dynamically checks the module's
// localhost restriction setting. This allows hot-reload of the restriction without restarting. // localhost restriction setting. This allows hot-reload of the restriction without restarting.
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc { func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
@@ -129,6 +161,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings") authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
} }
// Inject client API key into request context for per-client upstream routing
ampAPI.Use(clientAPIKeyMiddleware())
// Dynamic proxy handler that uses m.getProxy() for hot-reload support // Dynamic proxy handler that uses m.getProxy() for hot-reload support
proxyHandler := func(c *gin.Context) { proxyHandler := func(c *gin.Context) {
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces // Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
@@ -175,6 +210,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
if authWithBypass != nil { if authWithBypass != nil {
rootMiddleware = append(rootMiddleware, authWithBypass) rootMiddleware = append(rootMiddleware, authWithBypass)
} }
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
engine.GET("/threads", append(rootMiddleware, proxyHandler)...) engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...) engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
engine.GET("/docs", append(rootMiddleware, proxyHandler)...) engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
@@ -244,6 +281,8 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
if auth != nil { if auth != nil {
ampProviders.Use(auth) ampProviders.Use(auth)
} }
// Inject client API key into request context for per-client upstream routing
ampProviders.Use(clientAPIKeyMiddleware())
provider := ampProviders.Group("/:provider") provider := ampProviders.Group("/:provider")

View File

@@ -9,6 +9,9 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
) )
// SecretSource provides Amp API keys with configurable precedence and caching // SecretSource provides Amp API keys with configurable precedence and caching
@@ -164,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource {
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) { func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
return s.key, nil return s.key, nil
} }
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
// When a request context contains a client API key that matches a configured mapping,
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
type MappedSecretSource struct {
defaultSource SecretSource
mu sync.RWMutex
lookup map[string]string // clientKey -> upstreamKey
}
// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
return &MappedSecretSource{
defaultSource: defaultSource,
lookup: make(map[string]string),
}
}
// Get retrieves the Amp API key, checking per-client mappings first.
// If the request context contains a client API key that matches a configured mapping,
// returns the corresponding upstream key. Otherwise, falls back to the default source.
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
// Try to get client API key from request context
clientKey := getClientAPIKeyFromContext(ctx)
if clientKey != "" {
s.mu.RLock()
if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
s.mu.RUnlock()
return upstreamKey, nil
}
s.mu.RUnlock()
}
// Fall back to default source
return s.defaultSource.Get(ctx)
}
// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
// If the same client key appears in multiple entries, logs a warning and uses the first one.
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
newLookup := make(map[string]string)
for _, entry := range entries {
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
for _, clientKey := range entry.APIKeys {
trimmedKey := strings.TrimSpace(clientKey)
if trimmedKey == "" {
continue
}
if _, exists := newLookup[trimmedKey]; exists {
// Log warning for duplicate client key, first one wins
log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
continue
}
newLookup[trimmedKey] = upstreamKey
}
}
s.mu.Lock()
s.lookup = newLookup
s.mu.Unlock()
}
// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(key)
}
}
// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
func (s *MappedSecretSource) InvalidateCache() {
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
ms.InvalidateCache()
}
}

View File

@@ -8,6 +8,10 @@ import (
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
) )
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) { func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
@@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
t.Fatalf("after cache expiry, expected new-value, got %q", got3) t.Fatalf("after cache expiry, expected new-value, got %q", got3)
} }
} }
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
got, err := s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "u1" {
t.Fatalf("want u1, got %q", got)
}
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
got, err = s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "default" {
t.Fatalf("want default fallback, got %q", got)
}
}
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
{
UpstreamAPIKey: "u2",
APIKeys: []string{"k1"},
},
})
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
got, err := s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "u1" {
t.Fatalf("want u1 (first wins), got %q", got)
}
}
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
hook := test.NewLocal(log.StandardLogger())
defer hook.Reset()
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
{
UpstreamAPIKey: "u2",
APIKeys: []string{"k1"},
},
})
foundWarning := false
for _, entry := range hook.AllEntries() {
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
foundWarning = true
break
}
}
if !foundWarning {
t.Fatal("expected warning log for duplicate client key, but none was found")
}
}

View File

@@ -209,13 +209,15 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// Resolve logs directory relative to the configuration file directory. // Resolve logs directory relative to the configuration file directory.
var requestLogger logging.RequestLogger var requestLogger logging.RequestLogger
var toggle func(bool) var toggle func(bool)
if optionState.requestLoggerFactory != nil { if !cfg.CommercialMode {
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath) if optionState.requestLoggerFactory != nil {
} requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
if requestLogger != nil { }
engine.Use(middleware.RequestLoggingMiddleware(requestLogger)) if requestLogger != nil {
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok { engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
toggle = setter.SetEnabled if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
toggle = setter.SetEnabled
}
} }
} }
@@ -494,6 +496,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware()) mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
{ {
mgmt.GET("/usage", s.mgmt.GetUsageStatistics) mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
mgmt.GET("/config", s.mgmt.GetConfig) mgmt.GET("/config", s.mgmt.GetConfig)
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML) mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML) mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
@@ -516,6 +520,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL) mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL) mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
mgmt.POST("/api-call", s.mgmt.APICall)
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject) mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject) mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject) mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
@@ -565,6 +571,10 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings) mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
@@ -867,7 +877,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
} }
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB { if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil { if err := logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to reconfigure log output: %v", err) log.Errorf("failed to reconfigure log output: %v", err)
} else { } else {
if oldCfg == nil { if oldCfg == nil {

View File

@@ -39,6 +39,9 @@ type Config struct {
// Debug enables or disables debug-level logging and other debug features. // Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug" json:"debug"` Debug bool `yaml:"debug" json:"debug"`
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
// LoggingToFile controls whether application logs are written to rotating files or stdout. // LoggingToFile controls whether application logs are written to rotating files or stdout.
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"` LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
@@ -95,6 +98,14 @@ type Config struct {
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries. // OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"` OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
// OAuthModelMappings defines global model name mappings for OAuth/file-backed auth channels.
// These mappings affect both model listing and model routing for supported channels:
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
//
// NOTE: This does not apply to existing per-credential model alias features under:
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
OAuthModelMappings map[string][]ModelNameMapping `yaml:"oauth-model-mappings,omitempty" json:"oauth-model-mappings,omitempty"`
// Payload defines default and override rules for provider payload parameters. // Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"` Payload PayloadConfig `yaml:"payload" json:"payload"`
@@ -146,6 +157,13 @@ type RoutingConfig struct {
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"` Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
} }
// ModelNameMapping defines a model ID rename mapping for a specific channel.
// It maps the original model name (Name) to the client-visible alias (Alias).
type ModelNameMapping struct {
Name string `yaml:"name" json:"name"`
Alias string `yaml:"alias" json:"alias"`
}
// AmpModelMapping defines a model name mapping for Amp CLI requests. // AmpModelMapping defines a model name mapping for Amp CLI requests.
// When Amp requests a model that isn't available locally, this mapping // When Amp requests a model that isn't available locally, this mapping
// allows routing to an alternative model that IS available. // allows routing to an alternative model that IS available.
@@ -172,6 +190,11 @@ type AmpCode struct {
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls. // UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"` UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
// When a client authenticates with a key that matches an entry, that upstream key is used.
// If no match is found, falls back to UpstreamAPIKey (default behavior).
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.) // RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by // to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient). // browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
@@ -187,6 +210,17 @@ type AmpCode struct {
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"` ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
} }
// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key.
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
// is used for the upstream Amp request.
type AmpUpstreamAPIKeyEntry struct {
// UpstreamAPIKey is the API key to use when proxying to the Amp upstream.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// APIKeys are the client API keys (from top-level api-keys) that map to this upstream key.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
}
// PayloadConfig defines default and override parameter rules applied to provider payloads. // PayloadConfig defines default and override parameter rules applied to provider payloads.
type PayloadConfig struct { type PayloadConfig struct {
// Default defines rules that only set parameters when they are missing in the payload. // Default defines rules that only set parameters when they are missing in the payload.
@@ -262,6 +296,9 @@ type CodexKey struct {
// ProxyURL overrides the global proxy setting for this API key if provided. // ProxyURL overrides the global proxy setting for this API key if provided.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"` ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
// Models defines upstream model names and aliases for request routing.
Models []CodexModel `yaml:"models" json:"models"`
// Headers optionally adds extra HTTP headers for requests sent with this key. // Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
@@ -269,6 +306,15 @@ type CodexKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
} }
// CodexModel describes a mapping between an alias and the actual upstream model name.
type CodexModel struct {
// Name is the upstream model identifier used when issuing requests.
Name string `yaml:"name" json:"name"`
// Alias is the client-facing model name that maps to Name.
Alias string `yaml:"alias" json:"alias"`
}
// GeminiKey represents the configuration for a Gemini API key, // GeminiKey represents the configuration for a Gemini API key,
// including optional overrides for upstream base URL, proxy routing, and headers. // including optional overrides for upstream base URL, proxy routing, and headers.
type GeminiKey struct { type GeminiKey struct {
@@ -475,6 +521,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Normalize OAuth provider model exclusion map. // Normalize OAuth provider model exclusion map.
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels) cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
// Normalize global OAuth model name mappings.
cfg.SanitizeOAuthModelMappings()
if cfg.legacyMigrationPending { if cfg.legacyMigrationPending {
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...") fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
if !optional && configFile != "" { if !optional && configFile != "" {
@@ -491,6 +540,50 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
return &cfg, nil return &cfg, nil
} }
// SanitizeOAuthModelMappings normalizes and deduplicates global OAuth model name mappings.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// and ensures (From, To) pairs are unique within each channel.
func (cfg *Config) SanitizeOAuthModelMappings() {
if cfg == nil || len(cfg.OAuthModelMappings) == 0 {
return
}
out := make(map[string][]ModelNameMapping, len(cfg.OAuthModelMappings))
for rawChannel, mappings := range cfg.OAuthModelMappings {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(mappings) == 0 {
continue
}
seenName := make(map[string]struct{}, len(mappings))
seenAlias := make(map[string]struct{}, len(mappings))
clean := make([]ModelNameMapping, 0, len(mappings))
for _, mapping := range mappings {
name := strings.TrimSpace(mapping.Name)
alias := strings.TrimSpace(mapping.Alias)
if name == "" || alias == "" {
continue
}
if strings.EqualFold(name, alias) {
continue
}
nameKey := strings.ToLower(name)
aliasKey := strings.ToLower(alias)
if _, ok := seenName[nameKey]; ok {
continue
}
if _, ok := seenAlias[aliasKey]; ok {
continue
}
seenName[nameKey] = struct{}{}
seenAlias[aliasKey] = struct{}{}
clean = append(clean, ModelNameMapping{Name: name, Alias: alias})
}
if len(clean) > 0 {
out[channel] = clean
}
}
cfg.OAuthModelMappings = out
}
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are // SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
// not actionable, specifically those missing a BaseURL. It trims whitespace before // not actionable, specifically those missing a BaseURL. It trims whitespace before
// evaluation and preserves the relative order of remaining entries. // evaluation and preserves the relative order of remaining entries.
@@ -876,8 +969,8 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
} }
// mergeMappingPreserve merges keys from src into dst mapping node while preserving // mergeMappingPreserve merges keys from src into dst mapping node while preserving
// key order and comments of existing keys in dst. Unknown keys from src are appended // key order and comments of existing keys in dst. New keys are only added if their
// to dst at the end, copying their node structure from src. // value is non-zero to avoid polluting the config with defaults.
func mergeMappingPreserve(dst, src *yaml.Node) { func mergeMappingPreserve(dst, src *yaml.Node) {
if dst == nil || src == nil { if dst == nil || src == nil {
return return
@@ -888,20 +981,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
copyNodeShallow(dst, src) copyNodeShallow(dst, src)
return return
} }
// Build a lookup of existing keys in dst
for i := 0; i+1 < len(src.Content); i += 2 { for i := 0; i+1 < len(src.Content); i += 2 {
sk := src.Content[i] sk := src.Content[i]
sv := src.Content[i+1] sv := src.Content[i+1]
idx := findMapKeyIndex(dst, sk.Value) idx := findMapKeyIndex(dst, sk.Value)
if idx >= 0 { if idx >= 0 {
// Merge into existing value node // Merge into existing value node (always update, even to zero values)
dv := dst.Content[idx+1] dv := dst.Content[idx+1]
mergeNodePreserve(dv, sv) mergeNodePreserve(dv, sv)
} else { } else {
if shouldSkipEmptyCollectionOnPersist(sk.Value, sv) { // New key: only add if value is non-zero to avoid polluting config with defaults
if isZeroValueNode(sv) {
continue continue
} }
// Append new key/value pair by deep-copying from src
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv)) dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
} }
} }
@@ -984,32 +1076,49 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
return -1 return -1
} }
func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool { // isZeroValueNode returns true if the YAML node represents a zero/default value
switch key { // that should not be written as a new key to preserve config cleanliness.
case "generative-language-api-key", // For mappings and sequences, recursively checks if all children are zero values.
"gemini-api-key", func isZeroValueNode(node *yaml.Node) bool {
"vertex-api-key",
"claude-api-key",
"codex-api-key",
"openai-compatibility":
return isEmptyCollectionNode(node)
default:
return false
}
}
func isEmptyCollectionNode(node *yaml.Node) bool {
if node == nil { if node == nil {
return true return true
} }
switch node.Kind { switch node.Kind {
case yaml.SequenceNode:
return len(node.Content) == 0
case yaml.ScalarNode: case yaml.ScalarNode:
return node.Tag == "!!null" switch node.Tag {
default: case "!!bool":
return false return node.Value == "false"
case "!!int", "!!float":
return node.Value == "0" || node.Value == "0.0"
case "!!str":
return node.Value == ""
case "!!null":
return true
}
case yaml.SequenceNode:
if len(node.Content) == 0 {
return true
}
// Check if all elements are zero values
for _, child := range node.Content {
if !isZeroValueNode(child) {
return false
}
}
return true
case yaml.MappingNode:
if len(node.Content) == 0 {
return true
}
// Check if all values are zero values (values are at odd indices)
for i := 1; i < len(node.Content); i += 2 {
if !isZeroValueNode(node.Content[i]) {
return false
}
}
return true
} }
return false
} }
// deepCopyNode creates a deep copy of a yaml.Node graph. // deepCopyNode creates a deep copy of a yaml.Node graph.

View File

@@ -30,13 +30,13 @@ type SDKConfig struct {
// StreamingConfig holds server streaming behavior configuration. // StreamingConfig holds server streaming behavior configuration.
type StreamingConfig struct { type StreamingConfig struct {
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n"). // KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
// nil means default (15 seconds). <= 0 disables keep-alives. // <= 0 disables keep-alives. Default is 0.
KeepAliveSeconds *int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"` KeepAliveSeconds int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent, // BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
// to allow auth rotation / transient recovery. // to allow auth rotation / transient recovery.
// nil means default (2). 0 disables bootstrap retries. // <= 0 disables bootstrap retries. Default is 0.
BootstrapRetries *int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"` BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
} }
// AccessConfig groups request authentication providers. // AccessConfig groups request authentication providers.

View File

@@ -10,6 +10,7 @@ import (
"sync" "sync"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
@@ -84,10 +85,30 @@ func SetupBaseLogger() {
}) })
} }
// isDirWritable checks if the specified directory exists and is writable by attempting to create and remove a test file.
func isDirWritable(dir string) bool {
info, err := os.Stat(dir)
if err != nil || !info.IsDir() {
return false
}
testFile := filepath.Join(dir, ".perm_test")
f, err := os.Create(testFile)
if err != nil {
return false
}
defer func() {
_ = f.Close()
_ = os.Remove(testFile)
}()
return true
}
// ConfigureLogOutput switches the global log destination between rotating files and stdout. // ConfigureLogOutput switches the global log destination between rotating files and stdout.
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory // When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
// until the total size is within the limit. // until the total size is within the limit.
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error { func ConfigureLogOutput(cfg *config.Config) error {
SetupBaseLogger() SetupBaseLogger()
writerMu.Lock() writerMu.Lock()
@@ -96,10 +117,12 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
logDir := "logs" logDir := "logs"
if base := util.WritablePath(); base != "" { if base := util.WritablePath(); base != "" {
logDir = filepath.Join(base, "logs") logDir = filepath.Join(base, "logs")
} else if !isDirWritable(logDir) {
logDir = filepath.Join(cfg.AuthDir, "logs")
} }
protectedPath := "" protectedPath := ""
if loggingToFile { if cfg.LoggingToFile {
if err := os.MkdirAll(logDir, 0o755); err != nil { if err := os.MkdirAll(logDir, 0o755); err != nil {
return fmt.Errorf("logging: failed to create log directory: %w", err) return fmt.Errorf("logging: failed to create log directory: %w", err)
} }
@@ -123,7 +146,7 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
log.SetOutput(os.Stdout) log.SetOutput(os.Stdout)
} }
configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath) configureLogDirCleanerLocked(logDir, cfg.LogsMaxTotalSizeMB, protectedPath)
return nil return nil
} }

View File

@@ -741,7 +741,7 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600}, {ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600}, {ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000}, {ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000}, {ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
} }
models := make([]*ModelInfo, 0, len(entries)) models := make([]*ModelInfo, 0, len(entries))
for _, entry := range entries { for _, entry := range entries {

View File

@@ -76,7 +76,12 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
// Execute performs a non-streaming request to the Antigravity API. // Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if strings.Contains(req.Model, "claude") { upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
if isClaude {
return e.executeClaudeNonStream(ctx, auth, req, opts) return e.executeClaudeNonStream(ctx, auth, req, opts)
} }
@@ -98,7 +103,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
@@ -109,7 +114,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL) httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, false, opts.Alt, baseURL)
if errReq != nil { if errReq != nil {
err = errReq err = errReq
return resp, err return resp, err
@@ -190,10 +195,15 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated, true)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
@@ -204,7 +214,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL) httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL)
if errReq != nil { if errReq != nil {
err = errReq err = errReq
return resp, err return resp, err
@@ -524,10 +534,16 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
@@ -538,7 +554,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL) httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL)
if errReq != nil { if errReq != nil {
err = errReq err = errReq
return nil, err return nil, err
@@ -676,6 +692,12 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -694,7 +716,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model) payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload) payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
payload = normalizeAntigravityThinking(req.Model, payload) payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
payload = deleteJSONField(payload, "project") payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings") payload = deleteJSONField(payload, "request.safetySettings")
@@ -1308,7 +1330,7 @@ func alias2ModelName(modelName string) string {
// normalizeAntigravityThinking clamps or removes thinking config based on model support. // normalizeAntigravityThinking clamps or removes thinking config based on model support.
// For Claude models, it additionally ensures thinking budget < max_tokens. // For Claude models, it additionally ensures thinking budget < max_tokens.
func normalizeAntigravityThinking(model string, payload []byte) []byte { func normalizeAntigravityThinking(model string, payload []byte, isClaude bool) []byte {
payload = util.StripThinkingConfigIfUnsupported(model, payload) payload = util.StripThinkingConfigIfUnsupported(model, payload)
if !util.ModelSupportsThinking(model) { if !util.ModelSupportsThinking(model) {
return payload return payload
@@ -1320,7 +1342,6 @@ func normalizeAntigravityThinking(model string, payload []byte) []byte {
raw := int(budget.Int()) raw := int(budget.Int())
normalized := util.NormalizeThinkingBudget(model, raw) normalized := util.NormalizeThinkingBudget(model, raw)
isClaude := strings.Contains(strings.ToLower(model), "claude")
if isClaude { if isClaude {
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload) effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
if effectiveMax > 0 && normalized >= effectiveMax { if effectiveMax > 0 && normalized >= effectiveMax {

View File

@@ -74,6 +74,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} }
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled // Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body) body = ensureMaxTokensForThinking(req.Model, body)
@@ -185,6 +188,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
body = checkSystemInstructions(body) body = checkSystemInstructions(body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled // Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body) body = ensureMaxTokensForThinking(req.Model, body)
@@ -461,6 +467,19 @@ func (e *ClaudeExecutor) injectThinkingConfig(modelName string, metadata map[str
return util.ApplyClaudeThinkingConfig(body, budget) return util.ApplyClaudeThinkingConfig(body, budget)
} }
// disableThinkingIfToolChoiceForced checks if tool_choice forces tool use and disables thinking.
// Anthropic API does not allow thinking when tool_choice is set to "any" or a specific tool.
// See: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations
func disableThinkingIfToolChoiceForced(body []byte) []byte {
toolChoiceType := gjson.GetBytes(body, "tool_choice.type").String()
// "auto" is allowed with thinking, but "any" or "tool" (specific tool) are not
if toolChoiceType == "any" || toolChoiceType == "tool" {
// Remove thinking configuration entirely to avoid API error
body, _ = sjson.DeleteBytes(body, "thinking")
}
return body
}
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled. // ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
// Anthropic API requires this constraint; violating it returns a 400 error. // Anthropic API requires this constraint; violating it returns a 400 error.
// This function should be called after all thinking configuration is finalized. // This function should be called after all thinking configuration is finalized.

View File

@@ -50,6 +50,16 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
@@ -147,6 +157,16 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
@@ -247,12 +267,22 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
modelForCounting := req.Model modelForCounting := upstreamModel
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", upstreamModel)
@@ -520,3 +550,87 @@ func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
} }
return return
} }
func (e *CodexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveCodexConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
func (e *CodexExecutor) resolveCodexConfig(auth *cliproxyauth.Auth) *config.CodexKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.CodexKey {
entry := &e.cfg.CodexKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.CodexKey {
entry := &e.cfg.CodexKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}

View File

@@ -67,6 +67,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
return resp, errValidate return resp, errValidate
} }
body = applyIFlowThinkingConfig(body) body = applyIFlowThinkingConfig(body)
body = preserveReasoningContentInMessages(body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -159,6 +160,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
return nil, errValidate return nil, errValidate
} }
body = applyIFlowThinkingConfig(body) body = applyIFlowThinkingConfig(body)
body = preserveReasoningContentInMessages(body)
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour. // Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
toolsResult := gjson.GetBytes(body, "tools") toolsResult := gjson.GetBytes(body, "tools")
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 { if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
@@ -445,20 +447,98 @@ func ensureToolsArray(body []byte) []byte {
return updated return updated
} }
// applyIFlowThinkingConfig converts normalized reasoning_effort to iFlow chat_template_kwargs.enable_thinking. // preserveReasoningContentInMessages ensures reasoning_content from assistant messages in the
// This should be called after NormalizeThinkingConfig has processed the payload. // conversation history is preserved when sending to iFlow models that support thinking.
// iFlow only supports boolean enable_thinking, so any non-"none" effort enables thinking. // This is critical for multi-turn conversations where the model needs to see its previous
func applyIFlowThinkingConfig(body []byte) []byte { // reasoning to maintain coherent thought chains across tool calls and conversation turns.
effort := gjson.GetBytes(body, "reasoning_effort") //
if !effort.Exists() { // For GLM-4.7 and MiniMax-M2.1, the full assistant response (including reasoning) must be
// appended back into message history before the next call.
func preserveReasoningContentInMessages(body []byte) []byte {
model := strings.ToLower(gjson.GetBytes(body, "model").String())
// Only apply to models that support thinking with history preservation
needsPreservation := strings.HasPrefix(model, "glm-4.7") ||
strings.HasPrefix(model, "glm-4-7") ||
strings.HasPrefix(model, "minimax-m2.1") ||
strings.HasPrefix(model, "minimax-m2-1")
if !needsPreservation {
return body return body
} }
val := strings.ToLower(strings.TrimSpace(effort.String())) messages := gjson.GetBytes(body, "messages")
enableThinking := val != "none" && val != "" if !messages.Exists() || !messages.IsArray() {
return body
}
body, _ = sjson.DeleteBytes(body, "reasoning_effort") // Check if any assistant message already has reasoning_content preserved
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking) hasReasoningContent := false
messages.ForEach(func(_, msg gjson.Result) bool {
role := msg.Get("role").String()
if role == "assistant" {
rc := msg.Get("reasoning_content")
if rc.Exists() && rc.String() != "" {
hasReasoningContent = true
return false // stop iteration
}
}
return true
})
// If reasoning content is already present, the messages are properly formatted
// No need to modify - the client has correctly preserved reasoning in history
if hasReasoningContent {
log.Debugf("iflow executor: reasoning_content found in message history for %s", model)
}
return body
}
// applyIFlowThinkingConfig converts normalized reasoning_effort to model-specific thinking configurations.
// This should be called after NormalizeThinkingConfig has processed the payload.
//
// Model-specific handling:
// - GLM-4.7: Uses extra_body={"thinking": {"type": "enabled"}, "clear_thinking": false}
// - MiniMax-M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
// - Other iFlow models: Uses chat_template_kwargs.enable_thinking (boolean)
func applyIFlowThinkingConfig(body []byte) []byte {
effort := gjson.GetBytes(body, "reasoning_effort")
model := strings.ToLower(gjson.GetBytes(body, "model").String())
// Check if thinking should be enabled
val := ""
if effort.Exists() {
val = strings.ToLower(strings.TrimSpace(effort.String()))
}
enableThinking := effort.Exists() && val != "none" && val != ""
// Remove reasoning_effort as we'll convert to model-specific format
if effort.Exists() {
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
}
// GLM-4.7: Use extra_body with thinking config and clear_thinking: false
if strings.HasPrefix(model, "glm-4.7") || strings.HasPrefix(model, "glm-4-7") {
if enableThinking {
body, _ = sjson.SetBytes(body, "extra_body.thinking.type", "enabled")
body, _ = sjson.SetBytes(body, "extra_body.clear_thinking", false)
}
return body
}
// MiniMax-M2.1: Use reasoning_split=true for interleaved thinking
if strings.HasPrefix(model, "minimax-m2.1") || strings.HasPrefix(model, "minimax-m2-1") {
if enableThinking {
body, _ = sjson.SetBytes(body, "reasoning_split", true)
}
return body
}
// Other iFlow models (including GLM-4.6): Use chat_template_kwargs.enable_thinking
if effort.Exists() {
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
}
return body return body
} }

View File

@@ -19,7 +19,7 @@ type usageReporter struct {
provider string provider string
model string model string
authID string authID string
authIndex uint64 authIndex string
apiKey string apiKey string
source string source string
requestedAt time.Time requestedAt time.Time
@@ -482,12 +482,16 @@ func StripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
cleaned := jsonBytes cleaned := jsonBytes
var changed bool var changed bool
if gjson.GetBytes(cleaned, "usageMetadata").Exists() { if usageMetadata = gjson.GetBytes(cleaned, "usageMetadata"); usageMetadata.Exists() {
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
cleaned, _ = sjson.SetRawBytes(cleaned, "cpaUsageMetadata", []byte(usageMetadata.Raw))
cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata") cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata")
changed = true changed = true
} }
if gjson.GetBytes(cleaned, "response.usageMetadata").Exists() { if usageMetadata = gjson.GetBytes(cleaned, "response.usageMetadata"); usageMetadata.Exists() {
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
cleaned, _ = sjson.SetRawBytes(cleaned, "response.cpaUsageMetadata", []byte(usageMetadata.Raw))
cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata") cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata")
changed = true changed = true
} }

View File

@@ -99,6 +99,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// This follows the Claude Code API specification for streaming message initialization // This follows the Claude Code API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}` messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
// Use cpaUsageMetadata within the message_start event for Claude.
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
}
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
}
// Override default values with actual response metadata if available from the Gemini CLI response // Override default values with actual response metadata if available from the Gemini CLI response
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String()) messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
@@ -271,11 +279,11 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() { if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
params.HasUsageMetadata = true params.HasUsageMetadata = true
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int() params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int() - params.CachedTokenCount
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int() params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int() params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int() params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
if params.CandidatesTokenCount == 0 && params.TotalTokenCount > 0 { if params.CandidatesTokenCount == 0 && params.TotalTokenCount > 0 {
params.CandidatesTokenCount = params.TotalTokenCount - params.PromptTokenCount - params.ThoughtsTokenCount params.CandidatesTokenCount = params.TotalTokenCount - params.PromptTokenCount - params.ThoughtsTokenCount
if params.CandidatesTokenCount < 0 { if params.CandidatesTokenCount < 0 {

View File

@@ -247,10 +247,30 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if role == "assistant" { } else if role == "assistant" {
node := []byte(`{"role":"model","parts":[]}`) node := []byte(`{"role":"model","parts":[]}`)
p := 0 p := 0
if content.Type == gjson.String { if content.Type == gjson.String && content.String() != "" {
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String()) node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
p++ p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 { // expect data:...
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
p++
}
}
}
}
} }
// Tool calls -> single model content with functionCall parts // Tool calls -> single model content with functionCall parts
@@ -305,6 +325,8 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if pp > 0 { if pp > 0 {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode) out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
} }
} else {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
} }
} }
} }

View File

@@ -87,15 +87,15 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// Extract and set usage metadata (token counts). // Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() { if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
} }
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
} }
promptTokenCount := usageResult.Get("promptTokenCount").Int() promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 { if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
@@ -181,12 +181,14 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
imagesResult := gjson.Get(template, "choices.0.delta.images") imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`) template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
} }
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload) template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
} }

View File

@@ -209,9 +209,12 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
inputTokens := usage.Get("input_tokens").Int() inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int() outputTokens := usage.Get("output_tokens").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens) cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens) template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens) template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
} }
return []string{template} return []string{template}
@@ -285,8 +288,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
var messageID string var messageID string
var model string var model string
var createdAt int64 var createdAt int64
var inputTokens, outputTokens int64
var reasoningTokens int64
var stopReason string var stopReason string
var contentParts []string var contentParts []string
var reasoningParts []string var reasoningParts []string
@@ -303,9 +304,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
messageID = message.Get("id").String() messageID = message.Get("id").String()
model = message.Get("model").String() model = message.Get("model").String()
createdAt = time.Now().Unix() createdAt = time.Now().Unix()
if usage := message.Get("usage"); usage.Exists() {
inputTokens = usage.Get("input_tokens").Int()
}
} }
case "content_block_start": case "content_block_start":
@@ -368,11 +366,14 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
} }
} }
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
outputTokens = usage.Get("output_tokens").Int() inputTokens := usage.Get("input_tokens").Int()
// Estimate reasoning tokens from accumulated thinking content outputTokens := usage.Get("output_tokens").Int()
if len(reasoningParts) > 0 { cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
reasoningTokens = int64(len(strings.Join(reasoningParts, "")) / 4) // Rough estimation cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
} out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
} }
} }
} }
@@ -431,16 +432,5 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason)) out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
} }
// Set usage information including prompt tokens, completion tokens, and total tokens
totalTokens := inputTokens + outputTokens
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", totalTokens)
// Add reasoning tokens to usage details if any reasoning content was processed
if reasoningTokens > 0 {
out, _ = sjson.Set(out, "usage.completion_tokens_details.reasoning_tokens", reasoningTokens)
}
return out return out
} }

View File

@@ -114,13 +114,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
var builder strings.Builder var builder strings.Builder
if parts := item.Get("content"); parts.Exists() && parts.IsArray() { if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
text := part.Get("text").String() textResult := part.Get("text")
text := textResult.String()
if builder.Len() > 0 && text != "" { if builder.Len() > 0 && text != "" {
builder.WriteByte('\n') builder.WriteByte('\n')
} }
builder.WriteString(text) builder.WriteString(text)
return true return true
}) })
} else if parts.Type == gjson.String {
builder.WriteString(parts.String())
} }
instructionsText = builder.String() instructionsText = builder.String()
if instructionsText != "" { if instructionsText != "" {
@@ -207,6 +210,8 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
} }
return true return true
}) })
} else if parts.Type == gjson.String {
textAggregate.WriteString(parts.String())
} }
// Fallback to given role if content types not decisive // Fallback to given role if content types not decisive

View File

@@ -218,8 +218,29 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if content.Type == gjson.String { if content.Type == gjson.String {
// Assistant text -> single model content // Assistant text -> single model content
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String()) node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
p++ p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 { // expect data:...
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
p++
}
}
}
}
} }
// Tool calls -> single model content with functionCall parts // Tool calls -> single model content with functionCall parts
@@ -260,6 +281,8 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if pp > 0 { if pp > 0 {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode) out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
} }
} else {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
} }
} }
} }

View File

@@ -170,12 +170,14 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
imagesResult := gjson.Get(template, "choices.0.delta.images") imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`) template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
} }
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload) template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
} }

View File

@@ -56,7 +56,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction) out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
} }
} else if systemResult.Type == gjson.String { } else if systemResult.Type == gjson.String {
out, _ = sjson.Set(out, "request.system_instruction.parts.-1.text", systemResult.String()) out, _ = sjson.Set(out, "system_instruction.parts.-1.text", systemResult.String())
} }
// contents // contents

View File

@@ -233,18 +233,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} else if role == "assistant" { } else if role == "assistant" {
node := []byte(`{"role":"model","parts":[]}`) node := []byte(`{"role":"model","parts":[]}`)
p := 0 p := 0
if content.Type == gjson.String { if content.Type == gjson.String {
// Assistant text -> single model content // Assistant text -> single model content
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String()) node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
p++ p++
} else if content.IsArray() { } else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts // Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() { for _, item := range content.Array() {
switch item.Get("type").String() { switch item.Get("type").String() {
case "text": case "text":
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
p++ p++
case "image_url": case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity. // If the assistant returned an inline data URL, preserve it for history fidelity.
@@ -261,7 +258,6 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} }
} }
} }
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
} }
// Tool calls -> single model content with functionCall parts // Tool calls -> single model content with functionCall parts
@@ -302,6 +298,8 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
if pp > 0 { if pp > 0 {
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode) out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
} }
} else {
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
} }
} }
} }

View File

@@ -89,15 +89,15 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
// Extract and set usage metadata (token counts). // Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() { if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
} }
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
} }
promptTokenCount := usageResult.Get("promptTokenCount").Int() promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 { if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
@@ -182,12 +182,14 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
imagesResult := gjson.Get(template, "choices.0.delta.images") imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`) template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
} }
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload) template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
} }
@@ -316,12 +318,14 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
imagesResult := gjson.Get(template, "choices.0.message.images") imagesResult := gjson.Get(template, "choices.0.message.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`) template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`)
} }
imageIndex := len(gjson.Get(template, "choices.0.message.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.message.role", "assistant") template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", imagePayload) template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", imagePayload)
} }

View File

@@ -6,6 +6,7 @@ package usage
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -90,7 +91,7 @@ type modelStats struct {
type RequestDetail struct { type RequestDetail struct {
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
Source string `json:"source"` Source string `json:"source"`
AuthIndex uint64 `json:"auth_index"` AuthIndex string `json:"auth_index"`
Tokens TokenStats `json:"tokens"` Tokens TokenStats `json:"tokens"`
Failed bool `json:"failed"` Failed bool `json:"failed"`
} }
@@ -281,6 +282,118 @@ func (s *RequestStatistics) Snapshot() StatisticsSnapshot {
return result return result
} }
type MergeResult struct {
Added int64 `json:"added"`
Skipped int64 `json:"skipped"`
}
// MergeSnapshot merges an exported statistics snapshot into the current store.
// Existing data is preserved and duplicate request details are skipped.
func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResult {
result := MergeResult{}
if s == nil {
return result
}
s.mu.Lock()
defer s.mu.Unlock()
seen := make(map[string]struct{})
for apiName, stats := range s.apis {
if stats == nil {
continue
}
for modelName, modelStatsValue := range stats.Models {
if modelStatsValue == nil {
continue
}
for _, detail := range modelStatsValue.Details {
seen[dedupKey(apiName, modelName, detail)] = struct{}{}
}
}
}
for apiName, apiSnapshot := range snapshot.APIs {
apiName = strings.TrimSpace(apiName)
if apiName == "" {
continue
}
stats, ok := s.apis[apiName]
if !ok || stats == nil {
stats = &apiStats{Models: make(map[string]*modelStats)}
s.apis[apiName] = stats
} else if stats.Models == nil {
stats.Models = make(map[string]*modelStats)
}
for modelName, modelSnapshot := range apiSnapshot.Models {
modelName = strings.TrimSpace(modelName)
if modelName == "" {
modelName = "unknown"
}
for _, detail := range modelSnapshot.Details {
detail.Tokens = normaliseTokenStats(detail.Tokens)
if detail.Timestamp.IsZero() {
detail.Timestamp = time.Now()
}
key := dedupKey(apiName, modelName, detail)
if _, exists := seen[key]; exists {
result.Skipped++
continue
}
seen[key] = struct{}{}
s.recordImported(apiName, modelName, stats, detail)
result.Added++
}
}
}
return result
}
func (s *RequestStatistics) recordImported(apiName, modelName string, stats *apiStats, detail RequestDetail) {
totalTokens := detail.Tokens.TotalTokens
if totalTokens < 0 {
totalTokens = 0
}
s.totalRequests++
if detail.Failed {
s.failureCount++
} else {
s.successCount++
}
s.totalTokens += totalTokens
s.updateAPIStats(stats, modelName, detail)
dayKey := detail.Timestamp.Format("2006-01-02")
hourKey := detail.Timestamp.Hour()
s.requestsByDay[dayKey]++
s.requestsByHour[hourKey]++
s.tokensByDay[dayKey] += totalTokens
s.tokensByHour[hourKey] += totalTokens
}
func dedupKey(apiName, modelName string, detail RequestDetail) string {
timestamp := detail.Timestamp.UTC().Format(time.RFC3339Nano)
tokens := normaliseTokenStats(detail.Tokens)
return fmt.Sprintf(
"%s|%s|%s|%s|%s|%t|%d|%d|%d|%d|%d",
apiName,
modelName,
timestamp,
detail.Source,
detail.AuthIndex,
detail.Failed,
tokens.InputTokens,
tokens.OutputTokens,
tokens.ReasoningTokens,
tokens.CachedTokens,
tokens.TotalTokens,
)
}
func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string { func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
if ctx != nil { if ctx != nil {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil { if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
@@ -340,6 +453,16 @@ func normaliseDetail(detail coreusage.Detail) TokenStats {
return tokens return tokens
} }
func normaliseTokenStats(tokens TokenStats) TokenStats {
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
}
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
}
return tokens
}
func formatHour(hour int) string { func formatHour(hour int) string {
if hour < 0 { if hour < 0 {
hour = 0 hour = 0

View File

@@ -344,7 +344,7 @@ func cleanupRequiredFields(jsonStr string) string {
} }
// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas. // addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
// Claude VALIDATED mode requires at least one property in tool schemas. // Claude VALIDATED mode requires at least one required property in tool schemas.
func addEmptySchemaPlaceholder(jsonStr string) string { func addEmptySchemaPlaceholder(jsonStr string) string {
// Find all "type" fields // Find all "type" fields
paths := findPaths(jsonStr, "type") paths := findPaths(jsonStr, "type")
@@ -364,6 +364,9 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
// Check if properties exists and is empty or missing // Check if properties exists and is empty or missing
propsPath := joinPath(parentPath, "properties") propsPath := joinPath(parentPath, "properties")
propsVal := gjson.Get(jsonStr, propsPath) propsVal := gjson.Get(jsonStr, propsPath)
reqPath := joinPath(parentPath, "required")
reqVal := gjson.Get(jsonStr, reqPath)
hasRequiredProperties := reqVal.IsArray() && len(reqVal.Array()) > 0
needsPlaceholder := false needsPlaceholder := false
if !propsVal.Exists() { if !propsVal.Exists() {
@@ -381,8 +384,17 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool") jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
// Add to required array // Add to required array
reqPath := joinPath(parentPath, "required")
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"}) jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
continue
}
// If schema has properties but none are required, add a minimal placeholder.
if propsVal.IsObject() && !hasRequiredProperties {
placeholderPath := joinPath(propsPath, "_")
if !gjson.Get(jsonStr, placeholderPath).Exists() {
jsonStr, _ = sjson.Set(jsonStr, placeholderPath+".type", "boolean")
}
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"_"})
} }
} }

View File

@@ -614,71 +614,6 @@ func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) {
// propertyNames is used to validate object property names (e.g., must match a pattern)
// Gemini doesn't support this keyword and will reject requests containing it
input := `{
"type": "object",
"properties": {
"metadata": {
"type": "object",
"propertyNames": {
"pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
},
"additionalProperties": {
"type": "string"
}
}
}
}`
expected := `{
"type": "object",
"properties": {
"metadata": {
"type": "object"
}
}
}`
result := CleanJSONSchemaForGemini(input)
compareJSON(t, expected, result)
// Verify propertyNames is completely removed
if strings.Contains(result, "propertyNames") {
t.Errorf("propertyNames keyword should be removed, got: %s", result)
}
}
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) {
// Test deeply nested propertyNames (as seen in real Claude tool schemas)
input := `{
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {
"config": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
}
}
}
}
}`
result := CleanJSONSchemaForGemini(input)
if strings.Contains(result, "propertyNames") {
t.Errorf("Nested propertyNames should be removed, got: %s", result)
}
}
func compareJSON(t *testing.T, expectedJSON, actualJSON string) { func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
var expMap, actMap map[string]interface{} var expMap, actMap map[string]interface{}
errExp := json.Unmarshal([]byte(expectedJSON), &expMap) errExp := json.Unmarshal([]byte(expectedJSON), &expMap)

View File

@@ -7,10 +7,11 @@ import (
) )
const ( const (
ThinkingBudgetMetadataKey = "thinking_budget" ThinkingBudgetMetadataKey = "thinking_budget"
ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts" ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts"
ReasoningEffortMetadataKey = "reasoning_effort" ReasoningEffortMetadataKey = "reasoning_effort"
ThinkingOriginalModelMetadataKey = "thinking_original_model" ThinkingOriginalModelMetadataKey = "thinking_original_model"
ModelMappingOriginalModelMetadataKey = "model_mapping_original_model"
) )
// NormalizeThinkingModel parses dynamic thinking suffixes on model names and returns // NormalizeThinkingModel parses dynamic thinking suffixes on model names and returns
@@ -215,6 +216,13 @@ func ResolveOriginalModel(model string, metadata map[string]any) string {
} }
if metadata != nil { if metadata != nil {
if v, ok := metadata[ModelMappingOriginalModelMetadataKey]; ok {
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
if base := normalize(s); base != "" {
return base
}
}
}
if v, ok := metadata[ThinkingOriginalModelMetadataKey]; ok { if v, ok := metadata[ThinkingOriginalModelMetadataKey]; ok {
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" { if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
if base := normalize(s); base != "" { if base := normalize(s); base != "" {

View File

@@ -6,6 +6,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"os" "os"
"reflect"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -126,7 +127,7 @@ func (w *Watcher) reloadConfig() bool {
} }
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelMappings, newConfig.OAuthModelMappings))
log.Infof("config successfully reloaded, triggering client reload") log.Infof("config successfully reloaded, triggering client reload")
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh) w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)

View File

@@ -185,6 +185,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings { if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings)) changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
} }
oldUpstreamAPIKeysCount := len(oldCfg.AmpCode.UpstreamAPIKeys)
newUpstreamAPIKeysCount := len(newCfg.AmpCode.UpstreamAPIKeys)
if !equalUpstreamAPIKeys(oldCfg.AmpCode.UpstreamAPIKeys, newCfg.AmpCode.UpstreamAPIKeys) {
changes = append(changes, fmt.Sprintf("ampcode.upstream-api-keys: updated (%d -> %d entries)", oldUpstreamAPIKeysCount, newUpstreamAPIKeysCount))
}
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 { if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
changes = append(changes, entries...) changes = append(changes, entries...)
@@ -301,3 +306,43 @@ func formatProxyURL(raw string) string {
} }
return scheme + "://" + host return scheme + "://" + host
} }
func equalStringSet(a, b []string) bool {
if len(a) == 0 && len(b) == 0 {
return true
}
aSet := make(map[string]struct{}, len(a))
for _, k := range a {
aSet[strings.TrimSpace(k)] = struct{}{}
}
bSet := make(map[string]struct{}, len(b))
for _, k := range b {
bSet[strings.TrimSpace(k)] = struct{}{}
}
if len(aSet) != len(bSet) {
return false
}
for k := range aSet {
if _, ok := bSet[k]; !ok {
return false
}
}
return true
}
// equalUpstreamAPIKeys compares two slices of AmpUpstreamAPIKeyEntry for equality.
// Comparison is done by count and content (upstream key and client keys).
func equalUpstreamAPIKeys(a, b []config.AmpUpstreamAPIKeyEntry) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if strings.TrimSpace(a[i].UpstreamAPIKey) != strings.TrimSpace(b[i].UpstreamAPIKey) {
return false
}
if !equalStringSet(a[i].APIKeys, b[i].APIKeys) {
return false
}
}
return true
}

View File

@@ -56,6 +56,21 @@ func ComputeClaudeModelsHash(models []config.ClaudeModel) string {
return hashJoined(keys) return hashJoined(keys)
} }
// ComputeCodexModelsHash returns a stable hash for Codex model aliases.
func ComputeCodexModelsHash(models []config.CodexModel) string {
keys := normalizeModelPairs(func(out func(key string)) {
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
}
})
return hashJoined(keys)
}
// ComputeExcludedModelsHash returns a normalized hash for excluded model lists. // ComputeExcludedModelsHash returns a normalized hash for excluded model lists.
func ComputeExcludedModelsHash(excluded []string) string { func ComputeExcludedModelsHash(excluded []string) string {
if len(excluded) == 0 { if len(excluded) == 0 {

View File

@@ -81,6 +81,15 @@ func TestComputeClaudeModelsHash_Empty(t *testing.T) {
} }
} }
func TestComputeCodexModelsHash_Empty(t *testing.T) {
if got := ComputeCodexModelsHash(nil); got != "" {
t.Fatalf("expected empty hash for nil models, got %q", got)
}
if got := ComputeCodexModelsHash([]config.CodexModel{}); got != "" {
t.Fatalf("expected empty hash for empty slice, got %q", got)
}
}
func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) { func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
a := []config.ClaudeModel{ a := []config.ClaudeModel{
{Name: "m1", Alias: "a1"}, {Name: "m1", Alias: "a1"},
@@ -95,6 +104,20 @@ func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
} }
} }
func TestComputeCodexModelsHash_IgnoresBlankAndDedup(t *testing.T) {
a := []config.CodexModel{
{Name: "m1", Alias: "a1"},
{Name: " "},
{Name: "M1", Alias: "A1"},
}
b := []config.CodexModel{
{Name: "m1", Alias: "a1"},
}
if h1, h2 := ComputeCodexModelsHash(a), ComputeCodexModelsHash(b); h1 == "" || h1 != h2 {
t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
}
}
func TestComputeExcludedModelsHash_Normalizes(t *testing.T) { func TestComputeExcludedModelsHash_Normalizes(t *testing.T) {
hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"}) hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"})
hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"}) hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"})
@@ -157,3 +180,15 @@ func TestComputeClaudeModelsHash_Deterministic(t *testing.T) {
t.Fatalf("expected different hash when models change, got %s", h3) t.Fatalf("expected different hash when models change, got %s", h3)
} }
} }
func TestComputeCodexModelsHash_Deterministic(t *testing.T) {
models := []config.CodexModel{{Name: "a", Alias: "A"}, {Name: "b"}}
h1 := ComputeCodexModelsHash(models)
h2 := ComputeCodexModelsHash(models)
if h1 == "" || h1 != h2 {
t.Fatalf("expected deterministic hash, got %s / %s", h1, h2)
}
if h3 := ComputeCodexModelsHash([]config.CodexModel{{Name: "a"}}); h3 == h1 {
t.Fatalf("expected different hash when models change, got %s", h3)
}
}

View File

@@ -151,6 +151,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
if ck.BaseURL != "" { if ck.BaseURL != "" {
attrs["base_url"] = ck.BaseURL attrs["base_url"] = ck.BaseURL
} }
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
attrs["models_hash"] = hash
}
addConfigHeadersToAttrs(ck.Headers, attrs) addConfigHeadersToAttrs(ck.Headers, attrs)
proxyURL := strings.TrimSpace(ck.ProxyURL) proxyURL := strings.TrimSpace(ck.ProxyURL)
a := &coreauth.Auth{ a := &coreauth.Auth{

View File

@@ -104,8 +104,8 @@ func BuildErrorResponseBody(status int, errText string) []byte {
// Returning 0 disables keep-alives (default when unset). // Returning 0 disables keep-alives (default when unset).
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration { func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
seconds := defaultStreamingKeepAliveSeconds seconds := defaultStreamingKeepAliveSeconds
if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil { if cfg != nil {
seconds = *cfg.Streaming.KeepAliveSeconds seconds = cfg.Streaming.KeepAliveSeconds
} }
if seconds <= 0 { if seconds <= 0 {
return 0 return 0
@@ -116,8 +116,8 @@ func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent. // StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
func StreamingBootstrapRetries(cfg *config.SDKConfig) int { func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
retries := defaultStreamingBootstrapRetries retries := defaultStreamingBootstrapRetries
if cfg != nil && cfg.Streaming.BootstrapRetries != nil { if cfg != nil {
retries = *cfg.Streaming.BootstrapRetries retries = cfg.Streaming.BootstrapRetries
} }
if retries < 0 { if retries < 0 {
retries = 0 retries = 0
@@ -619,7 +619,22 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
} }
body := BuildErrorResponseBody(status, errText) body := BuildErrorResponseBody(status, errText)
c.Set("API_RESPONSE", bytes.Clone(body)) // Append first to preserve upstream response logs, then drop duplicate payloads if already recorded.
var previous []byte
if existing, exists := c.Get("API_RESPONSE"); exists {
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
previous = bytes.Clone(existingBytes)
}
}
appendAPIResponse(c, body)
trimmedErrText := strings.TrimSpace(errText)
trimmedBody := bytes.TrimSpace(body)
if len(previous) > 0 {
if (trimmedErrText != "" && bytes.Contains(previous, []byte(trimmedErrText))) ||
(len(trimmedBody) > 0 && bytes.Contains(previous, trimmedBody)) {
c.Set("API_RESPONSE", previous)
}
}
if !c.Writer.Written() { if !c.Writer.Written() {
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")

View File

@@ -94,12 +94,11 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
registry.GetGlobalRegistry().UnregisterClient(auth2.ID) registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
}) })
bootstrapRetries := 1
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{ handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{ Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: &bootstrapRetries, BootstrapRetries: 1,
}, },
}, manager, nil) }, manager)
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil { if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels") t.Fatalf("expected non-nil channels")

62
sdk/api/management.go Normal file
View File

@@ -0,0 +1,62 @@
// Package api exposes helpers for embedding CLIProxyAPI.
//
// It wraps internal management handler types so external projects can integrate
// management endpoints without importing internal packages.
package api
import (
"github.com/gin-gonic/gin"
internalmanagement "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// ManagementTokenRequester exposes a limited subset of management endpoints for requesting tokens.
type ManagementTokenRequester interface {
RequestAnthropicToken(*gin.Context)
RequestGeminiCLIToken(*gin.Context)
RequestCodexToken(*gin.Context)
RequestAntigravityToken(*gin.Context)
RequestQwenToken(*gin.Context)
RequestIFlowToken(*gin.Context)
RequestIFlowCookieToken(*gin.Context)
}
type managementTokenRequester struct {
handler *internalmanagement.Handler
}
// NewManagementTokenRequester creates a limited management handler exposing only token request endpoints.
func NewManagementTokenRequester(cfg *config.Config, manager *coreauth.Manager) ManagementTokenRequester {
return &managementTokenRequester{
handler: internalmanagement.NewHandlerWithoutConfigFilePath(cfg, manager),
}
}
func (m *managementTokenRequester) RequestAnthropicToken(c *gin.Context) {
m.handler.RequestAnthropicToken(c)
}
func (m *managementTokenRequester) RequestGeminiCLIToken(c *gin.Context) {
m.handler.RequestGeminiCLIToken(c)
}
func (m *managementTokenRequester) RequestCodexToken(c *gin.Context) {
m.handler.RequestCodexToken(c)
}
func (m *managementTokenRequester) RequestAntigravityToken(c *gin.Context) {
m.handler.RequestAntigravityToken(c)
}
func (m *managementTokenRequester) RequestQwenToken(c *gin.Context) {
m.handler.RequestQwenToken(c)
}
func (m *managementTokenRequester) RequestIFlowToken(c *gin.Context) {
m.handler.RequestIFlowToken(c)
}
func (m *managementTokenRequester) RequestIFlowCookieToken(c *gin.Context) {
m.handler.RequestIFlowCookieToken(c)
}

View File

@@ -111,6 +111,9 @@ type Manager struct {
requestRetry atomic.Int32 requestRetry atomic.Int32
maxRetryInterval atomic.Int64 maxRetryInterval atomic.Int64
// modelNameMappings stores global model name alias mappings (alias -> upstream name) keyed by channel.
modelNameMappings atomic.Value
// Optional HTTP RoundTripper provider injected by host. // Optional HTTP RoundTripper provider injected by host.
rtProvider RoundTripperProvider rtProvider RoundTripperProvider
@@ -203,10 +206,10 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
if auth == nil { if auth == nil {
return nil, nil return nil, nil
} }
auth.EnsureIndex()
if auth.ID == "" { if auth.ID == "" {
auth.ID = uuid.NewString() auth.ID = uuid.NewString()
} }
auth.EnsureIndex()
m.mu.Lock() m.mu.Lock()
m.auths[auth.ID] = auth.Clone() m.auths[auth.ID] = auth.Clone()
m.mu.Unlock() m.mu.Unlock()
@@ -221,7 +224,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
return nil, nil return nil, nil
} }
m.mu.Lock() m.mu.Lock()
if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == 0 { if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == "" {
auth.Index = existing.Index auth.Index = existing.Index
auth.indexAssigned = existing.indexAssigned auth.indexAssigned = existing.indexAssigned
} }
@@ -263,7 +266,6 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
} }
rotated := m.rotateProviders(req.Model, normalized) rotated := m.rotateProviders(req.Model, normalized)
defer m.advanceProviderCursor(req.Model, normalized)
retryTimes, maxWait := m.retrySettings() retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1 attempts := retryTimes + 1
@@ -302,7 +304,6 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
} }
rotated := m.rotateProviders(req.Model, normalized) rotated := m.rotateProviders(req.Model, normalized)
defer m.advanceProviderCursor(req.Model, normalized)
retryTimes, maxWait := m.retrySettings() retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1 attempts := retryTimes + 1
@@ -341,7 +342,6 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
} }
rotated := m.rotateProviders(req.Model, normalized) rotated := m.rotateProviders(req.Model, normalized)
defer m.advanceProviderCursor(req.Model, normalized)
retryTimes, maxWait := m.retrySettings() retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1 attempts := retryTimes + 1
@@ -413,6 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
} }
execReq := req execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts) resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil { if errExec != nil {
@@ -474,6 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
} }
execReq := req execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil { if errExec != nil {
@@ -535,6 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
} }
execReq := req execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil { if errStream != nil {
rerr := &Error{Message: errStream.Error()} rerr := &Error{Message: errStream.Error()}
@@ -595,6 +598,7 @@ func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]
keys := []string{ keys := []string{
util.ThinkingOriginalModelMetadataKey, util.ThinkingOriginalModelMetadataKey,
util.GeminiOriginalModelMetadataKey, util.GeminiOriginalModelMetadataKey,
util.ModelMappingOriginalModelMetadataKey,
} }
var out map[string]any var out map[string]any
for _, key := range keys { for _, key := range keys {
@@ -640,13 +644,20 @@ func (m *Manager) normalizeProviders(providers []string) []string {
return result return result
} }
// rotateProviders returns a rotated view of the providers list starting from the
// current offset for the model, and atomically increments the offset for the next call.
// This ensures concurrent requests get different starting providers.
func (m *Manager) rotateProviders(model string, providers []string) []string { func (m *Manager) rotateProviders(model string, providers []string) []string {
if len(providers) == 0 { if len(providers) == 0 {
return nil return nil
} }
m.mu.RLock()
// Atomic read-and-increment: get current offset and advance cursor in one lock
m.mu.Lock()
offset := m.providerOffsets[model] offset := m.providerOffsets[model]
m.mu.RUnlock() m.providerOffsets[model] = (offset + 1) % len(providers)
m.mu.Unlock()
if len(providers) > 0 { if len(providers) > 0 {
offset %= len(providers) offset %= len(providers)
} }
@@ -662,19 +673,6 @@ func (m *Manager) rotateProviders(model string, providers []string) []string {
return rotated return rotated
} }
func (m *Manager) advanceProviderCursor(model string, providers []string) {
if len(providers) == 0 {
m.mu.Lock()
delete(m.providerOffsets, model)
m.mu.Unlock()
return
}
m.mu.Lock()
current := m.providerOffsets[model]
m.providerOffsets[model] = (current + 1) % len(providers)
m.mu.Unlock()
}
func (m *Manager) retrySettings() (int, time.Duration) { func (m *Manager) retrySettings() (int, time.Duration) {
if m == nil { if m == nil {
return 0, 0 return 0, 0

View File

@@ -0,0 +1,172 @@
package auth
import (
"strings"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
type modelNameMappingTable struct {
// reverse maps channel -> alias (lower) -> original upstream model name.
reverse map[string]map[string]string
}
func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable {
if len(mappings) == 0 {
return &modelNameMappingTable{}
}
out := &modelNameMappingTable{
reverse: make(map[string]map[string]string, len(mappings)),
}
for rawChannel, entries := range mappings {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(entries) == 0 {
continue
}
rev := make(map[string]string, len(entries))
for _, entry := range entries {
name := strings.TrimSpace(entry.Name)
alias := strings.TrimSpace(entry.Alias)
if name == "" || alias == "" {
continue
}
if strings.EqualFold(name, alias) {
continue
}
aliasKey := strings.ToLower(alias)
if _, exists := rev[aliasKey]; exists {
continue
}
rev[aliasKey] = name
}
if len(rev) > 0 {
out.reverse[channel] = rev
}
}
if len(out.reverse) == 0 {
out.reverse = nil
}
return out
}
// SetOAuthModelMappings updates the OAuth model name mapping table used during execution.
// The mapping is applied per-auth channel to resolve the upstream model name while keeping the
// client-visible model name unchanged for translation/response formatting.
func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.ModelNameMapping) {
if m == nil {
return
}
table := compileModelNameMappingTable(mappings)
// atomic.Value requires non-nil store values.
if table == nil {
table = &modelNameMappingTable{}
}
m.modelNameMappings.Store(table)
}
func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any {
original := m.resolveOAuthUpstreamModel(auth, requestedModel)
if original == "" {
return metadata
}
if metadata != nil {
if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok {
if s, okStr := v.(string); okStr && strings.EqualFold(s, original) {
return metadata
}
}
}
out := make(map[string]any, 1)
if len(metadata) > 0 {
out = make(map[string]any, len(metadata)+1)
for k, v := range metadata {
out[k] = v
}
}
out[util.ModelMappingOriginalModelMetadataKey] = original
return out
}
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
if m == nil || auth == nil {
return ""
}
channel := modelMappingChannel(auth)
if channel == "" {
return ""
}
key := strings.ToLower(strings.TrimSpace(requestedModel))
if key == "" {
return ""
}
raw := m.modelNameMappings.Load()
table, _ := raw.(*modelNameMappingTable)
if table == nil || table.reverse == nil {
return ""
}
rev := table.reverse[channel]
if rev == nil {
return ""
}
original := strings.TrimSpace(rev[key])
if original == "" || strings.EqualFold(original, requestedModel) {
return ""
}
return original
}
// modelMappingChannel extracts the OAuth model mapping channel from an Auth object.
// It determines the provider and auth kind from the Auth's attributes and delegates
// to OAuthModelMappingChannel for the actual channel resolution.
func modelMappingChannel(auth *Auth) string {
if auth == nil {
return ""
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
authKind := ""
if auth.Attributes != nil {
authKind = strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"]))
}
if authKind == "" {
if kind, _ := auth.AccountInfo(); strings.EqualFold(kind, "api_key") {
authKind = "apikey"
}
}
return OAuthModelMappingChannel(provider, authKind)
}
// OAuthModelMappingChannel returns the OAuth model mapping channel name for a given provider
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
// OAuth model mappings (e.g., API key authentication).
//
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
func OAuthModelMappingChannel(provider, authKind string) string {
provider = strings.ToLower(strings.TrimSpace(provider))
authKind = strings.ToLower(strings.TrimSpace(authKind))
switch provider {
case "gemini":
// gemini provider uses gemini-api-key config, not oauth-model-mappings.
// OAuth-based gemini auth is converted to "gemini-cli" by the synthesizer.
return ""
case "vertex":
if authKind == "apikey" {
return ""
}
return "vertex"
case "claude":
if authKind == "apikey" {
return ""
}
return "claude"
case "codex":
if authKind == "apikey" {
return ""
}
return "codex"
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
return provider
default:
return ""
}
}

View File

@@ -1,11 +1,12 @@
package auth package auth
import ( import (
"crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth" baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
@@ -15,8 +16,8 @@ import (
type Auth struct { type Auth struct {
// ID uniquely identifies the auth record across restarts. // ID uniquely identifies the auth record across restarts.
ID string `json:"id"` ID string `json:"id"`
// Index is a monotonically increasing runtime identifier used for diagnostics. // Index is a stable runtime identifier derived from auth metadata (not persisted).
Index uint64 `json:"-"` Index string `json:"-"`
// Provider is the upstream provider key (e.g. "gemini", "claude"). // Provider is the upstream provider key (e.g. "gemini", "claude").
Provider string `json:"provider"` Provider string `json:"provider"`
// Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview"). // Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview").
@@ -94,12 +95,6 @@ type ModelState struct {
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }
var authIndexCounter atomic.Uint64
func nextAuthIndex() uint64 {
return authIndexCounter.Add(1) - 1
}
// Clone shallow copies the Auth structure, duplicating maps to avoid accidental mutation. // Clone shallow copies the Auth structure, duplicating maps to avoid accidental mutation.
func (a *Auth) Clone() *Auth { func (a *Auth) Clone() *Auth {
if a == nil { if a == nil {
@@ -128,15 +123,41 @@ func (a *Auth) Clone() *Auth {
return &copyAuth return &copyAuth
} }
// EnsureIndex returns the global index, assigning one if it was not set yet. func stableAuthIndex(seed string) string {
func (a *Auth) EnsureIndex() uint64 { seed = strings.TrimSpace(seed)
if a == nil { if seed == "" {
return 0 return ""
} }
if a.indexAssigned { sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:8])
}
// EnsureIndex returns a stable index derived from the auth file name or API key.
func (a *Auth) EnsureIndex() string {
if a == nil {
return ""
}
if a.indexAssigned && a.Index != "" {
return a.Index return a.Index
} }
idx := nextAuthIndex()
seed := strings.TrimSpace(a.FileName)
if seed != "" {
seed = "file:" + seed
} else if a.Attributes != nil {
if apiKey := strings.TrimSpace(a.Attributes["api_key"]); apiKey != "" {
seed = "api_key:" + apiKey
}
}
if seed == "" {
if id := strings.TrimSpace(a.ID); id != "" {
seed = "id:" + id
} else {
return ""
}
}
idx := stableAuthIndex(seed)
a.Index = idx a.Index = idx
a.indexAssigned = true a.indexAssigned = true
return idx return idx

View File

@@ -215,6 +215,7 @@ func (b *Builder) Build() (*Service, error) {
} }
// Attach a default RoundTripper provider so providers can opt-in per-auth transports. // Attach a default RoundTripper provider so providers can opt-in per-auth transports.
coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider()) coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())
coreManager.SetOAuthModelMappings(b.cfg.OAuthModelMappings)
service := &Service{ service := &Service{
cfg: b.cfg, cfg: b.cfg,

View File

@@ -556,6 +556,9 @@ func (s *Service) Run(ctx context.Context) error {
s.cfgMu.Lock() s.cfgMu.Lock()
s.cfg = newCfg s.cfg = newCfg
s.cfgMu.Unlock() s.cfgMu.Unlock()
if s.coreManager != nil {
s.coreManager.SetOAuthModelMappings(newCfg.OAuthModelMappings)
}
s.rebindExecutors() s.rebindExecutors()
} }
@@ -681,6 +684,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
return return
} }
authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"])) authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
if authKind == "" {
if kind, _ := a.AccountInfo(); strings.EqualFold(kind, "api_key") {
authKind = "apikey"
}
}
if a.Attributes != nil { if a.Attributes != nil {
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") { if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
GlobalModelRegistry().UnregisterClient(a.ID) GlobalModelRegistry().UnregisterClient(a.ID)
@@ -745,6 +753,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "codex": case "codex":
models = registry.GetOpenAIModels() models = registry.GetOpenAIModels()
if entry := s.resolveConfigCodexKey(a); entry != nil { if entry := s.resolveConfigCodexKey(a); entry != nil {
if len(entry.Models) > 0 {
models = buildCodexConfigModels(entry)
}
if authKind == "apikey" { if authKind == "apikey" {
excluded = entry.ExcludedModels excluded = entry.ExcludedModels
} }
@@ -842,6 +853,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
} }
} }
} }
models = applyOAuthModelMappings(s.cfg, provider, authKind, models)
if len(models) > 0 { if len(models) > 0 {
key := provider key := provider
if key == "" { if key == "" {
@@ -1151,6 +1163,93 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
return out return out
} }
func rewriteModelInfoName(name, oldID, newID string) string {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return name
}
oldID = strings.TrimSpace(oldID)
newID = strings.TrimSpace(newID)
if oldID == "" || newID == "" {
return name
}
if strings.EqualFold(oldID, newID) {
return name
}
if strings.HasSuffix(trimmed, "/"+oldID) {
prefix := strings.TrimSuffix(trimmed, oldID)
return prefix + newID
}
if trimmed == "models/"+oldID {
return "models/" + newID
}
return name
}
func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
if cfg == nil || len(models) == 0 {
return models
}
channel := coreauth.OAuthModelMappingChannel(provider, authKind)
if channel == "" || len(cfg.OAuthModelMappings) == 0 {
return models
}
mappings := cfg.OAuthModelMappings[channel]
if len(mappings) == 0 {
return models
}
forward := make(map[string]string, len(mappings))
for i := range mappings {
name := strings.TrimSpace(mappings[i].Name)
alias := strings.TrimSpace(mappings[i].Alias)
if name == "" || alias == "" {
continue
}
if strings.EqualFold(name, alias) {
continue
}
key := strings.ToLower(name)
if _, exists := forward[key]; exists {
continue
}
forward[key] = alias
}
if len(forward) == 0 {
return models
}
out := make([]*ModelInfo, 0, len(models))
seen := make(map[string]struct{}, len(models))
for _, model := range models {
if model == nil {
continue
}
id := strings.TrimSpace(model.ID)
if id == "" {
continue
}
mappedID := id
if to, ok := forward[strings.ToLower(id)]; ok && strings.TrimSpace(to) != "" {
mappedID = strings.TrimSpace(to)
}
uniqueKey := strings.ToLower(mappedID)
if _, exists := seen[uniqueKey]; exists {
continue
}
seen[uniqueKey] = struct{}{}
if mappedID == id {
out = append(out, model)
continue
}
clone := *model
clone.ID = mappedID
if clone.Name != "" {
clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
}
out = append(out, &clone)
}
return out
}
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo { func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
if entry == nil || len(entry.Models) == 0 { if entry == nil || len(entry.Models) == 0 {
return nil return nil
@@ -1188,3 +1287,41 @@ func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
} }
return out return out
} }
func buildCodexConfigModels(entry *config.CodexKey) []*ModelInfo {
if entry == nil || len(entry.Models) == 0 {
return nil
}
now := time.Now().Unix()
out := make([]*ModelInfo, 0, len(entry.Models))
seen := make(map[string]struct{}, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if alias == "" {
alias = name
}
if alias == "" {
continue
}
key := strings.ToLower(alias)
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
display := name
if display == "" {
display = alias
}
out = append(out, &ModelInfo{
ID: alias,
Object: "model",
Created: now,
OwnedBy: "openai",
Type: "openai",
DisplayName: display,
})
}
return out
}

View File

@@ -14,7 +14,7 @@ type Record struct {
Model string Model string
APIKey string APIKey string
AuthID string AuthID string
AuthIndex uint64 AuthIndex string
Source string Source string
RequestedAt time.Time RequestedAt time.Time
Failed bool Failed bool

View File

@@ -16,6 +16,7 @@ type StreamingConfig = internalconfig.StreamingConfig
type TLSConfig = internalconfig.TLSConfig type TLSConfig = internalconfig.TLSConfig
type RemoteManagement = internalconfig.RemoteManagement type RemoteManagement = internalconfig.RemoteManagement
type AmpCode = internalconfig.AmpCode type AmpCode = internalconfig.AmpCode
type ModelNameMapping = internalconfig.ModelNameMapping
type PayloadConfig = internalconfig.PayloadConfig type PayloadConfig = internalconfig.PayloadConfig
type PayloadRule = internalconfig.PayloadRule type PayloadRule = internalconfig.PayloadRule
type PayloadModelRule = internalconfig.PayloadModelRule type PayloadModelRule = internalconfig.PayloadModelRule

View File

@@ -56,6 +56,10 @@ func setupAmpRouter(h *management.Handler) *gin.Engine {
mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey) mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey)
mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey) mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey)
mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey) mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey)
mgmt.GET("/ampcode/upstream-api-keys", h.GetAmpUpstreamAPIKeys)
mgmt.PUT("/ampcode/upstream-api-keys", h.PutAmpUpstreamAPIKeys)
mgmt.PATCH("/ampcode/upstream-api-keys", h.PatchAmpUpstreamAPIKeys)
mgmt.DELETE("/ampcode/upstream-api-keys", h.DeleteAmpUpstreamAPIKeys)
mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost) mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost)
mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost) mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost)
mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings) mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings)
@@ -188,6 +192,90 @@ func TestPutAmpUpstreamAPIKey(t *testing.T) {
} }
} }
func TestPutAmpUpstreamAPIKeys_PersistsAndReturns(t *testing.T) {
h, configPath := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value":[{"upstream-api-key":" u1 ","api-keys":[" k1 ","","k2"]}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
}
// Verify it was persisted to disk
loaded, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("failed to load config from disk: %v", err)
}
if len(loaded.AmpCode.UpstreamAPIKeys) != 1 {
t.Fatalf("expected 1 upstream-api-keys entry, got %d", len(loaded.AmpCode.UpstreamAPIKeys))
}
entry := loaded.AmpCode.UpstreamAPIKeys[0]
if entry.UpstreamAPIKey != "u1" {
t.Fatalf("expected upstream-api-key u1, got %q", entry.UpstreamAPIKey)
}
if len(entry.APIKeys) != 2 || entry.APIKeys[0] != "k1" || entry.APIKeys[1] != "k2" {
t.Fatalf("expected api-keys [k1 k2], got %#v", entry.APIKeys)
}
// Verify it is returned by GET /ampcode
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]config.AmpCode
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if got := resp["ampcode"].UpstreamAPIKeys; len(got) != 1 || got[0].UpstreamAPIKey != "u1" {
t.Fatalf("expected upstream-api-keys to be present after update, got %#v", got)
}
}
func TestDeleteAmpUpstreamAPIKeys_ClearsAll(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
// Seed with one entry
putBody := `{"value":[{"upstream-api-key":"u1","api-keys":["k1"]}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(putBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
}
deleteBody := `{"value":[]}`
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(deleteBody))
req.Header.Set("Content-Type", "application/json")
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-keys", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string][]config.AmpUpstreamAPIKeyEntry
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp["upstream-api-keys"] != nil && len(resp["upstream-api-keys"]) != 0 {
t.Fatalf("expected cleared list, got %#v", resp["upstream-api-keys"])
}
}
// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key. // TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key.
func TestDeleteAmpUpstreamAPIKey(t *testing.T) { func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
h, _ := newAmpTestHandler(t) h, _ := newAmpTestHandler(t)