mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-16 11:27:28 +00:00
Merge PR #525 (v6.9.27)
This commit is contained in:
@@ -7,13 +7,13 @@ import (
|
||||
|
||||
func main() {
|
||||
ecm := cursorproto.NewMsg("ExecClientMessage")
|
||||
|
||||
|
||||
// Try different field names
|
||||
names := []string{
|
||||
"mcp_result", "mcpResult", "McpResult", "MCP_RESULT",
|
||||
"shell_result", "shellResult",
|
||||
}
|
||||
|
||||
|
||||
for _, name := range names {
|
||||
fd := ecm.Descriptor().Fields().ByName(name)
|
||||
if fd != nil {
|
||||
@@ -22,7 +22,7 @@ func main() {
|
||||
fmt.Printf("Field %q NOT FOUND\n", name)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// List all fields
|
||||
fmt.Println("\nAll fields in ExecClientMessage:")
|
||||
for i := 0; i < ecm.Descriptor().Fields().Len(); i++ {
|
||||
|
||||
@@ -75,7 +75,6 @@ func main() {
|
||||
var codexLogin bool
|
||||
var codexDeviceLogin bool
|
||||
var claudeLogin bool
|
||||
var qwenLogin bool
|
||||
var kiloLogin bool
|
||||
var iflowLogin bool
|
||||
var iflowCookie bool
|
||||
@@ -113,7 +112,6 @@ func main() {
|
||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||
@@ -538,8 +536,6 @@ func main() {
|
||||
} else if claudeLogin {
|
||||
// Handle Claude login
|
||||
cmd.DoClaudeLogin(cfg, options)
|
||||
} else if qwenLogin {
|
||||
cmd.DoQwenLogin(cfg, options)
|
||||
} else if kiloLogin {
|
||||
cmd.DoKiloLogin(cfg, options)
|
||||
} else if iflowLogin {
|
||||
|
||||
@@ -95,6 +95,10 @@ max-retry-interval: 30
|
||||
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
|
||||
disable-cooling: false
|
||||
|
||||
# Core auth auto-refresh worker pool size (OAuth/file-based auth token refresh).
|
||||
# When > 0, overrides the default worker count (16).
|
||||
# auth-auto-refresh-workers: 16
|
||||
|
||||
# Quota exceeded behavior
|
||||
quota-exceeded:
|
||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||
@@ -103,7 +107,14 @@ quota-exceeded:
|
||||
|
||||
# Routing strategy for selecting credentials when multiple match.
|
||||
routing:
|
||||
strategy: 'round-robin' # round-robin (default), fill-first
|
||||
strategy: "round-robin" # round-robin (default), fill-first
|
||||
# Enable universal session-sticky routing for all clients.
|
||||
# Session IDs are extracted from: X-Session-ID header, Idempotency-Key,
|
||||
# metadata.user_id, conversation_id, or first few messages hash.
|
||||
# Automatic failover is always enabled when bound auth becomes unavailable.
|
||||
session-affinity: false # default: false
|
||||
# How long session-to-auth bindings are retained. Default: 1h
|
||||
session-affinity-ttl: "1h"
|
||||
|
||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||
ws-auth: false
|
||||
@@ -269,7 +280,7 @@ nonstream-keepalive-interval: 0
|
||||
# # Requests to that alias will round-robin across the upstream names below,
|
||||
# # and if the chosen upstream fails before producing output, the request will
|
||||
# # continue with the next upstream model in the same alias pool.
|
||||
# - name: "qwen3.5-plus"
|
||||
# - name: "deepseek-v3.1"
|
||||
# alias: "claude-opus-4.66"
|
||||
# - name: "glm-5"
|
||||
# alias: "claude-opus-4.66"
|
||||
@@ -330,7 +341,7 @@ nonstream-keepalive-interval: 0
|
||||
|
||||
# Global OAuth model name aliases (per channel)
|
||||
# These aliases rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot, kimi.
|
||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
|
||||
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
|
||||
@@ -369,9 +380,6 @@ nonstream-keepalive-interval: 0
|
||||
# codex:
|
||||
# - name: "gpt-5"
|
||||
# alias: "g5"
|
||||
# qwen:
|
||||
# - name: "qwen3-coder-plus"
|
||||
# alias: "qwen-plus"
|
||||
# iflow:
|
||||
# - name: "glm-4.7"
|
||||
# alias: "glm-god"
|
||||
@@ -403,8 +411,6 @@ nonstream-keepalive-interval: 0
|
||||
# - "claude-3-5-haiku-20241022"
|
||||
# codex:
|
||||
# - "gpt-5-codex-mini"
|
||||
# qwen:
|
||||
# - "vision-model"
|
||||
# iflow:
|
||||
# - "tstars2.0"
|
||||
# kimi:
|
||||
|
||||
@@ -36,7 +36,6 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
@@ -2526,62 +2525,6 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Qwen authentication...")
|
||||
|
||||
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
||||
// Initialize Qwen auth service
|
||||
qwenAuth := qwen.NewQwenAuth(h.cfg)
|
||||
|
||||
// Generate authorization URL
|
||||
deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate authorization URL: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||
return
|
||||
}
|
||||
authURL := deviceFlow.VerificationURIComplete
|
||||
|
||||
RegisterOAuthSession(state, "qwen")
|
||||
|
||||
go func() {
|
||||
fmt.Println("Waiting for authentication...")
|
||||
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
||||
if errPollForToken != nil {
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
||||
return
|
||||
}
|
||||
|
||||
// Create token storage
|
||||
tokenStorage := qwenAuth.CreateTokenStorage(tokenData)
|
||||
|
||||
tokenStorage.Email = fmt.Sprintf("%d", time.Now().UnixMilli())
|
||||
record := &coreauth.Auth{
|
||||
ID: fmt.Sprintf("qwen-%s.json", tokenStorage.Email),
|
||||
Provider: "qwen",
|
||||
FileName: fmt.Sprintf("qwen-%s.json", tokenStorage.Email),
|
||||
Storage: tokenStorage,
|
||||
Metadata: map[string]any{"email": tokenStorage.Email},
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use Qwen services through this CLI")
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
@@ -236,8 +236,6 @@ func NormalizeOAuthProvider(provider string) (string, error) {
|
||||
return "iflow", nil
|
||||
case "antigravity", "anti-gravity":
|
||||
return "antigravity", nil
|
||||
case "qwen":
|
||||
return "qwen", nil
|
||||
case "kiro":
|
||||
return "kiro", nil
|
||||
case "github":
|
||||
|
||||
@@ -24,8 +24,8 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
||||
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||
@@ -684,7 +684,6 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.POST("/gitlab-auth-url", s.mgmt.RequestGitLabPATToken)
|
||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
|
||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
@@ -1122,20 +1121,17 @@ func applySignatureCacheConfig(oldCfg, cfg *config.Config) {
|
||||
if oldCfg == nil {
|
||||
cache.SetSignatureCacheEnabled(newVal)
|
||||
cache.SetSignatureBypassStrictMode(newStrict)
|
||||
log.Debugf("antigravity_signature_cache_enabled toggled to %t", newVal)
|
||||
return
|
||||
}
|
||||
|
||||
oldVal := configuredSignatureCacheEnabled(oldCfg)
|
||||
if oldVal != newVal {
|
||||
cache.SetSignatureCacheEnabled(newVal)
|
||||
log.Debugf("antigravity_signature_cache_enabled updated from %t to %t", oldVal, newVal)
|
||||
}
|
||||
|
||||
oldStrict := configuredSignatureBypassStrict(oldCfg)
|
||||
if oldStrict != newStrict {
|
||||
cache.SetSignatureBypassStrictMode(newStrict)
|
||||
log.Debugf("antigravity_signature_bypass_strict updated from %t to %t", oldStrict, newStrict)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ func (a *CodeBuddyAuth) FetchAuthState(ctx context.Context) (*AuthState, error)
|
||||
return nil, fmt.Errorf("codebuddy: failed to create auth state request: %w", err)
|
||||
}
|
||||
|
||||
requestID := uuid.NewString()
|
||||
requestID := uuid.NewString()
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
|
||||
@@ -19,4 +19,3 @@ func TestDecodeUserID_ValidJWT(t *testing.T) {
|
||||
t.Errorf("expected 'test-user-id-123', got '%s'", userID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,11 +24,11 @@ const (
|
||||
copilotAPIEndpoint = "https://api.githubcopilot.com"
|
||||
|
||||
// Common HTTP header values for Copilot API requests.
|
||||
copilotUserAgent = "GithubCopilot/1.0"
|
||||
copilotEditorVersion = "vscode/1.100.0"
|
||||
copilotPluginVersion = "copilot/1.300.0"
|
||||
copilotIntegrationID = "vscode-chat"
|
||||
copilotOpenAIIntent = "conversation-panel"
|
||||
copilotUserAgent = "GithubCopilot/1.0"
|
||||
copilotEditorVersion = "vscode/1.100.0"
|
||||
copilotPluginVersion = "copilot/1.300.0"
|
||||
copilotIntegrationID = "vscode-chat"
|
||||
copilotOpenAIIntent = "conversation-panel"
|
||||
)
|
||||
|
||||
// CopilotAPIToken represents the Copilot API token response.
|
||||
@@ -314,9 +314,9 @@ const maxModelsResponseSize = 2 * 1024 * 1024
|
||||
|
||||
// allowedCopilotAPIHosts is the set of hosts that are considered safe for Copilot API requests.
|
||||
var allowedCopilotAPIHosts = map[string]bool{
|
||||
"api.githubcopilot.com": true,
|
||||
"api.individual.githubcopilot.com": true,
|
||||
"api.business.githubcopilot.com": true,
|
||||
"api.githubcopilot.com": true,
|
||||
"api.individual.githubcopilot.com": true,
|
||||
"api.business.githubcopilot.com": true,
|
||||
"copilot-proxy.githubusercontent.com": true,
|
||||
}
|
||||
|
||||
|
||||
@@ -12,30 +12,30 @@ import (
|
||||
type ServerMessageType int
|
||||
|
||||
const (
|
||||
ServerMsgUnknown ServerMessageType = iota
|
||||
ServerMsgTextDelta // Text content delta
|
||||
ServerMsgThinkingDelta // Thinking/reasoning delta
|
||||
ServerMsgThinkingCompleted // Thinking completed
|
||||
ServerMsgKvGetBlob // Server wants a blob
|
||||
ServerMsgKvSetBlob // Server wants to store a blob
|
||||
ServerMsgExecRequestCtx // Server requests context (tools, etc.)
|
||||
ServerMsgExecMcpArgs // Server wants MCP tool execution
|
||||
ServerMsgExecShellArgs // Rejected: shell command
|
||||
ServerMsgExecReadArgs // Rejected: file read
|
||||
ServerMsgExecWriteArgs // Rejected: file write
|
||||
ServerMsgExecDeleteArgs // Rejected: file delete
|
||||
ServerMsgExecLsArgs // Rejected: directory listing
|
||||
ServerMsgExecGrepArgs // Rejected: grep search
|
||||
ServerMsgExecFetchArgs // Rejected: HTTP fetch
|
||||
ServerMsgExecDiagnostics // Respond with empty diagnostics
|
||||
ServerMsgExecShellStream // Rejected: shell stream
|
||||
ServerMsgExecBgShellSpawn // Rejected: background shell
|
||||
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
|
||||
ServerMsgExecOther // Other exec types (respond with empty)
|
||||
ServerMsgTurnEnded // Turn has ended (no more output)
|
||||
ServerMsgHeartbeat // Server heartbeat
|
||||
ServerMsgTokenDelta // Token usage delta
|
||||
ServerMsgCheckpoint // Conversation checkpoint update
|
||||
ServerMsgUnknown ServerMessageType = iota
|
||||
ServerMsgTextDelta // Text content delta
|
||||
ServerMsgThinkingDelta // Thinking/reasoning delta
|
||||
ServerMsgThinkingCompleted // Thinking completed
|
||||
ServerMsgKvGetBlob // Server wants a blob
|
||||
ServerMsgKvSetBlob // Server wants to store a blob
|
||||
ServerMsgExecRequestCtx // Server requests context (tools, etc.)
|
||||
ServerMsgExecMcpArgs // Server wants MCP tool execution
|
||||
ServerMsgExecShellArgs // Rejected: shell command
|
||||
ServerMsgExecReadArgs // Rejected: file read
|
||||
ServerMsgExecWriteArgs // Rejected: file write
|
||||
ServerMsgExecDeleteArgs // Rejected: file delete
|
||||
ServerMsgExecLsArgs // Rejected: directory listing
|
||||
ServerMsgExecGrepArgs // Rejected: grep search
|
||||
ServerMsgExecFetchArgs // Rejected: HTTP fetch
|
||||
ServerMsgExecDiagnostics // Respond with empty diagnostics
|
||||
ServerMsgExecShellStream // Rejected: shell stream
|
||||
ServerMsgExecBgShellSpawn // Rejected: background shell
|
||||
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
|
||||
ServerMsgExecOther // Other exec types (respond with empty)
|
||||
ServerMsgTurnEnded // Turn has ended (no more output)
|
||||
ServerMsgHeartbeat // Server heartbeat
|
||||
ServerMsgTokenDelta // Token usage delta
|
||||
ServerMsgCheckpoint // Conversation checkpoint update
|
||||
)
|
||||
|
||||
// DecodedServerMessage holds parsed data from an AgentServerMessage.
|
||||
@@ -561,4 +561,3 @@ func decodeVarintField(data []byte, targetField protowire.Number) int64 {
|
||||
func BlobIdHex(blobId []byte) string {
|
||||
return hex.EncodeToString(blobId)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,23 +4,23 @@ package proto
|
||||
|
||||
// AgentClientMessage (msg 118) oneof "message"
|
||||
const (
|
||||
ACM_RunRequest = 1 // AgentRunRequest
|
||||
ACM_ExecClientMessage = 2 // ExecClientMessage
|
||||
ACM_KvClientMessage = 3 // KvClientMessage
|
||||
ACM_ConversationAction = 4 // ConversationAction
|
||||
ACM_ExecClientControlMsg = 5 // ExecClientControlMessage
|
||||
ACM_InteractionResponse = 6 // InteractionResponse
|
||||
ACM_ClientHeartbeat = 7 // ClientHeartbeat
|
||||
ACM_RunRequest = 1 // AgentRunRequest
|
||||
ACM_ExecClientMessage = 2 // ExecClientMessage
|
||||
ACM_KvClientMessage = 3 // KvClientMessage
|
||||
ACM_ConversationAction = 4 // ConversationAction
|
||||
ACM_ExecClientControlMsg = 5 // ExecClientControlMessage
|
||||
ACM_InteractionResponse = 6 // InteractionResponse
|
||||
ACM_ClientHeartbeat = 7 // ClientHeartbeat
|
||||
)
|
||||
|
||||
// AgentServerMessage (msg 119) oneof "message"
|
||||
const (
|
||||
ASM_InteractionUpdate = 1 // InteractionUpdate
|
||||
ASM_ExecServerMessage = 2 // ExecServerMessage
|
||||
ASM_ConversationCheckpoint = 3 // ConversationStateStructure
|
||||
ASM_KvServerMessage = 4 // KvServerMessage
|
||||
ASM_ExecServerControlMessage = 5 // ExecServerControlMessage
|
||||
ASM_InteractionQuery = 7 // InteractionQuery
|
||||
ASM_InteractionUpdate = 1 // InteractionUpdate
|
||||
ASM_ExecServerMessage = 2 // ExecServerMessage
|
||||
ASM_ConversationCheckpoint = 3 // ConversationStateStructure
|
||||
ASM_KvServerMessage = 4 // KvServerMessage
|
||||
ASM_ExecServerControlMessage = 5 // ExecServerControlMessage
|
||||
ASM_InteractionQuery = 7 // InteractionQuery
|
||||
)
|
||||
|
||||
// AgentRunRequest (msg 91)
|
||||
@@ -77,10 +77,10 @@ const (
|
||||
|
||||
// ModelDetails (msg 88)
|
||||
const (
|
||||
MD_ModelId = 1 // string
|
||||
MD_ModelId = 1 // string
|
||||
MD_ThinkingDetails = 2 // ThinkingDetails (optional)
|
||||
MD_DisplayModelId = 3 // string
|
||||
MD_DisplayName = 4 // string
|
||||
MD_DisplayModelId = 3 // string
|
||||
MD_DisplayName = 4 // string
|
||||
)
|
||||
|
||||
// McpTools (msg 307)
|
||||
@@ -122,9 +122,9 @@ const (
|
||||
|
||||
// InteractionUpdate oneof "message"
|
||||
const (
|
||||
IU_TextDelta = 1 // TextDeltaUpdate
|
||||
IU_ThinkingDelta = 4 // ThinkingDeltaUpdate
|
||||
IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate
|
||||
IU_TextDelta = 1 // TextDeltaUpdate
|
||||
IU_ThinkingDelta = 4 // ThinkingDeltaUpdate
|
||||
IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate
|
||||
)
|
||||
|
||||
// TextDeltaUpdate (msg 92)
|
||||
@@ -169,22 +169,22 @@ const (
|
||||
|
||||
// ExecServerMessage
|
||||
const (
|
||||
ESM_Id = 1 // uint32
|
||||
ESM_ExecId = 15 // string
|
||||
ESM_Id = 1 // uint32
|
||||
ESM_ExecId = 15 // string
|
||||
// oneof message:
|
||||
ESM_ShellArgs = 2 // ShellArgs
|
||||
ESM_WriteArgs = 3 // WriteArgs
|
||||
ESM_DeleteArgs = 4 // DeleteArgs
|
||||
ESM_GrepArgs = 5 // GrepArgs
|
||||
ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped)
|
||||
ESM_LsArgs = 8 // LsArgs
|
||||
ESM_DiagnosticsArgs = 9 // DiagnosticsArgs
|
||||
ESM_RequestContextArgs = 10 // RequestContextArgs
|
||||
ESM_McpArgs = 11 // McpArgs
|
||||
ESM_ShellStreamArgs = 14 // ShellArgs (stream variant)
|
||||
ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs
|
||||
ESM_FetchArgs = 20 // FetchArgs
|
||||
ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs
|
||||
ESM_ShellArgs = 2 // ShellArgs
|
||||
ESM_WriteArgs = 3 // WriteArgs
|
||||
ESM_DeleteArgs = 4 // DeleteArgs
|
||||
ESM_GrepArgs = 5 // GrepArgs
|
||||
ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped)
|
||||
ESM_LsArgs = 8 // LsArgs
|
||||
ESM_DiagnosticsArgs = 9 // DiagnosticsArgs
|
||||
ESM_RequestContextArgs = 10 // RequestContextArgs
|
||||
ESM_McpArgs = 11 // McpArgs
|
||||
ESM_ShellStreamArgs = 14 // ShellArgs (stream variant)
|
||||
ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs
|
||||
ESM_FetchArgs = 20 // FetchArgs
|
||||
ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs
|
||||
)
|
||||
|
||||
// ExecClientMessage
|
||||
@@ -192,19 +192,19 @@ const (
|
||||
ECM_Id = 1 // uint32
|
||||
ECM_ExecId = 15 // string
|
||||
// oneof message (mirrors server fields):
|
||||
ECM_ShellResult = 2
|
||||
ECM_WriteResult = 3
|
||||
ECM_DeleteResult = 4
|
||||
ECM_GrepResult = 5
|
||||
ECM_ReadResult = 7
|
||||
ECM_LsResult = 8
|
||||
ECM_DiagnosticsResult = 9
|
||||
ECM_RequestContextResult = 10
|
||||
ECM_McpResult = 11
|
||||
ECM_ShellStream = 14
|
||||
ECM_BackgroundShellSpawnRes = 16
|
||||
ECM_FetchResult = 20
|
||||
ECM_WriteShellStdinResult = 23
|
||||
ECM_ShellResult = 2
|
||||
ECM_WriteResult = 3
|
||||
ECM_DeleteResult = 4
|
||||
ECM_GrepResult = 5
|
||||
ECM_ReadResult = 7
|
||||
ECM_LsResult = 8
|
||||
ECM_DiagnosticsResult = 9
|
||||
ECM_RequestContextResult = 10
|
||||
ECM_McpResult = 11
|
||||
ECM_ShellStream = 14
|
||||
ECM_BackgroundShellSpawnRes = 16
|
||||
ECM_FetchResult = 20
|
||||
ECM_WriteShellStdinResult = 23
|
||||
)
|
||||
|
||||
// McpArgs
|
||||
@@ -276,28 +276,28 @@ const (
|
||||
// ShellResult oneof: success=1 (+ various), rejected=?
|
||||
// The TS code uses specific result field numbers from the oneof:
|
||||
const (
|
||||
RR_Rejected = 3 // ReadResult.rejected
|
||||
SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected)
|
||||
WR_Rejected = 5 // WriteResult.rejected
|
||||
DR_Rejected = 3 // DeleteResult.rejected
|
||||
LR_Rejected = 3 // LsResult.rejected
|
||||
GR_Error = 2 // GrepResult.error
|
||||
FR_Error = 2 // FetchResult.error
|
||||
RR_Rejected = 3 // ReadResult.rejected
|
||||
SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected)
|
||||
WR_Rejected = 5 // WriteResult.rejected
|
||||
DR_Rejected = 3 // DeleteResult.rejected
|
||||
LR_Rejected = 3 // LsResult.rejected
|
||||
GR_Error = 2 // GrepResult.error
|
||||
FR_Error = 2 // FetchResult.error
|
||||
BSSR_Rejected = 2 // BackgroundShellSpawnResult.rejected (error field)
|
||||
WSSR_Error = 2 // WriteShellStdinResult.error
|
||||
)
|
||||
|
||||
// --- Rejection struct fields ---
|
||||
const (
|
||||
REJ_Path = 1
|
||||
REJ_Reason = 2
|
||||
SREJ_Command = 1
|
||||
SREJ_WorkingDir = 2
|
||||
SREJ_Reason = 3
|
||||
SREJ_IsReadonly = 4
|
||||
GERR_Error = 1
|
||||
FERR_Url = 1
|
||||
FERR_Error = 2
|
||||
REJ_Path = 1
|
||||
REJ_Reason = 2
|
||||
SREJ_Command = 1
|
||||
SREJ_WorkingDir = 2
|
||||
SREJ_Reason = 3
|
||||
SREJ_IsReadonly = 4
|
||||
GERR_Error = 1
|
||||
FERR_Url = 1
|
||||
FERR_Error = 2
|
||||
)
|
||||
|
||||
// ReadArgs
|
||||
|
||||
@@ -33,10 +33,10 @@ type H2Stream struct {
|
||||
err error
|
||||
|
||||
// Send-side flow control
|
||||
sendWindow int32 // available bytes we can send on this stream
|
||||
connWindow int32 // available bytes on the connection level
|
||||
windowCond *sync.Cond // signaled when window is updated
|
||||
windowMu sync.Mutex // protects sendWindow, connWindow
|
||||
sendWindow int32 // available bytes we can send on this stream
|
||||
connWindow int32 // available bytes on the connection level
|
||||
windowCond *sync.Cond // signaled when window is updated
|
||||
windowMu sync.Mutex // protects sendWindow, connWindow
|
||||
}
|
||||
|
||||
// ID returns the unique identifier for this stream (for logging).
|
||||
|
||||
@@ -748,4 +748,3 @@ func TestExtractRegionFromMetadata(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
CooldownReason429 = "rate_limit_exceeded"
|
||||
CooldownReasonSuspended = "account_suspended"
|
||||
CooldownReason429 = "rate_limit_exceeded"
|
||||
CooldownReasonSuspended = "account_suspended"
|
||||
CooldownReasonQuotaExhausted = "quota_exhausted"
|
||||
|
||||
DefaultShortCooldown = 1 * time.Minute
|
||||
|
||||
@@ -26,9 +26,9 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
jitterRand *rand.Rand
|
||||
jitterRandOnce sync.Once
|
||||
jitterMu sync.Mutex
|
||||
jitterRand *rand.Rand
|
||||
jitterRandOnce sync.Once
|
||||
jitterMu sync.Mutex
|
||||
lastRequestTime time.Time
|
||||
)
|
||||
|
||||
|
||||
@@ -24,10 +24,10 @@ type TokenScorer struct {
|
||||
metrics map[string]*TokenMetrics
|
||||
|
||||
// Scoring weights
|
||||
successRateWeight float64
|
||||
quotaWeight float64
|
||||
latencyWeight float64
|
||||
lastUsedWeight float64
|
||||
successRateWeight float64
|
||||
quotaWeight float64
|
||||
latencyWeight float64
|
||||
lastUsedWeight float64
|
||||
failPenaltyMultiplier float64
|
||||
}
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ func (h *ProtocolHandler) Start(ctx context.Context) (int, error) {
|
||||
var listener net.Listener
|
||||
var err error
|
||||
portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4}
|
||||
|
||||
|
||||
for _, port := range portRange {
|
||||
listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
if err == nil {
|
||||
@@ -105,7 +105,7 @@ func (h *ProtocolHandler) Start(ctx context.Context) (int, error) {
|
||||
}
|
||||
log.Debugf("kiro protocol handler: port %d busy, trying next", port)
|
||||
}
|
||||
|
||||
|
||||
if listener == nil {
|
||||
return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4)
|
||||
}
|
||||
|
||||
@@ -1,359 +0,0 @@
|
||||
package qwen
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// QwenOAuthDeviceCodeEndpoint is the URL for initiating the OAuth 2.0 device authorization flow.
|
||||
QwenOAuthDeviceCodeEndpoint = "https://chat.qwen.ai/api/v1/oauth2/device/code"
|
||||
// QwenOAuthTokenEndpoint is the URL for exchanging device codes or refresh tokens for access tokens.
|
||||
QwenOAuthTokenEndpoint = "https://chat.qwen.ai/api/v1/oauth2/token"
|
||||
// QwenOAuthClientID is the client identifier for the Qwen OAuth 2.0 application.
|
||||
QwenOAuthClientID = "f0304373b74a44d2b584a3fb70ca9e56"
|
||||
// QwenOAuthScope defines the permissions requested by the application.
|
||||
QwenOAuthScope = "openid profile email model.completion"
|
||||
// QwenOAuthGrantType specifies the grant type for the device code flow.
|
||||
QwenOAuthGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
)
|
||||
|
||||
// QwenTokenData represents the OAuth credentials, including access and refresh tokens.
|
||||
type QwenTokenData struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain a new access token when the current one expires.
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
// TokenType indicates the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// ResourceURL specifies the base URL of the resource server.
|
||||
ResourceURL string `json:"resource_url,omitempty"`
|
||||
// Expire indicates the expiration date and time of the access token.
|
||||
Expire string `json:"expiry_date,omitempty"`
|
||||
}
|
||||
|
||||
// DeviceFlow represents the response from the device authorization endpoint.
|
||||
type DeviceFlow struct {
|
||||
// DeviceCode is the code that the client uses to poll for an access token.
|
||||
DeviceCode string `json:"device_code"`
|
||||
// UserCode is the code that the user enters at the verification URI.
|
||||
UserCode string `json:"user_code"`
|
||||
// VerificationURI is the URL where the user can enter the user code to authorize the device.
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
// VerificationURIComplete is a URI that includes the user_code, which can be used to automatically
|
||||
// fill in the code on the verification page.
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
// ExpiresIn is the time in seconds until the device_code and user_code expire.
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
// Interval is the minimum time in seconds that the client should wait between polling requests.
|
||||
Interval int `json:"interval"`
|
||||
// CodeVerifier is the cryptographically random string used in the PKCE flow.
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
}
|
||||
|
||||
// QwenTokenResponse represents the successful token response from the token endpoint.
|
||||
type QwenTokenResponse struct {
|
||||
// AccessToken is the token used to access protected resources.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain a new access token.
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
// TokenType indicates the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// ResourceURL specifies the base URL of the resource server.
|
||||
ResourceURL string `json:"resource_url,omitempty"`
|
||||
// ExpiresIn is the time in seconds until the access token expires.
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// QwenAuth manages authentication and token handling for the Qwen API.
|
||||
type QwenAuth struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewQwenAuth creates a new QwenAuth instance with a proxy-configured HTTP client.
|
||||
func NewQwenAuth(cfg *config.Config) *QwenAuth {
|
||||
return &QwenAuth{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
||||
}
|
||||
}
|
||||
|
||||
// generateCodeVerifier generates a cryptographically random string for the PKCE code verifier.
|
||||
func (qa *QwenAuth) generateCodeVerifier() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// generateCodeChallenge creates a SHA-256 hash of the code verifier, used as the PKCE code challenge.
|
||||
func (qa *QwenAuth) generateCodeChallenge(codeVerifier string) string {
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
return base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// generatePKCEPair creates a new code verifier and its corresponding code challenge for PKCE.
|
||||
func (qa *QwenAuth) generatePKCEPair() (string, string, error) {
|
||||
codeVerifier, err := qa.generateCodeVerifier()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
codeChallenge := qa.generateCodeChallenge(codeVerifier)
|
||||
return codeVerifier, codeChallenge, nil
|
||||
}
|
||||
|
||||
// RefreshTokens exchanges a refresh token for a new access token.
|
||||
func (qa *QwenAuth) RefreshTokens(ctx context.Context, refreshToken string) (*QwenTokenData, error) {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "refresh_token")
|
||||
data.Set("refresh_token", refreshToken)
|
||||
data.Set("client_id", QwenOAuthClientID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthTokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := qa.httpClient.Do(req)
|
||||
|
||||
// resp, err := qa.httpClient.PostForm(QwenOAuthTokenEndpoint, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errorData map[string]interface{}
|
||||
if err = json.Unmarshal(body, &errorData); err == nil {
|
||||
return nil, fmt.Errorf("token refresh failed: %v - %v", errorData["error"], errorData["error_description"])
|
||||
}
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
|
||||
var tokenData QwenTokenResponse
|
||||
if err = json.Unmarshal(body, &tokenData); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
return &QwenTokenData{
|
||||
AccessToken: tokenData.AccessToken,
|
||||
TokenType: tokenData.TokenType,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
ResourceURL: tokenData.ResourceURL,
|
||||
Expire: time.Now().Add(time.Duration(tokenData.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InitiateDeviceFlow starts the OAuth 2.0 device authorization flow and returns the device flow details.
|
||||
func (qa *QwenAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceFlow, error) {
|
||||
// Generate PKCE code verifier and challenge
|
||||
codeVerifier, codeChallenge, err := qa.generatePKCEPair()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate PKCE pair: %w", err)
|
||||
}
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("client_id", QwenOAuthClientID)
|
||||
data.Set("scope", QwenOAuthScope)
|
||||
data.Set("code_challenge", codeChallenge)
|
||||
data.Set("code_challenge_method", "S256")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthDeviceCodeEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := qa.httpClient.Do(req)
|
||||
|
||||
// resp, err := qa.httpClient.PostForm(QwenOAuthDeviceCodeEndpoint, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("device authorization request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("device authorization failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result DeviceFlow
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse device flow response: %w", err)
|
||||
}
|
||||
|
||||
// Check if the response indicates success
|
||||
if result.DeviceCode == "" {
|
||||
return nil, fmt.Errorf("device authorization failed: device_code not found in response")
|
||||
}
|
||||
|
||||
// Add the code_verifier to the result so it can be used later for polling
|
||||
result.CodeVerifier = codeVerifier
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// PollForToken polls the token endpoint with the device code to obtain an access token.
|
||||
func (qa *QwenAuth) PollForToken(deviceCode, codeVerifier string) (*QwenTokenData, error) {
|
||||
pollInterval := 5 * time.Second
|
||||
maxAttempts := 60 // 5 minutes max
|
||||
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", QwenOAuthGrantType)
|
||||
data.Set("client_id", QwenOAuthClientID)
|
||||
data.Set("device_code", deviceCode)
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
|
||||
resp, err := http.PostForm(QwenOAuthTokenEndpoint, data)
|
||||
if err != nil {
|
||||
fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err)
|
||||
time.Sleep(pollInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err)
|
||||
time.Sleep(pollInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Parse the response as JSON to check for OAuth RFC 8628 standard errors
|
||||
var errorData map[string]interface{}
|
||||
if err = json.Unmarshal(body, &errorData); err == nil {
|
||||
// According to OAuth RFC 8628, handle standard polling responses
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
errorType, _ := errorData["error"].(string)
|
||||
switch errorType {
|
||||
case "authorization_pending":
|
||||
// User has not yet approved the authorization request. Continue polling.
|
||||
fmt.Printf("Polling attempt %d/%d...\n\n", attempt+1, maxAttempts)
|
||||
time.Sleep(pollInterval)
|
||||
continue
|
||||
case "slow_down":
|
||||
// Client is polling too frequently. Increase poll interval.
|
||||
pollInterval = time.Duration(float64(pollInterval) * 1.5)
|
||||
if pollInterval > 10*time.Second {
|
||||
pollInterval = 10 * time.Second
|
||||
}
|
||||
fmt.Printf("Server requested to slow down, increasing poll interval to %v\n\n", pollInterval)
|
||||
time.Sleep(pollInterval)
|
||||
continue
|
||||
case "expired_token":
|
||||
return nil, fmt.Errorf("device code expired. Please restart the authentication process")
|
||||
case "access_denied":
|
||||
return nil, fmt.Errorf("authorization denied by user. Please restart the authentication process")
|
||||
}
|
||||
}
|
||||
|
||||
// For other errors, return with proper error information
|
||||
errorType, _ := errorData["error"].(string)
|
||||
errorDesc, _ := errorData["error_description"].(string)
|
||||
return nil, fmt.Errorf("device token poll failed: %s - %s", errorType, errorDesc)
|
||||
}
|
||||
|
||||
// If JSON parsing fails, fall back to text response
|
||||
return nil, fmt.Errorf("device token poll failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body))
|
||||
}
|
||||
// log.Debugf("%s", string(body))
|
||||
// Success - parse token data
|
||||
var response QwenTokenResponse
|
||||
if err = json.Unmarshal(body, &response); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
// Convert to QwenTokenData format and save
|
||||
tokenData := &QwenTokenData{
|
||||
AccessToken: response.AccessToken,
|
||||
RefreshToken: response.RefreshToken,
|
||||
TokenType: response.TokenType,
|
||||
ResourceURL: response.ResourceURL,
|
||||
Expire: time.Now().Add(time.Duration(response.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}
|
||||
|
||||
return tokenData, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("authentication timeout. Please restart the authentication process")
|
||||
}
|
||||
|
||||
// RefreshTokensWithRetry attempts to refresh tokens with a specified number of retries upon failure.
|
||||
func (o *QwenAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*QwenTokenData, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Wait before retry
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(time.Duration(attempt) * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
tokenData, err := o.RefreshTokens(ctx, refreshToken)
|
||||
if err == nil {
|
||||
return tokenData, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a QwenTokenStorage object from a QwenTokenData object.
|
||||
func (o *QwenAuth) CreateTokenStorage(tokenData *QwenTokenData) *QwenTokenStorage {
|
||||
storage := &QwenTokenStorage{
|
||||
AccessToken: tokenData.AccessToken,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
ResourceURL: tokenData.ResourceURL,
|
||||
Expire: tokenData.Expire,
|
||||
}
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates an existing token storage with new token data
|
||||
func (o *QwenAuth) UpdateTokenStorage(storage *QwenTokenStorage, tokenData *QwenTokenData) {
|
||||
storage.AccessToken = tokenData.AccessToken
|
||||
storage.RefreshToken = tokenData.RefreshToken
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
storage.ResourceURL = tokenData.ResourceURL
|
||||
storage.Expire = tokenData.Expire
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
// Package qwen provides authentication and token management functionality
|
||||
// for Alibaba's Qwen AI services. It handles OAuth2 token storage, serialization,
|
||||
// and retrieval for maintaining authenticated sessions with the Qwen API.
|
||||
package qwen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// QwenTokenStorage stores OAuth2 token information for Alibaba Qwen API authentication.
|
||||
// It maintains compatibility with the existing auth system while adding Qwen-specific fields
|
||||
// for managing access tokens, refresh tokens, and user account information.
|
||||
type QwenTokenStorage struct {
|
||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain new access tokens when the current one expires.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// LastRefresh is the timestamp of the last token refresh operation.
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
// ResourceURL is the base URL for API requests.
|
||||
ResourceURL string `json:"resource_url"`
|
||||
// Email is the Qwen account email address associated with this token.
|
||||
Email string `json:"email"`
|
||||
// Type indicates the authentication provider type, always "qwen" for this storage.
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the operation fails, nil otherwise
|
||||
func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "qwen"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func CloseBrowser() error {
|
||||
if lastBrowserProcess == nil || lastBrowserProcess.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
err := lastBrowserProcess.Process.Kill()
|
||||
lastBrowserProcess = nil
|
||||
return err
|
||||
|
||||
16
internal/cache/signature_cache.go
vendored
16
internal/cache/signature_cache.go
vendored
@@ -207,9 +207,12 @@ func init() {
|
||||
|
||||
// SetSignatureCacheEnabled switches Antigravity signature handling between cache mode and bypass mode.
|
||||
func SetSignatureCacheEnabled(enabled bool) {
|
||||
signatureCacheEnabled.Store(enabled)
|
||||
previous := signatureCacheEnabled.Swap(enabled)
|
||||
if previous == enabled {
|
||||
return
|
||||
}
|
||||
if !enabled {
|
||||
log.Warn("antigravity signature cache DISABLED - bypass mode active, cached signatures will not be used for request translation")
|
||||
log.Info("antigravity signature cache DISABLED - bypass mode active, cached signatures will not be used for request translation")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,11 +223,14 @@ func SignatureCacheEnabled() bool {
|
||||
|
||||
// SetSignatureBypassStrictMode controls whether bypass mode uses strict protobuf-tree validation.
|
||||
func SetSignatureBypassStrictMode(strict bool) {
|
||||
signatureBypassStrictMode.Store(strict)
|
||||
previous := signatureBypassStrictMode.Swap(strict)
|
||||
if previous == strict {
|
||||
return
|
||||
}
|
||||
if strict {
|
||||
log.Info("antigravity bypass signature validation: strict mode (protobuf tree)")
|
||||
log.Debug("antigravity bypass signature validation: strict mode (protobuf tree)")
|
||||
} else {
|
||||
log.Info("antigravity bypass signature validation: basic mode (R/E + 0x12)")
|
||||
log.Debug("antigravity bypass signature validation: basic mode (R/E + 0x12)")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
91
internal/cache/signature_cache_test.go
vendored
91
internal/cache/signature_cache_test.go
vendored
@@ -1,8 +1,12 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const testModelName = "claude-sonnet-4-5"
|
||||
@@ -208,3 +212,90 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
||||
// but the logic is verified by the implementation
|
||||
_ = time.Now() // Acknowledge we're not testing time passage
|
||||
}
|
||||
|
||||
func TestSignatureModeSetters_LogAtInfoLevel(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
previousOutput := logger.Out
|
||||
previousLevel := logger.Level
|
||||
previousCache := SignatureCacheEnabled()
|
||||
previousStrict := SignatureBypassStrictMode()
|
||||
SetSignatureCacheEnabled(true)
|
||||
SetSignatureBypassStrictMode(false)
|
||||
buffer := &bytes.Buffer{}
|
||||
log.SetOutput(buffer)
|
||||
log.SetLevel(log.InfoLevel)
|
||||
t.Cleanup(func() {
|
||||
log.SetOutput(previousOutput)
|
||||
log.SetLevel(previousLevel)
|
||||
SetSignatureCacheEnabled(previousCache)
|
||||
SetSignatureBypassStrictMode(previousStrict)
|
||||
})
|
||||
|
||||
SetSignatureCacheEnabled(false)
|
||||
SetSignatureBypassStrictMode(true)
|
||||
SetSignatureBypassStrictMode(false)
|
||||
|
||||
output := buffer.String()
|
||||
if !strings.Contains(output, "antigravity signature cache DISABLED") {
|
||||
t.Fatalf("expected info output for disabling signature cache, got: %q", output)
|
||||
}
|
||||
if strings.Contains(output, "strict mode (protobuf tree)") {
|
||||
t.Fatalf("expected strict bypass mode log to stay below info level, got: %q", output)
|
||||
}
|
||||
if strings.Contains(output, "basic mode (R/E + 0x12)") {
|
||||
t.Fatalf("expected basic bypass mode log to stay below info level, got: %q", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureModeSetters_DoNotRepeatSameStateLogs(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
previousOutput := logger.Out
|
||||
previousLevel := logger.Level
|
||||
previousCache := SignatureCacheEnabled()
|
||||
previousStrict := SignatureBypassStrictMode()
|
||||
SetSignatureCacheEnabled(false)
|
||||
SetSignatureBypassStrictMode(true)
|
||||
buffer := &bytes.Buffer{}
|
||||
log.SetOutput(buffer)
|
||||
log.SetLevel(log.InfoLevel)
|
||||
t.Cleanup(func() {
|
||||
log.SetOutput(previousOutput)
|
||||
log.SetLevel(previousLevel)
|
||||
SetSignatureCacheEnabled(previousCache)
|
||||
SetSignatureBypassStrictMode(previousStrict)
|
||||
})
|
||||
|
||||
SetSignatureCacheEnabled(false)
|
||||
SetSignatureBypassStrictMode(true)
|
||||
|
||||
if buffer.Len() != 0 {
|
||||
t.Fatalf("expected repeated setter calls with unchanged state to stay silent, got: %q", buffer.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureBypassStrictMode_LogsAtDebugLevel(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
previousOutput := logger.Out
|
||||
previousLevel := logger.Level
|
||||
previousStrict := SignatureBypassStrictMode()
|
||||
SetSignatureBypassStrictMode(false)
|
||||
buffer := &bytes.Buffer{}
|
||||
log.SetOutput(buffer)
|
||||
log.SetLevel(log.DebugLevel)
|
||||
t.Cleanup(func() {
|
||||
log.SetOutput(previousOutput)
|
||||
log.SetLevel(previousLevel)
|
||||
SetSignatureBypassStrictMode(previousStrict)
|
||||
})
|
||||
|
||||
SetSignatureBypassStrictMode(true)
|
||||
SetSignatureBypassStrictMode(false)
|
||||
|
||||
output := buffer.String()
|
||||
if !strings.Contains(output, "strict mode (protobuf tree)") {
|
||||
t.Fatalf("expected debug output for strict bypass mode, got: %q", output)
|
||||
}
|
||||
if !strings.Contains(output, "basic mode (R/E + 0x12)") {
|
||||
t.Fatalf("expected debug output for basic bypass mode, got: %q", output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewGeminiAuthenticator(),
|
||||
sdkAuth.NewCodexAuthenticator(),
|
||||
sdkAuth.NewClaudeAuthenticator(),
|
||||
sdkAuth.NewQwenAuthenticator(),
|
||||
sdkAuth.NewIFlowAuthenticator(),
|
||||
sdkAuth.NewAntigravityAuthenticator(),
|
||||
sdkAuth.NewKimiAuthenticator(),
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoQwenLogin handles the Qwen device flow using the shared authentication manager.
|
||||
// It initiates the device-based authentication process for Qwen services and saves
|
||||
// the authentication tokens to the configured auth directory.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration
|
||||
// - options: Login options including browser behavior and prompts
|
||||
func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = func(prompt string) (string, error) {
|
||||
fmt.Println()
|
||||
fmt.Println(prompt)
|
||||
var value string
|
||||
_, err := fmt.Scanln(&value)
|
||||
return value, err
|
||||
}
|
||||
}
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
||||
if err != nil {
|
||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
||||
log.Error(emailErr.Error())
|
||||
return
|
||||
}
|
||||
fmt.Printf("Qwen authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
|
||||
fmt.Println("Qwen authentication successful!")
|
||||
}
|
||||
@@ -68,6 +68,10 @@ type Config struct {
|
||||
// DisableCooling disables quota cooldown scheduling when true.
|
||||
DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"`
|
||||
|
||||
// AuthAutoRefreshWorkers overrides the size of the core auth auto-refresh worker pool.
|
||||
// When <= 0, the default worker count is used.
|
||||
AuthAutoRefreshWorkers int `yaml:"auth-auto-refresh-workers" json:"auth-auto-refresh-workers"`
|
||||
|
||||
// RequestRetry defines the retry times when the request failed.
|
||||
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
||||
// MaxRetryCredentials defines the maximum number of credentials to try for a failed request.
|
||||
@@ -131,12 +135,12 @@ type Config struct {
|
||||
AmpCode AmpCode `yaml:"ampcode" json:"ampcode"`
|
||||
|
||||
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot.
|
||||
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
||||
|
||||
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
|
||||
// These aliases affect both model listing and model routing for supported channels:
|
||||
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
||||
// gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot.
|
||||
//
|
||||
// NOTE: This does not apply to existing per-credential model alias features under:
|
||||
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
||||
@@ -229,6 +233,22 @@ type RoutingConfig struct {
|
||||
// Strategy selects the credential selection strategy.
|
||||
// Supported values: "round-robin" (default), "fill-first".
|
||||
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
||||
|
||||
// ClaudeCodeSessionAffinity enables session-sticky routing for Claude Code clients.
|
||||
// When enabled, requests with the same session ID (extracted from metadata.user_id)
|
||||
// are routed to the same auth credential when available.
|
||||
// Deprecated: Use SessionAffinity instead for universal session support.
|
||||
ClaudeCodeSessionAffinity bool `yaml:"claude-code-session-affinity,omitempty" json:"claude-code-session-affinity,omitempty"`
|
||||
|
||||
// SessionAffinity enables universal session-sticky routing for all clients.
|
||||
// Session IDs are extracted from multiple sources:
|
||||
// X-Session-ID header, Idempotency-Key, metadata.user_id, conversation_id, or message hash.
|
||||
// Automatic failover is always enabled when bound auth becomes unavailable.
|
||||
SessionAffinity bool `yaml:"session-affinity,omitempty" json:"session-affinity,omitempty"`
|
||||
|
||||
// SessionAffinityTTL specifies how long session-to-auth bindings are retained.
|
||||
// Default: 1h. Accepts duration strings like "30m", "1h", "2h30m".
|
||||
SessionAffinityTTL string `yaml:"session-affinity-ttl,omitempty" json:"session-affinity-ttl,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthModelAlias defines a model ID alias for a specific channel.
|
||||
|
||||
@@ -17,7 +17,6 @@ type staticModelsJSON struct {
|
||||
CodexTeam []*ModelInfo `json:"codex-team"`
|
||||
CodexPlus []*ModelInfo `json:"codex-plus"`
|
||||
CodexPro []*ModelInfo `json:"codex-pro"`
|
||||
Qwen []*ModelInfo `json:"qwen"`
|
||||
IFlow []*ModelInfo `json:"iflow"`
|
||||
Kimi []*ModelInfo `json:"kimi"`
|
||||
Antigravity []*ModelInfo `json:"antigravity"`
|
||||
@@ -68,11 +67,6 @@ func GetCodexProModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().CodexPro)
|
||||
}
|
||||
|
||||
// GetQwenModels returns the standard Qwen model definitions.
|
||||
func GetQwenModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Qwen)
|
||||
}
|
||||
|
||||
// GetIFlowModels returns the standard iFlow model definitions.
|
||||
func GetIFlowModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().IFlow)
|
||||
@@ -239,7 +233,6 @@ func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
|
||||
// - gemini-cli
|
||||
// - aistudio
|
||||
// - codex
|
||||
// - qwen
|
||||
// - iflow
|
||||
// - kimi
|
||||
// - kilo
|
||||
@@ -261,8 +254,6 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
return GetAIStudioModels()
|
||||
case "codex":
|
||||
return GetCodexProModels()
|
||||
case "qwen":
|
||||
return GetQwenModels()
|
||||
case "iflow":
|
||||
return GetIFlowModels()
|
||||
case "kimi":
|
||||
@@ -313,7 +304,6 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
data.GeminiCLI,
|
||||
data.AIStudio,
|
||||
data.CodexPro,
|
||||
data.Qwen,
|
||||
data.IFlow,
|
||||
data.Kimi,
|
||||
data.Antigravity,
|
||||
|
||||
@@ -213,7 +213,6 @@ func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
|
||||
{"codex", oldData.CodexTeam, newData.CodexTeam},
|
||||
{"codex", oldData.CodexPlus, newData.CodexPlus},
|
||||
{"codex", oldData.CodexPro, newData.CodexPro},
|
||||
{"qwen", oldData.Qwen, newData.Qwen},
|
||||
{"iflow", oldData.IFlow, newData.IFlow},
|
||||
{"kimi", oldData.Kimi, newData.Kimi},
|
||||
{"antigravity", oldData.Antigravity, newData.Antigravity},
|
||||
@@ -335,7 +334,6 @@ func validateModelsCatalog(data *staticModelsJSON) error {
|
||||
{name: "codex-team", models: data.CodexTeam},
|
||||
{name: "codex-plus", models: data.CodexPlus},
|
||||
{name: "codex-pro", models: data.CodexPro},
|
||||
{name: "qwen", models: data.Qwen},
|
||||
{name: "iflow", models: data.IFlow},
|
||||
{name: "kimi", models: data.Kimi},
|
||||
{name: "antigravity", models: data.Antigravity},
|
||||
|
||||
@@ -1177,163 +1177,6 @@
|
||||
}
|
||||
],
|
||||
"codex-free": [
|
||||
{
|
||||
"id": "gpt-5",
|
||||
"object": "model",
|
||||
"created": 1754524800,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5",
|
||||
"version": "gpt-5-2025-08-07",
|
||||
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"minimal",
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5-codex",
|
||||
"object": "model",
|
||||
"created": 1757894400,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5 Codex",
|
||||
"version": "gpt-5-2025-09-15",
|
||||
"description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5-codex-mini",
|
||||
"object": "model",
|
||||
"created": 1762473600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5 Codex Mini",
|
||||
"version": "gpt-5-2025-11-07",
|
||||
"description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1",
|
||||
"object": "model",
|
||||
"created": 1762905600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5",
|
||||
"version": "gpt-5.1-2025-11-12",
|
||||
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"none",
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1-codex",
|
||||
"object": "model",
|
||||
"created": 1762905600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.1 Codex",
|
||||
"version": "gpt-5.1-2025-11-12",
|
||||
"description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1-codex-mini",
|
||||
"object": "model",
|
||||
"created": 1762905600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.1 Codex Mini",
|
||||
"version": "gpt-5.1-2025-11-12",
|
||||
"description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1-codex-max",
|
||||
"object": "model",
|
||||
"created": 1763424000,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.1 Codex Max",
|
||||
"version": "gpt-5.1-max",
|
||||
"description": "Stable version of GPT 5.1 Codex Max",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"xhigh"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.2",
|
||||
"object": "model",
|
||||
@@ -1358,29 +1201,6 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.2-codex",
|
||||
"object": "model",
|
||||
"created": 1765440000,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.2 Codex",
|
||||
"version": "gpt-5.2",
|
||||
"description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"xhigh"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.3-codex",
|
||||
"object": "model",
|
||||
@@ -1452,163 +1272,6 @@
|
||||
}
|
||||
],
|
||||
"codex-team": [
|
||||
{
|
||||
"id": "gpt-5",
|
||||
"object": "model",
|
||||
"created": 1754524800,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5",
|
||||
"version": "gpt-5-2025-08-07",
|
||||
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"minimal",
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5-codex",
|
||||
"object": "model",
|
||||
"created": 1757894400,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5 Codex",
|
||||
"version": "gpt-5-2025-09-15",
|
||||
"description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5-codex-mini",
|
||||
"object": "model",
|
||||
"created": 1762473600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5 Codex Mini",
|
||||
"version": "gpt-5-2025-11-07",
|
||||
"description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1",
|
||||
"object": "model",
|
||||
"created": 1762905600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5",
|
||||
"version": "gpt-5.1-2025-11-12",
|
||||
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"none",
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1-codex",
|
||||
"object": "model",
|
||||
"created": 1762905600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.1 Codex",
|
||||
"version": "gpt-5.1-2025-11-12",
|
||||
"description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1-codex-mini",
|
||||
"object": "model",
|
||||
"created": 1762905600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.1 Codex Mini",
|
||||
"version": "gpt-5.1-2025-11-12",
|
||||
"description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1-codex-max",
|
||||
"object": "model",
|
||||
"created": 1763424000,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.1 Codex Max",
|
||||
"version": "gpt-5.1-max",
|
||||
"description": "Stable version of GPT 5.1 Codex Max",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"xhigh"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.2",
|
||||
"object": "model",
|
||||
@@ -1633,29 +1296,6 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.2-codex",
|
||||
"object": "model",
|
||||
"created": 1765440000,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.2 Codex",
|
||||
"version": "gpt-5.2",
|
||||
"description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"xhigh"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.3-codex",
|
||||
"object": "model",
|
||||
@@ -1727,163 +1367,6 @@
|
||||
}
|
||||
],
|
||||
"codex-plus": [
|
||||
{
|
||||
"id": "gpt-5",
|
||||
"object": "model",
|
||||
"created": 1754524800,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5",
|
||||
"version": "gpt-5-2025-08-07",
|
||||
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"minimal",
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5-codex",
|
||||
"object": "model",
|
||||
"created": 1757894400,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5 Codex",
|
||||
"version": "gpt-5-2025-09-15",
|
||||
"description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5-codex-mini",
|
||||
"object": "model",
|
||||
"created": 1762473600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5 Codex Mini",
|
||||
"version": "gpt-5-2025-11-07",
|
||||
"description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1",
|
||||
"object": "model",
|
||||
"created": 1762905600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5",
|
||||
"version": "gpt-5.1-2025-11-12",
|
||||
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"none",
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1-codex",
|
||||
"object": "model",
|
||||
"created": 1762905600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.1 Codex",
|
||||
"version": "gpt-5.1-2025-11-12",
|
||||
"description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1-codex-mini",
|
||||
"object": "model",
|
||||
"created": 1762905600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.1 Codex Mini",
|
||||
"version": "gpt-5.1-2025-11-12",
|
||||
"description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1-codex-max",
|
||||
"object": "model",
|
||||
"created": 1763424000,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.1 Codex Max",
|
||||
"version": "gpt-5.1-max",
|
||||
"description": "Stable version of GPT 5.1 Codex Max",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"xhigh"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.2",
|
||||
"object": "model",
|
||||
@@ -1908,29 +1391,6 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.2-codex",
|
||||
"object": "model",
|
||||
"created": 1765440000,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.2 Codex",
|
||||
"version": "gpt-5.2",
|
||||
"description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"xhigh"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.3-codex",
|
||||
"object": "model",
|
||||
@@ -2025,163 +1485,6 @@
|
||||
}
|
||||
],
|
||||
"codex-pro": [
|
||||
{
|
||||
"id": "gpt-5",
|
||||
"object": "model",
|
||||
"created": 1754524800,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5",
|
||||
"version": "gpt-5-2025-08-07",
|
||||
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"minimal",
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5-codex",
|
||||
"object": "model",
|
||||
"created": 1757894400,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5 Codex",
|
||||
"version": "gpt-5-2025-09-15",
|
||||
"description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5-codex-mini",
|
||||
"object": "model",
|
||||
"created": 1762473600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5 Codex Mini",
|
||||
"version": "gpt-5-2025-11-07",
|
||||
"description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1",
|
||||
"object": "model",
|
||||
"created": 1762905600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5",
|
||||
"version": "gpt-5.1-2025-11-12",
|
||||
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"none",
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1-codex",
|
||||
"object": "model",
|
||||
"created": 1762905600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.1 Codex",
|
||||
"version": "gpt-5.1-2025-11-12",
|
||||
"description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1-codex-mini",
|
||||
"object": "model",
|
||||
"created": 1762905600,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.1 Codex Mini",
|
||||
"version": "gpt-5.1-2025-11-12",
|
||||
"description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.1-codex-max",
|
||||
"object": "model",
|
||||
"created": 1763424000,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.1 Codex Max",
|
||||
"version": "gpt-5.1-max",
|
||||
"description": "Stable version of GPT 5.1 Codex Max",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"xhigh"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.2",
|
||||
"object": "model",
|
||||
@@ -2206,29 +1509,6 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.2-codex",
|
||||
"object": "model",
|
||||
"created": 1765440000,
|
||||
"owned_by": "openai",
|
||||
"type": "openai",
|
||||
"display_name": "GPT 5.2 Codex",
|
||||
"version": "gpt-5.2",
|
||||
"description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
|
||||
"context_length": 400000,
|
||||
"max_completion_tokens": 128000,
|
||||
"supported_parameters": [
|
||||
"tools"
|
||||
],
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"xhigh"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.3-codex",
|
||||
"object": "model",
|
||||
@@ -2322,27 +1602,6 @@
|
||||
}
|
||||
}
|
||||
],
|
||||
"qwen": [
|
||||
{
|
||||
"id": "coder-model",
|
||||
"object": "model",
|
||||
"created": 1771171200,
|
||||
"owned_by": "qwen",
|
||||
"type": "qwen",
|
||||
"display_name": "Qwen 3.6 Plus",
|
||||
"version": "3.6",
|
||||
"description": "efficient hybrid model with leading coding performance",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65536,
|
||||
"supported_parameters": [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"stop"
|
||||
]
|
||||
}
|
||||
],
|
||||
"iflow": [
|
||||
{
|
||||
"id": "qwen3-coder-plus",
|
||||
@@ -2606,38 +1865,6 @@
|
||||
"dynamic_allowed": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gemini-2.5-flash",
|
||||
"object": "model",
|
||||
"owned_by": "antigravity",
|
||||
"type": "antigravity",
|
||||
"display_name": "Gemini 2.5 Flash",
|
||||
"name": "gemini-2.5-flash",
|
||||
"description": "Gemini 2.5 Flash",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65535,
|
||||
"thinking": {
|
||||
"max": 24576,
|
||||
"zero_allowed": true,
|
||||
"dynamic_allowed": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gemini-2.5-flash-lite",
|
||||
"object": "model",
|
||||
"owned_by": "antigravity",
|
||||
"type": "antigravity",
|
||||
"display_name": "Gemini 2.5 Flash Lite",
|
||||
"name": "gemini-2.5-flash-lite",
|
||||
"description": "Gemini 2.5 Flash Lite",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65535,
|
||||
"thinking": {
|
||||
"max": 24576,
|
||||
"zero_allowed": true,
|
||||
"dynamic_allowed": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gemini-3-flash",
|
||||
"object": "model",
|
||||
@@ -2770,6 +1997,29 @@
|
||||
"description": "GPT-OSS 120B (Medium)",
|
||||
"context_length": 114000,
|
||||
"max_completion_tokens": 32768
|
||||
},
|
||||
{
|
||||
"id": "gemini-3.1-flash-lite",
|
||||
"object": "model",
|
||||
"owned_by": "antigravity",
|
||||
"type": "antigravity",
|
||||
"display_name": "Gemini 3.1 Flash Lite",
|
||||
"name": "gemini-3.1-flash-lite",
|
||||
"description": "Gemini 3.1 Flash Lite",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65535,
|
||||
"thinking": {
|
||||
"min": 1,
|
||||
"max": 65535,
|
||||
"zero_allowed": true,
|
||||
"dynamic_allowed": true,
|
||||
"levels": [
|
||||
"minimal",
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
antigravityclaude "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/claude"
|
||||
@@ -184,22 +185,24 @@ func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cli
|
||||
return client
|
||||
}
|
||||
|
||||
func validateAntigravityRequestSignatures(from sdktranslator.Format, rawJSON []byte) error {
|
||||
func validateAntigravityRequestSignatures(from sdktranslator.Format, rawJSON []byte) ([]byte, error) {
|
||||
if from.String() != "claude" {
|
||||
return nil
|
||||
return rawJSON, nil
|
||||
}
|
||||
// Always strip thinking blocks with empty signatures (proxy-generated).
|
||||
rawJSON = antigravityclaude.StripEmptySignatureThinkingBlocks(rawJSON)
|
||||
if cache.SignatureCacheEnabled() {
|
||||
return nil
|
||||
return rawJSON, nil
|
||||
}
|
||||
if !cache.SignatureBypassStrictMode() {
|
||||
// Non-strict bypass: let the translator handle invalid signatures
|
||||
// by dropping unsigned thinking blocks silently (no 400).
|
||||
return nil
|
||||
return rawJSON, nil
|
||||
}
|
||||
if err := antigravityclaude.ValidateClaudeBypassSignatures(rawJSON); err != nil {
|
||||
return statusErr{code: http.StatusBadRequest, msg: err.Error()}
|
||||
return rawJSON, statusErr{code: http.StatusBadRequest, msg: err.Error()}
|
||||
}
|
||||
return nil
|
||||
return rawJSON, nil
|
||||
}
|
||||
|
||||
// Identifier returns the executor identifier.
|
||||
@@ -695,9 +698,11 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
if errValidate := validateAntigravityRequestSignatures(from, originalPayload); errValidate != nil {
|
||||
originalPayload, errValidate := validateAntigravityRequestSignatures(from, originalPayload)
|
||||
if errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
req.Payload = originalPayload
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return resp, errToken
|
||||
@@ -907,9 +912,11 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
if errValidate := validateAntigravityRequestSignatures(from, originalPayload); errValidate != nil {
|
||||
originalPayload, errValidate := validateAntigravityRequestSignatures(from, originalPayload)
|
||||
if errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
req.Payload = originalPayload
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return resp, errToken
|
||||
@@ -1370,9 +1377,11 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
if errValidate := validateAntigravityRequestSignatures(from, originalPayload); errValidate != nil {
|
||||
originalPayload, errValidate := validateAntigravityRequestSignatures(from, originalPayload)
|
||||
if errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
req.Payload = originalPayload
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return nil, errToken
|
||||
@@ -1626,9 +1635,11 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
if errValidate := validateAntigravityRequestSignatures(from, originalPayloadSource); errValidate != nil {
|
||||
originalPayloadSource, errValidate := validateAntigravityRequestSignatures(from, originalPayloadSource)
|
||||
if errValidate != nil {
|
||||
return cliproxyexecutor.Response{}, errValidate
|
||||
}
|
||||
req.Payload = originalPayloadSource
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return cliproxyexecutor.Response{}, errToken
|
||||
@@ -1945,18 +1956,56 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
||||
|
||||
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro") || strings.Contains(modelName, "gemini-3.1-pro")
|
||||
payloadStr := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
payloadStr, _ = util.RenameKey(payloadStr, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
// Cap maxOutputTokens to model's max_completion_tokens from registry
|
||||
if maxOut := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxOut.Exists() && maxOut.Type == gjson.Number {
|
||||
if modelInfo := registry.LookupModelInfo(modelName, "antigravity"); modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
|
||||
if int(maxOut.Int()) > modelInfo.MaxCompletionTokens {
|
||||
payload, _ = sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", modelInfo.MaxCompletionTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if useAntigravitySchema {
|
||||
payloadStr = util.CleanJSONSchemaForAntigravity(payloadStr)
|
||||
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro") || strings.Contains(modelName, "gemini-3.1-pro")
|
||||
var (
|
||||
bodyReader io.Reader
|
||||
payloadLog []byte
|
||||
)
|
||||
if antigravityRequestNeedsSchemaSanitization(payload) {
|
||||
payloadStr := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
payloadStr, _ = util.RenameKey(payloadStr, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
|
||||
if useAntigravitySchema {
|
||||
payloadStr = util.CleanJSONSchemaForAntigravity(payloadStr)
|
||||
} else {
|
||||
payloadStr = util.CleanJSONSchemaForGemini(payloadStr)
|
||||
}
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
updated, _ := sjson.SetBytes([]byte(payloadStr), "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
payloadStr = string(updated)
|
||||
} else {
|
||||
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
|
||||
}
|
||||
|
||||
bodyReader = strings.NewReader(payloadStr)
|
||||
if e.cfg != nil && e.cfg.RequestLog {
|
||||
payloadLog = []byte(payloadStr)
|
||||
}
|
||||
} else {
|
||||
payloadStr = util.CleanJSONSchemaForGemini(payloadStr)
|
||||
if strings.Contains(modelName, "claude") {
|
||||
payload, _ = sjson.SetBytes(payload, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
} else {
|
||||
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens")
|
||||
}
|
||||
|
||||
bodyReader = bytes.NewReader(payload)
|
||||
if e.cfg != nil && e.cfg.RequestLog {
|
||||
payloadLog = append([]byte(nil), payload...)
|
||||
}
|
||||
}
|
||||
|
||||
// if useAntigravitySchema {
|
||||
@@ -1972,14 +2021,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
// }
|
||||
// }
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
updated, _ := sjson.SetBytes([]byte(payloadStr), "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
payloadStr = string(updated)
|
||||
} else {
|
||||
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), strings.NewReader(payloadStr))
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bodyReader)
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
}
|
||||
@@ -2002,10 +2044,6 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
var payloadLog []byte
|
||||
if e.cfg != nil && e.cfg.RequestLog {
|
||||
payloadLog = []byte(payloadStr)
|
||||
}
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: requestURL.String(),
|
||||
Method: http.MethodPost,
|
||||
@@ -2021,6 +2059,19 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
return httpReq, nil
|
||||
}
|
||||
|
||||
func antigravityRequestNeedsSchemaSanitization(payload []byte) bool {
|
||||
if gjson.GetBytes(payload, "request.tools.0").Exists() {
|
||||
return true
|
||||
}
|
||||
if gjson.GetBytes(payload, "request.generationConfig.responseJsonSchema").Exists() {
|
||||
return true
|
||||
}
|
||||
if gjson.GetBytes(payload, "request.generationConfig.responseSchema").Exists() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func tokenExpiry(metadata map[string]any) time.Time {
|
||||
if metadata == nil {
|
||||
return time.Time{}
|
||||
|
||||
@@ -35,12 +35,102 @@ func TestAntigravityBuildRequest_SanitizesAntigravityToolSchema(t *testing.T) {
|
||||
assertSchemaSanitizedAndPropertyPreserved(t, params)
|
||||
}
|
||||
|
||||
func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any {
|
||||
func TestAntigravityBuildRequest_SkipsSchemaSanitizationWithoutToolsField(t *testing.T) {
|
||||
body := buildRequestBodyFromRawPayload(t, "gemini-3.1-flash-image", []byte(`{
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "user",
|
||||
"x-debug": "keep-me",
|
||||
"parts": [
|
||||
{
|
||||
"text": "hello"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"nonSchema": {
|
||||
"nullable": true,
|
||||
"x-extra": "keep-me"
|
||||
},
|
||||
"generationConfig": {
|
||||
"maxOutputTokens": 128
|
||||
}
|
||||
}
|
||||
}`))
|
||||
|
||||
assertNonSchemaRequestPreserved(t, body)
|
||||
}
|
||||
|
||||
func TestAntigravityBuildRequest_SkipsSchemaSanitizationWithEmptyToolsArray(t *testing.T) {
|
||||
body := buildRequestBodyFromRawPayload(t, "gemini-3.1-flash-image", []byte(`{
|
||||
"request": {
|
||||
"tools": [],
|
||||
"contents": [
|
||||
{
|
||||
"role": "user",
|
||||
"x-debug": "keep-me",
|
||||
"parts": [
|
||||
{
|
||||
"text": "hello"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"nonSchema": {
|
||||
"nullable": true,
|
||||
"x-extra": "keep-me"
|
||||
},
|
||||
"generationConfig": {
|
||||
"maxOutputTokens": 128
|
||||
}
|
||||
}
|
||||
}`))
|
||||
|
||||
assertNonSchemaRequestPreserved(t, body)
|
||||
}
|
||||
|
||||
func assertNonSchemaRequestPreserved(t *testing.T, body map[string]any) {
|
||||
t.Helper()
|
||||
|
||||
executor := &AntigravityExecutor{}
|
||||
auth := &cliproxyauth.Auth{}
|
||||
payload := []byte(`{
|
||||
request, ok := body["request"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("request missing or invalid type")
|
||||
}
|
||||
|
||||
contents, ok := request["contents"].([]any)
|
||||
if !ok || len(contents) == 0 {
|
||||
t.Fatalf("contents missing or empty")
|
||||
}
|
||||
content, ok := contents[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("content missing or invalid type")
|
||||
}
|
||||
if got, ok := content["x-debug"].(string); !ok || got != "keep-me" {
|
||||
t.Fatalf("x-debug should be preserved when no tool schema exists, got=%v", content["x-debug"])
|
||||
}
|
||||
|
||||
nonSchema, ok := request["nonSchema"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("nonSchema missing or invalid type")
|
||||
}
|
||||
if _, ok := nonSchema["nullable"]; !ok {
|
||||
t.Fatalf("nullable should be preserved outside schema cleanup path")
|
||||
}
|
||||
if got, ok := nonSchema["x-extra"].(string); !ok || got != "keep-me" {
|
||||
t.Fatalf("x-extra should be preserved outside schema cleanup path, got=%v", nonSchema["x-extra"])
|
||||
}
|
||||
|
||||
if generationConfig, ok := request["generationConfig"].(map[string]any); ok {
|
||||
if _, ok := generationConfig["maxOutputTokens"]; ok {
|
||||
t.Fatalf("maxOutputTokens should still be removed for non-Claude requests")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any {
|
||||
t.Helper()
|
||||
return buildRequestBodyFromRawPayload(t, modelName, []byte(`{
|
||||
"request": {
|
||||
"tools": [
|
||||
{
|
||||
@@ -75,7 +165,14 @@ func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any
|
||||
}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
}`))
|
||||
}
|
||||
|
||||
func buildRequestBodyFromRawPayload(t *testing.T, modelName string, payload []byte) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
executor := &AntigravityExecutor{}
|
||||
auth := &cliproxyauth.Auth{}
|
||||
|
||||
req, err := executor.buildRequest(context.Background(), auth, "token", modelName, payload, false, "", "https://example.com")
|
||||
if err != nil {
|
||||
|
||||
@@ -21,6 +21,14 @@ func testGeminiSignaturePayload() string {
|
||||
return base64.StdEncoding.EncodeToString(payload)
|
||||
}
|
||||
|
||||
// testFakeClaudeSignature returns a base64 string starting with 'E' that passes
|
||||
// the lightweight hasValidClaudeSignature check but has invalid protobuf content
|
||||
// (first decoded byte 0x12 is correct, but no valid protobuf field 2 follows),
|
||||
// so it fails deep validation in strict mode.
|
||||
func testFakeClaudeSignature() string {
|
||||
return base64.StdEncoding.EncodeToString([]byte{0x12, 0xFF, 0xFE, 0xFD})
|
||||
}
|
||||
|
||||
func testAntigravityAuth(baseURL string) *cliproxyauth.Auth {
|
||||
return &cliproxyauth.Auth{
|
||||
Attributes: map[string]string{
|
||||
@@ -40,7 +48,7 @@ func invalidClaudeThinkingPayload() []byte {
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "bad", "signature": "` + testGeminiSignaturePayload() + `"},
|
||||
{"type": "thinking", "thinking": "bad", "signature": "` + testFakeClaudeSignature() + `"},
|
||||
{"type": "text", "text": "hello"}
|
||||
]
|
||||
}
|
||||
@@ -134,7 +142,7 @@ func TestAntigravityExecutor_NonStrictBypassSkipsPrecheck(t *testing.T) {
|
||||
payload := invalidClaudeThinkingPayload()
|
||||
from := sdktranslator.FromString("claude")
|
||||
|
||||
err := validateAntigravityRequestSignatures(from, payload)
|
||||
_, err := validateAntigravityRequestSignatures(from, payload)
|
||||
if err != nil {
|
||||
t.Fatalf("non-strict bypass should skip precheck, got: %v", err)
|
||||
}
|
||||
@@ -150,7 +158,7 @@ func TestAntigravityExecutor_CacheModeSkipsPrecheck(t *testing.T) {
|
||||
payload := invalidClaudeThinkingPayload()
|
||||
from := sdktranslator.FromString("claude")
|
||||
|
||||
err := validateAntigravityRequestSignatures(from, payload)
|
||||
_, err := validateAntigravityRequestSignatures(from, payload)
|
||||
if err != nil {
|
||||
t.Fatalf("cache mode should skip precheck, got: %v", err)
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -30,14 +30,14 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
cursorAPIURL = "https://api2.cursor.sh"
|
||||
cursorRunPath = "/agent.v1.AgentService/Run"
|
||||
cursorModelsPath = "/agent.v1.AgentService/GetUsableModels"
|
||||
cursorClientVersion = "cli-2026.02.13-41ac335"
|
||||
cursorAuthType = "cursor"
|
||||
cursorAPIURL = "https://api2.cursor.sh"
|
||||
cursorRunPath = "/agent.v1.AgentService/Run"
|
||||
cursorModelsPath = "/agent.v1.AgentService/GetUsableModels"
|
||||
cursorClientVersion = "cli-2026.02.13-41ac335"
|
||||
cursorAuthType = "cursor"
|
||||
cursorHeartbeatInterval = 5 * time.Second
|
||||
cursorSessionTTL = 5 * time.Minute
|
||||
cursorCheckpointTTL = 30 * time.Minute
|
||||
cursorSessionTTL = 5 * time.Minute
|
||||
cursorCheckpointTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// CursorExecutor handles requests to the Cursor API via Connect+Protobuf protocol.
|
||||
@@ -63,9 +63,9 @@ type cursorSession struct {
|
||||
pending []pendingMcpExec
|
||||
cancel context.CancelFunc // cancels the session-scoped heartbeat (NOT tied to HTTP request)
|
||||
createdAt time.Time
|
||||
authID string // auth file ID that created this session (for multi-account isolation)
|
||||
toolResultCh chan []toolResultInfo // receives tool results from the next HTTP request
|
||||
resumeOutCh chan cliproxyexecutor.StreamChunk // output channel for resumed response
|
||||
authID string // auth file ID that created this session (for multi-account isolation)
|
||||
toolResultCh chan []toolResultInfo // receives tool results from the next HTTP request
|
||||
resumeOutCh chan cliproxyexecutor.StreamChunk // output channel for resumed response
|
||||
switchOutput func(ch chan cliproxyexecutor.StreamChunk) // callback to switch output channel
|
||||
}
|
||||
|
||||
@@ -148,7 +148,7 @@ type cursorStatusErr struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e cursorStatusErr) Error() string { return e.msg }
|
||||
func (e cursorStatusErr) Error() string { return e.msg }
|
||||
func (e cursorStatusErr) StatusCode() int { return e.code }
|
||||
func (e cursorStatusErr) RetryAfter() *time.Duration { return nil } // no retry-after info from Cursor; conductor uses exponential backoff
|
||||
|
||||
@@ -786,7 +786,7 @@ func (e *CursorExecutor) resumeWithToolResults(
|
||||
func openCursorH2Stream(accessToken string) (*cursorproto.H2Stream, error) {
|
||||
headers := map[string]string{
|
||||
":path": cursorRunPath,
|
||||
"content-type": "application/connect+proto",
|
||||
"content-type": "application/connect+proto",
|
||||
"connect-protocol-version": "1",
|
||||
"te": "trailers",
|
||||
"authorization": "Bearer " + accessToken,
|
||||
@@ -876,21 +876,21 @@ func processH2SessionFrames(
|
||||
buf.Write(data)
|
||||
log.Debugf("cursor: processH2SessionFrames[%s]: buf total=%d", stream.ID(), buf.Len())
|
||||
|
||||
// Process all complete frames
|
||||
for {
|
||||
currentBuf := buf.Bytes()
|
||||
if len(currentBuf) == 0 {
|
||||
break
|
||||
}
|
||||
flags, payload, consumed, ok := cursorproto.ParseConnectFrame(currentBuf)
|
||||
if !ok {
|
||||
// Log detailed info about why parsing failed
|
||||
previewLen := min(20, len(currentBuf))
|
||||
log.Debugf("cursor: incomplete frame in buffer, waiting for more data (buf=%d bytes, first bytes: %x = %q)", len(currentBuf), currentBuf[:previewLen], string(currentBuf[:previewLen]))
|
||||
break
|
||||
}
|
||||
buf.Next(consumed)
|
||||
log.Debugf("cursor: parsed Connect frame flags=0x%02x payload=%d bytes consumed=%d", flags, len(payload), consumed)
|
||||
// Process all complete frames
|
||||
for {
|
||||
currentBuf := buf.Bytes()
|
||||
if len(currentBuf) == 0 {
|
||||
break
|
||||
}
|
||||
flags, payload, consumed, ok := cursorproto.ParseConnectFrame(currentBuf)
|
||||
if !ok {
|
||||
// Log detailed info about why parsing failed
|
||||
previewLen := min(20, len(currentBuf))
|
||||
log.Debugf("cursor: incomplete frame in buffer, waiting for more data (buf=%d bytes, first bytes: %x = %q)", len(currentBuf), currentBuf[:previewLen], string(currentBuf[:previewLen]))
|
||||
break
|
||||
}
|
||||
buf.Next(consumed)
|
||||
log.Debugf("cursor: parsed Connect frame flags=0x%02x payload=%d bytes consumed=%d", flags, len(payload), consumed)
|
||||
|
||||
if flags&cursorproto.ConnectEndStreamFlag != 0 {
|
||||
if err := cursorproto.ParseConnectEndStream(payload); err != nil {
|
||||
@@ -1080,15 +1080,15 @@ func processH2SessionFrames(
|
||||
// --- OpenAI request parsing ---
|
||||
|
||||
type parsedOpenAIRequest struct {
|
||||
Model string
|
||||
Messages []gjson.Result
|
||||
Tools []gjson.Result
|
||||
Stream bool
|
||||
Model string
|
||||
Messages []gjson.Result
|
||||
Tools []gjson.Result
|
||||
Stream bool
|
||||
SystemPrompt string
|
||||
UserText string
|
||||
Images []cursorproto.ImageData
|
||||
Turns []cursorproto.TurnData
|
||||
ToolResults []toolResultInfo
|
||||
UserText string
|
||||
Images []cursorproto.ImageData
|
||||
Turns []cursorproto.TurnData
|
||||
ToolResults []toolResultInfo
|
||||
}
|
||||
|
||||
type toolResultInfo struct {
|
||||
|
||||
@@ -16,9 +16,9 @@ import (
|
||||
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -75,7 +75,7 @@ var gitLabAgenticCatalog = []gitLabCatalogModel{
|
||||
}
|
||||
|
||||
var gitLabModelAliases = map[string]string{
|
||||
"duo-chat-haiku-4-6": "duo-chat-haiku-4-5",
|
||||
"duo-chat-haiku-4-6": "duo-chat-haiku-4-5",
|
||||
}
|
||||
|
||||
func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor {
|
||||
|
||||
@@ -215,7 +215,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
|
||||
body = preserveReasoningContentInMessages(body)
|
||||
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
|
||||
// Ensure tools array exists to avoid provider quirks observed in some upstreams.
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||
body = ensureToolsArray(body)
|
||||
|
||||
@@ -281,8 +281,8 @@ func TestGetAuthValue(t *testing.T) {
|
||||
expected: "attribute_value",
|
||||
},
|
||||
{
|
||||
name: "Both nil",
|
||||
auth: &cliproxyauth.Auth{},
|
||||
name: "Both nil",
|
||||
auth: &cliproxyauth.Auth{},
|
||||
key: "test_key",
|
||||
expected: "",
|
||||
},
|
||||
@@ -326,9 +326,9 @@ func TestGetAuthValue(t *testing.T) {
|
||||
|
||||
func TestGetAccountKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
auth *cliproxyauth.Auth
|
||||
checkFn func(t *testing.T, result string)
|
||||
name string
|
||||
auth *cliproxyauth.Auth
|
||||
checkFn func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "From client_id",
|
||||
|
||||
@@ -1,739 +0,0 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
qwenUserAgent = "QwenCode/0.14.2 (darwin; arm64)"
|
||||
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
|
||||
qwenRateLimitWindow = time.Minute // sliding window duration
|
||||
)
|
||||
|
||||
var qwenDefaultSystemMessage = []byte(`{"role":"system","content":[{"type":"text","text":"","cache_control":{"type":"ephemeral"}}]}`)
|
||||
|
||||
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
|
||||
var qwenQuotaCodes = map[string]struct{}{
|
||||
"insufficient_quota": {},
|
||||
"quota_exceeded": {},
|
||||
}
|
||||
|
||||
// qwenRateLimiter tracks request timestamps per credential for rate limiting.
|
||||
// Qwen has a limit of 60 requests per minute per account.
|
||||
var qwenRateLimiter = struct {
|
||||
sync.Mutex
|
||||
requests map[string][]time.Time // authID -> request timestamps
|
||||
}{
|
||||
requests: make(map[string][]time.Time),
|
||||
}
|
||||
|
||||
// redactAuthID returns a redacted version of the auth ID for safe logging.
|
||||
// Keeps a small prefix/suffix to allow correlation across events.
|
||||
func redactAuthID(id string) string {
|
||||
if id == "" {
|
||||
return ""
|
||||
}
|
||||
if len(id) <= 8 {
|
||||
return id
|
||||
}
|
||||
return id[:4] + "..." + id[len(id)-4:]
|
||||
}
|
||||
|
||||
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
|
||||
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
|
||||
func checkQwenRateLimit(authID string) error {
|
||||
if authID == "" {
|
||||
// Empty authID should not bypass rate limiting in production
|
||||
// Use debug level to avoid log spam for certain auth flows
|
||||
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-qwenRateLimitWindow)
|
||||
|
||||
qwenRateLimiter.Lock()
|
||||
defer qwenRateLimiter.Unlock()
|
||||
|
||||
// Get and filter timestamps within the window
|
||||
timestamps := qwenRateLimiter.requests[authID]
|
||||
var validTimestamps []time.Time
|
||||
for _, ts := range timestamps {
|
||||
if ts.After(windowStart) {
|
||||
validTimestamps = append(validTimestamps, ts)
|
||||
}
|
||||
}
|
||||
|
||||
// Always prune expired entries to prevent memory leak
|
||||
// Delete empty entries, otherwise update with pruned slice
|
||||
if len(validTimestamps) == 0 {
|
||||
delete(qwenRateLimiter.requests, authID)
|
||||
}
|
||||
|
||||
// Check if rate limit exceeded
|
||||
if len(validTimestamps) >= qwenRateLimitPerMin {
|
||||
// Calculate when the oldest request will expire
|
||||
oldestInWindow := validTimestamps[0]
|
||||
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
|
||||
if retryAfter < time.Second {
|
||||
retryAfter = time.Second
|
||||
}
|
||||
retryAfterSec := int(retryAfter.Seconds())
|
||||
return statusErr{
|
||||
code: http.StatusTooManyRequests,
|
||||
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
|
||||
retryAfter: &retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
// Record this request and update the map with pruned timestamps
|
||||
validTimestamps = append(validTimestamps, now)
|
||||
qwenRateLimiter.requests[authID] = validTimestamps
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
|
||||
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
|
||||
func isQwenQuotaError(body []byte) bool {
|
||||
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
|
||||
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
|
||||
|
||||
// Primary check: exact match on error.code or error.type (most reliable)
|
||||
if _, ok := qwenQuotaCodes[code]; ok {
|
||||
return true
|
||||
}
|
||||
if _, ok := qwenQuotaCodes[errType]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fallback: check message only if code/type don't match (less reliable)
|
||||
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
|
||||
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
|
||||
strings.Contains(msg, "free allocated quota exceeded") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
|
||||
// Returns the appropriate status code and retryAfter duration for statusErr.
|
||||
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
|
||||
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
|
||||
errCode = httpCode
|
||||
// Only check quota errors for expected status codes to avoid false positives
|
||||
// Qwen returns 403 for quota errors, 429 for rate limits
|
||||
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
|
||||
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
|
||||
// Do not force an excessively long retry-after (e.g. until tomorrow), otherwise
|
||||
// the global request-retry scheduler may skip retries due to max-retry-interval.
|
||||
helps.LogWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d)", httpCode, errCode)
|
||||
}
|
||||
return errCode, retryAfter
|
||||
}
|
||||
|
||||
func qwenDisableCooling(cfg *config.Config, auth *cliproxyauth.Auth) bool {
|
||||
if auth != nil {
|
||||
if override, ok := auth.DisableCoolingOverride(); ok {
|
||||
return override
|
||||
}
|
||||
}
|
||||
if cfg == nil {
|
||||
return false
|
||||
}
|
||||
return cfg.DisableCooling
|
||||
}
|
||||
|
||||
func parseRetryAfterHeader(header http.Header, now time.Time) *time.Duration {
|
||||
raw := strings.TrimSpace(header.Get("Retry-After"))
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if seconds, err := strconv.Atoi(raw); err == nil {
|
||||
if seconds <= 0 {
|
||||
return nil
|
||||
}
|
||||
d := time.Duration(seconds) * time.Second
|
||||
return &d
|
||||
}
|
||||
if at, err := http.ParseTime(raw); err == nil {
|
||||
if !at.After(now) {
|
||||
return nil
|
||||
}
|
||||
d := at.Sub(now)
|
||||
return &d
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureQwenSystemMessage ensures the request has a single system message at the beginning.
|
||||
// It always injects the default system prompt and merges any user-provided system messages
|
||||
// into the injected system message content to satisfy Qwen's strict message ordering rules.
|
||||
func ensureQwenSystemMessage(payload []byte) ([]byte, error) {
|
||||
isInjectedSystemPart := func(part gjson.Result) bool {
|
||||
if !part.Exists() || !part.IsObject() {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(part.Get("type").String(), "text") {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(part.Get("cache_control.type").String(), "ephemeral") {
|
||||
return false
|
||||
}
|
||||
text := part.Get("text").String()
|
||||
return text == "" || text == "You are Qwen Code."
|
||||
}
|
||||
|
||||
defaultParts := gjson.ParseBytes(qwenDefaultSystemMessage).Get("content")
|
||||
var systemParts []any
|
||||
if defaultParts.Exists() && defaultParts.IsArray() {
|
||||
for _, part := range defaultParts.Array() {
|
||||
systemParts = append(systemParts, part.Value())
|
||||
}
|
||||
}
|
||||
if len(systemParts) == 0 {
|
||||
systemParts = append(systemParts, map[string]any{
|
||||
"type": "text",
|
||||
"text": "You are Qwen Code.",
|
||||
"cache_control": map[string]any{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
appendSystemContent := func(content gjson.Result) {
|
||||
makeTextPart := func(text string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "text",
|
||||
"text": text,
|
||||
}
|
||||
}
|
||||
|
||||
if !content.Exists() || content.Type == gjson.Null {
|
||||
return
|
||||
}
|
||||
if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
if part.Type == gjson.String {
|
||||
systemParts = append(systemParts, makeTextPart(part.String()))
|
||||
continue
|
||||
}
|
||||
if isInjectedSystemPart(part) {
|
||||
continue
|
||||
}
|
||||
systemParts = append(systemParts, part.Value())
|
||||
}
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.String {
|
||||
systemParts = append(systemParts, makeTextPart(content.String()))
|
||||
return
|
||||
}
|
||||
if content.IsObject() {
|
||||
if isInjectedSystemPart(content) {
|
||||
return
|
||||
}
|
||||
systemParts = append(systemParts, content.Value())
|
||||
return
|
||||
}
|
||||
systemParts = append(systemParts, makeTextPart(content.String()))
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
var nonSystemMessages []any
|
||||
if messages.Exists() && messages.IsArray() {
|
||||
for _, msg := range messages.Array() {
|
||||
if strings.EqualFold(msg.Get("role").String(), "system") {
|
||||
appendSystemContent(msg.Get("content"))
|
||||
continue
|
||||
}
|
||||
nonSystemMessages = append(nonSystemMessages, msg.Value())
|
||||
}
|
||||
}
|
||||
|
||||
newMessages := make([]any, 0, 1+len(nonSystemMessages))
|
||||
newMessages = append(newMessages, map[string]any{
|
||||
"role": "system",
|
||||
"content": systemParts,
|
||||
})
|
||||
newMessages = append(newMessages, nonSystemMessages...)
|
||||
|
||||
updated, errSet := sjson.SetBytes(payload, "messages", newMessages)
|
||||
if errSet != nil {
|
||||
return nil, fmt.Errorf("qwen executor: set system message failed: %w", errSet)
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
||||
type QwenExecutor struct {
|
||||
cfg *config.Config
|
||||
refreshForImmediateRetry func(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error)
|
||||
}
|
||||
|
||||
func NewQwenExecutor(cfg *config.Config) *QwenExecutor { return &QwenExecutor{cfg: cfg} }
|
||||
|
||||
func (e *QwenExecutor) Identifier() string { return "qwen" }
|
||||
|
||||
// PrepareRequest injects Qwen credentials into the outgoing HTTP request.
|
||||
func (e *QwenExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
token, _ := qwenCreds(auth)
|
||||
if strings.TrimSpace(token) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects Qwen credentials into the request and executes it.
|
||||
func (e *QwenExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("qwen executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
|
||||
var authID string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, err = ensureQwenSystemMessage(body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
for {
|
||||
if errRate := checkQwenRateLimit(authID); errRate != nil {
|
||||
helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||
return resp, errRate
|
||||
}
|
||||
|
||||
token, baseURL := qwenCreds(auth)
|
||||
if baseURL == "" {
|
||||
baseURL = "https://portal.qwen.ai/v1"
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if errReq != nil {
|
||||
return resp, errReq
|
||||
}
|
||||
applyQwenHeaders(httpReq, token, false)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return resp, errDo
|
||||
}
|
||||
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||
}
|
||||
|
||||
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||
if errCode == http.StatusTooManyRequests && retryAfter == nil {
|
||||
retryAfter = parseRetryAfterHeader(httpResp.Header, time.Now())
|
||||
}
|
||||
if errCode == http.StatusTooManyRequests && retryAfter == nil && qwenDisableCooling(e.cfg, auth) && isQwenQuotaError(b) {
|
||||
defaultRetryAfter := time.Second
|
||||
retryAfter = &defaultRetryAfter
|
||||
}
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
|
||||
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return resp, errRead
|
||||
}
|
||||
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
|
||||
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
|
||||
var authID string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// toolsResult := gjson.GetBytes(body, "tools")
|
||||
// I'm addressing the Qwen3 "poisoning" issue, which is caused by the model needing a tool to be defined. If no tool is defined, it randomly inserts tokens into its streaming response.
|
||||
// This will have no real consequences. It's just to scare Qwen3.
|
||||
// if (toolsResult.IsArray() && len(toolsResult.Array()) == 0) || !toolsResult.Exists() {
|
||||
// body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
|
||||
// }
|
||||
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, err = ensureQwenSystemMessage(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
if errRate := checkQwenRateLimit(authID); errRate != nil {
|
||||
helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||
return nil, errRate
|
||||
}
|
||||
|
||||
token, baseURL := qwenCreds(auth)
|
||||
if baseURL == "" {
|
||||
baseURL = "https://portal.qwen.ai/v1"
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
}
|
||||
applyQwenHeaders(httpReq, token, true)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return nil, errDo
|
||||
}
|
||||
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||
}
|
||||
|
||||
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||
if errCode == http.StatusTooManyRequests && retryAfter == nil {
|
||||
retryAfter = parseRetryAfterHeader(httpResp.Header, time.Now())
|
||||
}
|
||||
if errCode == http.StatusTooManyRequests && retryAfter == nil && qwenDisableCooling(e.cfg, auth) && isQwenQuotaError(b) {
|
||||
defaultRetryAfter := time.Second
|
||||
retryAfter = &defaultRetryAfter
|
||||
}
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
|
||||
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 52_428_800) // 50MB
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
modelName = baseModel
|
||||
}
|
||||
|
||||
enc, err := helps.TokenizerForModel(modelName)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
|
||||
count, err := helps.CountOpenAIChatTokens(enc, body)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: token counting failed: %w", err)
|
||||
}
|
||||
|
||||
usageJSON := helps.BuildOpenAIUsageJSON(count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
log.Debugf("qwen executor: refresh called")
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("qwen executor: auth is nil")
|
||||
}
|
||||
// Expect refresh_token in metadata for OAuth-based accounts
|
||||
var refreshToken string
|
||||
if auth.Metadata != nil {
|
||||
if v, ok := auth.Metadata["refresh_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
refreshToken = v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
// Nothing to refresh
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
svc := qwenauth.NewQwenAuth(e.cfg)
|
||||
td, err := svc.RefreshTokens(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
auth.Metadata["access_token"] = td.AccessToken
|
||||
if td.RefreshToken != "" {
|
||||
auth.Metadata["refresh_token"] = td.RefreshToken
|
||||
}
|
||||
if td.ResourceURL != "" {
|
||||
auth.Metadata["resource_url"] = td.ResourceURL
|
||||
}
|
||||
// Use "expired" for consistency with existing file format
|
||||
auth.Metadata["expired"] = td.Expire
|
||||
auth.Metadata["type"] = "qwen"
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
auth.Metadata["last_refresh"] = now
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func applyQwenHeaders(r *http.Request, token string, stream bool) {
|
||||
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
|
||||
r.Header.Set("User-Agent", qwenUserAgent)
|
||||
r.Header.Set("X-Stainless-Lang", "js")
|
||||
r.Header.Set("Accept-Language", "*")
|
||||
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
|
||||
r.Header.Set("X-Stainless-Os", "MacOS")
|
||||
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
|
||||
r.Header.Set("X-Stainless-Arch", "arm64")
|
||||
r.Header.Set("X-Stainless-Runtime", "node")
|
||||
r.Header.Set("X-Stainless-Retry-Count", "0")
|
||||
r.Header.Set("Accept-Encoding", "gzip, deflate")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
|
||||
r.Header.Set("Sec-Fetch-Mode", "cors")
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Connection", "keep-alive")
|
||||
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
return
|
||||
}
|
||||
r.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
// normaliseQwenBaseURL canonicalises the resource URL advertised by the Qwen
// OAuth endpoint into a usable API base URL: it defaults bare hosts to an
// https:// scheme, strips trailing slashes, and guarantees a "/v1" path
// suffix. Empty or whitespace-only input yields "".
func normaliseQwenBaseURL(resourceURL string) string {
	base := strings.TrimSpace(resourceURL)
	if base == "" {
		return ""
	}

	// Bare hosts such as "portal.qwen.ai" default to HTTPS; scheme matching is
	// case-insensitive.
	if lower := strings.ToLower(base); !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") {
		base = "https://" + base
	}

	base = strings.TrimRight(base, "/")
	if strings.HasSuffix(strings.ToLower(base), "/v1") {
		return base
	}
	return base + "/v1"
}
|
||||
|
||||
func qwenCreds(a *cliproxyauth.Auth) (token, baseURL string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
}
|
||||
if a.Attributes != nil {
|
||||
if v := a.Attributes["api_key"]; v != "" {
|
||||
token = v
|
||||
}
|
||||
if v := a.Attributes["base_url"]; v != "" {
|
||||
baseURL = v
|
||||
}
|
||||
}
|
||||
if token == "" && a.Metadata != nil {
|
||||
if v, ok := a.Metadata["access_token"].(string); ok {
|
||||
token = v
|
||||
}
|
||||
if v, ok := a.Metadata["resource_url"].(string); ok {
|
||||
baseURL = normaliseQwenBaseURL(v)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1,614 +0,0 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestQwenExecutorParseSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantBase string
|
||||
wantLevel string
|
||||
}{
|
||||
{"no suffix", "qwen-max", "qwen-max", ""},
|
||||
{"with level suffix", "qwen-max(high)", "qwen-max", "high"},
|
||||
{"with budget suffix", "qwen-max(16384)", "qwen-max", "16384"},
|
||||
{"complex model name", "qwen-plus-latest(medium)", "qwen-plus-latest", "medium"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := thinking.ParseSuffix(tt.model)
|
||||
if result.ModelName != tt.wantBase {
|
||||
t.Errorf("ParseSuffix(%q).ModelName = %q, want %q", tt.model, result.ModelName, tt.wantBase)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureQwenSystemMessage_MergeStringSystem(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"model": "qwen3.6-plus",
|
||||
"stream": true,
|
||||
"messages": [
|
||||
{ "role": "system", "content": "ABCDEFG" },
|
||||
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := ensureQwenSystemMessage(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||
}
|
||||
|
||||
msgs := gjson.GetBytes(out, "messages").Array()
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||
}
|
||||
if msgs[0].Get("role").String() != "system" {
|
||||
t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system")
|
||||
}
|
||||
parts := msgs[0].Get("content").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("messages[0].content length = %d, want 2", len(parts))
|
||||
}
|
||||
if parts[0].Get("type").String() != "text" || parts[0].Get("cache_control.type").String() != "ephemeral" {
|
||||
t.Fatalf("messages[0].content[0] = %s, want injected system part", parts[0].Raw)
|
||||
}
|
||||
if text := parts[0].Get("text").String(); text != "" && text != "You are Qwen Code." {
|
||||
t.Fatalf("messages[0].content[0].text = %q, want empty string or default prompt", text)
|
||||
}
|
||||
if parts[1].Get("type").String() != "text" || parts[1].Get("text").String() != "ABCDEFG" {
|
||||
t.Fatalf("messages[0].content[1] = %s, want text part with ABCDEFG", parts[1].Raw)
|
||||
}
|
||||
if msgs[1].Get("role").String() != "user" {
|
||||
t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureQwenSystemMessage_MergeObjectSystem(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"messages": [
|
||||
{ "role": "system", "content": { "type": "text", "text": "ABCDEFG" } },
|
||||
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := ensureQwenSystemMessage(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||
}
|
||||
|
||||
msgs := gjson.GetBytes(out, "messages").Array()
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||
}
|
||||
parts := msgs[0].Get("content").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("messages[0].content length = %d, want 2", len(parts))
|
||||
}
|
||||
if parts[1].Get("text").String() != "ABCDEFG" {
|
||||
t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "ABCDEFG")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureQwenSystemMessage_PrependsWhenMissing(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"messages": [
|
||||
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := ensureQwenSystemMessage(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||
}
|
||||
|
||||
msgs := gjson.GetBytes(out, "messages").Array()
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||
}
|
||||
if msgs[0].Get("role").String() != "system" {
|
||||
t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system")
|
||||
}
|
||||
if !msgs[0].Get("content").IsArray() || len(msgs[0].Get("content").Array()) == 0 {
|
||||
t.Fatalf("messages[0].content = %s, want non-empty array", msgs[0].Get("content").Raw)
|
||||
}
|
||||
if msgs[1].Get("role").String() != "user" {
|
||||
t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureQwenSystemMessage_MergesMultipleSystemMessages(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"messages": [
|
||||
{ "role": "system", "content": "A" },
|
||||
{ "role": "user", "content": [ { "type": "text", "text": "hi" } ] },
|
||||
{ "role": "system", "content": "B" }
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := ensureQwenSystemMessage(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||
}
|
||||
|
||||
msgs := gjson.GetBytes(out, "messages").Array()
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||
}
|
||||
parts := msgs[0].Get("content").Array()
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("messages[0].content length = %d, want 3", len(parts))
|
||||
}
|
||||
if parts[1].Get("text").String() != "A" {
|
||||
t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "A")
|
||||
}
|
||||
if parts[2].Get("text").String() != "B" {
|
||||
t.Fatalf("messages[0].content[2].text = %q, want %q", parts[2].Get("text").String(), "B")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapQwenError_InsufficientQuotaDoesNotSetRetryAfter(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":"insufficient_quota","message":"You exceeded your current quota","type":"insufficient_quota"}}`)
|
||||
code, retryAfter := wrapQwenError(context.Background(), http.StatusTooManyRequests, body)
|
||||
if code != http.StatusTooManyRequests {
|
||||
t.Fatalf("wrapQwenError status = %d, want %d", code, http.StatusTooManyRequests)
|
||||
}
|
||||
if retryAfter != nil {
|
||||
t.Fatalf("wrapQwenError retryAfter = %v, want nil", *retryAfter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapQwenError_Maps403QuotaTo429WithoutRetryAfter(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":"insufficient_quota","message":"You exceeded your current quota","type":"insufficient_quota"}}`)
|
||||
code, retryAfter := wrapQwenError(context.Background(), http.StatusForbidden, body)
|
||||
if code != http.StatusTooManyRequests {
|
||||
t.Fatalf("wrapQwenError status = %d, want %d", code, http.StatusTooManyRequests)
|
||||
}
|
||||
if retryAfter != nil {
|
||||
t.Fatalf("wrapQwenError retryAfter = %v, want nil", *retryAfter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenCreds_NormalizesResourceURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
resourceURL string
|
||||
wantBaseURL string
|
||||
}{
|
||||
{"host only", "portal.qwen.ai", "https://portal.qwen.ai/v1"},
|
||||
{"scheme no v1", "https://portal.qwen.ai", "https://portal.qwen.ai/v1"},
|
||||
{"scheme with v1", "https://portal.qwen.ai/v1", "https://portal.qwen.ai/v1"},
|
||||
{"scheme with v1 slash", "https://portal.qwen.ai/v1/", "https://portal.qwen.ai/v1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
auth := &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{
|
||||
"access_token": "test-token",
|
||||
"resource_url": tt.resourceURL,
|
||||
},
|
||||
}
|
||||
|
||||
token, baseURL := qwenCreds(auth)
|
||||
if token != "test-token" {
|
||||
t.Fatalf("qwenCreds token = %q, want %q", token, "test-token")
|
||||
}
|
||||
if baseURL != tt.wantBaseURL {
|
||||
t.Fatalf("qwenCreds baseURL = %q, want %q", baseURL, tt.wantBaseURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenExecutorExecute_429DoesNotRefreshOrRetry(t *testing.T) {
|
||||
qwenRateLimiter.Lock()
|
||||
qwenRateLimiter.requests = make(map[string][]time.Time)
|
||||
qwenRateLimiter.Unlock()
|
||||
|
||||
var calls int32
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
if r.URL.Path != "/v1/chat/completions" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
switch r.Header.Get("Authorization") {
|
||||
case "Bearer old-token":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"code":"quota_exceeded","message":"quota exceeded","type":"quota_exceeded"}}`))
|
||||
return
|
||||
case "Bearer new-token":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"id":"chatcmpl-test","object":"chat.completion","created":1,"model":"qwen-max","choices":[{"index":0,"message":{"role":"assistant","content":"hi"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`))
|
||||
return
|
||||
default:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewQwenExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-test",
|
||||
Provider: "qwen",
|
||||
Attributes: map[string]string{
|
||||
"base_url": srv.URL + "/v1",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "refresh-token",
|
||||
},
|
||||
}
|
||||
|
||||
var refresherCalls int32
|
||||
exec.refreshForImmediateRetry = func(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
atomic.AddInt32(&refresherCalls, 1)
|
||||
refreshed := auth.Clone()
|
||||
if refreshed.Metadata == nil {
|
||||
refreshed.Metadata = make(map[string]any)
|
||||
}
|
||||
refreshed.Metadata["access_token"] = "new-token"
|
||||
refreshed.Metadata["refresh_token"] = "refresh-token-2"
|
||||
return refreshed, nil
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := exec.Execute(ctx, auth, cliproxyexecutor.Request{
|
||||
Model: "qwen-max",
|
||||
Payload: []byte(`{"model":"qwen-max","messages":[{"role":"user","content":"hi"}]}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("Execute() expected error, got nil")
|
||||
}
|
||||
status, ok := err.(statusErr)
|
||||
if !ok {
|
||||
t.Fatalf("Execute() error type = %T, want statusErr", err)
|
||||
}
|
||||
if status.StatusCode() != http.StatusTooManyRequests {
|
||||
t.Fatalf("Execute() status code = %d, want %d", status.StatusCode(), http.StatusTooManyRequests)
|
||||
}
|
||||
if atomic.LoadInt32(&calls) != 1 {
|
||||
t.Fatalf("upstream calls = %d, want 1", atomic.LoadInt32(&calls))
|
||||
}
|
||||
if atomic.LoadInt32(&refresherCalls) != 0 {
|
||||
t.Fatalf("refresher calls = %d, want 0", atomic.LoadInt32(&refresherCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenExecutorExecuteStream_429DoesNotRefreshOrRetry(t *testing.T) {
|
||||
qwenRateLimiter.Lock()
|
||||
qwenRateLimiter.requests = make(map[string][]time.Time)
|
||||
qwenRateLimiter.Unlock()
|
||||
|
||||
var calls int32
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
if r.URL.Path != "/v1/chat/completions" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
switch r.Header.Get("Authorization") {
|
||||
case "Bearer old-token":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"code":"quota_exceeded","message":"quota exceeded","type":"quota_exceeded"}}`))
|
||||
return
|
||||
case "Bearer new-token":
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-test\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"qwen-max\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"},\"finish_reason\":null}]}\n"))
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
default:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewQwenExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-test",
|
||||
Provider: "qwen",
|
||||
Attributes: map[string]string{
|
||||
"base_url": srv.URL + "/v1",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "refresh-token",
|
||||
},
|
||||
}
|
||||
|
||||
var refresherCalls int32
|
||||
exec.refreshForImmediateRetry = func(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
atomic.AddInt32(&refresherCalls, 1)
|
||||
refreshed := auth.Clone()
|
||||
if refreshed.Metadata == nil {
|
||||
refreshed.Metadata = make(map[string]any)
|
||||
}
|
||||
refreshed.Metadata["access_token"] = "new-token"
|
||||
refreshed.Metadata["refresh_token"] = "refresh-token-2"
|
||||
return refreshed, nil
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := exec.ExecuteStream(ctx, auth, cliproxyexecutor.Request{
|
||||
Model: "qwen-max",
|
||||
Payload: []byte(`{"model":"qwen-max","stream":true,"messages":[{"role":"user","content":"hi"}]}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("ExecuteStream() expected error, got nil")
|
||||
}
|
||||
status, ok := err.(statusErr)
|
||||
if !ok {
|
||||
t.Fatalf("ExecuteStream() error type = %T, want statusErr", err)
|
||||
}
|
||||
if status.StatusCode() != http.StatusTooManyRequests {
|
||||
t.Fatalf("ExecuteStream() status code = %d, want %d", status.StatusCode(), http.StatusTooManyRequests)
|
||||
}
|
||||
if atomic.LoadInt32(&calls) != 1 {
|
||||
t.Fatalf("upstream calls = %d, want 1", atomic.LoadInt32(&calls))
|
||||
}
|
||||
if atomic.LoadInt32(&refresherCalls) != 0 {
|
||||
t.Fatalf("refresher calls = %d, want 0", atomic.LoadInt32(&refresherCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenExecutorExecute_429RetryAfterHeaderPropagatesToStatusErr(t *testing.T) {
|
||||
qwenRateLimiter.Lock()
|
||||
qwenRateLimiter.requests = make(map[string][]time.Time)
|
||||
qwenRateLimiter.Unlock()
|
||||
|
||||
var calls int32
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
if r.URL.Path != "/v1/chat/completions" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Retry-After", "2")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"code":"rate_limit_exceeded","message":"rate limited","type":"rate_limit_exceeded"}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewQwenExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-test",
|
||||
Provider: "qwen",
|
||||
Attributes: map[string]string{
|
||||
"base_url": srv.URL + "/v1",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "test-token",
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := exec.Execute(ctx, auth, cliproxyexecutor.Request{
|
||||
Model: "qwen-max",
|
||||
Payload: []byte(`{"model":"qwen-max","messages":[{"role":"user","content":"hi"}]}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("Execute() expected error, got nil")
|
||||
}
|
||||
status, ok := err.(statusErr)
|
||||
if !ok {
|
||||
t.Fatalf("Execute() error type = %T, want statusErr", err)
|
||||
}
|
||||
if status.StatusCode() != http.StatusTooManyRequests {
|
||||
t.Fatalf("Execute() status code = %d, want %d", status.StatusCode(), http.StatusTooManyRequests)
|
||||
}
|
||||
if status.RetryAfter() == nil {
|
||||
t.Fatalf("Execute() RetryAfter is nil, want non-nil")
|
||||
}
|
||||
if got := *status.RetryAfter(); got != 2*time.Second {
|
||||
t.Fatalf("Execute() RetryAfter = %v, want %v", got, 2*time.Second)
|
||||
}
|
||||
if atomic.LoadInt32(&calls) != 1 {
|
||||
t.Fatalf("upstream calls = %d, want 1", atomic.LoadInt32(&calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenExecutorExecuteStream_429RetryAfterHeaderPropagatesToStatusErr(t *testing.T) {
|
||||
qwenRateLimiter.Lock()
|
||||
qwenRateLimiter.requests = make(map[string][]time.Time)
|
||||
qwenRateLimiter.Unlock()
|
||||
|
||||
var calls int32
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
if r.URL.Path != "/v1/chat/completions" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Retry-After", "2")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"code":"rate_limit_exceeded","message":"rate limited","type":"rate_limit_exceeded"}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewQwenExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-test",
|
||||
Provider: "qwen",
|
||||
Attributes: map[string]string{
|
||||
"base_url": srv.URL + "/v1",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "test-token",
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := exec.ExecuteStream(ctx, auth, cliproxyexecutor.Request{
|
||||
Model: "qwen-max",
|
||||
Payload: []byte(`{"model":"qwen-max","stream":true,"messages":[{"role":"user","content":"hi"}]}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("ExecuteStream() expected error, got nil")
|
||||
}
|
||||
status, ok := err.(statusErr)
|
||||
if !ok {
|
||||
t.Fatalf("ExecuteStream() error type = %T, want statusErr", err)
|
||||
}
|
||||
if status.StatusCode() != http.StatusTooManyRequests {
|
||||
t.Fatalf("ExecuteStream() status code = %d, want %d", status.StatusCode(), http.StatusTooManyRequests)
|
||||
}
|
||||
if status.RetryAfter() == nil {
|
||||
t.Fatalf("ExecuteStream() RetryAfter is nil, want non-nil")
|
||||
}
|
||||
if got := *status.RetryAfter(); got != 2*time.Second {
|
||||
t.Fatalf("ExecuteStream() RetryAfter = %v, want %v", got, 2*time.Second)
|
||||
}
|
||||
if atomic.LoadInt32(&calls) != 1 {
|
||||
t.Fatalf("upstream calls = %d, want 1", atomic.LoadInt32(&calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenExecutorExecute_429QuotaExhausted_DisableCoolingSetsDefaultRetryAfter(t *testing.T) {
|
||||
qwenRateLimiter.Lock()
|
||||
qwenRateLimiter.requests = make(map[string][]time.Time)
|
||||
qwenRateLimiter.Unlock()
|
||||
|
||||
var calls int32
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
if r.URL.Path != "/v1/chat/completions" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"code":"quota_exceeded","message":"quota exceeded","type":"quota_exceeded"}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewQwenExecutor(&config.Config{DisableCooling: true})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-test",
|
||||
Provider: "qwen",
|
||||
Attributes: map[string]string{
|
||||
"base_url": srv.URL + "/v1",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "test-token",
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := exec.Execute(ctx, auth, cliproxyexecutor.Request{
|
||||
Model: "qwen-max",
|
||||
Payload: []byte(`{"model":"qwen-max","messages":[{"role":"user","content":"hi"}]}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("Execute() expected error, got nil")
|
||||
}
|
||||
status, ok := err.(statusErr)
|
||||
if !ok {
|
||||
t.Fatalf("Execute() error type = %T, want statusErr", err)
|
||||
}
|
||||
if status.StatusCode() != http.StatusTooManyRequests {
|
||||
t.Fatalf("Execute() status code = %d, want %d", status.StatusCode(), http.StatusTooManyRequests)
|
||||
}
|
||||
if status.RetryAfter() == nil {
|
||||
t.Fatalf("Execute() RetryAfter is nil, want non-nil")
|
||||
}
|
||||
if got := *status.RetryAfter(); got != time.Second {
|
||||
t.Fatalf("Execute() RetryAfter = %v, want %v", got, time.Second)
|
||||
}
|
||||
if atomic.LoadInt32(&calls) != 1 {
|
||||
t.Fatalf("upstream calls = %d, want 1", atomic.LoadInt32(&calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenExecutorExecuteStream_429QuotaExhausted_DisableCoolingSetsDefaultRetryAfter(t *testing.T) {
|
||||
qwenRateLimiter.Lock()
|
||||
qwenRateLimiter.requests = make(map[string][]time.Time)
|
||||
qwenRateLimiter.Unlock()
|
||||
|
||||
var calls int32
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
if r.URL.Path != "/v1/chat/completions" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"code":"quota_exceeded","message":"quota exceeded","type":"quota_exceeded"}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewQwenExecutor(&config.Config{DisableCooling: true})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-test",
|
||||
Provider: "qwen",
|
||||
Attributes: map[string]string{
|
||||
"base_url": srv.URL + "/v1",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "test-token",
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := exec.ExecuteStream(ctx, auth, cliproxyexecutor.Request{
|
||||
Model: "qwen-max",
|
||||
Payload: []byte(`{"model":"qwen-max","stream":true,"messages":[{"role":"user","content":"hi"}]}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("ExecuteStream() expected error, got nil")
|
||||
}
|
||||
status, ok := err.(statusErr)
|
||||
if !ok {
|
||||
t.Fatalf("ExecuteStream() error type = %T, want statusErr", err)
|
||||
}
|
||||
if status.StatusCode() != http.StatusTooManyRequests {
|
||||
t.Fatalf("ExecuteStream() status code = %d, want %d", status.StatusCode(), http.StatusTooManyRequests)
|
||||
}
|
||||
if status.RetryAfter() == nil {
|
||||
t.Fatalf("ExecuteStream() RetryAfter is nil, want non-nil")
|
||||
}
|
||||
if got := *status.RetryAfter(); got != time.Second {
|
||||
t.Fatalf("ExecuteStream() RetryAfter = %v, want %v", got, time.Second)
|
||||
}
|
||||
if atomic.LoadInt32(&calls) != 1 {
|
||||
t.Fatalf("upstream calls = %d, want 1", atomic.LoadInt32(&calls))
|
||||
}
|
||||
}
|
||||
@@ -154,7 +154,7 @@ func isEnableThinkingModel(modelID string) bool {
|
||||
}
|
||||
id := strings.ToLower(modelID)
|
||||
switch id {
|
||||
case "qwen3-max-preview", "deepseek-v3.2", "deepseek-v3.1":
|
||||
case "deepseek-v3.2", "deepseek-v3.1":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -101,6 +101,9 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
systemTypePromptResult := systemPromptResult.Get("type")
|
||||
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||
systemPrompt := systemPromptResult.Get("text").String()
|
||||
if strings.HasPrefix(systemPrompt, "x-anthropic-billing-header:") {
|
||||
continue
|
||||
}
|
||||
partJSON := []byte(`{}`)
|
||||
if systemPrompt != "" {
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", systemPrompt)
|
||||
@@ -170,9 +173,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
continue
|
||||
}
|
||||
|
||||
// Valid signature, send as thought block
|
||||
// Always include "text" field — Google Antigravity API requires it
|
||||
// even for redacted thinking where the text is empty.
|
||||
// Drop empty-text thinking blocks (redacted thinking from Claude Max).
|
||||
// Antigravity wraps empty text into a prompt-caching-scope object that
|
||||
// omits the required inner "thinking" field, causing:
|
||||
// 400 "messages.N.content.0.thinking.thinking: Field required"
|
||||
if thinkingText == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Valid signature with content, send as thought block.
|
||||
partJSON := []byte(`{}`)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thought", true)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
|
||||
|
||||
@@ -468,11 +468,7 @@ func TestValidateBypassMode_HandlesWhitespace(t *testing.T) {
|
||||
|
||||
func TestValidateBypassMode_RejectsOversizedSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
payload := append([]byte{0x12}, bytes.Repeat([]byte{0x34}, maxBypassSignatureLen)...)
|
||||
sig := base64.StdEncoding.EncodeToString(payload)
|
||||
if len(sig) <= maxBypassSignatureLen {
|
||||
t.Fatalf("test setup: signature should exceed max length, got %d", len(sig))
|
||||
}
|
||||
sig := strings.Repeat("A", maxBypassSignatureLen+1)
|
||||
|
||||
inputJSON := []byte(`{
|
||||
"messages": [{"role": "assistant", "content": [
|
||||
@@ -489,6 +485,33 @@ func TestValidateBypassMode_RejectsOversizedSignature(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateBypassMode_StrictAcceptsSignatureBetween16KiBAnd32MiB(t *testing.T) {
|
||||
previous := cache.SignatureBypassStrictMode()
|
||||
cache.SetSignatureBypassStrictMode(true)
|
||||
t.Cleanup(func() {
|
||||
cache.SetSignatureBypassStrictMode(previous)
|
||||
})
|
||||
|
||||
payload := buildClaudeSignaturePayload(t, 12, uint64Ptr(2), strings.Repeat("m", 20000), true)
|
||||
sig := base64.StdEncoding.EncodeToString(payload)
|
||||
if len(sig) <= 1<<14 {
|
||||
t.Fatalf("test setup: signature should exceed previous 16KiB guardrail, got %d", len(sig))
|
||||
}
|
||||
if len(sig) > maxBypassSignatureLen {
|
||||
t.Fatalf("test setup: signature should remain within new max length, got %d", len(sig))
|
||||
}
|
||||
|
||||
inputJSON := []byte(`{
|
||||
"messages": [{"role": "assistant", "content": [
|
||||
{"type": "thinking", "thinking": "t", "signature": "` + sig + `"}
|
||||
]}]
|
||||
}`)
|
||||
|
||||
if err := ValidateClaudeBypassSignatures(inputJSON); err != nil {
|
||||
t.Fatalf("expected strict mode to accept signature below 32MiB max, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveBypassModeSignature_TrimsWhitespace(t *testing.T) {
|
||||
previous := cache.SignatureCacheEnabled()
|
||||
cache.SetSignatureCacheEnabled(false)
|
||||
@@ -2158,6 +2181,225 @@ func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *te
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_BypassMode_DropsRedactedThinkingBlocks(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
previous := cache.SignatureCacheEnabled()
|
||||
cache.SetSignatureCacheEnabled(false)
|
||||
t.Cleanup(func() {
|
||||
cache.SetSignatureCacheEnabled(previous)
|
||||
cache.ClearSignatureCache("")
|
||||
})
|
||||
|
||||
validSignature := testAnthropicNativeSignature(t)
|
||||
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "", "signature": "` + validSignature + `"},
|
||||
{"type": "text", "text": "I can help with that."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Follow up question"}]
|
||||
}
|
||||
],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 10000}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6", inputJSON, false)
|
||||
|
||||
assistantParts := gjson.GetBytes(output, "request.contents.1.parts").Array()
|
||||
if len(assistantParts) != 1 {
|
||||
t.Fatalf("Expected 1 part (redacted thinking dropped), got %d: %s",
|
||||
len(assistantParts), gjson.GetBytes(output, "request.contents.1.parts").Raw)
|
||||
}
|
||||
if assistantParts[0].Get("thought").Bool() {
|
||||
t.Fatal("Redacted thinking block with empty text should be dropped")
|
||||
}
|
||||
if assistantParts[0].Get("text").String() != "I can help with that." {
|
||||
t.Fatalf("Expected text part preserved, got: %s", assistantParts[0].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_BypassMode_DropsWrappedRedactedThinking(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
previous := cache.SignatureCacheEnabled()
|
||||
cache.SetSignatureCacheEnabled(false)
|
||||
t.Cleanup(func() {
|
||||
cache.SetSignatureCacheEnabled(previous)
|
||||
cache.ClearSignatureCache("")
|
||||
})
|
||||
|
||||
validSignature := testAnthropicNativeSignature(t)
|
||||
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-6",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Test user message"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": {"cache_control": {"type": "ephemeral"}}, "signature": "` + validSignature + `"},
|
||||
{"type": "text", "text": "Answer"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Follow up"}]
|
||||
}
|
||||
],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 8000}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-6", inputJSON, false)
|
||||
|
||||
assistantParts := gjson.GetBytes(output, "request.contents.1.parts").Array()
|
||||
if len(assistantParts) != 1 {
|
||||
t.Fatalf("Expected 1 part (wrapped redacted thinking dropped), got %d: %s",
|
||||
len(assistantParts), gjson.GetBytes(output, "request.contents.1.parts").Raw)
|
||||
}
|
||||
if assistantParts[0].Get("text").String() != "Answer" {
|
||||
t.Fatalf("Expected text part preserved, got: %s", assistantParts[0].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_BypassMode_KeepsNonEmptyThinking(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
previous := cache.SignatureCacheEnabled()
|
||||
cache.SetSignatureCacheEnabled(false)
|
||||
t.Cleanup(func() {
|
||||
cache.SetSignatureCacheEnabled(previous)
|
||||
cache.ClearSignatureCache("")
|
||||
})
|
||||
|
||||
validSignature := testAnthropicNativeSignature(t)
|
||||
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me reason about this carefully...", "signature": "` + validSignature + `"},
|
||||
{"type": "text", "text": "Here is my answer."}
|
||||
]
|
||||
}
|
||||
],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 10000}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6", inputJSON, false)
|
||||
|
||||
assistantParts := gjson.GetBytes(output, "request.contents.1.parts").Array()
|
||||
if len(assistantParts) != 2 {
|
||||
t.Fatalf("Expected 2 parts (thinking + text), got %d", len(assistantParts))
|
||||
}
|
||||
if !assistantParts[0].Get("thought").Bool() {
|
||||
t.Fatal("First part should be a thought block")
|
||||
}
|
||||
if assistantParts[0].Get("text").String() != "Let me reason about this carefully..." {
|
||||
t.Fatalf("Thinking text mismatch, got: %s", assistantParts[0].Get("text").String())
|
||||
}
|
||||
if assistantParts[1].Get("text").String() != "Here is my answer." {
|
||||
t.Fatalf("Text part mismatch, got: %s", assistantParts[1].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_BypassMode_MultiTurnRedactedThinking(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
previous := cache.SignatureCacheEnabled()
|
||||
cache.SetSignatureCacheEnabled(false)
|
||||
t.Cleanup(func() {
|
||||
cache.SetSignatureCacheEnabled(previous)
|
||||
cache.ClearSignatureCache("")
|
||||
})
|
||||
|
||||
sig := testAnthropicNativeSignature(t)
|
||||
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6",
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "First question"}]},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "", "signature": "` + sig + `"},
|
||||
{"type": "text", "text": "First answer"},
|
||||
{"type": "tool_use", "id": "Bash-123-456", "name": "Bash", "input": {"command": "ls"}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "tool_use_id": "Bash-123-456", "content": "file1.txt\nfile2.txt"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "", "signature": "` + sig + `"},
|
||||
{"type": "text", "text": "Here are the files."}
|
||||
]
|
||||
},
|
||||
{"role": "user", "content": [{"type": "text", "text": "Thanks"}]}
|
||||
],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 10000}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6", inputJSON, false)
|
||||
|
||||
if !gjson.ValidBytes(output) {
|
||||
t.Fatalf("Output is not valid JSON: %s", string(output))
|
||||
}
|
||||
|
||||
firstAssistantParts := gjson.GetBytes(output, "request.contents.1.parts").Array()
|
||||
for _, p := range firstAssistantParts {
|
||||
if p.Get("thought").Bool() {
|
||||
t.Fatal("Redacted thinking should be dropped from first assistant message")
|
||||
}
|
||||
}
|
||||
hasText := false
|
||||
hasFC := false
|
||||
for _, p := range firstAssistantParts {
|
||||
if p.Get("text").String() == "First answer" {
|
||||
hasText = true
|
||||
}
|
||||
if p.Get("functionCall").Exists() {
|
||||
hasFC = true
|
||||
}
|
||||
}
|
||||
if !hasText || !hasFC {
|
||||
t.Fatalf("First assistant should have text + functionCall, got: %s",
|
||||
gjson.GetBytes(output, "request.contents.1.parts").Raw)
|
||||
}
|
||||
|
||||
secondAssistantParts := gjson.GetBytes(output, "request.contents.3.parts").Array()
|
||||
for _, p := range secondAssistantParts {
|
||||
if p.Get("thought").Bool() {
|
||||
t.Fatal("Redacted thinking should be dropped from second assistant message")
|
||||
}
|
||||
}
|
||||
if len(secondAssistantParts) != 1 || secondAssistantParts[0].Get("text").String() != "Here are the files." {
|
||||
t.Fatalf("Second assistant should have only text part, got: %s",
|
||||
gjson.GetBytes(output, "request.contents.3.parts").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||
// When tools + thinking but no system instruction, should create one with hint
|
||||
inputJSON := []byte(`{
|
||||
|
||||
@@ -55,10 +55,11 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
)
|
||||
|
||||
const maxBypassSignatureLen = 8192
|
||||
const maxBypassSignatureLen = 32 * 1024 * 1024
|
||||
|
||||
type claudeSignatureTree struct {
|
||||
EncodingLayers int
|
||||
@@ -72,6 +73,62 @@ type claudeSignatureTree struct {
|
||||
HasField7 bool
|
||||
}
|
||||
|
||||
// StripInvalidSignatureThinkingBlocks removes thinking blocks whose signatures
|
||||
// are empty or not valid Claude format (must start with 'E' or 'R' after
|
||||
// stripping any cache prefix). These come from proxy-generated responses
|
||||
// (Antigravity/Gemini) where no real Claude signature exists.
|
||||
func StripEmptySignatureThinkingBlocks(payload []byte) []byte {
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
if !messages.IsArray() {
|
||||
return payload
|
||||
}
|
||||
modified := false
|
||||
for i, msg := range messages.Array() {
|
||||
content := msg.Get("content")
|
||||
if !content.IsArray() {
|
||||
continue
|
||||
}
|
||||
var kept []string
|
||||
stripped := false
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "thinking" && !hasValidClaudeSignature(part.Get("signature").String()) {
|
||||
stripped = true
|
||||
continue
|
||||
}
|
||||
kept = append(kept, part.Raw)
|
||||
}
|
||||
if stripped {
|
||||
modified = true
|
||||
if len(kept) == 0 {
|
||||
payload, _ = sjson.SetRawBytes(payload, fmt.Sprintf("messages.%d.content", i), []byte("[]"))
|
||||
} else {
|
||||
payload, _ = sjson.SetRawBytes(payload, fmt.Sprintf("messages.%d.content", i), []byte("["+strings.Join(kept, ",")+"]"))
|
||||
}
|
||||
}
|
||||
}
|
||||
if !modified {
|
||||
return payload
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// hasValidClaudeSignature returns true if sig looks like a real Claude thinking
|
||||
// signature: non-empty and starts with 'E' or 'R' (after stripping optional
|
||||
// cache prefix like "modelGroup#").
|
||||
func hasValidClaudeSignature(sig string) bool {
|
||||
sig = strings.TrimSpace(sig)
|
||||
if sig == "" {
|
||||
return false
|
||||
}
|
||||
if idx := strings.IndexByte(sig, '#'); idx >= 0 {
|
||||
sig = strings.TrimSpace(sig[idx+1:])
|
||||
}
|
||||
if sig == "" {
|
||||
return false
|
||||
}
|
||||
return sig[0] == 'E' || sig[0] == 'R'
|
||||
}
|
||||
|
||||
func ValidateClaudeBypassSignatures(inputRawJSON []byte) error {
|
||||
messages := gjson.GetBytes(inputRawJSON, "messages")
|
||||
if !messages.IsArray() {
|
||||
|
||||
@@ -13,4 +13,4 @@ func GetString(m map[string]interface{}, key string) string {
|
||||
// GetStringValue is an alias for GetString for backward compatibility.
|
||||
func GetStringValue(m map[string]interface{}, key string) string {
|
||||
return GetString(m, key)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,4 +17,4 @@ func init() {
|
||||
NonStream: ConvertKiroNonStreamToOpenAI,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -274,4 +274,4 @@ func min(a, b int) int {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,4 +209,4 @@ func NewThinkingTagState() *ThinkingTagState {
|
||||
PendingStartChars: 0,
|
||||
PendingEndChars: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ var oauthProviders = []oauthProvider{
|
||||
{"Claude (Anthropic)", "anthropic-auth-url", "🟧"},
|
||||
{"Codex (OpenAI)", "codex-auth-url", "🟩"},
|
||||
{"Antigravity", "antigravity-auth-url", "🟪"},
|
||||
{"Qwen", "qwen-auth-url", "🟨"},
|
||||
{"Kimi", "kimi-auth-url", "🟫"},
|
||||
{"IFlow", "iflow-auth-url", "⬜"},
|
||||
}
|
||||
@@ -280,8 +279,6 @@ func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd {
|
||||
providerKey = "codex"
|
||||
case "antigravity-auth-url":
|
||||
providerKey = "antigravity"
|
||||
case "qwen-auth-url":
|
||||
providerKey = "qwen"
|
||||
case "kimi-auth-url":
|
||||
providerKey = "kimi"
|
||||
case "iflow-auth-url":
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
// - "gemini" for Google's Gemini family
|
||||
// - "codex" for OpenAI GPT-compatible providers
|
||||
// - "claude" for Anthropic models
|
||||
// - "qwen" for Alibaba's Qwen models
|
||||
// - "openai-compatibility" for external OpenAI-compatible providers
|
||||
//
|
||||
// Parameters:
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -85,14 +84,22 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
|
||||
} else if resolvedAuthDir != "" {
|
||||
_ = filepath.Walk(resolvedAuthDir, func(path string, info fs.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
|
||||
if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 {
|
||||
entries, errReadDir := os.ReadDir(resolvedAuthDir)
|
||||
if errReadDir != nil {
|
||||
log.Errorf("failed to read auth directory for hash cache: %v", errReadDir)
|
||||
} else {
|
||||
for _, entry := range entries {
|
||||
if entry == nil || entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
||||
continue
|
||||
}
|
||||
fullPath := filepath.Join(resolvedAuthDir, name)
|
||||
if data, errReadFile := os.ReadFile(fullPath); errReadFile == nil && len(data) > 0 {
|
||||
sum := sha256.Sum256(data)
|
||||
normalizedPath := w.normalizeAuthPath(path)
|
||||
normalizedPath := w.normalizeAuthPath(fullPath)
|
||||
w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
|
||||
// Parse and cache auth content for future diff comparisons (debug only).
|
||||
if cacheAuthContents {
|
||||
@@ -107,15 +114,14 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
Now: time.Now(),
|
||||
IDGenerator: synthesizer.NewStableIDGenerator(),
|
||||
}
|
||||
if generated := synthesizer.SynthesizeAuthFile(ctx, path, data); len(generated) > 0 {
|
||||
if generated := synthesizer.SynthesizeAuthFile(ctx, fullPath, data); len(generated) > 0 {
|
||||
if pathAuths := authSliceToMap(generated); len(pathAuths) > 0 {
|
||||
w.fileAuthsByPath[normalizedPath] = authIDSet(pathAuths)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
}
|
||||
@@ -306,23 +312,25 @@ func (w *Watcher) loadFileClients(cfg *config.Config) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
errWalk := filepath.Walk(authDir, func(path string, info fs.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
log.Debugf("error accessing path %s: %v", path, err)
|
||||
return err
|
||||
entries, errReadDir := os.ReadDir(authDir)
|
||||
if errReadDir != nil {
|
||||
log.Errorf("error reading auth directory: %v", errReadDir)
|
||||
return 0
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry == nil || entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
|
||||
authFileCount++
|
||||
log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
|
||||
if data, errCreate := os.ReadFile(path); errCreate == nil && len(data) > 0 {
|
||||
successfulAuthCount++
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
||||
continue
|
||||
}
|
||||
authFileCount++
|
||||
log.Debugf("processing auth file %d: %s", authFileCount, name)
|
||||
fullPath := filepath.Join(authDir, name)
|
||||
if data, errReadFile := os.ReadFile(fullPath); errReadFile == nil && len(data) > 0 {
|
||||
successfulAuthCount++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if errWalk != nil {
|
||||
log.Errorf("error walking auth directory: %v", errWalk)
|
||||
}
|
||||
log.Debugf("auth directory scan complete - found %d .json files, %d readable", authFileCount, successfulAuthCount)
|
||||
return authFileCount
|
||||
|
||||
@@ -96,7 +96,7 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
normalizedAuthDir := w.normalizeAuthPath(w.authDir)
|
||||
isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
|
||||
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
||||
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
|
||||
isAuthJSON := filepath.Dir(normalizedName) == normalizedAuthDir && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
|
||||
isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0
|
||||
if !isConfigEvent && !isAuthJSON && !isKiroIDEToken {
|
||||
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
@@ -188,7 +187,7 @@ func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
|
||||
|
||||
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
||||
// It is forwarded as execution metadata; when absent we generate a UUID.
|
||||
// Only include it if the client explicitly provides it.
|
||||
key := ""
|
||||
if ctx != nil {
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
@@ -196,7 +195,7 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||
}
|
||||
}
|
||||
if key == "" {
|
||||
key = uuid.NewString()
|
||||
return make(map[string]any)
|
||||
}
|
||||
|
||||
meta := map[string]any{idempotencyKeyMetadataKey: key}
|
||||
|
||||
@@ -17,7 +17,6 @@ type ManagementTokenRequester interface {
|
||||
RequestGeminiCLIToken(*gin.Context)
|
||||
RequestCodexToken(*gin.Context)
|
||||
RequestAntigravityToken(*gin.Context)
|
||||
RequestQwenToken(*gin.Context)
|
||||
RequestKimiToken(*gin.Context)
|
||||
RequestIFlowToken(*gin.Context)
|
||||
RequestIFlowCookieToken(*gin.Context)
|
||||
@@ -52,10 +51,6 @@ func (m *managementTokenRequester) RequestAntigravityToken(c *gin.Context) {
|
||||
m.handler.RequestAntigravityToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestQwenToken(c *gin.Context) {
|
||||
m.handler.RequestQwenToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestKimiToken(c *gin.Context) {
|
||||
m.handler.RequestKimiToken(c)
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
}
|
||||
|
||||
kilocodeAuth := kilo.NewKiloAuth()
|
||||
|
||||
|
||||
fmt.Println("Initiating Kilo device authentication...")
|
||||
resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
@@ -48,7 +48,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
|
||||
fmt.Printf("Please visit: %s\n", resp.VerificationURL)
|
||||
fmt.Printf("And enter code: %s\n", resp.Code)
|
||||
|
||||
|
||||
fmt.Println("Waiting for authorization...")
|
||||
status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
|
||||
if err != nil {
|
||||
@@ -68,7 +68,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
for i, org := range profile.Orgs {
|
||||
fmt.Printf("[%d] %s (%s)\n", i+1, org.Name, org.ID)
|
||||
}
|
||||
|
||||
|
||||
if opts.Prompt != nil {
|
||||
input, err := opts.Prompt("Enter the number of the organization: ")
|
||||
if err != nil {
|
||||
@@ -108,7 +108,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
metadata := map[string]any{
|
||||
"email": status.UserEmail,
|
||||
"organization_id": orgID,
|
||||
"model": defaults.Model,
|
||||
"model": defaults.Model,
|
||||
}
|
||||
|
||||
return &coreauth.Auth{
|
||||
|
||||
113
sdk/auth/qwen.go
113
sdk/auth/qwen.go
@@ -1,113 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
// legacy client removed
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// QwenAuthenticator implements the device flow login for Qwen accounts.
|
||||
type QwenAuthenticator struct{}
|
||||
|
||||
// NewQwenAuthenticator constructs a Qwen authenticator.
|
||||
func NewQwenAuthenticator() *QwenAuthenticator {
|
||||
return &QwenAuthenticator{}
|
||||
}
|
||||
|
||||
func (a *QwenAuthenticator) Provider() string {
|
||||
return "qwen"
|
||||
}
|
||||
|
||||
func (a *QwenAuthenticator) RefreshLead() *time.Duration {
|
||||
return new(20 * time.Minute)
|
||||
}
|
||||
|
||||
func (a *QwenAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
authSvc := qwen.NewQwenAuth(cfg)
|
||||
|
||||
deviceFlow, err := authSvc.InitiateDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qwen device flow initiation failed: %w", err)
|
||||
}
|
||||
|
||||
authURL := deviceFlow.VerificationURIComplete
|
||||
|
||||
if !opts.NoBrowser {
|
||||
fmt.Println("Opening browser for Qwen authentication")
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the URL manually")
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
} else if err = browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", err)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
|
||||
fmt.Println("Waiting for Qwen authentication...")
|
||||
|
||||
tokenData, err := authSvc.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qwen authentication failed: %w", err)
|
||||
}
|
||||
|
||||
tokenStorage := authSvc.CreateTokenStorage(tokenData)
|
||||
|
||||
email := ""
|
||||
if opts.Metadata != nil {
|
||||
email = opts.Metadata["email"]
|
||||
if email == "" {
|
||||
email = opts.Metadata["alias"]
|
||||
}
|
||||
}
|
||||
|
||||
if email == "" && opts.Prompt != nil {
|
||||
email, err = opts.Prompt("Please input your email address or alias for Qwen:")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" {
|
||||
return nil, &EmailRequiredError{Prompt: "Please provide an email address or alias for Qwen."}
|
||||
}
|
||||
|
||||
tokenStorage.Email = email
|
||||
|
||||
// no legacy client construction
|
||||
|
||||
fileName := fmt.Sprintf("qwen-%s.json", tokenStorage.Email)
|
||||
metadata := map[string]any{
|
||||
"email": tokenStorage.Email,
|
||||
}
|
||||
|
||||
fmt.Println("Qwen authentication successful")
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestQwenAuthenticator_RefreshLeadIsSane(t *testing.T) {
|
||||
lead := NewQwenAuthenticator().RefreshLead()
|
||||
if lead == nil {
|
||||
t.Fatal("RefreshLead() = nil, want non-nil")
|
||||
}
|
||||
if *lead <= 0 {
|
||||
t.Fatalf("RefreshLead() = %s, want > 0", *lead)
|
||||
}
|
||||
if *lead > 30*time.Minute {
|
||||
t.Fatalf("RefreshLead() = %s, want <= %s", *lead, 30*time.Minute)
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
func init() {
|
||||
registerRefreshLead("codex", func() Authenticator { return NewCodexAuthenticator() })
|
||||
registerRefreshLead("claude", func() Authenticator { return NewClaudeAuthenticator() })
|
||||
registerRefreshLead("qwen", func() Authenticator { return NewQwenAuthenticator() })
|
||||
registerRefreshLead("iflow", func() Authenticator { return NewIFlowAuthenticator() })
|
||||
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
|
||||
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
||||
|
||||
453
sdk/cliproxy/auth/auto_refresh_loop.go
Normal file
453
sdk/cliproxy/auth/auto_refresh_loop.go
Normal file
@@ -0,0 +1,453 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type authAutoRefreshLoop struct {
|
||||
manager *Manager
|
||||
interval time.Duration
|
||||
concurrency int
|
||||
|
||||
mu sync.Mutex
|
||||
queue refreshMinHeap
|
||||
index map[string]*refreshHeapItem
|
||||
dirty map[string]struct{}
|
||||
|
||||
wakeCh chan struct{}
|
||||
jobs chan string
|
||||
}
|
||||
|
||||
func newAuthAutoRefreshLoop(manager *Manager, interval time.Duration, concurrency int) *authAutoRefreshLoop {
|
||||
if interval <= 0 {
|
||||
interval = refreshCheckInterval
|
||||
}
|
||||
if concurrency <= 0 {
|
||||
concurrency = refreshMaxConcurrency
|
||||
}
|
||||
jobBuffer := concurrency * 4
|
||||
if jobBuffer < 64 {
|
||||
jobBuffer = 64
|
||||
}
|
||||
return &authAutoRefreshLoop{
|
||||
manager: manager,
|
||||
interval: interval,
|
||||
concurrency: concurrency,
|
||||
index: make(map[string]*refreshHeapItem),
|
||||
dirty: make(map[string]struct{}),
|
||||
wakeCh: make(chan struct{}, 1),
|
||||
jobs: make(chan string, jobBuffer),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) queueReschedule(authID string) {
|
||||
if l == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
l.dirty[authID] = struct{}{}
|
||||
l.mu.Unlock()
|
||||
select {
|
||||
case l.wakeCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) run(ctx context.Context) {
|
||||
if l == nil || l.manager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
workers := l.concurrency
|
||||
if workers <= 0 {
|
||||
workers = refreshMaxConcurrency
|
||||
}
|
||||
for i := 0; i < workers; i++ {
|
||||
go l.worker(ctx)
|
||||
}
|
||||
|
||||
l.loop(ctx)
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) worker(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case authID := <-l.jobs:
|
||||
if authID == "" {
|
||||
continue
|
||||
}
|
||||
l.manager.refreshAuth(ctx, authID)
|
||||
l.queueReschedule(authID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) rebuild(now time.Time) {
|
||||
type entry struct {
|
||||
id string
|
||||
next time.Time
|
||||
}
|
||||
|
||||
entries := make([]entry, 0)
|
||||
|
||||
l.manager.mu.RLock()
|
||||
for id, auth := range l.manager.auths {
|
||||
next, ok := nextRefreshCheckAt(now, auth, l.interval)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, entry{id: id, next: next})
|
||||
}
|
||||
l.manager.mu.RUnlock()
|
||||
|
||||
l.mu.Lock()
|
||||
l.queue = l.queue[:0]
|
||||
l.index = make(map[string]*refreshHeapItem, len(entries))
|
||||
for _, e := range entries {
|
||||
item := &refreshHeapItem{id: e.id, next: e.next}
|
||||
heap.Push(&l.queue, item)
|
||||
l.index[e.id] = item
|
||||
}
|
||||
l.mu.Unlock()
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) loop(ctx context.Context) {
|
||||
timer := time.NewTimer(time.Hour)
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
defer timer.Stop()
|
||||
|
||||
var timerCh <-chan time.Time
|
||||
l.resetTimer(timer, &timerCh, time.Now())
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-l.wakeCh:
|
||||
now := time.Now()
|
||||
l.applyDirty(now)
|
||||
l.resetTimer(timer, &timerCh, now)
|
||||
case <-timerCh:
|
||||
now := time.Now()
|
||||
l.handleDue(ctx, now)
|
||||
l.applyDirty(now)
|
||||
l.resetTimer(timer, &timerCh, now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) resetTimer(timer *time.Timer, timerCh *<-chan time.Time, now time.Time) {
|
||||
next, ok := l.peek()
|
||||
if !ok {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
*timerCh = nil
|
||||
return
|
||||
}
|
||||
|
||||
wait := next.Sub(now)
|
||||
if wait < 0 {
|
||||
wait = 0
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timer.Reset(wait)
|
||||
*timerCh = timer.C
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) peek() (time.Time, bool) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if len(l.queue) == 0 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return l.queue[0].next, true
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) handleDue(ctx context.Context, now time.Time) {
|
||||
due := l.popDue(now)
|
||||
if len(due) == 0 {
|
||||
return
|
||||
}
|
||||
if log.IsLevelEnabled(log.DebugLevel) {
|
||||
log.Debugf("auto-refresh scheduler due auths: %d", len(due))
|
||||
}
|
||||
for _, authID := range due {
|
||||
l.handleDueAuth(ctx, now, authID)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) popDue(now time.Time) []string {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
var due []string
|
||||
for len(l.queue) > 0 {
|
||||
item := l.queue[0]
|
||||
if item == nil || item.next.After(now) {
|
||||
break
|
||||
}
|
||||
popped := heap.Pop(&l.queue).(*refreshHeapItem)
|
||||
if popped == nil {
|
||||
continue
|
||||
}
|
||||
delete(l.index, popped.id)
|
||||
due = append(due, popped.id)
|
||||
}
|
||||
return due
|
||||
}
|
||||
|
||||
// handleDueAuth re-evaluates a single due auth and either reschedules it,
// drops it from the queue, or hands it off to a refresh worker via l.jobs.
func (l *authAutoRefreshLoop) handleDueAuth(ctx context.Context, now time.Time, authID string) {
	if authID == "" {
		return
	}

	manager := l.manager

	// Snapshot everything needed for the decision under one read lock.
	manager.mu.RLock()
	auth := manager.auths[authID]
	if auth == nil {
		manager.mu.RUnlock()
		return
	}
	next, shouldSchedule := nextRefreshCheckAt(now, auth, l.interval)
	shouldRefresh := manager.shouldRefresh(auth, now)
	exec := manager.executors[auth.Provider]
	manager.mu.RUnlock()

	// Auth is no longer eligible for auto refresh; unschedule it.
	if !shouldSchedule {
		l.remove(authID)
		return
	}

	// Not yet due for an actual refresh; push back to its next check time.
	if !shouldRefresh {
		l.upsert(authID, next)
		return
	}

	// No executor registered for this provider; retry one interval later.
	if exec == nil {
		l.upsert(authID, now.Add(l.interval))
		return
	}

	// Try to claim the refresh. If the pending gate rejects it (another
	// path already claimed it, or its state changed), reschedule based on
	// a fresh evaluation of the auth instead.
	if !manager.markRefreshPending(authID, now) {
		manager.mu.RLock()
		auth = manager.auths[authID]
		next, shouldSchedule = nextRefreshCheckAt(now, auth, l.interval)
		manager.mu.RUnlock()
		if shouldSchedule {
			l.upsert(authID, next)
		} else {
			l.remove(authID)
		}
		return
	}

	// Hand the refresh to a worker; give up if the loop is shutting down.
	select {
	case <-ctx.Done():
		return
	case l.jobs <- authID:
	}
}
func (l *authAutoRefreshLoop) applyDirty(now time.Time) {
|
||||
dirty := l.drainDirty()
|
||||
if len(dirty) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, authID := range dirty {
|
||||
l.manager.mu.RLock()
|
||||
auth := l.manager.auths[authID]
|
||||
next, ok := nextRefreshCheckAt(now, auth, l.interval)
|
||||
l.manager.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
l.remove(authID)
|
||||
continue
|
||||
}
|
||||
l.upsert(authID, next)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) drainDirty() []string {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if len(l.dirty) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(l.dirty))
|
||||
for authID := range l.dirty {
|
||||
out = append(out, authID)
|
||||
delete(l.dirty, authID)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) upsert(authID string, next time.Time) {
|
||||
if authID == "" || next.IsZero() {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if item, ok := l.index[authID]; ok && item != nil {
|
||||
item.next = next
|
||||
heap.Fix(&l.queue, item.index)
|
||||
return
|
||||
}
|
||||
item := &refreshHeapItem{id: authID, next: next}
|
||||
heap.Push(&l.queue, item)
|
||||
l.index[authID] = item
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) remove(authID string) {
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
item, ok := l.index[authID]
|
||||
if !ok || item == nil {
|
||||
return
|
||||
}
|
||||
heap.Remove(&l.queue, item.index)
|
||||
delete(l.index, authID)
|
||||
}
|
||||
|
||||
// nextRefreshCheckAt computes when auth should next be examined by the
// auto-refresh scheduler. The boolean is false when the auth must not be
// scheduled at all (nil, disabled, or api_key based).
func nextRefreshCheckAt(now time.Time, auth *Auth, interval time.Duration) (time.Time, bool) {
	if auth == nil || auth.Disabled {
		return time.Time{}, false
	}

	// API-key auths are excluded from auto refresh entirely.
	accountType, _ := auth.AccountInfo()
	if accountType == "api_key" {
		return time.Time{}, false
	}

	// A future NextRefreshAfter acts as a hard gate (e.g. failure or
	// pending backoff): never check before it.
	if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
		return auth.NextRefreshAfter, true
	}

	// Runtimes implementing RefreshEvaluator decide refresh timing
	// themselves; simply poll them on the regular interval.
	if evaluator, ok := auth.Runtime.(RefreshEvaluator); ok && evaluator != nil {
		if interval <= 0 {
			interval = refreshCheckInterval
		}
		return now.Add(interval), true
	}

	lastRefresh := auth.LastRefreshedAt
	if lastRefresh.IsZero() {
		// Fall back to a refresh timestamp recorded on the auth itself.
		if ts, ok := authLastRefreshTimestamp(auth); ok {
			lastRefresh = ts
		}
	}

	expiry, hasExpiry := auth.ExpirationTime()

	// Preferred-interval mode: refresh `pref` before expiry or `pref`
	// after the last refresh, whichever comes first.
	if pref := authPreferredInterval(auth); pref > 0 {
		candidates := make([]time.Time, 0, 2)
		if hasExpiry && !expiry.IsZero() {
			// Already expired, or within the preferred window: due now.
			if !expiry.After(now) || expiry.Sub(now) <= pref {
				return now, true
			}
			candidates = append(candidates, expiry.Add(-pref))
		}
		if lastRefresh.IsZero() {
			// No refresh history: check immediately.
			return now, true
		}
		candidates = append(candidates, lastRefresh.Add(pref))
		// The earliest candidate wins.
		next := candidates[0]
		for _, candidate := range candidates[1:] {
			if candidate.Before(next) {
				next = candidate
			}
		}
		if !next.After(now) {
			return now, true
		}
		return next, true
	}

	// Provider-lead mode: a provider exposes how far ahead of expiry (or
	// after the last refresh) its tokens should be renewed. A nil lead
	// means the provider opts out of auto refresh.
	provider := strings.ToLower(auth.Provider)
	lead := ProviderRefreshLead(provider, auth.Runtime)
	if lead == nil {
		return time.Time{}, false
	}
	if hasExpiry && !expiry.IsZero() {
		dueAt := expiry.Add(-*lead)
		if !dueAt.After(now) {
			return now, true
		}
		return dueAt, true
	}
	if !lastRefresh.IsZero() {
		dueAt := lastRefresh.Add(*lead)
		if !dueAt.After(now) {
			return now, true
		}
		return dueAt, true
	}
	// No expiry and no refresh history: check right away.
	return now, true
}
// refreshHeapItem is one scheduled entry in the refresh min-heap.
type refreshHeapItem struct {
	id    string    // auth ID this entry schedules
	next  time.Time // when the auth should next be checked
	index int       // position within the heap, maintained by refreshMinHeap
}
// refreshMinHeap orders refresh entries by their next check time (earliest
// first) and implements container/heap.Interface.
type refreshMinHeap []*refreshHeapItem

func (h refreshMinHeap) Len() int { return len(h) }

func (h refreshMinHeap) Less(i, j int) bool {
	return h[i].next.Before(h[j].next)
}

func (h refreshMinHeap) Swap(i, j int) {
	h[i], h[j] = h[j], h[i]
	// Keep each item's index in sync so heap.Fix/heap.Remove can locate it.
	h[i].index = i
	h[j].index = j
}

// Push appends x to the heap; container/heap calls this, not users.
func (h *refreshMinHeap) Push(x any) {
	item, ok := x.(*refreshHeapItem)
	if !ok || item == nil {
		// Ignore non-item values rather than panicking.
		return
	}
	item.index = len(*h)
	*h = append(*h, item)
}

// Pop removes and returns the last element; container/heap calls this
// after swapping the minimum to the end.
func (h *refreshMinHeap) Pop() any {
	old := *h
	n := len(old)
	if n == 0 {
		// Typed nil so callers' *refreshHeapItem type assertions succeed.
		return (*refreshHeapItem)(nil)
	}
	item := old[n-1]
	item.index = -1 // mark as no longer in the heap
	*h = old[:n-1]
	return item
}
137
sdk/cliproxy/auth/auto_refresh_loop_test.go
Normal file
137
sdk/cliproxy/auth/auto_refresh_loop_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// testRefreshEvaluator is a RefreshEvaluator stub that never requests a refresh.
type testRefreshEvaluator struct{}

func (testRefreshEvaluator) ShouldRefresh(time.Time, *Auth) bool { return false }
||||
// setRefreshLeadFactory installs (or, with a nil factory, removes) a
// provider refresh-lead factory for the duration of the test and restores
// the previous registration via t.Cleanup.
func setRefreshLeadFactory(t *testing.T, provider string, factory func() *time.Duration) {
	t.Helper()
	// Keys are normalized the same way the production lookup does.
	key := strings.ToLower(strings.TrimSpace(provider))
	refreshLeadMu.Lock()
	prev, hadPrev := refreshLeadFactories[key]
	if factory == nil {
		delete(refreshLeadFactories, key)
	} else {
		refreshLeadFactories[key] = factory
	}
	refreshLeadMu.Unlock()
	t.Cleanup(func() {
		refreshLeadMu.Lock()
		if hadPrev {
			refreshLeadFactories[key] = prev
		} else {
			delete(refreshLeadFactories, key)
		}
		refreshLeadMu.Unlock()
	})
}
||||
func TestNextRefreshCheckAt_DisabledUnschedule(t *testing.T) {
|
||||
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
|
||||
auth := &Auth{ID: "a1", Provider: "test", Disabled: true}
|
||||
if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok {
|
||||
t.Fatalf("nextRefreshCheckAt() ok = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextRefreshCheckAt_APIKeyUnschedule(t *testing.T) {
|
||||
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
|
||||
auth := &Auth{ID: "a1", Provider: "test", Attributes: map[string]string{"api_key": "k"}}
|
||||
if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok {
|
||||
t.Fatalf("nextRefreshCheckAt() ok = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNextRefreshCheckAt_NextRefreshAfterGate verifies that a future
// NextRefreshAfter is returned verbatim as the next check time.
func TestNextRefreshCheckAt_NextRefreshAfterGate(t *testing.T) {
	now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
	nextAfter := now.Add(30 * time.Minute)
	auth := &Auth{
		ID:               "a1",
		Provider:         "test",
		NextRefreshAfter: nextAfter,
		// Email metadata keeps the auth from being classified as api_key.
		Metadata: map[string]any{"email": "x@example.com"},
	}
	got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
	if !ok {
		t.Fatalf("nextRefreshCheckAt() ok = false, want true")
	}
	if !got.Equal(nextAfter) {
		t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, nextAfter)
	}
}
||||
// TestNextRefreshCheckAt_PreferredInterval_PicksEarliestCandidate verifies
// that in preferred-interval mode the earlier of (expiry - pref) and
// (lastRefresh + pref) is chosen as the next check time.
func TestNextRefreshCheckAt_PreferredInterval_PicksEarliestCandidate(t *testing.T) {
	now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
	expiry := now.Add(20 * time.Minute)
	auth := &Auth{
		ID:              "a1",
		Provider:        "test",
		LastRefreshedAt: now,
		Metadata: map[string]any{
			"email":                    "x@example.com",
			"expires_at":               expiry.Format(time.RFC3339),
			"refresh_interval_seconds": 900, // 15m
		},
	}
	got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
	if !ok {
		t.Fatalf("nextRefreshCheckAt() ok = false, want true")
	}
	// expiry-pref (now+5m) precedes lastRefresh+pref (now+15m).
	want := expiry.Add(-15 * time.Minute)
	if !got.Equal(want) {
		t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
	}
}
||||
// TestNextRefreshCheckAt_ProviderLead_Expiry verifies that, with a
// registered provider lead and a known expiry, the next check lands
// exactly `lead` before the expiry.
func TestNextRefreshCheckAt_ProviderLead_Expiry(t *testing.T) {
	now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
	expiry := now.Add(time.Hour)
	lead := 10 * time.Minute
	setRefreshLeadFactory(t, "provider-lead-expiry", func() *time.Duration {
		d := lead
		return &d
	})

	auth := &Auth{
		ID:       "a1",
		Provider: "provider-lead-expiry",
		Metadata: map[string]any{
			"email":      "x@example.com",
			"expires_at": expiry.Format(time.RFC3339),
		},
	}

	got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
	if !ok {
		t.Fatalf("nextRefreshCheckAt() ok = false, want true")
	}
	want := expiry.Add(-lead)
	if !got.Equal(want) {
		t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
	}
}
||||
// TestNextRefreshCheckAt_RefreshEvaluatorFallback verifies that auths whose
// Runtime implements RefreshEvaluator are polled on the plain interval.
func TestNextRefreshCheckAt_RefreshEvaluatorFallback(t *testing.T) {
	now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
	interval := 15 * time.Minute
	auth := &Auth{
		ID:       "a1",
		Provider: "test",
		Metadata: map[string]any{"email": "x@example.com"},
		Runtime:  testRefreshEvaluator{},
	}
	got, ok := nextRefreshCheckAt(now, auth, interval)
	if !ok {
		t.Fatalf("nextRefreshCheckAt() ok = false, want true")
	}
	want := now.Add(interval)
	if !got.Equal(want) {
		t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
	}
}
||||
@@ -105,6 +105,13 @@ type Selector interface {
|
||||
Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error)
|
||||
}
|
||||
|
||||
// StoppableSelector is an optional interface for selectors that hold resources.
|
||||
// Selectors that implement this interface will have Stop called during shutdown.
|
||||
type StoppableSelector interface {
|
||||
Selector
|
||||
Stop()
|
||||
}
|
||||
|
||||
// Hook captures lifecycle callbacks for observing auth changes.
|
||||
type Hook interface {
|
||||
// OnAuthRegistered fires when a new auth is registered.
|
||||
@@ -162,8 +169,8 @@ type Manager struct {
|
||||
rtProvider RoundTripperProvider
|
||||
|
||||
// Auto refresh state
|
||||
refreshCancel context.CancelFunc
|
||||
refreshSemaphore chan struct{}
|
||||
refreshCancel context.CancelFunc
|
||||
refreshLoop *authAutoRefreshLoop
|
||||
}
|
||||
|
||||
// NewManager constructs a manager with optional custom selector and hook.
|
||||
@@ -182,7 +189,6 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
||||
auths: make(map[string]*Auth),
|
||||
providerOffsets: make(map[string]int),
|
||||
modelPoolOffsets: make(map[string]int),
|
||||
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
|
||||
}
|
||||
// atomic.Value requires non-nil initial value.
|
||||
manager.runtimeConfig.Store(&internalconfig.Config{})
|
||||
@@ -214,6 +220,16 @@ func (m *Manager) syncScheduler() {
|
||||
m.syncSchedulerFromSnapshot(m.snapshotAuths())
|
||||
}
|
||||
|
||||
func (m *Manager) snapshotAuths() []*Auth {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
out := make([]*Auth, 0, len(m.auths))
|
||||
for _, a := range m.auths {
|
||||
out = append(out, a.Clone())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
|
||||
// supportedModelSet is rebuilt from the current global model registry state.
|
||||
// This must be called after models have been registered for a newly added auth,
|
||||
@@ -1088,6 +1104,7 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.upsertAuth(authClone)
|
||||
}
|
||||
m.queueRefreshReschedule(auth.ID)
|
||||
_ = m.persist(ctx, auth)
|
||||
m.hook.OnAuthRegistered(ctx, auth.Clone())
|
||||
return auth.Clone(), nil
|
||||
@@ -1118,6 +1135,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.upsertAuth(authClone)
|
||||
}
|
||||
m.queueRefreshReschedule(auth.ID)
|
||||
_ = m.persist(ctx, auth)
|
||||
m.hook.OnAuthUpdated(ctx, auth.Clone())
|
||||
return auth.Clone(), nil
|
||||
@@ -2890,80 +2908,60 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio
|
||||
if interval <= 0 {
|
||||
interval = refreshCheckInterval
|
||||
}
|
||||
if m.refreshCancel != nil {
|
||||
m.refreshCancel()
|
||||
m.refreshCancel = nil
|
||||
|
||||
m.mu.Lock()
|
||||
cancelPrev := m.refreshCancel
|
||||
m.refreshCancel = nil
|
||||
m.refreshLoop = nil
|
||||
m.mu.Unlock()
|
||||
if cancelPrev != nil {
|
||||
cancelPrev()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(parent)
|
||||
m.refreshCancel = cancel
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
m.checkRefreshes(ctx)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.checkRefreshes(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancelCtx := context.WithCancel(parent)
|
||||
workers := refreshMaxConcurrency
|
||||
if cfg, ok := m.runtimeConfig.Load().(*internalconfig.Config); ok && cfg != nil && cfg.AuthAutoRefreshWorkers > 0 {
|
||||
workers = cfg.AuthAutoRefreshWorkers
|
||||
}
|
||||
loop := newAuthAutoRefreshLoop(m, interval, workers)
|
||||
|
||||
m.mu.Lock()
|
||||
m.refreshCancel = cancelCtx
|
||||
m.refreshLoop = loop
|
||||
m.mu.Unlock()
|
||||
|
||||
loop.rebuild(time.Now())
|
||||
go loop.run(ctx)
|
||||
}
|
||||
|
||||
// StopAutoRefresh cancels the background refresh loop, if running.
|
||||
// It also stops the selector if it implements StoppableSelector.
|
||||
func (m *Manager) StopAutoRefresh() {
|
||||
if m.refreshCancel != nil {
|
||||
m.refreshCancel()
|
||||
m.refreshCancel = nil
|
||||
m.mu.Lock()
|
||||
cancel := m.refreshCancel
|
||||
m.refreshCancel = nil
|
||||
m.refreshLoop = nil
|
||||
m.mu.Unlock()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
// Stop selector if it implements StoppableSelector (e.g., SessionAffinitySelector)
|
||||
if stoppable, ok := m.selector.(StoppableSelector); ok {
|
||||
stoppable.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) checkRefreshes(ctx context.Context) {
|
||||
// log.Debugf("checking refreshes")
|
||||
now := time.Now()
|
||||
snapshot := m.snapshotAuths()
|
||||
for _, a := range snapshot {
|
||||
typ, _ := a.AccountInfo()
|
||||
if typ != "api_key" {
|
||||
if !m.shouldRefresh(a, now) {
|
||||
continue
|
||||
}
|
||||
log.Debugf("checking refresh for %s, %s, %s", a.Provider, a.ID, typ)
|
||||
|
||||
if exec := m.executorFor(a.Provider); exec == nil {
|
||||
continue
|
||||
}
|
||||
if !m.markRefreshPending(a.ID, now) {
|
||||
continue
|
||||
}
|
||||
go m.refreshAuthWithLimit(ctx, a.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) refreshAuthWithLimit(ctx context.Context, id string) {
|
||||
if m.refreshSemaphore == nil {
|
||||
m.refreshAuth(ctx, id)
|
||||
func (m *Manager) queueRefreshReschedule(authID string) {
|
||||
if m == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case m.refreshSemaphore <- struct{}{}:
|
||||
defer func() { <-m.refreshSemaphore }()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
m.refreshAuth(ctx, id)
|
||||
}
|
||||
|
||||
func (m *Manager) snapshotAuths() []*Auth {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
out := make([]*Auth, 0, len(m.auths))
|
||||
for _, a := range m.auths {
|
||||
out = append(out, a.Clone())
|
||||
loop := m.refreshLoop
|
||||
m.mu.RUnlock()
|
||||
if loop == nil {
|
||||
return
|
||||
}
|
||||
return out
|
||||
loop.queueReschedule(authID)
|
||||
}
|
||||
|
||||
func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool {
|
||||
@@ -3173,16 +3171,20 @@ func lookupMetadataTime(meta map[string]any, keys ...string) (time.Time, bool) {
|
||||
|
||||
func (m *Manager) markRefreshPending(id string, now time.Time) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
auth, ok := m.auths[id]
|
||||
if !ok || auth == nil || auth.Disabled {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
auth.NextRefreshAfter = now.Add(refreshPendingBackoff)
|
||||
m.auths[id] = auth
|
||||
m.mu.Unlock()
|
||||
|
||||
m.queueRefreshReschedule(id)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -3209,16 +3211,21 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err)
|
||||
now := time.Now()
|
||||
if err != nil {
|
||||
shouldReschedule := false
|
||||
m.mu.Lock()
|
||||
if current := m.auths[id]; current != nil {
|
||||
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
|
||||
current.LastError = &Error{Message: err.Error()}
|
||||
m.auths[id] = current
|
||||
shouldReschedule = true
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.upsertAuth(current.Clone())
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
if shouldReschedule {
|
||||
m.queueRefreshReschedule(id)
|
||||
}
|
||||
return
|
||||
}
|
||||
if updated == nil {
|
||||
|
||||
@@ -69,18 +69,18 @@ func TestManager_ShouldRetryAfterError_UsesOAuthModelAliasForCooldown(t *testing
|
||||
m := NewManager(nil, nil, nil)
|
||||
m.SetRetryConfig(3, 30*time.Second, 0)
|
||||
m.SetOAuthModelAlias(map[string][]internalconfig.OAuthModelAlias{
|
||||
"qwen": {
|
||||
{Name: "qwen3.6-plus", Alias: "coder-model"},
|
||||
"iflow": {
|
||||
{Name: "deepseek-v3.1", Alias: "pool-model"},
|
||||
},
|
||||
})
|
||||
|
||||
routeModel := "coder-model"
|
||||
upstreamModel := "qwen3.6-plus"
|
||||
routeModel := "pool-model"
|
||||
upstreamModel := "deepseek-v3.1"
|
||||
next := time.Now().Add(5 * time.Second)
|
||||
|
||||
auth := &Auth{
|
||||
ID: "auth-1",
|
||||
Provider: "qwen",
|
||||
Provider: "iflow",
|
||||
ModelStates: map[string]*ModelState{
|
||||
upstreamModel: {
|
||||
Unavailable: true,
|
||||
@@ -99,7 +99,7 @@ func TestManager_ShouldRetryAfterError_UsesOAuthModelAliasForCooldown(t *testing
|
||||
}
|
||||
|
||||
_, _, maxWait := m.retrySettings()
|
||||
wait, shouldRetry := m.shouldRetryAfterError(&Error{HTTPStatus: 429, Message: "quota"}, 0, []string{"qwen"}, routeModel, maxWait)
|
||||
wait, shouldRetry := m.shouldRetryAfterError(&Error{HTTPStatus: 429, Message: "quota"}, 0, []string{"iflow"}, routeModel, maxWait)
|
||||
if !shouldRetry {
|
||||
t.Fatalf("expected shouldRetry=true, got false (wait=%v)", wait)
|
||||
}
|
||||
|
||||
@@ -265,7 +265,7 @@ func modelAliasChannel(auth *Auth) string {
|
||||
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
|
||||
// OAuth model alias (e.g., API key authentication).
|
||||
//
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot, kimi.
|
||||
func OAuthModelAliasChannel(provider, authKind string) string {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
||||
@@ -289,7 +289,7 @@ func OAuthModelAliasChannel(provider, authKind string) string {
|
||||
return ""
|
||||
}
|
||||
return "codex"
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kiro", "github-copilot", "kimi":
|
||||
case "gemini-cli", "aistudio", "antigravity", "iflow", "kiro", "github-copilot", "kimi":
|
||||
return provider
|
||||
default:
|
||||
return ""
|
||||
|
||||
@@ -184,8 +184,6 @@ func createAuthForChannel(channel string) *Auth {
|
||||
return &Auth{Provider: "aistudio"}
|
||||
case "antigravity":
|
||||
return &Auth{Provider: "antigravity"}
|
||||
case "qwen":
|
||||
return &Auth{Provider: "qwen"}
|
||||
case "iflow":
|
||||
return &Auth{Provider: "iflow"}
|
||||
case "kimi":
|
||||
|
||||
@@ -215,10 +215,10 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testi
|
||||
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
countErrors: map[string]error{"qwen3.5-plus": invalidErr},
|
||||
countErrors: map[string]error{"deepseek-v3.1": invalidErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -227,18 +227,18 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testi
|
||||
t.Fatalf("execute count error = %v, want %v", err, invalidErr)
|
||||
}
|
||||
got := executor.CountModels()
|
||||
if len(got) != 1 || got[0] != "qwen3.5-plus" {
|
||||
if len(got) != 1 || got[0] != "deepseek-v3.1" {
|
||||
t.Fatalf("count calls = %v, want only first invalid model", got)
|
||||
}
|
||||
}
|
||||
func TestResolveModelAliasPoolFromConfigModels(t *testing.T) {
|
||||
models := []modelAliasEntry{
|
||||
internalconfig.OpenAICompatibilityModel{Name: "qwen3.5-plus", Alias: "claude-opus-4.66"},
|
||||
internalconfig.OpenAICompatibilityModel{Name: "deepseek-v3.1", Alias: "claude-opus-4.66"},
|
||||
internalconfig.OpenAICompatibilityModel{Name: "glm-5", Alias: "claude-opus-4.66"},
|
||||
internalconfig.OpenAICompatibilityModel{Name: "kimi-k2.5", Alias: "claude-opus-4.66"},
|
||||
}
|
||||
got := resolveModelAliasPoolFromConfigModels("claude-opus-4.66(8192)", models)
|
||||
want := []string{"qwen3.5-plus(8192)", "glm-5(8192)", "kimi-k2.5(8192)"}
|
||||
want := []string{"deepseek-v3.1(8192)", "glm-5(8192)", "kimi-k2.5(8192)"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("pool len = %d, want %d (%v)", len(got), len(want), got)
|
||||
}
|
||||
@@ -253,7 +253,7 @@ func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
executor := &openAICompatPoolExecutor{id: "pool"}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -268,7 +268,7 @@ func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
|
||||
}
|
||||
|
||||
got := executor.ExecuteModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5", "qwen3.5-plus"}
|
||||
want := []string{"deepseek-v3.1", "glm-5", "deepseek-v3.1"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("execute calls = %v, want %v", got, want)
|
||||
}
|
||||
@@ -284,10 +284,10 @@ func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
|
||||
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
executeErrors: map[string]error{"qwen3.5-plus": invalidErr},
|
||||
executeErrors: map[string]error{"deepseek-v3.1": invalidErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -296,7 +296,7 @@ func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
|
||||
t.Fatalf("execute error = %v, want %v", err, invalidErr)
|
||||
}
|
||||
got := executor.ExecuteModels()
|
||||
if len(got) != 1 || got[0] != "qwen3.5-plus" {
|
||||
if len(got) != 1 || got[0] != "deepseek-v3.1" {
|
||||
t.Fatalf("execute calls = %v, want only first invalid model", got)
|
||||
}
|
||||
}
|
||||
@@ -309,10 +309,10 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t
|
||||
}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
|
||||
executeErrors: map[string]error{"deepseek-v3.1": modelSupportErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -324,7 +324,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t
|
||||
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
|
||||
}
|
||||
got := executor.ExecuteModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5"}
|
||||
want := []string{"deepseek-v3.1", "glm-5"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("execute calls = %v, want %v", got, want)
|
||||
}
|
||||
@@ -338,7 +338,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth to remain registered")
|
||||
}
|
||||
state := updated.ModelStates["qwen3.5-plus"]
|
||||
state := updated.ModelStates["deepseek-v3.1"]
|
||||
if state == nil {
|
||||
t.Fatalf("expected suspended upstream model state")
|
||||
}
|
||||
@@ -355,10 +355,10 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportUnprocessabl
|
||||
}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
|
||||
executeErrors: map[string]error{"deepseek-v3.1": modelSupportErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -370,7 +370,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportUnprocessabl
|
||||
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
|
||||
}
|
||||
got := executor.ExecuteModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5"}
|
||||
want := []string{"deepseek-v3.1", "glm-5"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("execute calls = %v, want %v", got, want)
|
||||
}
|
||||
@@ -385,10 +385,10 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.
|
||||
alias := "claude-opus-4.66"
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
executeErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
|
||||
executeErrors: map[string]error{"deepseek-v3.1": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -400,7 +400,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.
|
||||
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
|
||||
}
|
||||
got := executor.ExecuteModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5"}
|
||||
want := []string{"deepseek-v3.1", "glm-5"}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
|
||||
@@ -413,11 +413,11 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *te
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
streamPayloads: map[string][]cliproxyexecutor.StreamChunk{
|
||||
"qwen3.5-plus": {},
|
||||
"deepseek-v3.1": {},
|
||||
},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -436,7 +436,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *te
|
||||
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
|
||||
}
|
||||
got := executor.StreamModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5"}
|
||||
want := []string{"deepseek-v3.1", "glm-5"}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
|
||||
@@ -448,10 +448,10 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *t
|
||||
alias := "claude-opus-4.66"
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
streamFirstErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
|
||||
streamFirstErrors: map[string]error{"deepseek-v3.1": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -470,7 +470,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *t
|
||||
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
|
||||
}
|
||||
got := executor.StreamModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5"}
|
||||
want := []string{"deepseek-v3.1", "glm-5"}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
|
||||
@@ -486,10 +486,10 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *test
|
||||
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
|
||||
streamFirstErrors: map[string]error{"deepseek-v3.1": invalidErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -498,7 +498,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *test
|
||||
t.Fatalf("execute stream error = %v, want %v", err, invalidErr)
|
||||
}
|
||||
got := executor.StreamModels()
|
||||
if len(got) != 1 || got[0] != "qwen3.5-plus" {
|
||||
if len(got) != 1 || got[0] != "deepseek-v3.1" {
|
||||
t.Fatalf("stream calls = %v, want only first invalid model", got)
|
||||
}
|
||||
}
|
||||
@@ -511,10 +511,10 @@ func TestManagerExecute_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterReques
|
||||
}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
|
||||
executeErrors: map[string]error{"deepseek-v3.1": modelSupportErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -529,7 +529,7 @@ func TestManagerExecute_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterReques
|
||||
}
|
||||
|
||||
got := executor.ExecuteModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"}
|
||||
want := []string{"deepseek-v3.1", "glm-5", "glm-5", "glm-5"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("execute calls = %v, want %v", got, want)
|
||||
}
|
||||
@@ -548,10 +548,10 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLater
|
||||
}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
streamFirstErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
|
||||
streamFirstErrors: map[string]error{"deepseek-v3.1": modelSupportErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -569,7 +569,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLater
|
||||
}
|
||||
|
||||
got := executor.StreamModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"}
|
||||
want := []string{"deepseek-v3.1", "glm-5", "glm-5", "glm-5"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("stream calls = %v, want %v", got, want)
|
||||
}
|
||||
@@ -584,7 +584,7 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T
|
||||
alias := "claude-opus-4.66"
|
||||
executor := &openAICompatPoolExecutor{id: "pool"}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -599,7 +599,7 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T
|
||||
}
|
||||
|
||||
got := executor.CountModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5"}
|
||||
want := []string{"deepseek-v3.1", "glm-5"}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i])
|
||||
@@ -615,10 +615,10 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterR
|
||||
}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
countErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
|
||||
countErrors: map[string]error{"deepseek-v3.1": modelSupportErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -633,7 +633,7 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterR
|
||||
}
|
||||
|
||||
got := executor.CountModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"}
|
||||
want := []string{"deepseek-v3.1", "glm-5", "glm-5", "glm-5"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("count calls = %v, want %v", got, want)
|
||||
}
|
||||
@@ -650,7 +650,7 @@ func TestManagerExecute_OpenAICompatAliasPoolBlockedAuthDoesNotConsumeRetryBudge
|
||||
OpenAICompatibility: []internalconfig.OpenAICompatibility{{
|
||||
Name: "pool",
|
||||
Models: []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
},
|
||||
}},
|
||||
@@ -701,7 +701,7 @@ func TestManagerExecute_OpenAICompatAliasPoolBlockedAuthDoesNotConsumeRetryBudge
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "invalid_request_error: The requested model is not supported.",
|
||||
}
|
||||
for _, upstreamModel := range []string{"qwen3.5-plus", "glm-5"} {
|
||||
for _, upstreamModel := range []string{"deepseek-v3.1", "glm-5"} {
|
||||
m.MarkResult(context.Background(), Result{
|
||||
AuthID: badAuth.ID,
|
||||
Provider: "pool",
|
||||
@@ -733,10 +733,10 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *te
|
||||
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
|
||||
streamFirstErrors: map[string]error{"deepseek-v3.1": invalidErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "deepseek-v3.1", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
@@ -750,7 +750,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *te
|
||||
if streamResult != nil {
|
||||
t.Fatalf("streamResult = %#v, want nil on invalid bootstrap", streamResult)
|
||||
}
|
||||
if got := executor.StreamModels(); len(got) != 1 || got[0] != "qwen3.5-plus" {
|
||||
if got := executor.StreamModels(); len(got) != 1 || got[0] != "deepseek-v3.1" {
|
||||
t.Fatalf("stream calls = %v, want only first upstream model", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,15 +4,21 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
@@ -420,3 +426,448 @@ func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, block
|
||||
}
|
||||
return false, blockReasonNone, time.Time{}
|
||||
}
|
||||
|
||||
// sessionPattern matches Claude Code user_id format:
|
||||
// user_{hash}_account__session_{uuid}
|
||||
var sessionPattern = regexp.MustCompile(`_session_([a-f0-9-]+)$`)
|
||||
|
||||
// SessionAffinitySelector wraps another selector with session-sticky behavior.
|
||||
// It extracts session ID from multiple sources and maintains session-to-auth
|
||||
// mappings with automatic failover when the bound auth becomes unavailable.
|
||||
type SessionAffinitySelector struct {
|
||||
fallback Selector
|
||||
cache *SessionCache
|
||||
}
|
||||
|
||||
// SessionAffinityConfig configures the session affinity selector.
|
||||
type SessionAffinityConfig struct {
|
||||
Fallback Selector
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// NewSessionAffinitySelector creates a new session-aware selector.
|
||||
func NewSessionAffinitySelector(fallback Selector) *SessionAffinitySelector {
|
||||
return NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||
Fallback: fallback,
|
||||
TTL: time.Hour,
|
||||
})
|
||||
}
|
||||
|
||||
// NewSessionAffinitySelectorWithConfig creates a selector with custom configuration.
|
||||
func NewSessionAffinitySelectorWithConfig(cfg SessionAffinityConfig) *SessionAffinitySelector {
|
||||
if cfg.Fallback == nil {
|
||||
cfg.Fallback = &RoundRobinSelector{}
|
||||
}
|
||||
if cfg.TTL <= 0 {
|
||||
cfg.TTL = time.Hour
|
||||
}
|
||||
return &SessionAffinitySelector{
|
||||
fallback: cfg.Fallback,
|
||||
cache: NewSessionCache(cfg.TTL),
|
||||
}
|
||||
}
|
||||
|
||||
// Pick selects an auth with session affinity when possible.
|
||||
// Priority for session ID extraction:
|
||||
// 1. metadata.user_id (Claude Code format) - highest priority
|
||||
// 2. X-Session-ID header
|
||||
// 3. metadata.user_id (non-Claude Code format)
|
||||
// 4. conversation_id field
|
||||
// 5. Hash-based fallback from messages
|
||||
//
|
||||
// Note: The cache key includes provider, session ID, and model to handle cases where
|
||||
// a session uses multiple models (e.g., gemini-2.5-pro and gemini-3-flash-preview)
|
||||
// that may be supported by different auth credentials, and to avoid cross-provider conflicts.
|
||||
func (s *SessionAffinitySelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
entry := selectorLogEntry(ctx)
|
||||
primaryID, fallbackID := extractSessionIDs(opts.Headers, opts.OriginalRequest, opts.Metadata)
|
||||
if primaryID == "" {
|
||||
entry.Debugf("session-affinity: no session ID extracted, falling back to default selector | provider=%s model=%s", provider, model)
|
||||
return s.fallback.Pick(ctx, provider, model, opts, auths)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
available, err := getAvailableAuths(auths, provider, model, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cacheKey := provider + "::" + primaryID + "::" + model
|
||||
|
||||
if cachedAuthID, ok := s.cache.GetAndRefresh(cacheKey); ok {
|
||||
for _, auth := range available {
|
||||
if auth.ID == cachedAuthID {
|
||||
entry.Infof("session-affinity: cache hit | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
|
||||
return auth, nil
|
||||
}
|
||||
}
|
||||
// Cached auth not available, reselect via fallback selector for even distribution
|
||||
auth, err := s.fallback.Pick(ctx, provider, model, opts, auths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.cache.Set(cacheKey, auth.ID)
|
||||
entry.Infof("session-affinity: cache hit but auth unavailable, reselected | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
if fallbackID != "" && fallbackID != primaryID {
|
||||
fallbackKey := provider + "::" + fallbackID + "::" + model
|
||||
if cachedAuthID, ok := s.cache.Get(fallbackKey); ok {
|
||||
for _, auth := range available {
|
||||
if auth.ID == cachedAuthID {
|
||||
s.cache.Set(cacheKey, auth.ID)
|
||||
entry.Infof("session-affinity: fallback cache hit | session=%s fallback=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), truncateSessionID(fallbackID), auth.ID, provider, model)
|
||||
return auth, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auth, err := s.fallback.Pick(ctx, provider, model, opts, auths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.cache.Set(cacheKey, auth.ID)
|
||||
entry.Infof("session-affinity: cache miss, new binding | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func selectorLogEntry(ctx context.Context) *log.Entry {
|
||||
if ctx == nil {
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
if reqID := logging.GetRequestID(ctx); reqID != "" {
|
||||
return log.WithField("request_id", reqID)
|
||||
}
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
|
||||
// truncateSessionID shortens session ID for logging (first 8 chars + "...")
|
||||
func truncateSessionID(id string) string {
|
||||
if len(id) <= 20 {
|
||||
return id
|
||||
}
|
||||
return id[:8] + "..."
|
||||
}
|
||||
|
||||
// Stop releases resources held by the selector.
|
||||
func (s *SessionAffinitySelector) Stop() {
|
||||
if s.cache != nil {
|
||||
s.cache.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateAuth removes all session bindings for a specific auth.
|
||||
// Called when an auth becomes rate-limited or unavailable.
|
||||
func (s *SessionAffinitySelector) InvalidateAuth(authID string) {
|
||||
if s.cache != nil {
|
||||
s.cache.InvalidateAuth(authID)
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractSessionID extracts session identifier from multiple sources.
|
||||
// Priority order:
|
||||
// 1. metadata.user_id (Claude Code format with _session_{uuid}) - highest priority for Claude Code clients
|
||||
// 2. X-Session-ID header
|
||||
// 3. metadata.user_id (non-Claude Code format)
|
||||
// 4. conversation_id field in request body
|
||||
// 5. Stable hash from first few messages content (fallback)
|
||||
func ExtractSessionID(headers http.Header, payload []byte, metadata map[string]any) string {
|
||||
primary, _ := extractSessionIDs(headers, payload, metadata)
|
||||
return primary
|
||||
}
|
||||
|
||||
// extractSessionIDs returns (primaryID, fallbackID) for session affinity.
|
||||
// primaryID: full hash including assistant response (stable after first turn)
|
||||
// fallbackID: short hash without assistant (used to inherit binding from first turn)
|
||||
func extractSessionIDs(headers http.Header, payload []byte, metadata map[string]any) (string, string) {
|
||||
// 1. metadata.user_id with Claude Code session format (highest priority)
|
||||
if len(payload) > 0 {
|
||||
userID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||
if userID != "" {
|
||||
// Old format: user_{hash}_account__session_{uuid}
|
||||
if matches := sessionPattern.FindStringSubmatch(userID); len(matches) >= 2 {
|
||||
id := "claude:" + matches[1]
|
||||
return id, ""
|
||||
}
|
||||
// New format: JSON object with session_id field
|
||||
// e.g. {"device_id":"...","account_uuid":"...","session_id":"uuid"}
|
||||
if len(userID) > 0 && userID[0] == '{' {
|
||||
if sid := gjson.Get(userID, "session_id").String(); sid != "" {
|
||||
return "claude:" + sid, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. X-Session-ID header
|
||||
if headers != nil {
|
||||
if sid := headers.Get("X-Session-ID"); sid != "" {
|
||||
return "header:" + sid, ""
|
||||
}
|
||||
}
|
||||
|
||||
if len(payload) == 0 {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// 3. metadata.user_id (non-Claude Code format)
|
||||
userID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||
if userID != "" {
|
||||
return "user:" + userID, ""
|
||||
}
|
||||
|
||||
// 4. conversation_id field
|
||||
if convID := gjson.GetBytes(payload, "conversation_id").String(); convID != "" {
|
||||
return "conv:" + convID, ""
|
||||
}
|
||||
|
||||
// 5. Hash-based fallback from message content
|
||||
return extractMessageHashIDs(payload)
|
||||
}
|
||||
|
||||
func extractMessageHashIDs(payload []byte) (primaryID, fallbackID string) {
|
||||
var systemPrompt, firstUserMsg, firstAssistantMsg string
|
||||
|
||||
// OpenAI/Claude messages format
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
if messages.Exists() && messages.IsArray() {
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
role := msg.Get("role").String()
|
||||
content := extractMessageContent(msg.Get("content"))
|
||||
if content == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
switch role {
|
||||
case "system":
|
||||
if systemPrompt == "" {
|
||||
systemPrompt = truncateString(content, 100)
|
||||
}
|
||||
case "user":
|
||||
if firstUserMsg == "" {
|
||||
firstUserMsg = truncateString(content, 100)
|
||||
}
|
||||
case "assistant":
|
||||
if firstAssistantMsg == "" {
|
||||
firstAssistantMsg = truncateString(content, 100)
|
||||
}
|
||||
}
|
||||
|
||||
if systemPrompt != "" && firstUserMsg != "" && firstAssistantMsg != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Claude API: top-level "system" field (array or string)
|
||||
if systemPrompt == "" {
|
||||
topSystem := gjson.GetBytes(payload, "system")
|
||||
if topSystem.Exists() {
|
||||
if topSystem.IsArray() {
|
||||
topSystem.ForEach(func(_, part gjson.Result) bool {
|
||||
if text := part.Get("text").String(); text != "" && systemPrompt == "" {
|
||||
systemPrompt = truncateString(text, 100)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
} else if topSystem.Type == gjson.String {
|
||||
systemPrompt = truncateString(topSystem.String(), 100)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Gemini format
|
||||
if systemPrompt == "" && firstUserMsg == "" {
|
||||
sysInstr := gjson.GetBytes(payload, "systemInstruction.parts")
|
||||
if sysInstr.Exists() && sysInstr.IsArray() {
|
||||
sysInstr.ForEach(func(_, part gjson.Result) bool {
|
||||
if text := part.Get("text").String(); text != "" && systemPrompt == "" {
|
||||
systemPrompt = truncateString(text, 100)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
contents := gjson.GetBytes(payload, "contents")
|
||||
if contents.Exists() && contents.IsArray() {
|
||||
contents.ForEach(func(_, msg gjson.Result) bool {
|
||||
role := msg.Get("role").String()
|
||||
msg.Get("parts").ForEach(func(_, part gjson.Result) bool {
|
||||
text := part.Get("text").String()
|
||||
if text == "" {
|
||||
return true
|
||||
}
|
||||
switch role {
|
||||
case "user":
|
||||
if firstUserMsg == "" {
|
||||
firstUserMsg = truncateString(text, 100)
|
||||
}
|
||||
case "model":
|
||||
if firstAssistantMsg == "" {
|
||||
firstAssistantMsg = truncateString(text, 100)
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
if firstUserMsg != "" && firstAssistantMsg != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI Responses API format (v1/responses)
|
||||
if systemPrompt == "" && firstUserMsg == "" {
|
||||
if instr := gjson.GetBytes(payload, "instructions").String(); instr != "" {
|
||||
systemPrompt = truncateString(instr, 100)
|
||||
}
|
||||
|
||||
input := gjson.GetBytes(payload, "input")
|
||||
if input.Exists() && input.IsArray() {
|
||||
input.ForEach(func(_, item gjson.Result) bool {
|
||||
itemType := item.Get("type").String()
|
||||
if itemType == "reasoning" {
|
||||
return true
|
||||
}
|
||||
// Skip non-message typed items (function_call, function_call_output, etc.)
|
||||
// but allow items with no type that have a role (inline message format).
|
||||
if itemType != "" && itemType != "message" {
|
||||
return true
|
||||
}
|
||||
|
||||
role := item.Get("role").String()
|
||||
if itemType == "" && role == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle both string content and array content (multimodal).
|
||||
content := item.Get("content")
|
||||
var text string
|
||||
if content.Type == gjson.String {
|
||||
text = content.String()
|
||||
} else {
|
||||
text = extractResponsesAPIContent(content)
|
||||
}
|
||||
if text == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
switch role {
|
||||
case "developer", "system":
|
||||
if systemPrompt == "" {
|
||||
systemPrompt = truncateString(text, 100)
|
||||
}
|
||||
case "user":
|
||||
if firstUserMsg == "" {
|
||||
firstUserMsg = truncateString(text, 100)
|
||||
}
|
||||
case "assistant":
|
||||
if firstAssistantMsg == "" {
|
||||
firstAssistantMsg = truncateString(text, 100)
|
||||
}
|
||||
}
|
||||
|
||||
if firstUserMsg != "" && firstAssistantMsg != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if systemPrompt == "" && firstUserMsg == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
shortHash := computeSessionHash(systemPrompt, firstUserMsg, "")
|
||||
if firstAssistantMsg == "" {
|
||||
return shortHash, ""
|
||||
}
|
||||
|
||||
fullHash := computeSessionHash(systemPrompt, firstUserMsg, firstAssistantMsg)
|
||||
return fullHash, shortHash
|
||||
}
|
||||
|
||||
func computeSessionHash(systemPrompt, userMsg, assistantMsg string) string {
|
||||
h := fnv.New64a()
|
||||
if systemPrompt != "" {
|
||||
h.Write([]byte("sys:" + systemPrompt + "\n"))
|
||||
}
|
||||
if userMsg != "" {
|
||||
h.Write([]byte("usr:" + userMsg + "\n"))
|
||||
}
|
||||
if assistantMsg != "" {
|
||||
h.Write([]byte("ast:" + assistantMsg + "\n"))
|
||||
}
|
||||
return fmt.Sprintf("msg:%016x", h.Sum64())
|
||||
}
|
||||
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) > maxLen {
|
||||
return s[:maxLen]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// extractMessageContent extracts text content from a message content field.
|
||||
// Handles both string content and array content (multimodal messages).
|
||||
// For array content, extracts text from all text-type elements.
|
||||
func extractMessageContent(content gjson.Result) string {
|
||||
// String content: "Hello world"
|
||||
if content.Type == gjson.String {
|
||||
return content.String()
|
||||
}
|
||||
|
||||
// Array content: [{"type":"text","text":"Hello"},{"type":"image",...}]
|
||||
if content.IsArray() {
|
||||
var texts []string
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
// Handle Claude format: {"type":"text","text":"content"}
|
||||
if part.Get("type").String() == "text" {
|
||||
if text := part.Get("text").String(); text != "" {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
}
|
||||
// Handle OpenAI format: {"type":"text","text":"content"}
|
||||
// Same structure as Claude, already handled above
|
||||
return true
|
||||
})
|
||||
if len(texts) > 0 {
|
||||
return strings.Join(texts, " ")
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractResponsesAPIContent(content gjson.Result) string {
|
||||
if !content.IsArray() {
|
||||
return ""
|
||||
}
|
||||
var texts []string
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
if partType == "input_text" || partType == "output_text" || partType == "text" {
|
||||
if text := part.Get("text").String(); text != "" {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if len(texts) > 0 {
|
||||
return strings.Join(texts, " ")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractSessionID is kept for backward compatibility.
|
||||
// Deprecated: Use ExtractSessionID instead.
|
||||
func extractSessionID(payload []byte) string {
|
||||
return ExtractSessionID(nil, payload, nil)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -458,6 +460,159 @@ func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSessionID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
payload string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "valid_claude_code_format",
|
||||
payload: `{"metadata":{"user_id":"user_3f221fe75652cf9a89a31647f16274bb8036a9b85ac4dc226a4df0efec8dc04d_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`,
|
||||
want: "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344",
|
||||
},
|
||||
{
|
||||
name: "json_user_id_with_session_id",
|
||||
payload: `{"metadata":{"user_id":"{\"device_id\":\"be82c3aee1e0c2d74535bacc85f9f559228f02dd8a17298cf522b71e6c375714\",\"account_uuid\":\"\",\"session_id\":\"e26d4046-0f88-4b09-bb5b-f863ab5fb24e\"}"}}`,
|
||||
want: "claude:e26d4046-0f88-4b09-bb5b-f863ab5fb24e",
|
||||
},
|
||||
{
|
||||
name: "json_user_id_without_session_id",
|
||||
payload: `{"metadata":{"user_id":"{\"device_id\":\"abc123\"}"}}`,
|
||||
want: `user:{"device_id":"abc123"}`,
|
||||
},
|
||||
{
|
||||
name: "no_session_but_user_id",
|
||||
payload: `{"metadata":{"user_id":"user_abc123"}}`,
|
||||
want: "user:user_abc123",
|
||||
},
|
||||
{
|
||||
name: "conversation_id",
|
||||
payload: `{"conversation_id":"conv-12345"}`,
|
||||
want: "conv:conv-12345",
|
||||
},
|
||||
{
|
||||
name: "no_metadata",
|
||||
payload: `{"model":"claude-3"}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty_payload",
|
||||
payload: ``,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractSessionID([]byte(tt.payload))
|
||||
if got != tt.want {
|
||||
t.Errorf("extractSessionID() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionAffinitySelector_SameSessionSameAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fallback := &RoundRobinSelector{}
|
||||
selector := NewSessionAffinitySelector(fallback)
|
||||
|
||||
auths := []*Auth{
|
||||
{ID: "auth-a"},
|
||||
{ID: "auth-b"},
|
||||
{ID: "auth-c"},
|
||||
}
|
||||
|
||||
// Use valid UUID format for session ID
|
||||
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
|
||||
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||
|
||||
// Same session should always pick the same auth
|
||||
first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() error = %v", err)
|
||||
}
|
||||
if first == nil {
|
||||
t.Fatalf("Pick() returned nil")
|
||||
}
|
||||
|
||||
// Verify consistency: same session, same auths -> same result
|
||||
for i := 0; i < 10; i++ {
|
||||
got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got.ID != first.ID {
|
||||
t.Fatalf("Pick() #%d auth.ID = %q, want %q (same session should pick same auth)", i, got.ID, first.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionAffinitySelector_NoSessionFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fallback := &FillFirstSelector{}
|
||||
selector := NewSessionAffinitySelector(fallback)
|
||||
|
||||
auths := []*Auth{
|
||||
{ID: "auth-b"},
|
||||
{ID: "auth-a"},
|
||||
{ID: "auth-c"},
|
||||
}
|
||||
|
||||
// No session in payload, should fallback to FillFirstSelector (picks "auth-a" after sorting)
|
||||
payload := []byte(`{"model":"claude-3"}`)
|
||||
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||
|
||||
got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() error = %v", err)
|
||||
}
|
||||
if got.ID != "auth-a" {
|
||||
t.Fatalf("Pick() auth.ID = %q, want %q (should fallback to FillFirst)", got.ID, "auth-a")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionAffinitySelector_DifferentSessionsDifferentAuths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fallback := &RoundRobinSelector{}
|
||||
selector := NewSessionAffinitySelector(fallback)
|
||||
|
||||
auths := []*Auth{
|
||||
{ID: "auth-a"},
|
||||
{ID: "auth-b"},
|
||||
{ID: "auth-c"},
|
||||
}
|
||||
|
||||
// Use valid UUID format for session IDs
|
||||
session1 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_11111111-1111-1111-1111-111111111111"}}`)
|
||||
session2 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_22222222-2222-2222-2222-222222222222"}}`)
|
||||
|
||||
opts1 := cliproxyexecutor.Options{OriginalRequest: session1}
|
||||
opts2 := cliproxyexecutor.Options{OriginalRequest: session2}
|
||||
|
||||
auth1, _ := selector.Pick(context.Background(), "claude", "claude-3", opts1, auths)
|
||||
auth2, _ := selector.Pick(context.Background(), "claude", "claude-3", opts2, auths)
|
||||
|
||||
// Different sessions may or may not pick different auths (depends on hash collision)
|
||||
// But each session should be consistent
|
||||
for i := 0; i < 5; i++ {
|
||||
got1, _ := selector.Pick(context.Background(), "claude", "claude-3", opts1, auths)
|
||||
got2, _ := selector.Pick(context.Background(), "claude", "claude-3", opts2, auths)
|
||||
if got1.ID != auth1.ID {
|
||||
t.Fatalf("session1 Pick() #%d inconsistent: got %q, want %q", i, got1.ID, auth1.ID)
|
||||
}
|
||||
if got2.ID != auth2.ID {
|
||||
t.Fatalf("session2 Pick() #%d inconsistent: got %q, want %q", i, got2.ID, auth2.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -494,6 +649,57 @@ func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionAffinitySelector_FailoverWhenAuthUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fallback := &RoundRobinSelector{}
|
||||
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||
Fallback: fallback,
|
||||
TTL: time.Minute,
|
||||
})
|
||||
defer selector.Stop()
|
||||
|
||||
auths := []*Auth{
|
||||
{ID: "auth-a"},
|
||||
{ID: "auth-b"},
|
||||
{ID: "auth-c"},
|
||||
}
|
||||
|
||||
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_failover-test-uuid"}}`)
|
||||
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||
|
||||
// First pick establishes binding
|
||||
first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() error = %v", err)
|
||||
}
|
||||
|
||||
// Remove the bound auth from available list (simulating rate limit)
|
||||
availableWithoutFirst := make([]*Auth, 0, len(auths)-1)
|
||||
for _, a := range auths {
|
||||
if a.ID != first.ID {
|
||||
availableWithoutFirst = append(availableWithoutFirst, a)
|
||||
}
|
||||
}
|
||||
|
||||
// With failover enabled, should pick a new auth
|
||||
second, err := selector.Pick(context.Background(), "claude", "claude-3", opts, availableWithoutFirst)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() after failover error = %v", err)
|
||||
}
|
||||
if second.ID == first.ID {
|
||||
t.Fatalf("Pick() after failover returned same auth %q, expected different", first.ID)
|
||||
}
|
||||
|
||||
// Subsequent picks should consistently return the new binding
|
||||
for i := 0; i < 5; i++ {
|
||||
got, _ := selector.Pick(context.Background(), "claude", "claude-3", opts, availableWithoutFirst)
|
||||
if got.ID != second.ID {
|
||||
t.Fatalf("Pick() #%d after failover inconsistent: got %q, want %q", i, got.ID, second.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -527,3 +733,629 @@ func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *test
|
||||
}
|
||||
}
|
||||
}
|
||||
func TestExtractSessionID_ClaudeCodePriorityOverHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Claude Code metadata.user_id should have highest priority, even when X-Session-ID header is present
|
||||
headers := make(http.Header)
|
||||
headers.Set("X-Session-ID", "header-session-id")
|
||||
|
||||
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
|
||||
|
||||
got := ExtractSessionID(headers, payload, nil)
|
||||
want := "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344"
|
||||
if got != want {
|
||||
t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should have highest priority over header)", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSessionID_ClaudeCodePriorityOverIdempotencyKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Claude Code metadata.user_id should have highest priority, even when idempotency_key is present
|
||||
metadata := map[string]any{"idempotency_key": "idem-12345"}
|
||||
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
|
||||
|
||||
got := ExtractSessionID(nil, payload, metadata)
|
||||
want := "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344"
|
||||
if got != want {
|
||||
t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should have highest priority over idempotency_key)", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSessionID_Headers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("X-Session-ID", "my-explicit-session")
|
||||
|
||||
got := ExtractSessionID(headers, nil, nil)
|
||||
want := "header:my-explicit-session"
|
||||
if got != want {
|
||||
t.Errorf("ExtractSessionID() with header = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractSessionID_IdempotencyKey verifies that idempotency_key is intentionally
|
||||
// ignored for session affinity (it's auto-generated per-request, causing cache misses).
|
||||
func TestExtractSessionID_IdempotencyKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
metadata := map[string]any{"idempotency_key": "idem-12345"}
|
||||
|
||||
got := ExtractSessionID(nil, nil, metadata)
|
||||
// idempotency_key is disabled - should return empty (no payload to hash)
|
||||
if got != "" {
|
||||
t.Errorf("ExtractSessionID() with idempotency_key = %q, want empty (idempotency_key is disabled)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSessionID_MessageHashFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// First request (user only) generates short hash
|
||||
firstRequestPayload := []byte(`{"messages":[{"role":"user","content":"Hello world"}]}`)
|
||||
shortHash := ExtractSessionID(nil, firstRequestPayload, nil)
|
||||
if shortHash == "" {
|
||||
t.Error("ExtractSessionID() first request should return short hash")
|
||||
}
|
||||
if !strings.HasPrefix(shortHash, "msg:") {
|
||||
t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", shortHash)
|
||||
}
|
||||
|
||||
// Multi-turn with assistant generates full hash (different from short hash)
|
||||
multiTurnPayload := []byte(`{"messages":[
|
||||
{"role":"user","content":"Hello world"},
|
||||
{"role":"assistant","content":"Hi! How can I help?"},
|
||||
{"role":"user","content":"Tell me a joke"}
|
||||
]}`)
|
||||
fullHash := ExtractSessionID(nil, multiTurnPayload, nil)
|
||||
if fullHash == "" {
|
||||
t.Error("ExtractSessionID() multi-turn should return full hash")
|
||||
}
|
||||
if fullHash == shortHash {
|
||||
t.Error("Full hash should differ from short hash (includes assistant)")
|
||||
}
|
||||
|
||||
// Same multi-turn payload should produce same hash
|
||||
fullHash2 := ExtractSessionID(nil, multiTurnPayload, nil)
|
||||
if fullHash != fullHash2 {
|
||||
t.Errorf("ExtractSessionID() not stable: got %q then %q", fullHash, fullHash2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSessionID_ClaudeAPITopLevelSystem(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Claude API: system prompt in top-level "system" field (array format)
|
||||
arraySystem := []byte(`{
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"system": [{"type": "text", "text": "You are Claude Code"}]
|
||||
}`)
|
||||
got1 := ExtractSessionID(nil, arraySystem, nil)
|
||||
if got1 == "" || !strings.HasPrefix(got1, "msg:") {
|
||||
t.Errorf("ExtractSessionID() with array system = %q, want msg:* prefix", got1)
|
||||
}
|
||||
|
||||
// Claude API: system prompt in top-level "system" field (string format)
|
||||
stringSystem := []byte(`{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"system": "You are Claude Code"
|
||||
}`)
|
||||
got2 := ExtractSessionID(nil, stringSystem, nil)
|
||||
if got2 == "" || !strings.HasPrefix(got2, "msg:") {
|
||||
t.Errorf("ExtractSessionID() with string system = %q, want msg:* prefix", got2)
|
||||
}
|
||||
|
||||
// Multi-turn with top-level system should produce stable hash
|
||||
multiTurn := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi!"},
|
||||
{"role": "user", "content": "Help me"}
|
||||
],
|
||||
"system": "You are Claude Code"
|
||||
}`)
|
||||
got3 := ExtractSessionID(nil, multiTurn, nil)
|
||||
if got3 == "" {
|
||||
t.Error("ExtractSessionID() multi-turn with top-level system should return hash")
|
||||
}
|
||||
if got3 == got2 {
|
||||
t.Error("Multi-turn hash should differ from first-turn hash (includes assistant)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSessionID_GeminiFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Gemini format with systemInstruction and contents
|
||||
payload := []byte(`{
|
||||
"systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]},
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": "Hello Gemini"}]},
|
||||
{"role": "model", "parts": [{"text": "Hi there!"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
got := ExtractSessionID(nil, payload, nil)
|
||||
if got == "" {
|
||||
t.Error("ExtractSessionID() with Gemini format should return hash-based session ID")
|
||||
}
|
||||
if !strings.HasPrefix(got, "msg:") {
|
||||
t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", got)
|
||||
}
|
||||
|
||||
// Same payload should produce same hash
|
||||
got2 := ExtractSessionID(nil, payload, nil)
|
||||
if got != got2 {
|
||||
t.Errorf("ExtractSessionID() not stable: got %q then %q", got, got2)
|
||||
}
|
||||
|
||||
// Different user message should produce different hash
|
||||
differentPayload := []byte(`{
|
||||
"systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]},
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": "Hello different"}]},
|
||||
{"role": "model", "parts": [{"text": "Hi there!"}]}
|
||||
]
|
||||
}`)
|
||||
got3 := ExtractSessionID(nil, differentPayload, nil)
|
||||
if got == got3 {
|
||||
t.Errorf("ExtractSessionID() should produce different hash for different user message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSessionID_OpenAIResponsesAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
firstTurn := []byte(`{
|
||||
"instructions": "You are Codex, based on GPT-5.",
|
||||
"input": [
|
||||
{"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]},
|
||||
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
got1 := ExtractSessionID(nil, firstTurn, nil)
|
||||
if got1 == "" {
|
||||
t.Error("ExtractSessionID() should return hash for OpenAI Responses API format")
|
||||
}
|
||||
if !strings.HasPrefix(got1, "msg:") {
|
||||
t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", got1)
|
||||
}
|
||||
|
||||
secondTurn := []byte(`{
|
||||
"instructions": "You are Codex, based on GPT-5.",
|
||||
"input": [
|
||||
{"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]},
|
||||
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]},
|
||||
{"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking..."}], "encrypted_content": "xxx"},
|
||||
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hello!"}]},
|
||||
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what can you do"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
got2 := ExtractSessionID(nil, secondTurn, nil)
|
||||
if got2 == "" {
|
||||
t.Error("ExtractSessionID() should return hash for second turn")
|
||||
}
|
||||
|
||||
if got1 == got2 {
|
||||
t.Log("First turn and second turn have different hashes (expected: second includes assistant)")
|
||||
}
|
||||
|
||||
thirdTurn := []byte(`{
|
||||
"instructions": "You are Codex, based on GPT-5.",
|
||||
"input": [
|
||||
{"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]},
|
||||
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]},
|
||||
{"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking..."}], "encrypted_content": "xxx"},
|
||||
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hello!"}]},
|
||||
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what can you do"}]},
|
||||
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "I can help with..."}]},
|
||||
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "thanks"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
got3 := ExtractSessionID(nil, thirdTurn, nil)
|
||||
if got2 != got3 {
|
||||
t.Errorf("Second and third turn should have same hash (same first assistant): got %q vs %q", got2, got3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionAffinitySelector_ThreeScenarios(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fallback := &RoundRobinSelector{}
|
||||
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||
Fallback: fallback,
|
||||
TTL: time.Minute,
|
||||
})
|
||||
defer selector.Stop()
|
||||
|
||||
auths := []*Auth{{ID: "auth-a"}, {ID: "auth-b"}, {ID: "auth-c"}}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
scenario string
|
||||
payload []byte
|
||||
}{
|
||||
{
|
||||
name: "OpenAI_Scenario1_NewRequest",
|
||||
scenario: "new",
|
||||
payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"}]}`),
|
||||
},
|
||||
{
|
||||
name: "OpenAI_Scenario2_SecondTurn",
|
||||
scenario: "second",
|
||||
payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"Help me"}]}`),
|
||||
},
|
||||
{
|
||||
name: "OpenAI_Scenario3_ManyTurns",
|
||||
scenario: "many",
|
||||
payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"Help me"},{"role":"assistant","content":"Sure!"},{"role":"user","content":"Thanks"}]}`),
|
||||
},
|
||||
{
|
||||
name: "Gemini_Scenario1_NewRequest",
|
||||
scenario: "new",
|
||||
payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]}]}`),
|
||||
},
|
||||
{
|
||||
name: "Gemini_Scenario2_SecondTurn",
|
||||
scenario: "second",
|
||||
payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]},{"role":"model","parts":[{"text":"Hi!"}]},{"role":"user","parts":[{"text":"Help"}]}]}`),
|
||||
},
|
||||
{
|
||||
name: "Gemini_Scenario3_ManyTurns",
|
||||
scenario: "many",
|
||||
payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]},{"role":"model","parts":[{"text":"Hi!"}]},{"role":"user","parts":[{"text":"Help"}]},{"role":"model","parts":[{"text":"Sure!"}]},{"role":"user","parts":[{"text":"Thanks"}]}]}`),
|
||||
},
|
||||
{
|
||||
name: "Claude_Scenario1_NewRequest",
|
||||
scenario: "new",
|
||||
payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"}]}`),
|
||||
},
|
||||
{
|
||||
name: "Claude_Scenario2_SecondTurn",
|
||||
scenario: "second",
|
||||
payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"},{"role":"assistant","content":"Hello!"},{"role":"user","content":"Help me"}]}`),
|
||||
},
|
||||
{
|
||||
name: "Claude_Scenario3_ManyTurns",
|
||||
scenario: "many",
|
||||
payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"},{"role":"assistant","content":"Hello!"},{"role":"user","content":"Help"},{"role":"assistant","content":"Sure!"},{"role":"user","content":"Thanks"}]}`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
opts := cliproxyexecutor.Options{OriginalRequest: tc.payload}
|
||||
picked, err := selector.Pick(context.Background(), "provider", "model", opts, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() error = %v", err)
|
||||
}
|
||||
if picked == nil {
|
||||
t.Fatal("Pick() returned nil")
|
||||
}
|
||||
t.Logf("%s: picked %s", tc.name, picked.ID)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Scenario2And3_SameAuth", func(t *testing.T) {
|
||||
openaiS2 := []byte(`{"messages":[{"role":"system","content":"Stable test"},{"role":"user","content":"First msg"},{"role":"assistant","content":"Response"},{"role":"user","content":"Second"}]}`)
|
||||
openaiS3 := []byte(`{"messages":[{"role":"system","content":"Stable test"},{"role":"user","content":"First msg"},{"role":"assistant","content":"Response"},{"role":"user","content":"Second"},{"role":"assistant","content":"More"},{"role":"user","content":"Third"}]}`)
|
||||
|
||||
opts2 := cliproxyexecutor.Options{OriginalRequest: openaiS2}
|
||||
opts3 := cliproxyexecutor.Options{OriginalRequest: openaiS3}
|
||||
|
||||
picked2, _ := selector.Pick(context.Background(), "test", "model", opts2, auths)
|
||||
picked3, _ := selector.Pick(context.Background(), "test", "model", opts3, auths)
|
||||
|
||||
if picked2.ID != picked3.ID {
|
||||
t.Errorf("Scenario2 and Scenario3 should pick same auth: got %s vs %s", picked2.ID, picked3.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Scenario1To2_InheritBinding", func(t *testing.T) {
|
||||
s1 := []byte(`{"messages":[{"role":"system","content":"Inherit test"},{"role":"user","content":"Initial"}]}`)
|
||||
s2 := []byte(`{"messages":[{"role":"system","content":"Inherit test"},{"role":"user","content":"Initial"},{"role":"assistant","content":"Reply"},{"role":"user","content":"Continue"}]}`)
|
||||
|
||||
opts1 := cliproxyexecutor.Options{OriginalRequest: s1}
|
||||
opts2 := cliproxyexecutor.Options{OriginalRequest: s2}
|
||||
|
||||
picked1, _ := selector.Pick(context.Background(), "inherit", "model", opts1, auths)
|
||||
picked2, _ := selector.Pick(context.Background(), "inherit", "model", opts2, auths)
|
||||
|
||||
if picked1.ID != picked2.ID {
|
||||
t.Errorf("Scenario2 should inherit Scenario1 binding: got %s vs %s", picked1.ID, picked2.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionAffinitySelector_MultiModelSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fallback := &RoundRobinSelector{}
|
||||
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||
Fallback: fallback,
|
||||
TTL: time.Minute,
|
||||
})
|
||||
defer selector.Stop()
|
||||
|
||||
// auth-a supports only model-a, auth-b supports only model-b
|
||||
authA := &Auth{ID: "auth-a"}
|
||||
authB := &Auth{ID: "auth-b"}
|
||||
|
||||
// Same session ID for all requests
|
||||
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_multi-model-test"}}`)
|
||||
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||
|
||||
// Request model-a with only auth-a available for that model
|
||||
authsForModelA := []*Auth{authA}
|
||||
pickedA, err := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() for model-a error = %v", err)
|
||||
}
|
||||
if pickedA.ID != "auth-a" {
|
||||
t.Fatalf("Pick() for model-a = %q, want auth-a", pickedA.ID)
|
||||
}
|
||||
|
||||
// Request model-b with only auth-b available for that model
|
||||
authsForModelB := []*Auth{authB}
|
||||
pickedB, err := selector.Pick(context.Background(), "provider", "model-b", opts, authsForModelB)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() for model-b error = %v", err)
|
||||
}
|
||||
if pickedB.ID != "auth-b" {
|
||||
t.Fatalf("Pick() for model-b = %q, want auth-b", pickedB.ID)
|
||||
}
|
||||
|
||||
// Switch back to model-a - should still get auth-a (separate binding per model)
|
||||
pickedA2, err := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() for model-a (2nd) error = %v", err)
|
||||
}
|
||||
if pickedA2.ID != "auth-a" {
|
||||
t.Fatalf("Pick() for model-a (2nd) = %q, want auth-a", pickedA2.ID)
|
||||
}
|
||||
|
||||
// Verify bindings are stable for multiple calls
|
||||
for i := 0; i < 5; i++ {
|
||||
gotA, _ := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA)
|
||||
gotB, _ := selector.Pick(context.Background(), "provider", "model-b", opts, authsForModelB)
|
||||
if gotA.ID != "auth-a" {
|
||||
t.Fatalf("Pick() #%d for model-a = %q, want auth-a", i, gotA.ID)
|
||||
}
|
||||
if gotB.ID != "auth-b" {
|
||||
t.Fatalf("Pick() #%d for model-b = %q, want auth-b", i, gotB.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSessionID_MultimodalContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// First request generates short hash
|
||||
firstRequestPayload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"Hello world"},{"type":"image","source":{"data":"..."}}]}]}`)
|
||||
shortHash := ExtractSessionID(nil, firstRequestPayload, nil)
|
||||
if shortHash == "" {
|
||||
t.Error("ExtractSessionID() first request should return short hash")
|
||||
}
|
||||
if !strings.HasPrefix(shortHash, "msg:") {
|
||||
t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", shortHash)
|
||||
}
|
||||
|
||||
// Multi-turn generates full hash
|
||||
multiTurnPayload := []byte(`{"messages":[
|
||||
{"role":"user","content":[{"type":"text","text":"Hello world"},{"type":"image","source":{"data":"..."}}]},
|
||||
{"role":"assistant","content":"I see an image!"},
|
||||
{"role":"user","content":"What is it?"}
|
||||
]}`)
|
||||
fullHash := ExtractSessionID(nil, multiTurnPayload, nil)
|
||||
if fullHash == "" {
|
||||
t.Error("ExtractSessionID() multimodal multi-turn should return full hash")
|
||||
}
|
||||
if fullHash == shortHash {
|
||||
t.Error("Full hash should differ from short hash")
|
||||
}
|
||||
|
||||
// Different user content produces different hash
|
||||
differentPayload := []byte(`{"messages":[
|
||||
{"role":"user","content":[{"type":"text","text":"Different content"}]},
|
||||
{"role":"assistant","content":"I see something different!"}
|
||||
]}`)
|
||||
differentHash := ExtractSessionID(nil, differentPayload, nil)
|
||||
if fullHash == differentHash {
|
||||
t.Errorf("ExtractSessionID() should produce different hash for different content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionAffinitySelector_CrossProviderIsolation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fallback := &RoundRobinSelector{}
|
||||
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||
Fallback: fallback,
|
||||
TTL: time.Minute,
|
||||
})
|
||||
defer selector.Stop()
|
||||
|
||||
authClaude := &Auth{ID: "auth-claude"}
|
||||
authGemini := &Auth{ID: "auth-gemini"}
|
||||
|
||||
// Same session ID for both providers
|
||||
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_cross-provider-test"}}`)
|
||||
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||
|
||||
// Request via claude provider
|
||||
pickedClaude, err := selector.Pick(context.Background(), "claude", "claude-3", opts, []*Auth{authClaude})
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() for claude error = %v", err)
|
||||
}
|
||||
if pickedClaude.ID != "auth-claude" {
|
||||
t.Fatalf("Pick() for claude = %q, want auth-claude", pickedClaude.ID)
|
||||
}
|
||||
|
||||
// Same session but via gemini provider should get different auth
|
||||
pickedGemini, err := selector.Pick(context.Background(), "gemini", "gemini-2.5-pro", opts, []*Auth{authGemini})
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() for gemini error = %v", err)
|
||||
}
|
||||
if pickedGemini.ID != "auth-gemini" {
|
||||
t.Fatalf("Pick() for gemini = %q, want auth-gemini", pickedGemini.ID)
|
||||
}
|
||||
|
||||
// Verify both bindings remain stable
|
||||
for i := 0; i < 5; i++ {
|
||||
gotC, _ := selector.Pick(context.Background(), "claude", "claude-3", opts, []*Auth{authClaude})
|
||||
gotG, _ := selector.Pick(context.Background(), "gemini", "gemini-2.5-pro", opts, []*Auth{authGemini})
|
||||
if gotC.ID != "auth-claude" {
|
||||
t.Fatalf("Pick() #%d for claude = %q, want auth-claude", i, gotC.ID)
|
||||
}
|
||||
if gotG.ID != "auth-gemini" {
|
||||
t.Fatalf("Pick() #%d for gemini = %q, want auth-gemini", i, gotG.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionCache_GetAndRefresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := NewSessionCache(100 * time.Millisecond)
|
||||
defer cache.Stop()
|
||||
|
||||
cache.Set("session1", "auth1")
|
||||
|
||||
// Verify initial value
|
||||
got, ok := cache.GetAndRefresh("session1")
|
||||
if !ok || got != "auth1" {
|
||||
t.Fatalf("GetAndRefresh() = %q, %v, want auth1, true", got, ok)
|
||||
}
|
||||
|
||||
// Wait half TTL and access again (should refresh)
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
got, ok = cache.GetAndRefresh("session1")
|
||||
if !ok || got != "auth1" {
|
||||
t.Fatalf("GetAndRefresh() after 60ms = %q, %v, want auth1, true", got, ok)
|
||||
}
|
||||
|
||||
// Wait another 60ms (total 120ms from original, but TTL refreshed at 60ms)
|
||||
// Entry should still be valid because TTL was refreshed
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
got, ok = cache.GetAndRefresh("session1")
|
||||
if !ok || got != "auth1" {
|
||||
t.Fatalf("GetAndRefresh() after refresh = %q, %v, want auth1, true (TTL should have been refreshed)", got, ok)
|
||||
}
|
||||
|
||||
// Now wait full TTL without access
|
||||
time.Sleep(110 * time.Millisecond)
|
||||
got, ok = cache.GetAndRefresh("session1")
|
||||
if ok {
|
||||
t.Fatalf("GetAndRefresh() after expiry = %q, %v, want '', false", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionAffinitySelector_RoundRobinDistribution(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fallback := &RoundRobinSelector{}
|
||||
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||
Fallback: fallback,
|
||||
TTL: time.Minute,
|
||||
})
|
||||
defer selector.Stop()
|
||||
|
||||
auths := []*Auth{
|
||||
{ID: "auth-a"},
|
||||
{ID: "auth-b"},
|
||||
{ID: "auth-c"},
|
||||
}
|
||||
|
||||
sessionCount := 12
|
||||
counts := make(map[string]int)
|
||||
for i := 0; i < sessionCount; i++ {
|
||||
payload := []byte(fmt.Sprintf(`{"metadata":{"user_id":"user_xxx_account__session_%08d-0000-0000-0000-000000000000"}}`, i))
|
||||
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||
got, err := selector.Pick(context.Background(), "provider", "model", opts, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() session %d error = %v", i, err)
|
||||
}
|
||||
counts[got.ID]++
|
||||
}
|
||||
|
||||
expected := sessionCount / len(auths)
|
||||
for _, auth := range auths {
|
||||
got := counts[auth.ID]
|
||||
if got != expected {
|
||||
t.Errorf("auth %s got %d sessions, want %d (round-robin should distribute evenly)", auth.ID, got, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionAffinitySelector_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fallback := &RoundRobinSelector{}
|
||||
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||
Fallback: fallback,
|
||||
TTL: time.Minute,
|
||||
})
|
||||
defer selector.Stop()
|
||||
|
||||
auths := []*Auth{
|
||||
{ID: "auth-a"},
|
||||
{ID: "auth-b"},
|
||||
{ID: "auth-c"},
|
||||
}
|
||||
|
||||
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_concurrent-test"}}`)
|
||||
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||
|
||||
// First pick to establish binding
|
||||
first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Initial Pick() error = %v", err)
|
||||
}
|
||||
expectedID := first.ID
|
||||
|
||||
start := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
goroutines := 32
|
||||
iterations := 50
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
for j := 0; j < iterations; j++ {
|
||||
got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
|
||||
if err != nil {
|
||||
select {
|
||||
case errCh <- err:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
if got.ID != expectedID {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("concurrent Pick() returned %q, want %q", got.ID, expectedID):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatalf("concurrent Pick() error = %v", err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
152
sdk/cliproxy/auth/session_cache.go
Normal file
152
sdk/cliproxy/auth/session_cache.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// sessionEntry stores auth binding with expiration.
|
||||
type sessionEntry struct {
|
||||
authID string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// SessionCache provides TTL-based session to auth mapping with automatic cleanup.
|
||||
type SessionCache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]sessionEntry
|
||||
ttl time.Duration
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionCache creates a cache with the specified TTL.
|
||||
// A background goroutine periodically cleans expired entries.
|
||||
func NewSessionCache(ttl time.Duration) *SessionCache {
|
||||
if ttl <= 0 {
|
||||
ttl = 30 * time.Minute
|
||||
}
|
||||
c := &SessionCache{
|
||||
entries: make(map[string]sessionEntry),
|
||||
ttl: ttl,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
go c.cleanupLoop()
|
||||
return c
|
||||
}
|
||||
|
||||
// Get retrieves the auth ID bound to a session, if still valid.
|
||||
// Does NOT refresh the TTL on access.
|
||||
func (c *SessionCache) Get(sessionID string) (string, bool) {
|
||||
if sessionID == "" {
|
||||
return "", false
|
||||
}
|
||||
c.mu.RLock()
|
||||
entry, ok := c.entries[sessionID]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
c.mu.Lock()
|
||||
delete(c.entries, sessionID)
|
||||
c.mu.Unlock()
|
||||
return "", false
|
||||
}
|
||||
return entry.authID, true
|
||||
}
|
||||
|
||||
// GetAndRefresh retrieves the auth ID bound to a session and refreshes TTL on hit.
|
||||
// This extends the binding lifetime for active sessions.
|
||||
func (c *SessionCache) GetAndRefresh(sessionID string) (string, bool) {
|
||||
if sessionID == "" {
|
||||
return "", false
|
||||
}
|
||||
now := time.Now()
|
||||
c.mu.Lock()
|
||||
entry, ok := c.entries[sessionID]
|
||||
if !ok {
|
||||
c.mu.Unlock()
|
||||
return "", false
|
||||
}
|
||||
if now.After(entry.expiresAt) {
|
||||
delete(c.entries, sessionID)
|
||||
c.mu.Unlock()
|
||||
return "", false
|
||||
}
|
||||
// Refresh TTL on successful access
|
||||
entry.expiresAt = now.Add(c.ttl)
|
||||
c.entries[sessionID] = entry
|
||||
c.mu.Unlock()
|
||||
return entry.authID, true
|
||||
}
|
||||
|
||||
// Set binds a session to an auth ID with TTL refresh.
|
||||
func (c *SessionCache) Set(sessionID, authID string) {
|
||||
if sessionID == "" || authID == "" {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.entries[sessionID] = sessionEntry{
|
||||
authID: authID,
|
||||
expiresAt: time.Now().Add(c.ttl),
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Invalidate removes a specific session binding.
|
||||
func (c *SessionCache) Invalidate(sessionID string) {
|
||||
if sessionID == "" {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
delete(c.entries, sessionID)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// InvalidateAuth removes all sessions bound to a specific auth ID.
|
||||
// Used when an auth becomes unavailable.
|
||||
func (c *SessionCache) InvalidateAuth(authID string) {
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
for sid, entry := range c.entries {
|
||||
if entry.authID == authID {
|
||||
delete(c.entries, sid)
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Stop terminates the background cleanup goroutine.
|
||||
func (c *SessionCache) Stop() {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
default:
|
||||
close(c.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SessionCache) cleanupLoop() {
|
||||
ticker := time.NewTicker(c.ttl / 2)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SessionCache) cleanup() {
|
||||
now := time.Now()
|
||||
c.mu.Lock()
|
||||
for sid, entry := range c.entries {
|
||||
if now.After(entry.expiresAt) {
|
||||
delete(c.entries, sid)
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
@@ -6,6 +6,7 @@ package cliproxy
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
@@ -208,8 +209,17 @@ func (b *Builder) Build() (*Service, error) {
|
||||
}
|
||||
|
||||
strategy := ""
|
||||
sessionAffinity := false
|
||||
sessionAffinityTTL := time.Hour
|
||||
if b.cfg != nil {
|
||||
strategy = strings.ToLower(strings.TrimSpace(b.cfg.Routing.Strategy))
|
||||
// Support both legacy ClaudeCodeSessionAffinity and new universal SessionAffinity
|
||||
sessionAffinity = b.cfg.Routing.ClaudeCodeSessionAffinity || b.cfg.Routing.SessionAffinity
|
||||
if ttlStr := strings.TrimSpace(b.cfg.Routing.SessionAffinityTTL); ttlStr != "" {
|
||||
if parsed, err := time.ParseDuration(ttlStr); err == nil && parsed > 0 {
|
||||
sessionAffinityTTL = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
var selector coreauth.Selector
|
||||
switch strategy {
|
||||
@@ -219,6 +229,14 @@ func (b *Builder) Build() (*Service, error) {
|
||||
selector = &coreauth.RoundRobinSelector{}
|
||||
}
|
||||
|
||||
// Wrap with session affinity if enabled (failover is always on)
|
||||
if sessionAffinity {
|
||||
selector = coreauth.NewSessionAffinitySelectorWithConfig(coreauth.SessionAffinityConfig{
|
||||
Fallback: selector,
|
||||
TTL: sessionAffinityTTL,
|
||||
})
|
||||
}
|
||||
|
||||
coreManager = coreauth.NewManager(tokenStore, selector, nil)
|
||||
}
|
||||
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
|
||||
|
||||
@@ -118,7 +118,6 @@ func newDefaultAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewGeminiAuthenticator(),
|
||||
sdkAuth.NewCodexAuthenticator(),
|
||||
sdkAuth.NewClaudeAuthenticator(),
|
||||
sdkAuth.NewQwenAuthenticator(),
|
||||
sdkAuth.NewGitLabAuthenticator(),
|
||||
)
|
||||
}
|
||||
@@ -435,8 +434,6 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
|
||||
s.coreManager.RegisterExecutor(executor.NewAntigravityExecutor(s.cfg))
|
||||
case "claude":
|
||||
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
|
||||
case "qwen":
|
||||
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
||||
case "iflow":
|
||||
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
|
||||
case "kimi":
|
||||
@@ -639,9 +636,13 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
var watcherWrapper *WatcherWrapper
|
||||
reloadCallback := func(newCfg *config.Config) {
|
||||
previousStrategy := ""
|
||||
var previousSessionAffinity bool
|
||||
var previousSessionAffinityTTL string
|
||||
s.cfgMu.RLock()
|
||||
if s.cfg != nil {
|
||||
previousStrategy = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Strategy))
|
||||
previousSessionAffinity = s.cfg.Routing.ClaudeCodeSessionAffinity || s.cfg.Routing.SessionAffinity
|
||||
previousSessionAffinityTTL = s.cfg.Routing.SessionAffinityTTL
|
||||
}
|
||||
s.cfgMu.RUnlock()
|
||||
|
||||
@@ -665,7 +666,15 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
}
|
||||
previousStrategy = normalizeStrategy(previousStrategy)
|
||||
nextStrategy = normalizeStrategy(nextStrategy)
|
||||
if s.coreManager != nil && previousStrategy != nextStrategy {
|
||||
|
||||
nextSessionAffinity := newCfg.Routing.ClaudeCodeSessionAffinity || newCfg.Routing.SessionAffinity
|
||||
nextSessionAffinityTTL := newCfg.Routing.SessionAffinityTTL
|
||||
|
||||
selectorChanged := previousStrategy != nextStrategy ||
|
||||
previousSessionAffinity != nextSessionAffinity ||
|
||||
previousSessionAffinityTTL != nextSessionAffinityTTL
|
||||
|
||||
if s.coreManager != nil && selectorChanged {
|
||||
var selector coreauth.Selector
|
||||
switch nextStrategy {
|
||||
case "fill-first":
|
||||
@@ -673,6 +682,20 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
default:
|
||||
selector = &coreauth.RoundRobinSelector{}
|
||||
}
|
||||
|
||||
if nextSessionAffinity {
|
||||
ttl := time.Hour
|
||||
if ttlStr := strings.TrimSpace(nextSessionAffinityTTL); ttlStr != "" {
|
||||
if parsed, err := time.ParseDuration(ttlStr); err == nil && parsed > 0 {
|
||||
ttl = parsed
|
||||
}
|
||||
}
|
||||
selector = coreauth.NewSessionAffinitySelectorWithConfig(coreauth.SessionAffinityConfig{
|
||||
Fallback: selector,
|
||||
TTL: ttl,
|
||||
})
|
||||
}
|
||||
|
||||
s.coreManager.SetSelector(selector)
|
||||
}
|
||||
|
||||
@@ -939,9 +962,6 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
}
|
||||
}
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "qwen":
|
||||
models = registry.GetQwenModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "iflow":
|
||||
models = registry.GetIFlowModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
|
||||
Reference in New Issue
Block a user