diff --git a/cmd/fetch_antigravity_models/main.go b/cmd/fetch_antigravity_models/main.go index 54ec16ca..d4328eb3 100644 --- a/cmd/fetch_antigravity_models/main.go +++ b/cmd/fetch_antigravity_models/main.go @@ -26,6 +26,7 @@ import ( "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" @@ -188,7 +189,7 @@ func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry { httpReq.Close = true httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+accessToken) - httpReq.Header.Set("User-Agent", "antigravity/1.21.9 darwin/arm64") + httpReq.Header.Set("User-Agent", misc.AntigravityUserAgent()) httpClient := &http.Client{Timeout: 30 * time.Second} if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil { diff --git a/cmd/server/main.go b/cmd/server/main.go index 4bb90dc7..986c61df 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -598,6 +598,7 @@ func main() { if standalone { // Standalone mode: start an embedded local server and connect TUI client to it. managementasset.StartAutoUpdater(context.Background(), configFilePath) + misc.StartAntigravityVersionUpdater(context.Background()) if !localModel { registry.StartModelsUpdater(context.Background()) } @@ -673,6 +674,7 @@ func main() { } else { // Start the main proxy service managementasset.StartAutoUpdater(context.Background(), configFilePath) + misc.StartAntigravityVersionUpdater(context.Background()) if !localModel { registry.StartModelsUpdater(context.Background()) } diff --git a/internal/api/handlers/management/api_tools.go b/internal/api/handlers/management/api_tools.go index c9aa55ed..46ea9060 100644 --- a/internal/api/handlers/management/api_tools.go +++ b/internal/api/handlers/management/api_tools.go @@ -13,6 +13,7 @@ import ( "github.com/fxamacker/cbor/v2" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" @@ -700,6 +701,11 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" { proxyCandidates = append(proxyCandidates, proxyStr) } + if h != nil && h.cfg != nil { + if proxyStr := strings.TrimSpace(proxyURLFromAPIKeyConfig(h.cfg, auth)); proxyStr != "" { + proxyCandidates = append(proxyCandidates, proxyStr) + } + } } if h != nil && h.cfg != nil { if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" { @@ -722,6 +728,123 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { return clone } +type apiKeyConfigEntry interface { + GetAPIKey() string + GetBaseURL() string +} + +func resolveAPIKeyConfig[T apiKeyConfigEntry](entries []T, auth *coreauth.Auth) *T { + if auth == nil || len(entries) == 0 { + return nil + } + attrKey, attrBase := "", "" + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range entries { + entry := &entries[i] + cfgKey := strings.TrimSpace((*entry).GetAPIKey()) + cfgBase := strings.TrimSpace((*entry).GetBaseURL()) + if attrKey != "" && attrBase != "" { + if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey != "" { + for i := range entries { + entry := &entries[i] + if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) { + return entry + } + } + } + return nil +} + +func proxyURLFromAPIKeyConfig(cfg *config.Config, auth *coreauth.Auth) string { + if cfg == nil || auth == nil { + return "" + } + authKind, authAccount := auth.AccountInfo() + if !strings.EqualFold(strings.TrimSpace(authKind), "api_key") { + return "" + } + + attrs := auth.Attributes + compatName := "" + providerKey := "" + if len(attrs) > 0 { + compatName = strings.TrimSpace(attrs["compat_name"]) + providerKey = strings.TrimSpace(attrs["provider_key"]) + } + if compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") { + return resolveOpenAICompatAPIKeyProxyURL(cfg, auth, strings.TrimSpace(authAccount), providerKey, compatName) + } + + switch strings.ToLower(strings.TrimSpace(auth.Provider)) { + case "gemini": + if entry := resolveAPIKeyConfig(cfg.GeminiKey, auth); entry != nil { + return strings.TrimSpace(entry.ProxyURL) + } + case "claude": + if entry := resolveAPIKeyConfig(cfg.ClaudeKey, auth); entry != nil { + return strings.TrimSpace(entry.ProxyURL) + } + case "codex": + if entry := resolveAPIKeyConfig(cfg.CodexKey, auth); entry != nil { + return strings.TrimSpace(entry.ProxyURL) + } + } + return "" +} + +func resolveOpenAICompatAPIKeyProxyURL(cfg *config.Config, auth *coreauth.Auth, apiKey, providerKey, compatName string) string { + if cfg == nil || auth == nil { + return "" + } + apiKey = strings.TrimSpace(apiKey) + if apiKey == "" { + return "" + } + candidates := make([]string, 0, 3) + if v := strings.TrimSpace(compatName); v != "" { + candidates = append(candidates, v) + } + if v := strings.TrimSpace(providerKey); v != "" { + candidates = append(candidates, v) + } + if v := strings.TrimSpace(auth.Provider); v != "" { + candidates = append(candidates, v) + } + + for i := range cfg.OpenAICompatibility { + compat := &cfg.OpenAICompatibility[i] + for _, candidate := range candidates { + if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) { + for j := range compat.APIKeyEntries { + entry := &compat.APIKeyEntries[j] + if strings.EqualFold(strings.TrimSpace(entry.APIKey), apiKey) { + return strings.TrimSpace(entry.ProxyURL) + } + } + return "" + } + } + } + return "" +} + func buildProxyTransport(proxyStr string) *http.Transport { transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr) if errBuild != nil { diff --git a/internal/api/handlers/management/api_tools_test.go b/internal/api/handlers/management/api_tools_test.go index 6ed98c6e..b27fe639 100644 --- a/internal/api/handlers/management/api_tools_test.go +++ b/internal/api/handlers/management/api_tools_test.go @@ -58,6 +58,105 @@ func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) { } } +func TestAPICallTransportAPIKeyAuthFallsBackToConfigProxyURL(t *testing.T) { + t.Parallel() + + h := &Handler{ + cfg: &config.Config{ + SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}, + GeminiKey: []config.GeminiKey{{ + APIKey: "gemini-key", + ProxyURL: "http://gemini-proxy.example.com:8080", + }}, + ClaudeKey: []config.ClaudeKey{{ + APIKey: "claude-key", + ProxyURL: "http://claude-proxy.example.com:8080", + }}, + CodexKey: []config.CodexKey{{ + APIKey: "codex-key", + ProxyURL: "http://codex-proxy.example.com:8080", + }}, + OpenAICompatibility: []config.OpenAICompatibility{{ + Name: "bohe", + BaseURL: "https://bohe.example.com", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{{ + APIKey: "compat-key", + ProxyURL: "http://compat-proxy.example.com:8080", + }}, + }}, + }, + } + + cases := []struct { + name string + auth *coreauth.Auth + wantProxy string + }{ + { + name: "gemini", + auth: &coreauth.Auth{ + Provider: "gemini", + Attributes: map[string]string{"api_key": "gemini-key"}, + }, + wantProxy: "http://gemini-proxy.example.com:8080", + }, + { + name: "claude", + auth: &coreauth.Auth{ + Provider: "claude", + Attributes: map[string]string{"api_key": "claude-key"}, + }, + wantProxy: "http://claude-proxy.example.com:8080", + }, + { + name: "codex", + auth: &coreauth.Auth{ + Provider: "codex", + Attributes: map[string]string{"api_key": "codex-key"}, + }, + wantProxy: "http://codex-proxy.example.com:8080", + }, + { + name: "openai-compatibility", + auth: &coreauth.Auth{ + Provider: "bohe", + Attributes: map[string]string{ + "api_key": "compat-key", + "compat_name": "bohe", + "provider_key": "bohe", + }, + }, + wantProxy: "http://compat-proxy.example.com:8080", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + transport := h.apiCallTransport(tc.auth) + httpTransport, ok := transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", transport) + } + + req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil) + if errRequest != nil { + t.Fatalf("http.NewRequest returned error: %v", errRequest) + } + + proxyURL, errProxy := httpTransport.Proxy(req) + if errProxy != nil { + t.Fatalf("httpTransport.Proxy returned error: %v", errProxy) + } + if proxyURL == nil || proxyURL.String() != tc.wantProxy { + t.Fatalf("proxy URL = %v, want %s", proxyURL, tc.wantProxy) + } + }) + } +} + func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) { t.Parallel() diff --git a/internal/misc/antigravity_version.go b/internal/misc/antigravity_version.go new file mode 100644 index 00000000..595cfefd --- /dev/null +++ b/internal/misc/antigravity_version.go @@ -0,0 +1,151 @@ +// Package misc provides miscellaneous utility functions for the CLI Proxy API server. +package misc + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + antigravityReleasesURL = "https://antigravity-auto-updater-974169037036.us-central1.run.app/releases" + antigravityFallbackVersion = "1.21.9" + antigravityVersionCacheTTL = 6 * time.Hour + antigravityFetchTimeout = 10 * time.Second +) + +type antigravityRelease struct { + Version string `json:"version"` + ExecutionID string `json:"execution_id"` +} + +var ( + cachedAntigravityVersion = antigravityFallbackVersion + antigravityVersionMu sync.RWMutex + antigravityVersionExpiry time.Time + antigravityUpdaterOnce sync.Once +) + +// StartAntigravityVersionUpdater starts a background goroutine that periodically refreshes the cached antigravity version. +// This is intentionally decoupled from request execution to avoid blocking executors on version lookups. +func StartAntigravityVersionUpdater(ctx context.Context) { + antigravityUpdaterOnce.Do(func() { + go runAntigravityVersionUpdater(ctx) + }) +} + +func runAntigravityVersionUpdater(ctx context.Context) { + if ctx == nil { + ctx = context.Background() + } + + ticker := time.NewTicker(antigravityVersionCacheTTL / 2) + defer ticker.Stop() + + log.Infof("periodic antigravity version refresh started (interval=%s)", antigravityVersionCacheTTL/2) + + refreshAntigravityVersion(ctx) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + refreshAntigravityVersion(ctx) + } + } +} + +func refreshAntigravityVersion(ctx context.Context) { + version, errFetch := fetchAntigravityLatestVersion(ctx) + + antigravityVersionMu.Lock() + defer antigravityVersionMu.Unlock() + + now := time.Now() + + if errFetch == nil { + cachedAntigravityVersion = version + antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL) + log.WithField("version", version).Info("fetched latest antigravity version") + return + } + + if cachedAntigravityVersion == "" || now.After(antigravityVersionExpiry) { + cachedAntigravityVersion = antigravityFallbackVersion + antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL) + log.WithError(errFetch).Warn("failed to refresh antigravity version, using fallback version") + return + } + + log.WithError(errFetch).Debug("failed to refresh antigravity version, keeping cached value") +} + +// AntigravityLatestVersion returns the cached antigravity version refreshed by StartAntigravityVersionUpdater. +// It falls back to antigravityFallbackVersion if the cache is empty or stale. +func AntigravityLatestVersion() string { + antigravityVersionMu.RLock() + if cachedAntigravityVersion != "" && time.Now().Before(antigravityVersionExpiry) { + v := cachedAntigravityVersion + antigravityVersionMu.RUnlock() + return v + } + antigravityVersionMu.RUnlock() + + return antigravityFallbackVersion +} + +// AntigravityUserAgent returns the User-Agent string for antigravity requests +// using the latest version fetched from the releases API. +func AntigravityUserAgent() string { + return fmt.Sprintf("antigravity/%s darwin/arm64", AntigravityLatestVersion()) +} + +func fetchAntigravityLatestVersion(ctx context.Context) (string, error) { + if ctx == nil { + ctx = context.Background() + } + + client := &http.Client{Timeout: antigravityFetchTimeout} + + httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodGet, antigravityReleasesURL, nil) + if errReq != nil { + return "", fmt.Errorf("build antigravity releases request: %w", errReq) + } + + resp, errDo := client.Do(httpReq) + if errDo != nil { + return "", fmt.Errorf("fetch antigravity releases: %w", errDo) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.WithError(errClose).Warn("antigravity releases response body close error") + } + }() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("antigravity releases API returned status %d", resp.StatusCode) + } + + var releases []antigravityRelease + if errDecode := json.NewDecoder(resp.Body).Decode(&releases); errDecode != nil { + return "", fmt.Errorf("decode antigravity releases response: %w", errDecode) + } + + if len(releases) == 0 { + return "", errors.New("antigravity releases API returned empty list") + } + + version := releases[0].Version + if version == "" { + return "", errors.New("antigravity releases API returned empty version") + } + + return version, nil +} diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index b9bf4842..ecab3c87 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -24,6 +24,7 @@ import ( "github.com/google/uuid" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" @@ -45,7 +46,7 @@ const ( antigravityGeneratePath = "/v1internal:generateContent" antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" antigravityClientSecret = "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" - defaultAntigravityAgent = "antigravity/1.21.9 darwin/arm64" + defaultAntigravityAgent = "antigravity/1.21.9 darwin/arm64" // fallback only; overridden at runtime by misc.AntigravityUserAgent() antigravityAuthType = "antigravity" refreshSkew = 3000 * time.Second antigravityCreditsRetryTTL = 5 * time.Hour @@ -1739,7 +1740,7 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string { } } } - return defaultAntigravityAgent + return misc.AntigravityUserAgent() } func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int { diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index f771099c..d8eec537 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -172,32 +172,101 @@ func timeUntilNextDay() time.Duration { return tomorrow.Sub(now) } -// ensureQwenSystemMessage prepends a default system message if none exists in "messages". +// ensureQwenSystemMessage ensures the request has a single system message at the beginning. +// It always injects the default system prompt and merges any user-provided system messages +// into the injected system message content to satisfy Qwen's strict message ordering rules. func ensureQwenSystemMessage(payload []byte) ([]byte, error) { - messages := gjson.GetBytes(payload, "messages") - if messages.Exists() && messages.IsArray() { - var buf bytes.Buffer - buf.WriteByte('[') - buf.Write(qwenDefaultSystemMessage) - for _, msg := range messages.Array() { - buf.WriteByte(',') - buf.WriteString(msg.Raw) + isInjectedSystemPart := func(part gjson.Result) bool { + if !part.Exists() || !part.IsObject() { + return false } - buf.WriteByte(']') - updated, errSet := sjson.SetRawBytes(payload, "messages", buf.Bytes()) - if errSet != nil { - return nil, fmt.Errorf("qwen executor: set default system message failed: %w", errSet) + if !strings.EqualFold(part.Get("type").String(), "text") { + return false } - return updated, nil + if !strings.EqualFold(part.Get("cache_control.type").String(), "ephemeral") { + return false + } + text := part.Get("text").String() + return text == "" || text == "You are Qwen Code." } - var buf bytes.Buffer - buf.WriteByte('[') - buf.Write(qwenDefaultSystemMessage) - buf.WriteByte(']') - updated, errSet := sjson.SetRawBytes(payload, "messages", buf.Bytes()) + defaultParts := gjson.ParseBytes(qwenDefaultSystemMessage).Get("content") + var systemParts []any + if defaultParts.Exists() && defaultParts.IsArray() { + for _, part := range defaultParts.Array() { + systemParts = append(systemParts, part.Value()) + } + } + if len(systemParts) == 0 { + systemParts = append(systemParts, map[string]any{ + "type": "text", + "text": "You are Qwen Code.", + "cache_control": map[string]any{ + "type": "ephemeral", + }, + }) + } + + appendSystemContent := func(content gjson.Result) { + makeTextPart := func(text string) map[string]any { + return map[string]any{ + "type": "text", + "text": text, + } + } + + if !content.Exists() || content.Type == gjson.Null { + return + } + if content.IsArray() { + for _, part := range content.Array() { + if part.Type == gjson.String { + systemParts = append(systemParts, makeTextPart(part.String())) + continue + } + if isInjectedSystemPart(part) { + continue + } + systemParts = append(systemParts, part.Value()) + } + return + } + if content.Type == gjson.String { + systemParts = append(systemParts, makeTextPart(content.String())) + return + } + if content.IsObject() { + if isInjectedSystemPart(content) { + return + } + systemParts = append(systemParts, content.Value()) + return + } + systemParts = append(systemParts, makeTextPart(content.String())) + } + + messages := gjson.GetBytes(payload, "messages") + var nonSystemMessages []any + if messages.Exists() && messages.IsArray() { + for _, msg := range messages.Array() { + if strings.EqualFold(msg.Get("role").String(), "system") { + appendSystemContent(msg.Get("content")) + continue + } + nonSystemMessages = append(nonSystemMessages, msg.Value()) + } + } + + newMessages := make([]any, 0, 1+len(nonSystemMessages)) + newMessages = append(newMessages, map[string]any{ + "role": "system", + "content": systemParts, + }) + newMessages = append(newMessages, nonSystemMessages...) + + updated, errSet := sjson.SetBytes(payload, "messages", newMessages) if errSet != nil { - return nil, fmt.Errorf("qwen executor: set default system message failed: %w", errSet) + return nil, fmt.Errorf("qwen executor: set system message failed: %w", errSet) } return updated, nil } diff --git a/internal/runtime/executor/qwen_executor_test.go b/internal/runtime/executor/qwen_executor_test.go index 6a777c53..627cf453 100644 --- a/internal/runtime/executor/qwen_executor_test.go +++ b/internal/runtime/executor/qwen_executor_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + "github.com/tidwall/gjson" ) func TestQwenExecutorParseSuffix(t *testing.T) { @@ -28,3 +29,123 @@ func TestQwenExecutorParseSuffix(t *testing.T) { }) } } + +func TestEnsureQwenSystemMessage_MergeStringSystem(t *testing.T) { + payload := []byte(`{ + "model": "qwen3.6-plus", + "stream": true, + "messages": [ + { "role": "system", "content": "ABCDEFG" }, + { "role": "user", "content": [ { "type": "text", "text": "你好" } ] } + ] + }`) + + out, err := ensureQwenSystemMessage(payload) + if err != nil { + t.Fatalf("ensureQwenSystemMessage() error = %v", err) + } + + msgs := gjson.GetBytes(out, "messages").Array() + if len(msgs) != 2 { + t.Fatalf("messages length = %d, want 2", len(msgs)) + } + if msgs[0].Get("role").String() != "system" { + t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system") + } + parts := msgs[0].Get("content").Array() + if len(parts) != 2 { + t.Fatalf("messages[0].content length = %d, want 2", len(parts)) + } + if parts[0].Get("text").String() != "You are Qwen Code." || parts[0].Get("cache_control.type").String() != "ephemeral" { + t.Fatalf("messages[0].content[0] = %s, want injected system part", parts[0].Raw) + } + if parts[1].Get("type").String() != "text" || parts[1].Get("text").String() != "ABCDEFG" { + t.Fatalf("messages[0].content[1] = %s, want text part with ABCDEFG", parts[1].Raw) + } + if msgs[1].Get("role").String() != "user" { + t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user") + } +} + +func TestEnsureQwenSystemMessage_MergeObjectSystem(t *testing.T) { + payload := []byte(`{ + "messages": [ + { "role": "system", "content": { "type": "text", "text": "ABCDEFG" } }, + { "role": "user", "content": [ { "type": "text", "text": "你好" } ] } + ] + }`) + + out, err := ensureQwenSystemMessage(payload) + if err != nil { + t.Fatalf("ensureQwenSystemMessage() error = %v", err) + } + + msgs := gjson.GetBytes(out, "messages").Array() + if len(msgs) != 2 { + t.Fatalf("messages length = %d, want 2", len(msgs)) + } + parts := msgs[0].Get("content").Array() + if len(parts) != 2 { + t.Fatalf("messages[0].content length = %d, want 2", len(parts)) + } + if parts[1].Get("text").String() != "ABCDEFG" { + t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "ABCDEFG") + } +} + +func TestEnsureQwenSystemMessage_PrependsWhenMissing(t *testing.T) { + payload := []byte(`{ + "messages": [ + { "role": "user", "content": [ { "type": "text", "text": "你好" } ] } + ] + }`) + + out, err := ensureQwenSystemMessage(payload) + if err != nil { + t.Fatalf("ensureQwenSystemMessage() error = %v", err) + } + + msgs := gjson.GetBytes(out, "messages").Array() + if len(msgs) != 2 { + t.Fatalf("messages length = %d, want 2", len(msgs)) + } + if msgs[0].Get("role").String() != "system" { + t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system") + } + if !msgs[0].Get("content").IsArray() || len(msgs[0].Get("content").Array()) == 0 { + t.Fatalf("messages[0].content = %s, want non-empty array", msgs[0].Get("content").Raw) + } + if msgs[1].Get("role").String() != "user" { + t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user") + } +} + +func TestEnsureQwenSystemMessage_MergesMultipleSystemMessages(t *testing.T) { + payload := []byte(`{ + "messages": [ + { "role": "system", "content": "A" }, + { "role": "user", "content": [ { "type": "text", "text": "hi" } ] }, + { "role": "system", "content": "B" } + ] + }`) + + out, err := ensureQwenSystemMessage(payload) + if err != nil { + t.Fatalf("ensureQwenSystemMessage() error = %v", err) + } + + msgs := gjson.GetBytes(out, "messages").Array() + if len(msgs) != 2 { + t.Fatalf("messages length = %d, want 2", len(msgs)) + } + parts := msgs[0].Get("content").Array() + if len(parts) != 3 { + t.Fatalf("messages[0].content length = %d, want 3", len(parts)) + } + if parts[1].Get("text").String() != "A" { + t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "A") + } + if parts[2].Get("text").String() != "B" { + t.Fatalf("messages[0].content[2].text = %q, want %q", parts[2].Get("text").String(), "B") + } +}