mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-08 05:47:16 +00:00
Compare commits
55 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4607356333 | ||
|
|
9a9ed99072 | ||
|
|
5ae38584b8 | ||
|
|
c8b7e2b8d6 | ||
|
|
cad45ffa33 | ||
|
|
6a27bceec0 | ||
|
|
163d68318f | ||
|
|
0ea768011b | ||
|
|
341b4beea1 | ||
|
|
bea13f9724 | ||
|
|
9f5bdfaa31 | ||
|
|
9eabdd09db | ||
|
|
c3f8dc362e | ||
|
|
b85120873b | ||
|
|
6f58518c69 | ||
|
|
000fcb15fa | ||
|
|
ea43361492 | ||
|
|
c1818f197b | ||
|
|
b0653cec7b | ||
|
|
22a1a24cf5 | ||
|
|
7223fee2de | ||
|
|
ada8e2905e | ||
|
|
4ba10531da | ||
|
|
3774b56e9f | ||
|
|
c2d4137fb9 | ||
|
|
2ee938acaf | ||
|
|
8d5e470e1f | ||
|
|
65e9e892a4 | ||
|
|
3882494878 | ||
|
|
088c1d07f4 | ||
|
|
8430b28cfa | ||
|
|
f3ab8f4bc5 | ||
|
|
0e4f189c2e | ||
|
|
98509f615c | ||
|
|
87bf0b73d5 | ||
|
|
b849bf79d6 | ||
|
|
59af2c57b1 | ||
|
|
9b5ce8c64f | ||
|
|
058793c73a | ||
|
|
da3a498a28 | ||
|
|
5fc2bd393e | ||
|
|
66eb12294a | ||
|
|
73b22ec29b | ||
|
|
c31ae2f3b5 | ||
|
|
76b53d6b5b | ||
|
|
a34dfed378 | ||
|
|
36efcc6e28 | ||
|
|
a337ecf35c | ||
|
|
e08f68ed7c | ||
|
|
f09ed25fd3 | ||
|
|
e166e56249 | ||
|
|
5f58248016 | ||
|
|
07d6689d87 | ||
|
|
14cb2b95c6 | ||
|
|
fdeef48498 |
6
.gitignore
vendored
6
.gitignore
vendored
@@ -54,4 +54,10 @@ _bmad-output/*
|
||||
# macOS
|
||||
.DS_Store
|
||||
._*
|
||||
|
||||
# Opencode
|
||||
.beads/
|
||||
.opencode/
|
||||
.cli-proxy-api/
|
||||
.venv/
|
||||
*.bak
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
@@ -188,7 +189,7 @@ func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
httpReq.Header.Set("User-Agent", "antigravity/1.19.6 darwin/arm64")
|
||||
httpReq.Header.Set("User-Agent", misc.AntigravityUserAgent())
|
||||
|
||||
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {
|
||||
|
||||
@@ -99,6 +99,7 @@ func main() {
|
||||
var codeBuddyLogin bool
|
||||
var projectID string
|
||||
var vertexImport string
|
||||
var vertexImportPrefix string
|
||||
var configPath string
|
||||
var password string
|
||||
var tuiMode bool
|
||||
@@ -139,6 +140,7 @@ func main() {
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
|
||||
flag.StringVar(&password, "password", "", "")
|
||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||
@@ -188,6 +190,7 @@ func main() {
|
||||
gitStoreRemoteURL string
|
||||
gitStoreUser string
|
||||
gitStorePassword string
|
||||
gitStoreBranch string
|
||||
gitStoreLocalPath string
|
||||
gitStoreInst *store.GitTokenStore
|
||||
gitStoreRoot string
|
||||
@@ -257,6 +260,9 @@ func main() {
|
||||
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
|
||||
gitStoreLocalPath = value
|
||||
}
|
||||
if value, ok := lookupEnv("GITSTORE_GIT_BRANCH", "gitstore_git_branch"); ok {
|
||||
gitStoreBranch = value
|
||||
}
|
||||
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
|
||||
useObjectStore = true
|
||||
objectStoreEndpoint = value
|
||||
@@ -391,7 +397,7 @@ func main() {
|
||||
}
|
||||
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
|
||||
authDir := filepath.Join(gitStoreRoot, "auths")
|
||||
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
|
||||
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword, gitStoreBranch)
|
||||
gitStoreInst.SetBaseDir(authDir)
|
||||
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
|
||||
log.Errorf("failed to prepare git token store: %v", errRepo)
|
||||
@@ -510,7 +516,7 @@ func main() {
|
||||
|
||||
if vertexImport != "" {
|
||||
// Handle Vertex service account import
|
||||
cmd.DoVertexImport(cfg, vertexImport)
|
||||
cmd.DoVertexImport(cfg, vertexImport, vertexImportPrefix)
|
||||
} else if login {
|
||||
// Handle Google/Gemini login
|
||||
cmd.DoLogin(cfg, projectID, options)
|
||||
@@ -596,6 +602,7 @@ func main() {
|
||||
if standalone {
|
||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
misc.StartAntigravityVersionUpdater(context.Background())
|
||||
if !localModel {
|
||||
registry.StartModelsUpdater(context.Background())
|
||||
}
|
||||
@@ -671,6 +678,7 @@ func main() {
|
||||
} else {
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
misc.StartAntigravityVersionUpdater(context.Background())
|
||||
if !localModel {
|
||||
registry.StartModelsUpdater(context.Background())
|
||||
}
|
||||
|
||||
@@ -92,6 +92,9 @@ max-retry-credentials: 0
|
||||
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
||||
max-retry-interval: 30
|
||||
|
||||
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
|
||||
disable-cooling: false
|
||||
|
||||
# Quota exceeded behavior
|
||||
quota-exceeded:
|
||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
@@ -700,6 +701,11 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
||||
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
|
||||
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||
}
|
||||
if h != nil && h.cfg != nil {
|
||||
if proxyStr := strings.TrimSpace(proxyURLFromAPIKeyConfig(h.cfg, auth)); proxyStr != "" {
|
||||
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if h != nil && h.cfg != nil {
|
||||
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
|
||||
@@ -722,6 +728,123 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
||||
return clone
|
||||
}
|
||||
|
||||
type apiKeyConfigEntry interface {
|
||||
GetAPIKey() string
|
||||
GetBaseURL() string
|
||||
}
|
||||
|
||||
func resolveAPIKeyConfig[T apiKeyConfigEntry](entries []T, auth *coreauth.Auth) *T {
|
||||
if auth == nil || len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
attrKey, attrBase := "", ""
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range entries {
|
||||
entry := &entries[i]
|
||||
cfgKey := strings.TrimSpace((*entry).GetAPIKey())
|
||||
cfgBase := strings.TrimSpace((*entry).GetBaseURL())
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range entries {
|
||||
entry := &entries[i]
|
||||
if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func proxyURLFromAPIKeyConfig(cfg *config.Config, auth *coreauth.Auth) string {
|
||||
if cfg == nil || auth == nil {
|
||||
return ""
|
||||
}
|
||||
authKind, authAccount := auth.AccountInfo()
|
||||
if !strings.EqualFold(strings.TrimSpace(authKind), "api_key") {
|
||||
return ""
|
||||
}
|
||||
|
||||
attrs := auth.Attributes
|
||||
compatName := ""
|
||||
providerKey := ""
|
||||
if len(attrs) > 0 {
|
||||
compatName = strings.TrimSpace(attrs["compat_name"])
|
||||
providerKey = strings.TrimSpace(attrs["provider_key"])
|
||||
}
|
||||
if compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
|
||||
return resolveOpenAICompatAPIKeyProxyURL(cfg, auth, strings.TrimSpace(authAccount), providerKey, compatName)
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(auth.Provider)) {
|
||||
case "gemini":
|
||||
if entry := resolveAPIKeyConfig(cfg.GeminiKey, auth); entry != nil {
|
||||
return strings.TrimSpace(entry.ProxyURL)
|
||||
}
|
||||
case "claude":
|
||||
if entry := resolveAPIKeyConfig(cfg.ClaudeKey, auth); entry != nil {
|
||||
return strings.TrimSpace(entry.ProxyURL)
|
||||
}
|
||||
case "codex":
|
||||
if entry := resolveAPIKeyConfig(cfg.CodexKey, auth); entry != nil {
|
||||
return strings.TrimSpace(entry.ProxyURL)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func resolveOpenAICompatAPIKeyProxyURL(cfg *config.Config, auth *coreauth.Auth, apiKey, providerKey, compatName string) string {
|
||||
if cfg == nil || auth == nil {
|
||||
return ""
|
||||
}
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
if apiKey == "" {
|
||||
return ""
|
||||
}
|
||||
candidates := make([]string, 0, 3)
|
||||
if v := strings.TrimSpace(compatName); v != "" {
|
||||
candidates = append(candidates, v)
|
||||
}
|
||||
if v := strings.TrimSpace(providerKey); v != "" {
|
||||
candidates = append(candidates, v)
|
||||
}
|
||||
if v := strings.TrimSpace(auth.Provider); v != "" {
|
||||
candidates = append(candidates, v)
|
||||
}
|
||||
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
compat := &cfg.OpenAICompatibility[i]
|
||||
for _, candidate := range candidates {
|
||||
if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) {
|
||||
for j := range compat.APIKeyEntries {
|
||||
entry := &compat.APIKeyEntries[j]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), apiKey) {
|
||||
return strings.TrimSpace(entry.ProxyURL)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildProxyTransport(proxyStr string) *http.Transport {
|
||||
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
|
||||
if errBuild != nil {
|
||||
|
||||
@@ -58,6 +58,105 @@ func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPICallTransportAPIKeyAuthFallsBackToConfigProxyURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := &Handler{
|
||||
cfg: &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||
GeminiKey: []config.GeminiKey{{
|
||||
APIKey: "gemini-key",
|
||||
ProxyURL: "http://gemini-proxy.example.com:8080",
|
||||
}},
|
||||
ClaudeKey: []config.ClaudeKey{{
|
||||
APIKey: "claude-key",
|
||||
ProxyURL: "http://claude-proxy.example.com:8080",
|
||||
}},
|
||||
CodexKey: []config.CodexKey{{
|
||||
APIKey: "codex-key",
|
||||
ProxyURL: "http://codex-proxy.example.com:8080",
|
||||
}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{{
|
||||
Name: "bohe",
|
||||
BaseURL: "https://bohe.example.com",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{
|
||||
APIKey: "compat-key",
|
||||
ProxyURL: "http://compat-proxy.example.com:8080",
|
||||
}},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
auth *coreauth.Auth
|
||||
wantProxy string
|
||||
}{
|
||||
{
|
||||
name: "gemini",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: map[string]string{"api_key": "gemini-key"},
|
||||
},
|
||||
wantProxy: "http://gemini-proxy.example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "claude",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "claude",
|
||||
Attributes: map[string]string{"api_key": "claude-key"},
|
||||
},
|
||||
wantProxy: "http://claude-proxy.example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "codex",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "codex",
|
||||
Attributes: map[string]string{"api_key": "codex-key"},
|
||||
},
|
||||
wantProxy: "http://codex-proxy.example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "openai-compatibility",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "bohe",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "compat-key",
|
||||
"compat_name": "bohe",
|
||||
"provider_key": "bohe",
|
||||
},
|
||||
},
|
||||
wantProxy: "http://compat-proxy.example.com:8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
transport := h.apiCallTransport(tc.auth)
|
||||
httpTransport, ok := transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||
}
|
||||
|
||||
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
if errRequest != nil {
|
||||
t.Fatalf("http.NewRequest returned error: %v", errRequest)
|
||||
}
|
||||
|
||||
proxyURL, errProxy := httpTransport.Proxy(req)
|
||||
if errProxy != nil {
|
||||
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
|
||||
}
|
||||
if proxyURL == nil || proxyURL.String() != tc.wantProxy {
|
||||
t.Fatalf("proxy URL = %v, want %s", proxyURL, tc.wantProxy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -298,8 +299,10 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
|
||||
}
|
||||
|
||||
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
|
||||
// from the messages array in a request body before forwarding to the upstream API.
|
||||
// This prevents 400 errors from the API which requires valid signatures on thinking blocks.
|
||||
// and strips the proxy-injected "signature" field from tool_use blocks in the messages
|
||||
// array before forwarding to the upstream API.
|
||||
// This prevents 400 errors from the API which requires valid signatures on thinking
|
||||
// blocks and does not accept a signature field on tool_use blocks.
|
||||
func SanitizeAmpRequestBody(body []byte) []byte {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
@@ -317,21 +320,30 @@ func SanitizeAmpRequestBody(body []byte) []byte {
|
||||
}
|
||||
|
||||
var keepBlocks []interface{}
|
||||
removedCount := 0
|
||||
contentModified := false
|
||||
|
||||
for _, block := range content.Array() {
|
||||
blockType := block.Get("type").String()
|
||||
if blockType == "thinking" {
|
||||
sig := block.Get("signature")
|
||||
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
|
||||
removedCount++
|
||||
contentModified = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
keepBlocks = append(keepBlocks, block.Value())
|
||||
|
||||
// Use raw JSON to prevent float64 rounding of large integers in tool_use inputs
|
||||
blockRaw := []byte(block.Raw)
|
||||
if blockType == "tool_use" && block.Get("signature").Exists() {
|
||||
blockRaw, _ = sjson.DeleteBytes(blockRaw, "signature")
|
||||
contentModified = true
|
||||
}
|
||||
|
||||
// sjson.SetBytes supports raw JSON strings if wrapped in gjson.Raw
|
||||
keepBlocks = append(keepBlocks, json.RawMessage(blockRaw))
|
||||
}
|
||||
|
||||
if removedCount > 0 {
|
||||
if contentModified {
|
||||
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
|
||||
var err error
|
||||
if len(keepBlocks) == 0 {
|
||||
@@ -340,11 +352,10 @@ func SanitizeAmpRequestBody(body []byte) []byte {
|
||||
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
|
||||
}
|
||||
if err != nil {
|
||||
log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err)
|
||||
log.Warnf("Amp RequestSanitizer: failed to sanitize message %d: %v", msgIdx, err)
|
||||
continue
|
||||
}
|
||||
modified = true
|
||||
log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -145,6 +145,36 @@ func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks(t *testing.T) {
|
||||
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"thought","signature":"valid-sig"},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
|
||||
result := SanitizeAmpRequestBody(input)
|
||||
|
||||
if contains(result, []byte(`"signature":""`)) {
|
||||
t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
|
||||
}
|
||||
if !contains(result, []byte(`"valid-sig"`)) {
|
||||
t.Fatalf("expected thinking signature to remain, got %s", string(result))
|
||||
}
|
||||
if !contains(result, []byte(`"tool_use"`)) {
|
||||
t.Fatalf("expected tool_use block to remain, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testing.T) {
|
||||
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-me","signature":""},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
|
||||
result := SanitizeAmpRequestBody(input)
|
||||
|
||||
if contains(result, []byte("drop-me")) {
|
||||
t.Fatalf("expected invalid thinking block to be removed, got %s", string(result))
|
||||
}
|
||||
if contains(result, []byte(`"signature"`)) {
|
||||
t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
|
||||
}
|
||||
if !contains(result, []byte(`"tool_use"`)) {
|
||||
t.Fatalf("expected tool_use block to remain, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func contains(data, substr []byte) bool {
|
||||
for i := 0; i <= len(data)-len(substr); i++ {
|
||||
if string(data[i:i+len(substr)]) == string(substr) {
|
||||
|
||||
@@ -573,6 +573,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
||||
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
||||
|
||||
mgmt.GET("/copilot-quota", s.mgmt.GetCopilotQuota)
|
||||
|
||||
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
|
||||
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
|
||||
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
|
||||
|
||||
@@ -235,6 +235,74 @@ type CopilotModelEntry struct {
|
||||
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||
}
|
||||
|
||||
// CopilotModelLimits holds the token limits returned by the Copilot /models API
|
||||
// under capabilities.limits. These limits vary by account type (individual vs
|
||||
// business) and are the authoritative source for enforcing prompt size.
|
||||
type CopilotModelLimits struct {
|
||||
// MaxContextWindowTokens is the total context window (prompt + output).
|
||||
MaxContextWindowTokens int
|
||||
// MaxPromptTokens is the hard limit on input/prompt tokens.
|
||||
// Exceeding this triggers a 400 error from the Copilot API.
|
||||
MaxPromptTokens int
|
||||
// MaxOutputTokens is the maximum number of output/completion tokens.
|
||||
MaxOutputTokens int
|
||||
}
|
||||
|
||||
// Limits extracts the token limits from the model's capabilities map.
|
||||
// Returns nil if no limits are available or the structure is unexpected.
|
||||
//
|
||||
// Expected Copilot API shape:
|
||||
//
|
||||
// "capabilities": {
|
||||
// "limits": {
|
||||
// "max_context_window_tokens": 200000,
|
||||
// "max_prompt_tokens": 168000,
|
||||
// "max_output_tokens": 32000
|
||||
// }
|
||||
// }
|
||||
func (e *CopilotModelEntry) Limits() *CopilotModelLimits {
|
||||
if e.Capabilities == nil {
|
||||
return nil
|
||||
}
|
||||
limitsRaw, ok := e.Capabilities["limits"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
limitsMap, ok := limitsRaw.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := &CopilotModelLimits{
|
||||
MaxContextWindowTokens: anyToInt(limitsMap["max_context_window_tokens"]),
|
||||
MaxPromptTokens: anyToInt(limitsMap["max_prompt_tokens"]),
|
||||
MaxOutputTokens: anyToInt(limitsMap["max_output_tokens"]),
|
||||
}
|
||||
|
||||
// Only return if at least one field is populated.
|
||||
if result.MaxContextWindowTokens == 0 && result.MaxPromptTokens == 0 && result.MaxOutputTokens == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// anyToInt converts a JSON-decoded numeric value to int.
|
||||
// Go's encoding/json decodes numbers into float64 when the target is any/interface{}.
|
||||
func anyToInt(v any) int {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int(n)
|
||||
case float32:
|
||||
return int(n)
|
||||
case int:
|
||||
return n
|
||||
case int64:
|
||||
return int(n)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// CopilotModelsResponse represents the response from the Copilot /models endpoint.
|
||||
type CopilotModelsResponse struct {
|
||||
Data []CopilotModelEntry `json:"data"`
|
||||
|
||||
@@ -30,6 +30,10 @@ type VertexCredentialStorage struct {
|
||||
|
||||
// Type is the provider identifier stored alongside credentials. Always "vertex".
|
||||
Type string `json:"type"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA").
|
||||
// This results in model names like "teamA/gemini-2.0-flash".
|
||||
Prefix string `json:"prefix,omitempty"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile writes the credential payload to the given file path in JSON format.
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
// DoVertexImport imports a Google Cloud service account key JSON and persists
|
||||
// it as a "vertex" provider credential. The file content is embedded in the auth
|
||||
// file to allow portable deployment across stores.
|
||||
func DoVertexImport(cfg *config.Config, keyPath string) {
|
||||
func DoVertexImport(cfg *config.Config, keyPath string, prefix string) {
|
||||
if cfg == nil {
|
||||
cfg = &config.Config{}
|
||||
}
|
||||
@@ -62,13 +62,28 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
|
||||
// Default location if not provided by user. Can be edited in the saved file later.
|
||||
location := "us-central1"
|
||||
|
||||
fileName := fmt.Sprintf("vertex-%s.json", sanitizeFilePart(projectID))
|
||||
// Normalize and validate prefix: must be a single segment (no "/" allowed).
|
||||
prefix = strings.TrimSpace(prefix)
|
||||
prefix = strings.Trim(prefix, "/")
|
||||
if prefix != "" && strings.Contains(prefix, "/") {
|
||||
log.Errorf("vertex-import: prefix must be a single segment (no '/' allowed): %q", prefix)
|
||||
return
|
||||
}
|
||||
|
||||
// Include prefix in filename so importing the same project with different
|
||||
// prefixes creates separate credential files instead of overwriting.
|
||||
baseName := sanitizeFilePart(projectID)
|
||||
if prefix != "" {
|
||||
baseName = sanitizeFilePart(prefix) + "-" + baseName
|
||||
}
|
||||
fileName := fmt.Sprintf("vertex-%s.json", baseName)
|
||||
// Build auth record
|
||||
storage := &vertex.VertexCredentialStorage{
|
||||
ServiceAccount: sa,
|
||||
ProjectID: projectID,
|
||||
Email: email,
|
||||
Location: location,
|
||||
Prefix: prefix,
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"service_account": sa,
|
||||
@@ -76,6 +91,7 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
|
||||
"email": email,
|
||||
"location": location,
|
||||
"type": "vertex",
|
||||
"prefix": prefix,
|
||||
"label": labelForVertex(projectID, email),
|
||||
}
|
||||
record := &coreauth.Auth{
|
||||
|
||||
151
internal/misc/antigravity_version.go
Normal file
151
internal/misc/antigravity_version.go
Normal file
@@ -0,0 +1,151 @@
|
||||
// Package misc provides miscellaneous utility functions for the CLI Proxy API server.
|
||||
package misc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityReleasesURL = "https://antigravity-auto-updater-974169037036.us-central1.run.app/releases"
|
||||
antigravityFallbackVersion = "1.21.9"
|
||||
antigravityVersionCacheTTL = 6 * time.Hour
|
||||
antigravityFetchTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type antigravityRelease struct {
|
||||
Version string `json:"version"`
|
||||
ExecutionID string `json:"execution_id"`
|
||||
}
|
||||
|
||||
var (
|
||||
cachedAntigravityVersion = antigravityFallbackVersion
|
||||
antigravityVersionMu sync.RWMutex
|
||||
antigravityVersionExpiry time.Time
|
||||
antigravityUpdaterOnce sync.Once
|
||||
)
|
||||
|
||||
// StartAntigravityVersionUpdater starts a background goroutine that periodically refreshes the cached antigravity version.
|
||||
// This is intentionally decoupled from request execution to avoid blocking executors on version lookups.
|
||||
func StartAntigravityVersionUpdater(ctx context.Context) {
|
||||
antigravityUpdaterOnce.Do(func() {
|
||||
go runAntigravityVersionUpdater(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func runAntigravityVersionUpdater(ctx context.Context) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(antigravityVersionCacheTTL / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
log.Infof("periodic antigravity version refresh started (interval=%s)", antigravityVersionCacheTTL/2)
|
||||
|
||||
refreshAntigravityVersion(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
refreshAntigravityVersion(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func refreshAntigravityVersion(ctx context.Context) {
|
||||
version, errFetch := fetchAntigravityLatestVersion(ctx)
|
||||
|
||||
antigravityVersionMu.Lock()
|
||||
defer antigravityVersionMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
if errFetch == nil {
|
||||
cachedAntigravityVersion = version
|
||||
antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL)
|
||||
log.WithField("version", version).Info("fetched latest antigravity version")
|
||||
return
|
||||
}
|
||||
|
||||
if cachedAntigravityVersion == "" || now.After(antigravityVersionExpiry) {
|
||||
cachedAntigravityVersion = antigravityFallbackVersion
|
||||
antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL)
|
||||
log.WithError(errFetch).Warn("failed to refresh antigravity version, using fallback version")
|
||||
return
|
||||
}
|
||||
|
||||
log.WithError(errFetch).Debug("failed to refresh antigravity version, keeping cached value")
|
||||
}
|
||||
|
||||
// AntigravityLatestVersion returns the cached antigravity version refreshed by StartAntigravityVersionUpdater.
|
||||
// It falls back to antigravityFallbackVersion if the cache is empty or stale.
|
||||
func AntigravityLatestVersion() string {
|
||||
antigravityVersionMu.RLock()
|
||||
if cachedAntigravityVersion != "" && time.Now().Before(antigravityVersionExpiry) {
|
||||
v := cachedAntigravityVersion
|
||||
antigravityVersionMu.RUnlock()
|
||||
return v
|
||||
}
|
||||
antigravityVersionMu.RUnlock()
|
||||
|
||||
return antigravityFallbackVersion
|
||||
}
|
||||
|
||||
// AntigravityUserAgent returns the User-Agent string for antigravity requests
|
||||
// using the latest version fetched from the releases API.
|
||||
func AntigravityUserAgent() string {
|
||||
return fmt.Sprintf("antigravity/%s darwin/arm64", AntigravityLatestVersion())
|
||||
}
|
||||
|
||||
func fetchAntigravityLatestVersion(ctx context.Context) (string, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: antigravityFetchTimeout}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodGet, antigravityReleasesURL, nil)
|
||||
if errReq != nil {
|
||||
return "", fmt.Errorf("build antigravity releases request: %w", errReq)
|
||||
}
|
||||
|
||||
resp, errDo := client.Do(httpReq)
|
||||
if errDo != nil {
|
||||
return "", fmt.Errorf("fetch antigravity releases: %w", errDo)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("antigravity releases response body close error")
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("antigravity releases API returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var releases []antigravityRelease
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&releases); errDecode != nil {
|
||||
return "", fmt.Errorf("decode antigravity releases response: %w", errDecode)
|
||||
}
|
||||
|
||||
if len(releases) == 0 {
|
||||
return "", errors.New("antigravity releases API returned empty list")
|
||||
}
|
||||
|
||||
version := releases[0].Version
|
||||
if version == "" {
|
||||
return "", errors.New("antigravity releases API returned empty version")
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
@@ -549,6 +549,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4.6",
|
||||
@@ -561,6 +562,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4",
|
||||
@@ -573,6 +575,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4.5",
|
||||
@@ -585,6 +588,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4.6",
|
||||
@@ -597,6 +601,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-2.5-pro",
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
@@ -45,7 +46,7 @@ const (
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
|
||||
defaultAntigravityAgent = "antigravity/1.21.9 darwin/arm64" // fallback only; overridden at runtime by misc.AntigravityUserAgent()
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
antigravityCreditsRetryTTL = 5 * time.Hour
|
||||
@@ -1739,7 +1740,7 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string {
|
||||
}
|
||||
}
|
||||
}
|
||||
return defaultAntigravityAgent
|
||||
return misc.AntigravityUserAgent()
|
||||
}
|
||||
|
||||
func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -841,6 +840,9 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
baseBetas += ",oauth-2025-04-20"
|
||||
}
|
||||
}
|
||||
if !strings.Contains(baseBetas, "interleaved-thinking") {
|
||||
baseBetas += ",interleaved-thinking-2025-05-14"
|
||||
}
|
||||
|
||||
hasClaude1MHeader := false
|
||||
if ginHeaders != nil {
|
||||
@@ -848,6 +850,14 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
hasClaude1MHeader = true
|
||||
}
|
||||
}
|
||||
// Also check auth attributes — GitLab Duo sets gitlab_duo_force_context_1m
|
||||
// when routing through the Anthropic gateway, but the gin headers won't have
|
||||
// X-CPA-CLAUDE-1M because the request is internally constructed.
|
||||
if !hasClaude1MHeader && auth != nil && auth.Attributes != nil {
|
||||
if auth.Attributes["gitlab_duo_force_context_1m"] == "true" {
|
||||
hasClaude1MHeader = true
|
||||
}
|
||||
}
|
||||
|
||||
// Merge extra betas from request body and request flags.
|
||||
if len(extraBetas) > 0 || hasClaude1MHeader {
|
||||
@@ -949,12 +959,9 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// Collect built-in tool names (those with a non-empty "type" field) so we can
|
||||
// skip them consistently in both tools and message history.
|
||||
builtinTools := map[string]bool{}
|
||||
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
|
||||
builtinTools[name] = true
|
||||
}
|
||||
// Collect built-in tool names from the authoritative fallback seed list and
|
||||
// augment it with any typed built-ins present in the current request body.
|
||||
builtinTools := helps.AugmentClaudeBuiltinToolRegistry(body, nil)
|
||||
|
||||
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
||||
tools.ForEach(func(index, tool gjson.Result) bool {
|
||||
@@ -1463,182 +1470,6 @@ func countCacheControls(payload []byte) int {
|
||||
return count
|
||||
}
|
||||
|
||||
func parsePayloadObject(payload []byte) (map[string]any, bool) {
|
||||
if len(payload) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
var root map[string]any
|
||||
if err := json.Unmarshal(payload, &root); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return root, true
|
||||
}
|
||||
|
||||
func marshalPayloadObject(original []byte, root map[string]any) []byte {
|
||||
if root == nil {
|
||||
return original
|
||||
}
|
||||
out, err := json.Marshal(root)
|
||||
if err != nil {
|
||||
return original
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func asObject(v any) (map[string]any, bool) {
|
||||
obj, ok := v.(map[string]any)
|
||||
return obj, ok
|
||||
}
|
||||
|
||||
func asArray(v any) ([]any, bool) {
|
||||
arr, ok := v.([]any)
|
||||
return arr, ok
|
||||
}
|
||||
|
||||
func countCacheControlsMap(root map[string]any) int {
|
||||
count := 0
|
||||
|
||||
if system, ok := asArray(root["system"]); ok {
|
||||
for _, item := range system {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tools, ok := asArray(root["tools"]); ok {
|
||||
for _, item := range tools {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if messages, ok := asArray(root["messages"]); ok {
|
||||
for _, msg := range messages {
|
||||
msgObj, ok := asObject(msg)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := asArray(msgObj["content"])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, item := range content {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
|
||||
ccRaw, exists := obj["cache_control"]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
cc, ok := asObject(ccRaw)
|
||||
if !ok {
|
||||
*seen5m = true
|
||||
return false
|
||||
}
|
||||
ttlRaw, ttlExists := cc["ttl"]
|
||||
ttl, ttlIsString := ttlRaw.(string)
|
||||
if !ttlExists || !ttlIsString || ttl != "1h" {
|
||||
*seen5m = true
|
||||
return false
|
||||
}
|
||||
if *seen5m {
|
||||
delete(cc, "ttl")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func findLastCacheControlIndex(arr []any) int {
|
||||
last := -1
|
||||
for idx, item := range arr {
|
||||
obj, ok := asObject(item)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
last = idx
|
||||
}
|
||||
}
|
||||
return last
|
||||
}
|
||||
|
||||
func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) {
|
||||
for idx, item := range arr {
|
||||
if *excess <= 0 {
|
||||
return
|
||||
}
|
||||
obj, ok := asObject(item)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := obj["cache_control"]; exists && idx != preserveIdx {
|
||||
delete(obj, "cache_control")
|
||||
*excess--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stripAllCacheControl(arr []any, excess *int) {
|
||||
for _, item := range arr {
|
||||
if *excess <= 0 {
|
||||
return
|
||||
}
|
||||
obj, ok := asObject(item)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
delete(obj, "cache_control")
|
||||
*excess--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stripMessageCacheControl(messages []any, excess *int) {
|
||||
for _, msg := range messages {
|
||||
if *excess <= 0 {
|
||||
return
|
||||
}
|
||||
msgObj, ok := asObject(msg)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := asArray(msgObj["content"])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, item := range content {
|
||||
if *excess <= 0 {
|
||||
return
|
||||
}
|
||||
obj, ok := asObject(item)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
delete(obj, "cache_control")
|
||||
*excess--
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeCacheControlTTL ensures cache_control TTL values don't violate the
|
||||
// prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not
|
||||
// appear after a 5m-TTL block anywhere in the evaluation order.
|
||||
@@ -1651,58 +1482,75 @@ func stripMessageCacheControl(messages []any, excess *int) {
|
||||
// Strategy: walk all cache_control blocks in evaluation order. Once a 5m block
|
||||
// is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m).
|
||||
func normalizeCacheControlTTL(payload []byte) []byte {
|
||||
root, ok := parsePayloadObject(payload)
|
||||
if !ok {
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return payload
|
||||
}
|
||||
|
||||
original := payload
|
||||
seen5m := false
|
||||
modified := false
|
||||
|
||||
if tools, ok := asArray(root["tools"]); ok {
|
||||
for _, tool := range tools {
|
||||
if obj, ok := asObject(tool); ok {
|
||||
if normalizeTTLForBlock(obj, &seen5m) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
processBlock := func(path string, obj gjson.Result) {
|
||||
cc := obj.Get("cache_control")
|
||||
if !cc.Exists() {
|
||||
return
|
||||
}
|
||||
if !cc.IsObject() {
|
||||
seen5m = true
|
||||
return
|
||||
}
|
||||
ttl := cc.Get("ttl")
|
||||
if ttl.Type != gjson.String || ttl.String() != "1h" {
|
||||
seen5m = true
|
||||
return
|
||||
}
|
||||
if !seen5m {
|
||||
return
|
||||
}
|
||||
ttlPath := path + ".cache_control.ttl"
|
||||
updated, errDel := sjson.DeleteBytes(payload, ttlPath)
|
||||
if errDel != nil {
|
||||
return
|
||||
}
|
||||
payload = updated
|
||||
modified = true
|
||||
}
|
||||
|
||||
if system, ok := asArray(root["system"]); ok {
|
||||
for _, item := range system {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if normalizeTTLForBlock(obj, &seen5m) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
tools := gjson.GetBytes(payload, "tools")
|
||||
if tools.IsArray() {
|
||||
tools.ForEach(func(idx, item gjson.Result) bool {
|
||||
processBlock(fmt.Sprintf("tools.%d", int(idx.Int())), item)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if messages, ok := asArray(root["messages"]); ok {
|
||||
for _, msg := range messages {
|
||||
msgObj, ok := asObject(msg)
|
||||
if !ok {
|
||||
continue
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(idx, item gjson.Result) bool {
|
||||
processBlock(fmt.Sprintf("system.%d", int(idx.Int())), item)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
if messages.IsArray() {
|
||||
messages.ForEach(func(msgIdx, msg gjson.Result) bool {
|
||||
content := msg.Get("content")
|
||||
if !content.IsArray() {
|
||||
return true
|
||||
}
|
||||
content, ok := asArray(msgObj["content"])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, item := range content {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if normalizeTTLForBlock(obj, &seen5m) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
content.ForEach(func(itemIdx, item gjson.Result) bool {
|
||||
processBlock(fmt.Sprintf("messages.%d.content.%d", int(msgIdx.Int()), int(itemIdx.Int())), item)
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return payload
|
||||
return original
|
||||
}
|
||||
return marshalPayloadObject(payload, root)
|
||||
return payload
|
||||
}
|
||||
|
||||
// enforceCacheControlLimit removes excess cache_control blocks from a payload
|
||||
@@ -1722,64 +1570,166 @@ func normalizeCacheControlTTL(payload []byte) []byte {
|
||||
// Phase 4: remaining system blocks (last system).
|
||||
// Phase 5: remaining tool blocks (last tool).
|
||||
func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte {
|
||||
root, ok := parsePayloadObject(payload)
|
||||
if !ok {
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return payload
|
||||
}
|
||||
|
||||
total := countCacheControlsMap(root)
|
||||
total := countCacheControls(payload)
|
||||
if total <= maxBlocks {
|
||||
return payload
|
||||
}
|
||||
|
||||
excess := total - maxBlocks
|
||||
|
||||
var system []any
|
||||
if arr, ok := asArray(root["system"]); ok {
|
||||
system = arr
|
||||
}
|
||||
var tools []any
|
||||
if arr, ok := asArray(root["tools"]); ok {
|
||||
tools = arr
|
||||
}
|
||||
var messages []any
|
||||
if arr, ok := asArray(root["messages"]); ok {
|
||||
messages = arr
|
||||
}
|
||||
|
||||
if len(system) > 0 {
|
||||
stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess)
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
if system.IsArray() {
|
||||
lastIdx := -1
|
||||
system.ForEach(func(idx, item gjson.Result) bool {
|
||||
if item.Get("cache_control").Exists() {
|
||||
lastIdx = int(idx.Int())
|
||||
}
|
||||
return true
|
||||
})
|
||||
if lastIdx >= 0 {
|
||||
system.ForEach(func(idx, item gjson.Result) bool {
|
||||
if excess <= 0 {
|
||||
return false
|
||||
}
|
||||
i := int(idx.Int())
|
||||
if i == lastIdx {
|
||||
return true
|
||||
}
|
||||
if !item.Get("cache_control").Exists() {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("system.%d.cache_control", i)
|
||||
updated, errDel := sjson.DeleteBytes(payload, path)
|
||||
if errDel != nil {
|
||||
return true
|
||||
}
|
||||
payload = updated
|
||||
excess--
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
if excess <= 0 {
|
||||
return marshalPayloadObject(payload, root)
|
||||
return payload
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess)
|
||||
tools := gjson.GetBytes(payload, "tools")
|
||||
if tools.IsArray() {
|
||||
lastIdx := -1
|
||||
tools.ForEach(func(idx, item gjson.Result) bool {
|
||||
if item.Get("cache_control").Exists() {
|
||||
lastIdx = int(idx.Int())
|
||||
}
|
||||
return true
|
||||
})
|
||||
if lastIdx >= 0 {
|
||||
tools.ForEach(func(idx, item gjson.Result) bool {
|
||||
if excess <= 0 {
|
||||
return false
|
||||
}
|
||||
i := int(idx.Int())
|
||||
if i == lastIdx {
|
||||
return true
|
||||
}
|
||||
if !item.Get("cache_control").Exists() {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("tools.%d.cache_control", i)
|
||||
updated, errDel := sjson.DeleteBytes(payload, path)
|
||||
if errDel != nil {
|
||||
return true
|
||||
}
|
||||
payload = updated
|
||||
excess--
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
if excess <= 0 {
|
||||
return marshalPayloadObject(payload, root)
|
||||
return payload
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
stripMessageCacheControl(messages, &excess)
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
if messages.IsArray() {
|
||||
messages.ForEach(func(msgIdx, msg gjson.Result) bool {
|
||||
if excess <= 0 {
|
||||
return false
|
||||
}
|
||||
content := msg.Get("content")
|
||||
if !content.IsArray() {
|
||||
return true
|
||||
}
|
||||
content.ForEach(func(itemIdx, item gjson.Result) bool {
|
||||
if excess <= 0 {
|
||||
return false
|
||||
}
|
||||
if !item.Get("cache_control").Exists() {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.cache_control", int(msgIdx.Int()), int(itemIdx.Int()))
|
||||
updated, errDel := sjson.DeleteBytes(payload, path)
|
||||
if errDel != nil {
|
||||
return true
|
||||
}
|
||||
payload = updated
|
||||
excess--
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
}
|
||||
if excess <= 0 {
|
||||
return marshalPayloadObject(payload, root)
|
||||
return payload
|
||||
}
|
||||
|
||||
if len(system) > 0 {
|
||||
stripAllCacheControl(system, &excess)
|
||||
system = gjson.GetBytes(payload, "system")
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(idx, item gjson.Result) bool {
|
||||
if excess <= 0 {
|
||||
return false
|
||||
}
|
||||
if !item.Get("cache_control").Exists() {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("system.%d.cache_control", int(idx.Int()))
|
||||
updated, errDel := sjson.DeleteBytes(payload, path)
|
||||
if errDel != nil {
|
||||
return true
|
||||
}
|
||||
payload = updated
|
||||
excess--
|
||||
return true
|
||||
})
|
||||
}
|
||||
if excess <= 0 {
|
||||
return marshalPayloadObject(payload, root)
|
||||
return payload
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
stripAllCacheControl(tools, &excess)
|
||||
tools = gjson.GetBytes(payload, "tools")
|
||||
if tools.IsArray() {
|
||||
tools.ForEach(func(idx, item gjson.Result) bool {
|
||||
if excess <= 0 {
|
||||
return false
|
||||
}
|
||||
if !item.Get("cache_control").Exists() {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("tools.%d.cache_control", int(idx.Int()))
|
||||
updated, errDel := sjson.DeleteBytes(payload, path)
|
||||
if errDel != nil {
|
||||
return true
|
||||
}
|
||||
payload = updated
|
||||
excess--
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return marshalPayloadObject(payload, root)
|
||||
return payload
|
||||
}
|
||||
|
||||
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.
|
||||
|
||||
@@ -739,6 +739,35 @@ func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_KnownFallbackBuiltinsRemainUnprefixed(t *testing.T) {
|
||||
for _, builtin := range []string{"web_search", "code_execution", "text_editor", "computer"} {
|
||||
t.Run(builtin, func(t *testing.T) {
|
||||
input := []byte(fmt.Sprintf(`{
|
||||
"tools":[{"name":"Read"}],
|
||||
"tool_choice":{"type":"tool","name":%q},
|
||||
"messages":[{"role":"assistant","content":[{"type":"tool_use","name":%q,"id":"toolu_1","input":{}},{"type":"tool_reference","tool_name":%q},{"type":"tool_result","tool_use_id":"toolu_1","content":[{"type":"tool_reference","tool_name":%q}]}]}]
|
||||
}`, builtin, builtin, builtin, builtin))
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != builtin {
|
||||
t.Fatalf("tool_choice.name = %q, want %q", got, builtin)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != builtin {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, builtin)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != builtin {
|
||||
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, builtin)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.2.content.0.tool_name").String(); got != builtin {
|
||||
t.Fatalf("messages.0.content.2.content.0.tool_name = %q, want %q", got, builtin)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
@@ -965,6 +994,28 @@ func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCacheControlTTL_PreservesKeyOrderWhenModified(t *testing.T) {
|
||||
payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`)
|
||||
|
||||
out := normalizeCacheControlTTL(payload)
|
||||
|
||||
if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() {
|
||||
t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block")
|
||||
}
|
||||
|
||||
outStr := string(out)
|
||||
idxModel := strings.Index(outStr, `"model"`)
|
||||
idxMessages := strings.Index(outStr, `"messages"`)
|
||||
idxTools := strings.Index(outStr, `"tools"`)
|
||||
idxSystem := strings.Index(outStr, `"system"`)
|
||||
if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 {
|
||||
t.Fatalf("failed to locate top-level keys in output: %s", outStr)
|
||||
}
|
||||
if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) {
|
||||
t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
@@ -994,6 +1045,31 @@ func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceCacheControlLimit_PreservesKeyOrderWhenModified(t *testing.T) {
|
||||
payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}},{"name":"t2","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`)
|
||||
|
||||
out := enforceCacheControlLimit(payload, 4)
|
||||
|
||||
if got := countCacheControls(out); got != 4 {
|
||||
t.Fatalf("cache_control count = %d, want 4", got)
|
||||
}
|
||||
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
|
||||
t.Fatalf("tools.0.cache_control should be removed first (non-last tool)")
|
||||
}
|
||||
|
||||
outStr := string(out)
|
||||
idxModel := strings.Index(outStr, `"model"`)
|
||||
idxMessages := strings.Index(outStr, `"messages"`)
|
||||
idxTools := strings.Index(outStr, `"tools"`)
|
||||
idxSystem := strings.Index(outStr, `"system"`)
|
||||
if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 {
|
||||
t.Fatalf("failed to locate top-level keys in output: %s", outStr)
|
||||
}
|
||||
if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) {
|
||||
t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
|
||||
@@ -4,9 +4,11 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||
@@ -14,8 +16,11 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -98,10 +103,12 @@ func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
translated, _ = sjson.SetBytes(translated, "stream", true)
|
||||
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -114,6 +121,8 @@ func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
return resp, err
|
||||
}
|
||||
e.applyHeaders(httpReq, accessToken, userID, domain)
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpReq.Header.Set("Cache-Control", "no-cache")
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -160,11 +169,16 @@ func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, body)
|
||||
reporter.publish(ctx, parseOpenAIUsage(body))
|
||||
aggregatedBody, usageDetail, err := aggregateOpenAIChatCompletionStream(body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
reporter.publish(ctx, usageDetail)
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, aggregatedBody, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -341,3 +355,197 @@ func (e *CodeBuddyExecutor) applyHeaders(req *http.Request, accessToken, userID,
|
||||
req.Header.Set("X-IDE-Version", "2.63.2")
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
}
|
||||
|
||||
type openAIChatStreamChoiceAccumulator struct {
|
||||
Role string
|
||||
ContentParts []string
|
||||
ReasoningParts []string
|
||||
FinishReason string
|
||||
ToolCalls map[int]*openAIChatStreamToolCallAccumulator
|
||||
ToolCallOrder []int
|
||||
NativeFinishReason any
|
||||
}
|
||||
|
||||
type openAIChatStreamToolCallAccumulator struct {
|
||||
ID string
|
||||
Type string
|
||||
Name string
|
||||
Arguments strings.Builder
|
||||
}
|
||||
|
||||
func aggregateOpenAIChatCompletionStream(raw []byte) ([]byte, usage.Detail, error) {
|
||||
lines := bytes.Split(raw, []byte("\n"))
|
||||
var (
|
||||
responseID string
|
||||
model string
|
||||
created int64
|
||||
serviceTier string
|
||||
systemFP string
|
||||
usageDetail usage.Detail
|
||||
choices = map[int]*openAIChatStreamChoiceAccumulator{}
|
||||
choiceOrder []int
|
||||
)
|
||||
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
payload := bytes.TrimSpace(line[5:])
|
||||
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if !gjson.ValidBytes(payload) {
|
||||
continue
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(payload)
|
||||
if responseID == "" {
|
||||
responseID = root.Get("id").String()
|
||||
}
|
||||
if model == "" {
|
||||
model = root.Get("model").String()
|
||||
}
|
||||
if created == 0 {
|
||||
created = root.Get("created").Int()
|
||||
}
|
||||
if serviceTier == "" {
|
||||
serviceTier = root.Get("service_tier").String()
|
||||
}
|
||||
if systemFP == "" {
|
||||
systemFP = root.Get("system_fingerprint").String()
|
||||
}
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
usageDetail = detail
|
||||
}
|
||||
|
||||
for _, choiceResult := range root.Get("choices").Array() {
|
||||
idx := int(choiceResult.Get("index").Int())
|
||||
choice := choices[idx]
|
||||
if choice == nil {
|
||||
choice = &openAIChatStreamChoiceAccumulator{ToolCalls: map[int]*openAIChatStreamToolCallAccumulator{}}
|
||||
choices[idx] = choice
|
||||
choiceOrder = append(choiceOrder, idx)
|
||||
}
|
||||
|
||||
delta := choiceResult.Get("delta")
|
||||
if role := delta.Get("role").String(); role != "" {
|
||||
choice.Role = role
|
||||
}
|
||||
if content := delta.Get("content").String(); content != "" {
|
||||
choice.ContentParts = append(choice.ContentParts, content)
|
||||
}
|
||||
if reasoning := delta.Get("reasoning_content").String(); reasoning != "" {
|
||||
choice.ReasoningParts = append(choice.ReasoningParts, reasoning)
|
||||
}
|
||||
if finishReason := choiceResult.Get("finish_reason").String(); finishReason != "" {
|
||||
choice.FinishReason = finishReason
|
||||
}
|
||||
if nativeFinishReason := choiceResult.Get("native_finish_reason"); nativeFinishReason.Exists() {
|
||||
choice.NativeFinishReason = nativeFinishReason.Value()
|
||||
}
|
||||
|
||||
for _, toolCallResult := range delta.Get("tool_calls").Array() {
|
||||
toolIdx := int(toolCallResult.Get("index").Int())
|
||||
toolCall := choice.ToolCalls[toolIdx]
|
||||
if toolCall == nil {
|
||||
toolCall = &openAIChatStreamToolCallAccumulator{}
|
||||
choice.ToolCalls[toolIdx] = toolCall
|
||||
choice.ToolCallOrder = append(choice.ToolCallOrder, toolIdx)
|
||||
}
|
||||
if id := toolCallResult.Get("id").String(); id != "" {
|
||||
toolCall.ID = id
|
||||
}
|
||||
if typ := toolCallResult.Get("type").String(); typ != "" {
|
||||
toolCall.Type = typ
|
||||
}
|
||||
if name := toolCallResult.Get("function.name").String(); name != "" {
|
||||
toolCall.Name = name
|
||||
}
|
||||
if args := toolCallResult.Get("function.arguments").String(); args != "" {
|
||||
toolCall.Arguments.WriteString(args)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if responseID == "" && model == "" && len(choiceOrder) == 0 {
|
||||
return nil, usageDetail, fmt.Errorf("codebuddy: streaming response did not contain any chat completion chunks")
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"id": responseID,
|
||||
"object": "chat.completion",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": make([]map[string]any, 0, len(choiceOrder)),
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": usageDetail.InputTokens,
|
||||
"completion_tokens": usageDetail.OutputTokens,
|
||||
"total_tokens": usageDetail.TotalTokens,
|
||||
},
|
||||
}
|
||||
if serviceTier != "" {
|
||||
response["service_tier"] = serviceTier
|
||||
}
|
||||
if systemFP != "" {
|
||||
response["system_fingerprint"] = systemFP
|
||||
}
|
||||
|
||||
for _, idx := range choiceOrder {
|
||||
choice := choices[idx]
|
||||
message := map[string]any{
|
||||
"role": choice.Role,
|
||||
"content": strings.Join(choice.ContentParts, ""),
|
||||
}
|
||||
if message["role"] == "" {
|
||||
message["role"] = "assistant"
|
||||
}
|
||||
if len(choice.ReasoningParts) > 0 {
|
||||
message["reasoning_content"] = strings.Join(choice.ReasoningParts, "")
|
||||
}
|
||||
if len(choice.ToolCallOrder) > 0 {
|
||||
toolCalls := make([]map[string]any, 0, len(choice.ToolCallOrder))
|
||||
for _, toolIdx := range choice.ToolCallOrder {
|
||||
toolCall := choice.ToolCalls[toolIdx]
|
||||
toolCallType := toolCall.Type
|
||||
if toolCallType == "" {
|
||||
toolCallType = "function"
|
||||
}
|
||||
arguments := toolCall.Arguments.String()
|
||||
if arguments == "" {
|
||||
arguments = "{}"
|
||||
}
|
||||
toolCalls = append(toolCalls, map[string]any{
|
||||
"id": toolCall.ID,
|
||||
"type": toolCallType,
|
||||
"function": map[string]any{
|
||||
"name": toolCall.Name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
message["tool_calls"] = toolCalls
|
||||
}
|
||||
|
||||
finishReason := choice.FinishReason
|
||||
if finishReason == "" {
|
||||
finishReason = "stop"
|
||||
}
|
||||
choicePayload := map[string]any{
|
||||
"index": idx,
|
||||
"message": message,
|
||||
"finish_reason": finishReason,
|
||||
}
|
||||
if choice.NativeFinishReason != nil {
|
||||
choicePayload["native_finish_reason"] = choice.NativeFinishReason
|
||||
}
|
||||
response["choices"] = append(response["choices"].([]map[string]any), choicePayload)
|
||||
}
|
||||
|
||||
out, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return nil, usageDetail, fmt.Errorf("codebuddy: failed to encode aggregated response: %w", err)
|
||||
}
|
||||
return out, usageDetail, nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -167,22 +168,63 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
outputItemsByIndex := make(map[int64][]byte)
|
||||
var outputItemsFallback [][]byte
|
||||
for _, line := range lines {
|
||||
if !bytes.HasPrefix(line, dataTag) {
|
||||
continue
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line[5:])
|
||||
if gjson.GetBytes(line, "type").String() != "response.completed" {
|
||||
eventData := bytes.TrimSpace(line[5:])
|
||||
eventType := gjson.GetBytes(eventData, "type").String()
|
||||
|
||||
if eventType == "response.output_item.done" {
|
||||
itemResult := gjson.GetBytes(eventData, "item")
|
||||
if !itemResult.Exists() || itemResult.Type != gjson.JSON {
|
||||
continue
|
||||
}
|
||||
outputIndexResult := gjson.GetBytes(eventData, "output_index")
|
||||
if outputIndexResult.Exists() {
|
||||
outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw)
|
||||
} else {
|
||||
outputItemsFallback = append(outputItemsFallback, []byte(itemResult.Raw))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if detail, ok := helps.ParseCodexUsage(line); ok {
|
||||
if eventType != "response.completed" {
|
||||
continue
|
||||
}
|
||||
|
||||
if detail, ok := helps.ParseCodexUsage(eventData); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
|
||||
completedData := eventData
|
||||
outputResult := gjson.GetBytes(completedData, "response.output")
|
||||
shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0)
|
||||
if shouldPatchOutput {
|
||||
completedDataPatched := completedData
|
||||
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output", []byte(`[]`))
|
||||
|
||||
indexes := make([]int64, 0, len(outputItemsByIndex))
|
||||
for idx := range outputItemsByIndex {
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
sort.Slice(indexes, func(i, j int) bool {
|
||||
return indexes[i] < indexes[j]
|
||||
})
|
||||
for _, idx := range indexes {
|
||||
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output.-1", outputItemsByIndex[idx])
|
||||
}
|
||||
for _, item := range outputItemsFallback {
|
||||
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output.-1", item)
|
||||
}
|
||||
completedData = completedDataPatched
|
||||
}
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, completedData, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestCodexExecutorExecute_EmptyStreamCompletionOutputUsesOutputItemDone(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]},\"output_index\":0}\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":1775555723,\"status\":\"completed\",\"model\":\"gpt-5.4-mini-2026-03-17\",\"output\":[],\"usage\":{\"input_tokens\":8,\"output_tokens\":28,\"total_tokens\":36}}}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewCodexExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"api_key": "test",
|
||||
}}
|
||||
|
||||
resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "gpt-5.4-mini",
|
||||
Payload: []byte(`{"model":"gpt-5.4-mini","messages":[{"role":"user","content":"Say ok"}]}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
Stream: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute error: %v", err)
|
||||
}
|
||||
|
||||
gotContent := gjson.GetBytes(resp.Payload, "choices.0.message.content").String()
|
||||
if gotContent != "ok" {
|
||||
t.Fatalf("choices.0.message.content = %q, want %q; payload=%s", gotContent, "ok", string(resp.Payload))
|
||||
}
|
||||
}
|
||||
@@ -734,7 +734,7 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
|
||||
}
|
||||
|
||||
switch setting.URL.Scheme {
|
||||
case "socks5":
|
||||
case "socks5", "socks5h":
|
||||
var proxyAuth *proxy.Auth
|
||||
if setting.URL.User != nil {
|
||||
username := setting.URL.User.Username()
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -40,7 +42,7 @@ const (
|
||||
copilotEditorVersion = "vscode/1.107.0"
|
||||
copilotPluginVersion = "copilot-chat/0.35.0"
|
||||
copilotIntegrationID = "vscode-chat"
|
||||
copilotOpenAIIntent = "conversation-panel"
|
||||
copilotOpenAIIntent = "conversation-edits"
|
||||
copilotGitHubAPIVer = "2025-04-01"
|
||||
)
|
||||
|
||||
@@ -126,6 +128,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = e.normalizeModel(req.Model, body)
|
||||
body = flattenAssistantContent(body)
|
||||
body = stripUnsupportedBetas(body)
|
||||
|
||||
// Detect vision content before input normalization removes messages
|
||||
hasVision := detectVisionContent(body)
|
||||
@@ -142,6 +145,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
if useResponses {
|
||||
body = normalizeGitHubCopilotResponsesInput(body)
|
||||
body = normalizeGitHubCopilotResponsesTools(body)
|
||||
body = applyGitHubCopilotResponsesDefaults(body)
|
||||
} else {
|
||||
body = normalizeGitHubCopilotChatTools(body)
|
||||
}
|
||||
@@ -225,9 +229,10 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
if useResponses && from.String() == "claude" {
|
||||
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
|
||||
} else {
|
||||
data = normalizeGitHubCopilotReasoningField(data)
|
||||
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
}
|
||||
resp = cliproxyexecutor.Response{Payload: converted}
|
||||
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
@@ -256,6 +261,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body = e.normalizeModel(req.Model, body)
|
||||
body = flattenAssistantContent(body)
|
||||
body = stripUnsupportedBetas(body)
|
||||
|
||||
// Detect vision content before input normalization removes messages
|
||||
hasVision := detectVisionContent(body)
|
||||
@@ -272,6 +278,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
if useResponses {
|
||||
body = normalizeGitHubCopilotResponsesInput(body)
|
||||
body = normalizeGitHubCopilotResponsesTools(body)
|
||||
body = applyGitHubCopilotResponsesDefaults(body)
|
||||
} else {
|
||||
body = normalizeGitHubCopilotChatTools(body)
|
||||
}
|
||||
@@ -378,7 +385,20 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
if useResponses && from.String() == "claude" {
|
||||
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m)
|
||||
} else {
|
||||
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
// Strip SSE "data: " prefix before reasoning field normalization,
|
||||
// since normalizeGitHubCopilotReasoningField expects pure JSON.
|
||||
// Re-wrap with the prefix afterward for the translator.
|
||||
normalizedLine := bytes.Clone(line)
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
sseData := bytes.TrimSpace(line[len(dataTag):])
|
||||
if !bytes.Equal(sseData, []byte("[DONE]")) && gjson.ValidBytes(sseData) {
|
||||
normalized := normalizeGitHubCopilotReasoningField(bytes.Clone(sseData))
|
||||
if !bytes.Equal(normalized, sseData) {
|
||||
normalizedLine = append(append([]byte(nil), dataTag...), normalized...)
|
||||
}
|
||||
}
|
||||
}
|
||||
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, normalizedLine, ¶m)
|
||||
}
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(chunks[i])}
|
||||
@@ -400,9 +420,28 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CountTokens is not supported for GitHub Copilot.
|
||||
func (e *GitHubCopilotExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for github-copilot"}
|
||||
// CountTokens estimates token count locally using tiktoken, since the GitHub
|
||||
// Copilot API does not expose a dedicated token counting endpoint.
|
||||
func (e *GitHubCopilotExecutor) CountTokens(ctx context.Context, _ *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
enc, err := helps.TokenizerForModel(baseModel)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("github copilot executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
|
||||
count, err := helps.CountOpenAIChatTokens(enc, translated)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("github copilot executor: token counting failed: %w", err)
|
||||
}
|
||||
|
||||
usageJSON := helps.BuildOpenAIUsageJSON(count)
|
||||
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
|
||||
}
|
||||
|
||||
// Refresh validates the GitHub token is still working.
|
||||
@@ -491,46 +530,127 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, b
|
||||
r.Header.Set("X-Request-Id", uuid.NewString())
|
||||
|
||||
initiator := "user"
|
||||
if role := detectLastConversationRole(body); role == "assistant" || role == "tool" {
|
||||
if isAgentInitiated(body) {
|
||||
initiator = "agent"
|
||||
}
|
||||
r.Header.Set("X-Initiator", initiator)
|
||||
}
|
||||
|
||||
func detectLastConversationRole(body []byte) string {
|
||||
// isAgentInitiated determines whether the current request is agent-initiated
|
||||
// (tool callbacks, continuations) rather than user-initiated (new user prompt).
|
||||
//
|
||||
// GitHub Copilot uses the X-Initiator header for billing:
|
||||
// - "user" → consumes premium request quota
|
||||
// - "agent" → free (tool loops, continuations)
|
||||
//
|
||||
// The challenge: Claude Code sends tool results as role:"user" messages with
|
||||
// content type "tool_result". After translation to OpenAI format, the tool_result
|
||||
// part becomes a separate role:"tool" message, but if the original Claude message
|
||||
// also contained text content (e.g. skill invocations, attachment descriptions),
|
||||
// a role:"user" message is emitted AFTER the tool message, making the last message
|
||||
// appear user-initiated when it's actually part of an agent tool loop.
|
||||
//
|
||||
// VSCode Copilot Chat solves this with explicit flags (iterationNumber,
|
||||
// isContinuation, subAgentInvocationId). Since CPA doesn't have these flags,
|
||||
// we infer agent status by checking whether the conversation contains prior
|
||||
// assistant/tool messages — if it does, the current request is a continuation.
|
||||
//
|
||||
// References:
|
||||
// - opencode#8030, opencode#15824: same root cause and fix approach
|
||||
// - vscode-copilot-chat: toolCallingLoop.ts (iterationNumber === 0)
|
||||
// - pi-ai: github-copilot-headers.ts (last message role check)
|
||||
func isAgentInitiated(body []byte) bool {
|
||||
if len(body) == 0 {
|
||||
return ""
|
||||
return false
|
||||
}
|
||||
|
||||
// Chat Completions API: check messages array
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
arr := messages.Array()
|
||||
if len(arr) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
lastRole := ""
|
||||
for i := len(arr) - 1; i >= 0; i-- {
|
||||
if role := arr[i].Get("role").String(); role != "" {
|
||||
return role
|
||||
if r := arr[i].Get("role").String(); r != "" {
|
||||
lastRole = r
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If last message is assistant or tool, clearly agent-initiated.
|
||||
if lastRole == "assistant" || lastRole == "tool" {
|
||||
return true
|
||||
}
|
||||
|
||||
// If last message is "user", check whether it contains tool results
|
||||
// (indicating a tool-loop continuation) or if the preceding message
|
||||
// is an assistant tool_use. This is more precise than checking for
|
||||
// any prior assistant message, which would false-positive on genuine
|
||||
// multi-turn follow-ups.
|
||||
if lastRole == "user" {
|
||||
// Check if the last user message contains tool_result content
|
||||
lastContent := arr[len(arr)-1].Get("content")
|
||||
if lastContent.Exists() && lastContent.IsArray() {
|
||||
for _, part := range lastContent.Array() {
|
||||
if part.Get("type").String() == "tool_result" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check if the second-to-last message is an assistant with tool_use
|
||||
if len(arr) >= 2 {
|
||||
prev := arr[len(arr)-2]
|
||||
if prev.Get("role").String() == "assistant" {
|
||||
prevContent := prev.Get("content")
|
||||
if prevContent.Exists() && prevContent.IsArray() {
|
||||
for _, part := range prevContent.Array() {
|
||||
if part.Get("type").String() == "tool_use" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Responses API: check input array
|
||||
if inputs := gjson.GetBytes(body, "input"); inputs.Exists() && inputs.IsArray() {
|
||||
arr := inputs.Array()
|
||||
for i := len(arr) - 1; i >= 0; i-- {
|
||||
item := arr[i]
|
||||
if len(arr) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Most Responses input items carry a top-level role.
|
||||
if role := item.Get("role").String(); role != "" {
|
||||
return role
|
||||
// Check last item
|
||||
last := arr[len(arr)-1]
|
||||
if role := last.Get("role").String(); role == "assistant" {
|
||||
return true
|
||||
}
|
||||
switch last.Get("type").String() {
|
||||
case "function_call", "function_call_arguments", "computer_call":
|
||||
return true
|
||||
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
|
||||
return true
|
||||
}
|
||||
|
||||
// If last item is user-role, check for prior non-user items
|
||||
for _, item := range arr {
|
||||
if role := item.Get("role").String(); role == "assistant" {
|
||||
return true
|
||||
}
|
||||
|
||||
switch item.Get("type").String() {
|
||||
case "function_call", "function_call_arguments", "computer_call":
|
||||
return "assistant"
|
||||
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
|
||||
return "tool"
|
||||
case "function_call", "function_call_output", "function_call_response",
|
||||
"function_call_arguments", "computer_call", "computer_call_output":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
return false
|
||||
}
|
||||
|
||||
// detectVisionContent checks if the request body contains vision/image content.
|
||||
@@ -572,6 +692,85 @@ func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte
|
||||
return body
|
||||
}
|
||||
|
||||
// copilotUnsupportedBetas lists beta headers that are Anthropic-specific and
|
||||
// must not be forwarded to GitHub Copilot. The context-1m beta enables 1M
|
||||
// context on Anthropic's API, but Copilot's Claude models are limited to
|
||||
// ~128K-200K. Passing it through would not enable 1M on Copilot, but stripping
|
||||
// it from the translated body avoids confusing downstream translators.
|
||||
var copilotUnsupportedBetas = []string{
|
||||
"context-1m-2025-08-07",
|
||||
}
|
||||
|
||||
// stripUnsupportedBetas removes Anthropic-specific beta entries from the
|
||||
// translated request body. In OpenAI format the betas may appear under
|
||||
// "metadata.betas" or a top-level "betas" array; in Claude format they sit at
|
||||
// "betas". This function checks all known locations.
|
||||
func stripUnsupportedBetas(body []byte) []byte {
|
||||
betaPaths := []string{"betas", "metadata.betas"}
|
||||
for _, path := range betaPaths {
|
||||
arr := gjson.GetBytes(body, path)
|
||||
if !arr.Exists() || !arr.IsArray() {
|
||||
continue
|
||||
}
|
||||
var filtered []string
|
||||
changed := false
|
||||
for _, item := range arr.Array() {
|
||||
beta := item.String()
|
||||
if isCopilotUnsupportedBeta(beta) {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, beta)
|
||||
}
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
body, _ = sjson.DeleteBytes(body, path)
|
||||
} else {
|
||||
body, _ = sjson.SetBytes(body, path, filtered)
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func isCopilotUnsupportedBeta(beta string) bool {
|
||||
return slices.Contains(copilotUnsupportedBetas, beta)
|
||||
}
|
||||
|
||||
// normalizeGitHubCopilotReasoningField maps Copilot's non-standard
|
||||
// 'reasoning_text' field to the standard OpenAI 'reasoning_content' field
|
||||
// that the SDK translator expects. This handles both streaming deltas
|
||||
// (choices[].delta.reasoning_text) and non-streaming messages
|
||||
// (choices[].message.reasoning_text). The field is only renamed when
|
||||
// 'reasoning_content' is absent or null, preserving standard responses.
|
||||
// All choices are processed to support n>1 requests.
|
||||
func normalizeGitHubCopilotReasoningField(data []byte) []byte {
|
||||
choices := gjson.GetBytes(data, "choices")
|
||||
if !choices.Exists() || !choices.IsArray() {
|
||||
return data
|
||||
}
|
||||
for i := range choices.Array() {
|
||||
// Non-streaming: choices[i].message.reasoning_text
|
||||
msgRT := fmt.Sprintf("choices.%d.message.reasoning_text", i)
|
||||
msgRC := fmt.Sprintf("choices.%d.message.reasoning_content", i)
|
||||
if rt := gjson.GetBytes(data, msgRT); rt.Exists() && rt.String() != "" {
|
||||
if rc := gjson.GetBytes(data, msgRC); !rc.Exists() || rc.Type == gjson.Null || rc.String() == "" {
|
||||
data, _ = sjson.SetBytes(data, msgRC, rt.String())
|
||||
}
|
||||
}
|
||||
// Streaming: choices[i].delta.reasoning_text
|
||||
deltaRT := fmt.Sprintf("choices.%d.delta.reasoning_text", i)
|
||||
deltaRC := fmt.Sprintf("choices.%d.delta.reasoning_content", i)
|
||||
if rt := gjson.GetBytes(data, deltaRT); rt.Exists() && rt.String() != "" {
|
||||
if rc := gjson.GetBytes(data, deltaRC); !rc.Exists() || rc.Type == gjson.Null || rc.String() == "" {
|
||||
data, _ = sjson.SetBytes(data, deltaRC, rt.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool {
|
||||
if sourceFormat.String() == "openai-response" {
|
||||
return true
|
||||
@@ -596,12 +795,7 @@ func lookupGitHubCopilotStaticModelInfo(model string) *registry.ModelInfo {
|
||||
}
|
||||
|
||||
func containsEndpoint(endpoints []string, endpoint string) bool {
|
||||
for _, item := range endpoints {
|
||||
if item == endpoint {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return slices.Contains(endpoints, endpoint)
|
||||
}
|
||||
|
||||
// flattenAssistantContent converts assistant message content from array format
|
||||
@@ -856,6 +1050,32 @@ func stripGitHubCopilotResponsesUnsupportedFields(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// applyGitHubCopilotResponsesDefaults sets required fields for the Responses API
|
||||
// that both vscode-copilot-chat and pi-ai always include.
|
||||
//
|
||||
// References:
|
||||
// - vscode-copilot-chat: src/platform/endpoint/node/responsesApi.ts
|
||||
// - pi-ai (badlogic/pi-mono): packages/ai/src/providers/openai-responses.ts
|
||||
func applyGitHubCopilotResponsesDefaults(body []byte) []byte {
|
||||
// store: false — prevents request/response storage
|
||||
if !gjson.GetBytes(body, "store").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "store", false)
|
||||
}
|
||||
|
||||
// include: ["reasoning.encrypted_content"] — enables reasoning content
|
||||
// reuse across turns, avoiding redundant computation
|
||||
if !gjson.GetBytes(body, "include").Exists() {
|
||||
body, _ = sjson.SetRawBytes(body, "include", []byte(`["reasoning.encrypted_content"]`))
|
||||
}
|
||||
|
||||
// If reasoning.effort is set but reasoning.summary is not, default to "auto"
|
||||
if gjson.GetBytes(body, "reasoning.effort").Exists() && !gjson.GetBytes(body, "reasoning.summary").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "reasoning.summary", "auto")
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() {
|
||||
@@ -1406,6 +1626,21 @@ func FetchGitHubCopilotModels(ctx context.Context, auth *cliproxyauth.Auth, cfg
|
||||
m.MaxCompletionTokens = defaultCopilotMaxCompletionTokens
|
||||
}
|
||||
|
||||
// Override with real limits from the Copilot API when available.
|
||||
// The API returns per-account limits (individual vs business) under
|
||||
// capabilities.limits, which are more accurate than our static
|
||||
// fallback values. We use max_prompt_tokens as ContextLength because
|
||||
// that's the hard limit the Copilot API enforces on prompt size —
|
||||
// exceeding it triggers "prompt token count exceeds the limit" errors.
|
||||
if limits := entry.Limits(); limits != nil {
|
||||
if limits.MaxPromptTokens > 0 {
|
||||
m.ContextLength = limits.MaxPromptTokens
|
||||
}
|
||||
if limits.MaxOutputTokens > 0 {
|
||||
m.MaxCompletionTokens = limits.MaxOutputTokens
|
||||
}
|
||||
}
|
||||
|
||||
models = append(models, m)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -72,7 +75,7 @@ func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Not parallel: shares global model registry with DynamicRegistryWinsOverStatic.
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||
t.Fatal("expected responses-only registry model to use /responses")
|
||||
}
|
||||
@@ -82,7 +85,7 @@ func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Not parallel: mutates global model registry, conflicts with RegistryResponsesOnlyModel.
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
clientID := "github-copilot-test-client"
|
||||
@@ -251,14 +254,14 @@ func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.Get(out, "type").String() != "message" {
|
||||
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
|
||||
if gjson.GetBytes(out, "type").String() != "message" {
|
||||
t.Fatalf("type = %q, want message", gjson.GetBytes(out, "type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.type").String() != "text" {
|
||||
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
|
||||
if gjson.GetBytes(out, "content.0.type").String() != "text" {
|
||||
t.Fatalf("content.0.type = %q, want text", gjson.GetBytes(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.text").String() != "hello" {
|
||||
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
|
||||
if gjson.GetBytes(out, "content.0.text").String() != "hello" {
|
||||
t.Fatalf("content.0.text = %q, want hello", gjson.GetBytes(out, "content.0.text").String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -266,14 +269,14 @@ func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *test
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.Get(out, "content.0.type").String() != "tool_use" {
|
||||
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
|
||||
if gjson.GetBytes(out, "content.0.type").String() != "tool_use" {
|
||||
t.Fatalf("content.0.type = %q, want tool_use", gjson.GetBytes(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.name").String() != "sum" {
|
||||
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
|
||||
if gjson.GetBytes(out, "content.0.name").String() != "sum" {
|
||||
t.Fatalf("content.0.name = %q, want sum", gjson.GetBytes(out, "content.0.name").String())
|
||||
}
|
||||
if gjson.Get(out, "stop_reason").String() != "tool_use" {
|
||||
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
|
||||
if gjson.GetBytes(out, "stop_reason").String() != "tool_use" {
|
||||
t.Fatalf("stop_reason = %q, want tool_use", gjson.GetBytes(out, "stop_reason").String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,18 +285,24 @@ func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.
|
||||
var param any
|
||||
|
||||
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m)
|
||||
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
|
||||
if len(created) == 0 || !strings.Contains(string(created[0]), "message_start") {
|
||||
t.Fatalf("created events = %#v, want message_start", created)
|
||||
}
|
||||
|
||||
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m)
|
||||
joinedDelta := strings.Join(delta, "")
|
||||
var joinedDelta string
|
||||
for _, d := range delta {
|
||||
joinedDelta += string(d)
|
||||
}
|
||||
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
|
||||
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
|
||||
}
|
||||
|
||||
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m)
|
||||
joinedCompleted := strings.Join(completed, "")
|
||||
var joinedCompleted string
|
||||
for _, c := range completed {
|
||||
joinedCompleted += string(c)
|
||||
}
|
||||
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
|
||||
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
|
||||
}
|
||||
@@ -312,15 +321,17 @@ func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_UserWhenLastRoleIsUser(t *testing.T) {
|
||||
func TestApplyHeaders_XInitiator_AgentWhenLastUserButHistoryHasAssistant(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// Last role governs the initiator decision.
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
|
||||
// When the last role is "user" and the message contains tool_result content,
|
||||
// the request is a continuation (e.g. Claude tool result translated to a
|
||||
// synthetic user message). Should be "agent".
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu1","content":"file contents..."}]}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user (last role is user)", got)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (last user contains tool_result)", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -328,10 +339,11 @@ func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// When the last message has role "tool", it's clearly agent-initiated.
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got)
|
||||
t.Fatalf("X-Initiator = %q, want agent (last role is tool)", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -346,14 +358,15 @@ func TestApplyHeaders_XInitiator_InputArrayLastAssistantMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_InputArrayLastUserMessage(t *testing.T) {
|
||||
func TestApplyHeaders_XInitiator_InputArrayAgentWhenLastUserButHistoryHasAssistant(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// Responses API: last item is user-role but history contains assistant → agent.
|
||||
body := []byte(`{"input":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"I can help"}]},{"type":"message","role":"user","content":[{"type":"input_text","text":"Do X"}]}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user (last role is user)", got)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (history has assistant)", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -368,6 +381,33 @@ func TestApplyHeaders_XInitiator_InputArrayLastFunctionCallOutput(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_UserInMultiTurnNoTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// Genuine multi-turn: user → assistant (plain text) → user follow-up.
|
||||
// No tool messages → should be "user" (not a false-positive).
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"what is 2+2?"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user (genuine multi-turn, no tools)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_UserFollowUpAfterToolHistory(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// User follow-up after a completed tool-use conversation.
|
||||
// The last message is a genuine user question — should be "user", not "agent".
|
||||
// This aligns with opencode's behavior: only active tool loops are agent-initiated.
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":[{"type":"tool_use","id":"tu1","name":"Read","input":{}}]},{"role":"tool","tool_call_id":"tu1","content":"file data"},{"role":"assistant","content":"I read the file."},{"role":"user","content":"What did we do so far?"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user (genuine follow-up after tool history)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for x-github-api-version header (Problem M) ---
|
||||
|
||||
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
|
||||
@@ -414,3 +454,364 @@ func TestDetectVisionContent_NoMessages(t *testing.T) {
|
||||
t.Fatal("expected no vision content when messages field is absent")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for applyGitHubCopilotResponsesDefaults ---
|
||||
|
||||
func TestApplyGitHubCopilotResponsesDefaults_SetsAllDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":"hello","reasoning":{"effort":"medium"}}`)
|
||||
got := applyGitHubCopilotResponsesDefaults(body)
|
||||
|
||||
if gjson.GetBytes(got, "store").Bool() != false {
|
||||
t.Fatalf("store = %v, want false", gjson.GetBytes(got, "store").Raw)
|
||||
}
|
||||
inc := gjson.GetBytes(got, "include")
|
||||
if !inc.IsArray() || inc.Array()[0].String() != "reasoning.encrypted_content" {
|
||||
t.Fatalf("include = %s, want [\"reasoning.encrypted_content\"]", inc.Raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "reasoning.summary").String() != "auto" {
|
||||
t.Fatalf("reasoning.summary = %q, want auto", gjson.GetBytes(got, "reasoning.summary").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyGitHubCopilotResponsesDefaults_DoesNotOverrideExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":"hello","store":true,"include":["other"],"reasoning":{"effort":"high","summary":"concise"}}`)
|
||||
got := applyGitHubCopilotResponsesDefaults(body)
|
||||
|
||||
if gjson.GetBytes(got, "store").Bool() != true {
|
||||
t.Fatalf("store should not be overridden, got %s", gjson.GetBytes(got, "store").Raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "include").Array()[0].String() != "other" {
|
||||
t.Fatalf("include should not be overridden, got %s", gjson.GetBytes(got, "include").Raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "reasoning.summary").String() != "concise" {
|
||||
t.Fatalf("reasoning.summary should not be overridden, got %q", gjson.GetBytes(got, "reasoning.summary").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyGitHubCopilotResponsesDefaults_NoReasoningEffort(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":"hello"}`)
|
||||
got := applyGitHubCopilotResponsesDefaults(body)
|
||||
|
||||
if gjson.GetBytes(got, "store").Bool() != false {
|
||||
t.Fatalf("store = %v, want false", gjson.GetBytes(got, "store").Raw)
|
||||
}
|
||||
// reasoning.summary should NOT be set when reasoning.effort is absent
|
||||
if gjson.GetBytes(got, "reasoning.summary").Exists() {
|
||||
t.Fatalf("reasoning.summary should not be set when reasoning.effort is absent, got %q", gjson.GetBytes(got, "reasoning.summary").String())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for normalizeGitHubCopilotReasoningField ---
|
||||
|
||||
func TestNormalizeReasoningField_NonStreaming(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := []byte(`{"choices":[{"message":{"content":"hello","reasoning_text":"I think..."}}]}`)
|
||||
got := normalizeGitHubCopilotReasoningField(data)
|
||||
rc := gjson.GetBytes(got, "choices.0.message.reasoning_content").String()
|
||||
if rc != "I think..." {
|
||||
t.Fatalf("reasoning_content = %q, want %q", rc, "I think...")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeReasoningField_Streaming(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := []byte(`{"choices":[{"delta":{"reasoning_text":"thinking delta"}}]}`)
|
||||
got := normalizeGitHubCopilotReasoningField(data)
|
||||
rc := gjson.GetBytes(got, "choices.0.delta.reasoning_content").String()
|
||||
if rc != "thinking delta" {
|
||||
t.Fatalf("reasoning_content = %q, want %q", rc, "thinking delta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeReasoningField_PreservesExistingReasoningContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := []byte(`{"choices":[{"message":{"reasoning_text":"old","reasoning_content":"existing"}}]}`)
|
||||
got := normalizeGitHubCopilotReasoningField(data)
|
||||
rc := gjson.GetBytes(got, "choices.0.message.reasoning_content").String()
|
||||
if rc != "existing" {
|
||||
t.Fatalf("reasoning_content = %q, want %q (should not overwrite)", rc, "existing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeReasoningField_MultiChoice(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := []byte(`{"choices":[{"message":{"reasoning_text":"thought-0"}},{"message":{"reasoning_text":"thought-1"}}]}`)
|
||||
got := normalizeGitHubCopilotReasoningField(data)
|
||||
rc0 := gjson.GetBytes(got, "choices.0.message.reasoning_content").String()
|
||||
rc1 := gjson.GetBytes(got, "choices.1.message.reasoning_content").String()
|
||||
if rc0 != "thought-0" {
|
||||
t.Fatalf("choices[0].reasoning_content = %q, want %q", rc0, "thought-0")
|
||||
}
|
||||
if rc1 != "thought-1" {
|
||||
t.Fatalf("choices[1].reasoning_content = %q, want %q", rc1, "thought-1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeReasoningField_NoChoices(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := []byte(`{"id":"chatcmpl-123"}`)
|
||||
got := normalizeGitHubCopilotReasoningField(data)
|
||||
if string(got) != string(data) {
|
||||
t.Fatalf("expected no change, got %s", string(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OpenAIIntentValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
e.applyHeaders(req, "token", nil)
|
||||
if got := req.Header.Get("Openai-Intent"); got != "conversation-edits" {
|
||||
t.Fatalf("Openai-Intent = %q, want conversation-edits", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for CountTokens (local tiktoken estimation) ---
|
||||
|
||||
func TestCountTokens_ReturnsPositiveCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"Hello, world!"}]}`)
|
||||
resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||
Model: "gpt-4o",
|
||||
Payload: body,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens() error: %v", err)
|
||||
}
|
||||
if len(resp.Payload) == 0 {
|
||||
t.Fatal("CountTokens() returned empty payload")
|
||||
}
|
||||
// The response should contain a positive token count.
|
||||
tokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
|
||||
if tokens <= 0 {
|
||||
t.Fatalf("expected positive token count, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountTokens_ClaudeSourceFormatTranslates(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
body := []byte(`{"model":"claude-sonnet-4","messages":[{"role":"user","content":"Tell me a joke"}],"max_tokens":1024}`)
|
||||
resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||
Model: "claude-sonnet-4",
|
||||
Payload: body,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens() error: %v", err)
|
||||
}
|
||||
// Claude source format → should get input_tokens in response
|
||||
inputTokens := gjson.GetBytes(resp.Payload, "input_tokens").Int()
|
||||
if inputTokens <= 0 {
|
||||
// Fallback: check usage.prompt_tokens (depends on translator registration)
|
||||
promptTokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
|
||||
if promptTokens <= 0 {
|
||||
t.Fatalf("expected positive token count, got payload: %s", resp.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountTokens_EmptyPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||
Model: "gpt-4o",
|
||||
Payload: []byte(`{"model":"gpt-4o","messages":[]}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens() error: %v", err)
|
||||
}
|
||||
tokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
|
||||
// Empty messages should return 0 tokens.
|
||||
if tokens != 0 {
|
||||
t.Fatalf("expected 0 tokens for empty messages, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripUnsupportedBetas_RemovesContext1M(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"claude-opus-4.6","betas":["interleaved-thinking-2025-05-14","context-1m-2025-08-07","claude-code-20250219"],"messages":[]}`)
|
||||
result := stripUnsupportedBetas(body)
|
||||
|
||||
betas := gjson.GetBytes(result, "betas")
|
||||
if !betas.Exists() {
|
||||
t.Fatal("betas field should still exist after stripping")
|
||||
}
|
||||
for _, item := range betas.Array() {
|
||||
if item.String() == "context-1m-2025-08-07" {
|
||||
t.Fatal("context-1m-2025-08-07 should have been stripped")
|
||||
}
|
||||
}
|
||||
// Other betas should be preserved
|
||||
found := false
|
||||
for _, item := range betas.Array() {
|
||||
if item.String() == "interleaved-thinking-2025-05-14" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("other betas should be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripUnsupportedBetas_NoBetasField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"gpt-4o","messages":[]}`)
|
||||
result := stripUnsupportedBetas(body)
|
||||
|
||||
// Should be unchanged
|
||||
if string(result) != string(body) {
|
||||
t.Fatalf("body should be unchanged when no betas field exists, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripUnsupportedBetas_MetadataBetas(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"claude-opus-4.6","metadata":{"betas":["context-1m-2025-08-07","other-beta"]},"messages":[]}`)
|
||||
result := stripUnsupportedBetas(body)
|
||||
|
||||
betas := gjson.GetBytes(result, "metadata.betas")
|
||||
if !betas.Exists() {
|
||||
t.Fatal("metadata.betas field should still exist after stripping")
|
||||
}
|
||||
for _, item := range betas.Array() {
|
||||
if item.String() == "context-1m-2025-08-07" {
|
||||
t.Fatal("context-1m-2025-08-07 should have been stripped from metadata.betas")
|
||||
}
|
||||
}
|
||||
if betas.Array()[0].String() != "other-beta" {
|
||||
t.Fatal("other betas in metadata.betas should be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripUnsupportedBetas_AllBetasStripped(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"claude-opus-4.6","betas":["context-1m-2025-08-07"],"messages":[]}`)
|
||||
result := stripUnsupportedBetas(body)
|
||||
|
||||
betas := gjson.GetBytes(result, "betas")
|
||||
if betas.Exists() {
|
||||
t.Fatal("betas field should be deleted when all betas are stripped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotModelEntry_Limits(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
capabilities map[string]any
|
||||
wantNil bool
|
||||
wantPrompt int
|
||||
wantOutput int
|
||||
wantContext int
|
||||
}{
|
||||
{
|
||||
name: "nil capabilities",
|
||||
capabilities: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "no limits key",
|
||||
capabilities: map[string]any{"family": "claude-opus-4.6"},
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "limits is not a map",
|
||||
capabilities: map[string]any{"limits": "invalid"},
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "all zero values",
|
||||
capabilities: map[string]any{
|
||||
"limits": map[string]any{
|
||||
"max_context_window_tokens": float64(0),
|
||||
"max_prompt_tokens": float64(0),
|
||||
"max_output_tokens": float64(0),
|
||||
},
|
||||
},
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "individual account limits (128K prompt)",
|
||||
capabilities: map[string]any{
|
||||
"limits": map[string]any{
|
||||
"max_context_window_tokens": float64(144000),
|
||||
"max_prompt_tokens": float64(128000),
|
||||
"max_output_tokens": float64(64000),
|
||||
},
|
||||
},
|
||||
wantNil: false,
|
||||
wantPrompt: 128000,
|
||||
wantOutput: 64000,
|
||||
wantContext: 144000,
|
||||
},
|
||||
{
|
||||
name: "business account limits (168K prompt)",
|
||||
capabilities: map[string]any{
|
||||
"limits": map[string]any{
|
||||
"max_context_window_tokens": float64(200000),
|
||||
"max_prompt_tokens": float64(168000),
|
||||
"max_output_tokens": float64(32000),
|
||||
},
|
||||
},
|
||||
wantNil: false,
|
||||
wantPrompt: 168000,
|
||||
wantOutput: 32000,
|
||||
wantContext: 200000,
|
||||
},
|
||||
{
|
||||
name: "partial limits (only prompt)",
|
||||
capabilities: map[string]any{
|
||||
"limits": map[string]any{
|
||||
"max_prompt_tokens": float64(128000),
|
||||
},
|
||||
},
|
||||
wantNil: false,
|
||||
wantPrompt: 128000,
|
||||
wantOutput: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
entry := copilotauth.CopilotModelEntry{
|
||||
ID: "claude-opus-4.6",
|
||||
Capabilities: tt.capabilities,
|
||||
}
|
||||
limits := entry.Limits()
|
||||
if tt.wantNil {
|
||||
if limits != nil {
|
||||
t.Fatalf("expected nil limits, got %+v", limits)
|
||||
}
|
||||
return
|
||||
}
|
||||
if limits == nil {
|
||||
t.Fatal("expected non-nil limits, got nil")
|
||||
}
|
||||
if limits.MaxPromptTokens != tt.wantPrompt {
|
||||
t.Errorf("MaxPromptTokens = %d, want %d", limits.MaxPromptTokens, tt.wantPrompt)
|
||||
}
|
||||
if limits.MaxOutputTokens != tt.wantOutput {
|
||||
t.Errorf("MaxOutputTokens = %d, want %d", limits.MaxOutputTokens, tt.wantOutput)
|
||||
}
|
||||
if tt.wantContext > 0 && limits.MaxContextWindowTokens != tt.wantContext {
|
||||
t.Errorf("MaxContextWindowTokens = %d, want %d", limits.MaxContextWindowTokens, tt.wantContext)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
38
internal/runtime/executor/helps/claude_builtin_tools.go
Normal file
38
internal/runtime/executor/helps/claude_builtin_tools.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package helps
|
||||
|
||||
import "github.com/tidwall/gjson"
|
||||
|
||||
var defaultClaudeBuiltinToolNames = []string{
|
||||
"web_search",
|
||||
"code_execution",
|
||||
"text_editor",
|
||||
"computer",
|
||||
}
|
||||
|
||||
func newClaudeBuiltinToolRegistry() map[string]bool {
|
||||
registry := make(map[string]bool, len(defaultClaudeBuiltinToolNames))
|
||||
for _, name := range defaultClaudeBuiltinToolNames {
|
||||
registry[name] = true
|
||||
}
|
||||
return registry
|
||||
}
|
||||
|
||||
func AugmentClaudeBuiltinToolRegistry(body []byte, registry map[string]bool) map[string]bool {
|
||||
if registry == nil {
|
||||
registry = newClaudeBuiltinToolRegistry()
|
||||
}
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return registry
|
||||
}
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
if tool.Get("type").String() == "" {
|
||||
return true
|
||||
}
|
||||
if name := tool.Get("name").String(); name != "" {
|
||||
registry[name] = true
|
||||
}
|
||||
return true
|
||||
})
|
||||
return registry
|
||||
}
|
||||
32
internal/runtime/executor/helps/claude_builtin_tools_test.go
Normal file
32
internal/runtime/executor/helps/claude_builtin_tools_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package helps
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestClaudeBuiltinToolRegistry_DefaultSeedFallback(t *testing.T) {
|
||||
registry := AugmentClaudeBuiltinToolRegistry(nil, nil)
|
||||
for _, name := range defaultClaudeBuiltinToolNames {
|
||||
if !registry[name] {
|
||||
t.Fatalf("default builtin %q missing from fallback registry", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeBuiltinToolRegistry_AugmentsTypedBuiltinsFromBody(t *testing.T) {
|
||||
registry := AugmentClaudeBuiltinToolRegistry([]byte(`{
|
||||
"tools": [
|
||||
{"type": "web_search_20250305", "name": "web_search"},
|
||||
{"type": "custom_builtin_20250401", "name": "special_builtin"},
|
||||
{"name": "Read"}
|
||||
]
|
||||
}`), nil)
|
||||
|
||||
if !registry["web_search"] {
|
||||
t.Fatal("expected default typed builtin web_search in registry")
|
||||
}
|
||||
if !registry["special_builtin"] {
|
||||
t.Fatal("expected typed builtin from body to be added to registry")
|
||||
}
|
||||
if registry["Read"] {
|
||||
t.Fatal("expected untyped custom tool to stay out of builtin registry")
|
||||
}
|
||||
}
|
||||
@@ -298,6 +298,14 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
} else {
|
||||
// In case the upstream close the stream without a terminal [DONE] marker.
|
||||
// Feed a synthetic done marker through the translator so pending
|
||||
// response.completed events are still emitted exactly once.
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
// Ensure we record the request if no usage chunk was ever seen
|
||||
reporter.EnsurePublished(ctx)
|
||||
|
||||
@@ -172,32 +172,101 @@ func timeUntilNextDay() time.Duration {
|
||||
return tomorrow.Sub(now)
|
||||
}
|
||||
|
||||
// ensureQwenSystemMessage prepends a default system message if none exists in "messages".
|
||||
// ensureQwenSystemMessage ensures the request has a single system message at the beginning.
|
||||
// It always injects the default system prompt and merges any user-provided system messages
|
||||
// into the injected system message content to satisfy Qwen's strict message ordering rules.
|
||||
func ensureQwenSystemMessage(payload []byte) ([]byte, error) {
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
if messages.Exists() && messages.IsArray() {
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('[')
|
||||
buf.Write(qwenDefaultSystemMessage)
|
||||
for _, msg := range messages.Array() {
|
||||
buf.WriteByte(',')
|
||||
buf.WriteString(msg.Raw)
|
||||
isInjectedSystemPart := func(part gjson.Result) bool {
|
||||
if !part.Exists() || !part.IsObject() {
|
||||
return false
|
||||
}
|
||||
buf.WriteByte(']')
|
||||
updated, errSet := sjson.SetRawBytes(payload, "messages", buf.Bytes())
|
||||
if errSet != nil {
|
||||
return nil, fmt.Errorf("qwen executor: set default system message failed: %w", errSet)
|
||||
if !strings.EqualFold(part.Get("type").String(), "text") {
|
||||
return false
|
||||
}
|
||||
return updated, nil
|
||||
if !strings.EqualFold(part.Get("cache_control.type").String(), "ephemeral") {
|
||||
return false
|
||||
}
|
||||
text := part.Get("text").String()
|
||||
return text == "" || text == "You are Qwen Code."
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('[')
|
||||
buf.Write(qwenDefaultSystemMessage)
|
||||
buf.WriteByte(']')
|
||||
updated, errSet := sjson.SetRawBytes(payload, "messages", buf.Bytes())
|
||||
defaultParts := gjson.ParseBytes(qwenDefaultSystemMessage).Get("content")
|
||||
var systemParts []any
|
||||
if defaultParts.Exists() && defaultParts.IsArray() {
|
||||
for _, part := range defaultParts.Array() {
|
||||
systemParts = append(systemParts, part.Value())
|
||||
}
|
||||
}
|
||||
if len(systemParts) == 0 {
|
||||
systemParts = append(systemParts, map[string]any{
|
||||
"type": "text",
|
||||
"text": "You are Qwen Code.",
|
||||
"cache_control": map[string]any{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
appendSystemContent := func(content gjson.Result) {
|
||||
makeTextPart := func(text string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "text",
|
||||
"text": text,
|
||||
}
|
||||
}
|
||||
|
||||
if !content.Exists() || content.Type == gjson.Null {
|
||||
return
|
||||
}
|
||||
if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
if part.Type == gjson.String {
|
||||
systemParts = append(systemParts, makeTextPart(part.String()))
|
||||
continue
|
||||
}
|
||||
if isInjectedSystemPart(part) {
|
||||
continue
|
||||
}
|
||||
systemParts = append(systemParts, part.Value())
|
||||
}
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.String {
|
||||
systemParts = append(systemParts, makeTextPart(content.String()))
|
||||
return
|
||||
}
|
||||
if content.IsObject() {
|
||||
if isInjectedSystemPart(content) {
|
||||
return
|
||||
}
|
||||
systemParts = append(systemParts, content.Value())
|
||||
return
|
||||
}
|
||||
systemParts = append(systemParts, makeTextPart(content.String()))
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
var nonSystemMessages []any
|
||||
if messages.Exists() && messages.IsArray() {
|
||||
for _, msg := range messages.Array() {
|
||||
if strings.EqualFold(msg.Get("role").String(), "system") {
|
||||
appendSystemContent(msg.Get("content"))
|
||||
continue
|
||||
}
|
||||
nonSystemMessages = append(nonSystemMessages, msg.Value())
|
||||
}
|
||||
}
|
||||
|
||||
newMessages := make([]any, 0, 1+len(nonSystemMessages))
|
||||
newMessages = append(newMessages, map[string]any{
|
||||
"role": "system",
|
||||
"content": systemParts,
|
||||
})
|
||||
newMessages = append(newMessages, nonSystemMessages...)
|
||||
|
||||
updated, errSet := sjson.SetBytes(payload, "messages", newMessages)
|
||||
if errSet != nil {
|
||||
return nil, fmt.Errorf("qwen executor: set default system message failed: %w", errSet)
|
||||
return nil, fmt.Errorf("qwen executor: set system message failed: %w", errSet)
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestQwenExecutorParseSuffix(t *testing.T) {
|
||||
@@ -28,3 +29,123 @@ func TestQwenExecutorParseSuffix(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureQwenSystemMessage_MergeStringSystem(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"model": "qwen3.6-plus",
|
||||
"stream": true,
|
||||
"messages": [
|
||||
{ "role": "system", "content": "ABCDEFG" },
|
||||
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := ensureQwenSystemMessage(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||
}
|
||||
|
||||
msgs := gjson.GetBytes(out, "messages").Array()
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||
}
|
||||
if msgs[0].Get("role").String() != "system" {
|
||||
t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system")
|
||||
}
|
||||
parts := msgs[0].Get("content").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("messages[0].content length = %d, want 2", len(parts))
|
||||
}
|
||||
if parts[0].Get("text").String() != "You are Qwen Code." || parts[0].Get("cache_control.type").String() != "ephemeral" {
|
||||
t.Fatalf("messages[0].content[0] = %s, want injected system part", parts[0].Raw)
|
||||
}
|
||||
if parts[1].Get("type").String() != "text" || parts[1].Get("text").String() != "ABCDEFG" {
|
||||
t.Fatalf("messages[0].content[1] = %s, want text part with ABCDEFG", parts[1].Raw)
|
||||
}
|
||||
if msgs[1].Get("role").String() != "user" {
|
||||
t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureQwenSystemMessage_MergeObjectSystem(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"messages": [
|
||||
{ "role": "system", "content": { "type": "text", "text": "ABCDEFG" } },
|
||||
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := ensureQwenSystemMessage(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||
}
|
||||
|
||||
msgs := gjson.GetBytes(out, "messages").Array()
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||
}
|
||||
parts := msgs[0].Get("content").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("messages[0].content length = %d, want 2", len(parts))
|
||||
}
|
||||
if parts[1].Get("text").String() != "ABCDEFG" {
|
||||
t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "ABCDEFG")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureQwenSystemMessage_PrependsWhenMissing(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"messages": [
|
||||
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := ensureQwenSystemMessage(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||
}
|
||||
|
||||
msgs := gjson.GetBytes(out, "messages").Array()
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||
}
|
||||
if msgs[0].Get("role").String() != "system" {
|
||||
t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system")
|
||||
}
|
||||
if !msgs[0].Get("content").IsArray() || len(msgs[0].Get("content").Array()) == 0 {
|
||||
t.Fatalf("messages[0].content = %s, want non-empty array", msgs[0].Get("content").Raw)
|
||||
}
|
||||
if msgs[1].Get("role").String() != "user" {
|
||||
t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureQwenSystemMessage_MergesMultipleSystemMessages(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"messages": [
|
||||
{ "role": "system", "content": "A" },
|
||||
{ "role": "user", "content": [ { "type": "text", "text": "hi" } ] },
|
||||
{ "role": "system", "content": "B" }
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := ensureQwenSystemMessage(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||
}
|
||||
|
||||
msgs := gjson.GetBytes(out, "messages").Array()
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||
}
|
||||
parts := msgs[0].Get("content").Array()
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("messages[0].content length = %d, want 3", len(parts))
|
||||
}
|
||||
if parts[1].Get("text").String() != "A" {
|
||||
t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "A")
|
||||
}
|
||||
if parts[2].Get("text").String() != "B" {
|
||||
t.Fatalf("messages[0].content[2].text = %q, want %q", parts[2].Get("text").String(), "B")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,16 +32,24 @@ type GitTokenStore struct {
|
||||
repoDir string
|
||||
configDir string
|
||||
remote string
|
||||
branch string
|
||||
username string
|
||||
password string
|
||||
lastGC time.Time
|
||||
}
|
||||
|
||||
type resolvedRemoteBranch struct {
|
||||
name plumbing.ReferenceName
|
||||
hash plumbing.Hash
|
||||
}
|
||||
|
||||
// NewGitTokenStore creates a token store that saves credentials to disk through the
|
||||
// TokenStorage implementation embedded in the token record.
|
||||
func NewGitTokenStore(remote, username, password string) *GitTokenStore {
|
||||
// When branch is non-empty, clone/pull/push operations target that branch instead of the remote default.
|
||||
func NewGitTokenStore(remote, username, password, branch string) *GitTokenStore {
|
||||
return &GitTokenStore{
|
||||
remote: remote,
|
||||
branch: strings.TrimSpace(branch),
|
||||
username: username,
|
||||
password: password,
|
||||
}
|
||||
@@ -120,7 +128,11 @@ func (s *GitTokenStore) EnsureRepository() error {
|
||||
s.dirLock.Unlock()
|
||||
return fmt.Errorf("git token store: create repo dir: %w", errMk)
|
||||
}
|
||||
if _, errClone := git.PlainClone(repoDir, &git.CloneOptions{Auth: authMethod, URL: s.remote}); errClone != nil {
|
||||
cloneOpts := &git.CloneOptions{Auth: authMethod, URL: s.remote}
|
||||
if s.branch != "" {
|
||||
cloneOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch)
|
||||
}
|
||||
if _, errClone := git.PlainClone(repoDir, cloneOpts); errClone != nil {
|
||||
if errors.Is(errClone, transport.ErrEmptyRemoteRepository) {
|
||||
_ = os.RemoveAll(gitDir)
|
||||
repo, errInit := git.PlainInit(repoDir, false)
|
||||
@@ -128,6 +140,13 @@ func (s *GitTokenStore) EnsureRepository() error {
|
||||
s.dirLock.Unlock()
|
||||
return fmt.Errorf("git token store: init empty repo: %w", errInit)
|
||||
}
|
||||
if s.branch != "" {
|
||||
headRef := plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(s.branch))
|
||||
if errHead := repo.Storer.SetReference(headRef); errHead != nil {
|
||||
s.dirLock.Unlock()
|
||||
return fmt.Errorf("git token store: set head to branch %s: %w", s.branch, errHead)
|
||||
}
|
||||
}
|
||||
if _, errRemote := repo.Remote("origin"); errRemote != nil {
|
||||
if _, errCreate := repo.CreateRemote(&config.RemoteConfig{
|
||||
Name: "origin",
|
||||
@@ -176,16 +195,39 @@ func (s *GitTokenStore) EnsureRepository() error {
|
||||
s.dirLock.Unlock()
|
||||
return fmt.Errorf("git token store: worktree: %w", errWorktree)
|
||||
}
|
||||
if errPull := worktree.Pull(&git.PullOptions{Auth: authMethod, RemoteName: "origin"}); errPull != nil {
|
||||
if s.branch != "" {
|
||||
if errCheckout := s.checkoutConfiguredBranch(repo, worktree, authMethod); errCheckout != nil {
|
||||
s.dirLock.Unlock()
|
||||
return errCheckout
|
||||
}
|
||||
} else {
|
||||
// When branch is unset, ensure the working tree follows the remote default branch
|
||||
if err := checkoutRemoteDefaultBranch(repo, worktree, authMethod); err != nil {
|
||||
if !shouldFallbackToCurrentBranch(repo, err) {
|
||||
s.dirLock.Unlock()
|
||||
return fmt.Errorf("git token store: checkout remote default: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
pullOpts := &git.PullOptions{Auth: authMethod, RemoteName: "origin"}
|
||||
if s.branch != "" {
|
||||
pullOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch)
|
||||
}
|
||||
if errPull := worktree.Pull(pullOpts); errPull != nil {
|
||||
switch {
|
||||
case errors.Is(errPull, git.NoErrAlreadyUpToDate),
|
||||
errors.Is(errPull, git.ErrUnstagedChanges),
|
||||
errors.Is(errPull, git.ErrNonFastForwardUpdate):
|
||||
// Ignore clean syncs, local edits, and remote divergence—local changes win.
|
||||
case errors.Is(errPull, transport.ErrAuthenticationRequired),
|
||||
errors.Is(errPull, plumbing.ErrReferenceNotFound),
|
||||
errors.Is(errPull, transport.ErrEmptyRemoteRepository):
|
||||
// Ignore authentication prompts and empty remote references on initial sync.
|
||||
case errors.Is(errPull, plumbing.ErrReferenceNotFound):
|
||||
if s.branch != "" {
|
||||
s.dirLock.Unlock()
|
||||
return fmt.Errorf("git token store: pull: %w", errPull)
|
||||
}
|
||||
// Ignore missing references only when following the remote default branch.
|
||||
default:
|
||||
s.dirLock.Unlock()
|
||||
return fmt.Errorf("git token store: pull: %w", errPull)
|
||||
@@ -554,6 +596,192 @@ func (s *GitTokenStore) relativeToRepo(path string) (string, error) {
|
||||
return rel, nil
|
||||
}
|
||||
|
||||
func (s *GitTokenStore) checkoutConfiguredBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error {
|
||||
branchRefName := plumbing.NewBranchReferenceName(s.branch)
|
||||
headRef, errHead := repo.Head()
|
||||
switch {
|
||||
case errHead == nil && headRef.Name() == branchRefName:
|
||||
return nil
|
||||
case errHead != nil && !errors.Is(errHead, plumbing.ErrReferenceNotFound):
|
||||
return fmt.Errorf("git token store: get head: %w", errHead)
|
||||
}
|
||||
|
||||
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); err == nil {
|
||||
return nil
|
||||
} else if _, errRef := repo.Reference(branchRefName, true); errRef == nil {
|
||||
return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, err)
|
||||
} else if !errors.Is(errRef, plumbing.ErrReferenceNotFound) {
|
||||
return fmt.Errorf("git token store: inspect branch %s: %w", s.branch, errRef)
|
||||
} else if err := s.checkoutConfiguredRemoteTrackingBranch(repo, worktree, branchRefName, authMethod); err != nil {
|
||||
return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GitTokenStore) checkoutConfiguredRemoteTrackingBranch(repo *git.Repository, worktree *git.Worktree, branchRefName plumbing.ReferenceName, authMethod transport.AuthMethod) error {
|
||||
remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + s.branch)
|
||||
remoteRef, err := repo.Reference(remoteRefName, true)
|
||||
if errors.Is(err, plumbing.ErrReferenceNotFound) {
|
||||
if errSync := syncRemoteReferences(repo, authMethod); errSync != nil {
|
||||
return fmt.Errorf("sync remote refs: %w", errSync)
|
||||
}
|
||||
remoteRef, err = repo.Reference(remoteRefName, true)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: remoteRef.Hash()}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg, err := repo.Config()
|
||||
if err != nil {
|
||||
return fmt.Errorf("git token store: repo config: %w", err)
|
||||
}
|
||||
if _, ok := cfg.Branches[s.branch]; !ok {
|
||||
cfg.Branches[s.branch] = &config.Branch{Name: s.branch}
|
||||
}
|
||||
cfg.Branches[s.branch].Remote = "origin"
|
||||
cfg.Branches[s.branch].Merge = branchRefName
|
||||
if err := repo.SetConfig(cfg); err != nil {
|
||||
return fmt.Errorf("git token store: set branch config: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func syncRemoteReferences(repo *git.Repository, authMethod transport.AuthMethod) error {
|
||||
if err := repo.Fetch(&git.FetchOptions{Auth: authMethod, RemoteName: "origin"}); err != nil && !errors.Is(err, git.NoErrAlreadyUpToDate) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveRemoteDefaultBranch queries the origin remote to determine the remote's default branch
|
||||
// (the target of HEAD) and returns the corresponding local branch reference name (e.g. refs/heads/master).
|
||||
func resolveRemoteDefaultBranch(repo *git.Repository, authMethod transport.AuthMethod) (resolvedRemoteBranch, error) {
|
||||
if err := syncRemoteReferences(repo, authMethod); err != nil {
|
||||
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: sync remote refs: %w", err)
|
||||
}
|
||||
remote, err := repo.Remote("origin")
|
||||
if err != nil {
|
||||
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: get remote: %w", err)
|
||||
}
|
||||
refs, err := remote.List(&git.ListOptions{Auth: authMethod})
|
||||
if err != nil {
|
||||
if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok {
|
||||
return resolved, nil
|
||||
}
|
||||
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: list remote refs: %w", err)
|
||||
}
|
||||
for _, r := range refs {
|
||||
if r.Name() == plumbing.HEAD {
|
||||
if r.Type() == plumbing.SymbolicReference {
|
||||
if target, ok := normalizeRemoteBranchReference(r.Target()); ok {
|
||||
return resolvedRemoteBranch{name: target}, nil
|
||||
}
|
||||
}
|
||||
s := r.String()
|
||||
if idx := strings.Index(s, "->"); idx != -1 {
|
||||
if target, ok := normalizeRemoteBranchReference(plumbing.ReferenceName(strings.TrimSpace(s[idx+2:]))); ok {
|
||||
return resolvedRemoteBranch{name: target}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok {
|
||||
return resolved, nil
|
||||
}
|
||||
for _, r := range refs {
|
||||
if normalized, ok := normalizeRemoteBranchReference(r.Name()); ok {
|
||||
return resolvedRemoteBranch{name: normalized, hash: r.Hash()}, nil
|
||||
}
|
||||
}
|
||||
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: remote default branch not found")
|
||||
}
|
||||
|
||||
func resolveRemoteDefaultBranchFromLocal(repo *git.Repository) (resolvedRemoteBranch, bool) {
|
||||
ref, err := repo.Reference(plumbing.ReferenceName("refs/remotes/origin/HEAD"), true)
|
||||
if err != nil || ref.Type() != plumbing.SymbolicReference {
|
||||
return resolvedRemoteBranch{}, false
|
||||
}
|
||||
target, ok := normalizeRemoteBranchReference(ref.Target())
|
||||
if !ok {
|
||||
return resolvedRemoteBranch{}, false
|
||||
}
|
||||
return resolvedRemoteBranch{name: target}, true
|
||||
}
|
||||
|
||||
func normalizeRemoteBranchReference(name plumbing.ReferenceName) (plumbing.ReferenceName, bool) {
|
||||
switch {
|
||||
case strings.HasPrefix(name.String(), "refs/heads/"):
|
||||
return name, true
|
||||
case strings.HasPrefix(name.String(), "refs/remotes/origin/"):
|
||||
return plumbing.NewBranchReferenceName(strings.TrimPrefix(name.String(), "refs/remotes/origin/")), true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func shouldFallbackToCurrentBranch(repo *git.Repository, err error) bool {
|
||||
if !errors.Is(err, transport.ErrAuthenticationRequired) && !errors.Is(err, transport.ErrEmptyRemoteRepository) {
|
||||
return false
|
||||
}
|
||||
_, headErr := repo.Head()
|
||||
return headErr == nil
|
||||
}
|
||||
|
||||
// checkoutRemoteDefaultBranch ensures the working tree is checked out to the remote's default branch
|
||||
// (the branch target of origin/HEAD). If the local branch does not exist it will be created to track
|
||||
// the remote branch.
|
||||
func checkoutRemoteDefaultBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error {
|
||||
resolved, err := resolveRemoteDefaultBranch(repo, authMethod)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
branchRefName := resolved.name
|
||||
// If HEAD already points to the desired branch, nothing to do.
|
||||
headRef, errHead := repo.Head()
|
||||
if errHead == nil && headRef.Name() == branchRefName {
|
||||
return nil
|
||||
}
|
||||
// If local branch exists, attempt a checkout
|
||||
if _, err := repo.Reference(branchRefName, true); err == nil {
|
||||
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); err != nil {
|
||||
return fmt.Errorf("checkout branch %s: %w", branchRefName.String(), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Try to find the corresponding remote tracking ref (refs/remotes/origin/<name>)
|
||||
branchShort := strings.TrimPrefix(branchRefName.String(), "refs/heads/")
|
||||
remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + branchShort)
|
||||
hash := resolved.hash
|
||||
if remoteRef, err := repo.Reference(remoteRefName, true); err == nil {
|
||||
hash = remoteRef.Hash()
|
||||
} else if err != nil && !errors.Is(err, plumbing.ErrReferenceNotFound) {
|
||||
return fmt.Errorf("checkout remote default: remote ref %s: %w", remoteRefName.String(), err)
|
||||
}
|
||||
if hash == plumbing.ZeroHash {
|
||||
return fmt.Errorf("checkout remote default: remote ref %s not found", remoteRefName.String())
|
||||
}
|
||||
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: hash}); err != nil {
|
||||
return fmt.Errorf("checkout create branch %s: %w", branchRefName.String(), err)
|
||||
}
|
||||
cfg, err := repo.Config()
|
||||
if err != nil {
|
||||
return fmt.Errorf("git token store: repo config: %w", err)
|
||||
}
|
||||
if _, ok := cfg.Branches[branchShort]; !ok {
|
||||
cfg.Branches[branchShort] = &config.Branch{Name: branchShort}
|
||||
}
|
||||
cfg.Branches[branchShort].Remote = "origin"
|
||||
cfg.Branches[branchShort].Merge = branchRefName
|
||||
if err := repo.SetConfig(cfg); err != nil {
|
||||
return fmt.Errorf("git token store: set branch config: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string) error {
|
||||
repoDir := s.repoDirSnapshot()
|
||||
if repoDir == "" {
|
||||
@@ -619,7 +847,16 @@ func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string)
|
||||
return errRewrite
|
||||
}
|
||||
s.maybeRunGC(repo)
|
||||
if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil {
|
||||
pushOpts := &git.PushOptions{Auth: s.gitAuth(), Force: true}
|
||||
if s.branch != "" {
|
||||
pushOpts.RefSpecs = []config.RefSpec{config.RefSpec("refs/heads/" + s.branch + ":refs/heads/" + s.branch)}
|
||||
} else {
|
||||
// When branch is unset, pin push to the currently checked-out branch.
|
||||
if headRef, err := repo.Head(); err == nil {
|
||||
pushOpts.RefSpecs = []config.RefSpec{config.RefSpec(headRef.Name().String() + ":" + headRef.Name().String())}
|
||||
}
|
||||
}
|
||||
if err = repo.Push(pushOpts); err != nil {
|
||||
if errors.Is(err, git.NoErrAlreadyUpToDate) {
|
||||
return nil
|
||||
}
|
||||
|
||||
585
internal/store/gitstore_test.go
Normal file
585
internal/store/gitstore_test.go
Normal file
@@ -0,0 +1,585 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-git/go-git/v6"
|
||||
gitconfig "github.com/go-git/go-git/v6/config"
|
||||
"github.com/go-git/go-git/v6/plumbing"
|
||||
"github.com/go-git/go-git/v6/plumbing/object"
|
||||
)
|
||||
|
||||
type testBranchSpec struct {
|
||||
name string
|
||||
contents string
|
||||
}
|
||||
|
||||
func TestEnsureRepositoryUsesRemoteDefaultBranchWhenBranchNotConfigured(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
remoteDir := setupGitRemoteRepository(t, root, "trunk",
|
||||
testBranchSpec{name: "trunk", contents: "remote default branch\n"},
|
||||
testBranchSpec{name: "release/2026", contents: "release branch\n"},
|
||||
)
|
||||
|
||||
store := NewGitTokenStore(remoteDir, "", "", "")
|
||||
store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
|
||||
|
||||
if err := store.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository: %v", err)
|
||||
}
|
||||
|
||||
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "trunk", "remote default branch\n")
|
||||
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "trunk", "remote default branch updated\n", "advance trunk")
|
||||
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch updated\n", "advance release")
|
||||
|
||||
if err := store.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository second call: %v", err)
|
||||
}
|
||||
|
||||
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "trunk", "remote default branch updated\n")
|
||||
assertRemoteHeadBranch(t, remoteDir, "trunk")
|
||||
}
|
||||
|
||||
func TestEnsureRepositoryUsesConfiguredBranchWhenExplicitlySet(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
remoteDir := setupGitRemoteRepository(t, root, "trunk",
|
||||
testBranchSpec{name: "trunk", contents: "remote default branch\n"},
|
||||
testBranchSpec{name: "release/2026", contents: "release branch\n"},
|
||||
)
|
||||
|
||||
store := NewGitTokenStore(remoteDir, "", "", "release/2026")
|
||||
store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
|
||||
|
||||
if err := store.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository: %v", err)
|
||||
}
|
||||
|
||||
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch\n")
|
||||
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "trunk", "remote default branch updated\n", "advance trunk")
|
||||
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch updated\n", "advance release")
|
||||
|
||||
if err := store.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository second call: %v", err)
|
||||
}
|
||||
|
||||
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch updated\n")
|
||||
assertRemoteHeadBranch(t, remoteDir, "trunk")
|
||||
}
|
||||
|
||||
func TestEnsureRepositoryReturnsErrorForMissingConfiguredBranch(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
remoteDir := setupGitRemoteRepository(t, root, "trunk",
|
||||
testBranchSpec{name: "trunk", contents: "remote default branch\n"},
|
||||
)
|
||||
|
||||
store := NewGitTokenStore(remoteDir, "", "", "missing-branch")
|
||||
store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
|
||||
|
||||
err := store.EnsureRepository()
|
||||
if err == nil {
|
||||
t.Fatal("EnsureRepository succeeded, want error for nonexistent configured branch")
|
||||
}
|
||||
assertRemoteHeadBranch(t, remoteDir, "trunk")
|
||||
}
|
||||
|
||||
func TestEnsureRepositoryReturnsErrorForMissingConfiguredBranchOnExistingRepositoryPull(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
remoteDir := setupGitRemoteRepository(t, root, "trunk",
|
||||
testBranchSpec{name: "trunk", contents: "remote default branch\n"},
|
||||
)
|
||||
|
||||
baseDir := filepath.Join(root, "workspace", "auths")
|
||||
store := NewGitTokenStore(remoteDir, "", "", "")
|
||||
store.SetBaseDir(baseDir)
|
||||
|
||||
if err := store.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository initial clone: %v", err)
|
||||
}
|
||||
|
||||
reopened := NewGitTokenStore(remoteDir, "", "", "missing-branch")
|
||||
reopened.SetBaseDir(baseDir)
|
||||
|
||||
err := reopened.EnsureRepository()
|
||||
if err == nil {
|
||||
t.Fatal("EnsureRepository succeeded on reopen, want error for nonexistent configured branch")
|
||||
}
|
||||
assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "trunk")
|
||||
assertRemoteHeadBranch(t, remoteDir, "trunk")
|
||||
}
|
||||
|
||||
func TestEnsureRepositoryInitializesEmptyRemoteUsingConfiguredBranch(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
remoteDir := filepath.Join(root, "remote.git")
|
||||
if _, err := git.PlainInit(remoteDir, true); err != nil {
|
||||
t.Fatalf("init bare remote: %v", err)
|
||||
}
|
||||
|
||||
branch := "feature/gemini-fix"
|
||||
store := NewGitTokenStore(remoteDir, "", "", branch)
|
||||
store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
|
||||
|
||||
if err := store.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository: %v", err)
|
||||
}
|
||||
|
||||
assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), branch)
|
||||
assertRemoteBranchExistsWithCommit(t, remoteDir, branch)
|
||||
assertRemoteBranchDoesNotExist(t, remoteDir, "master")
|
||||
}
|
||||
|
||||
func TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranch(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
remoteDir := setupGitRemoteRepository(t, root, "master",
|
||||
testBranchSpec{name: "master", contents: "remote master branch\n"},
|
||||
testBranchSpec{name: "develop", contents: "remote develop branch\n"},
|
||||
)
|
||||
|
||||
baseDir := filepath.Join(root, "workspace", "auths")
|
||||
store := NewGitTokenStore(remoteDir, "", "", "")
|
||||
store.SetBaseDir(baseDir)
|
||||
|
||||
if err := store.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository initial clone: %v", err)
|
||||
}
|
||||
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
|
||||
|
||||
reopened := NewGitTokenStore(remoteDir, "", "", "develop")
|
||||
reopened.SetBaseDir(baseDir)
|
||||
|
||||
if err := reopened.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository reopen: %v", err)
|
||||
}
|
||||
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
|
||||
|
||||
workspaceDir := filepath.Join(root, "workspace")
|
||||
if err := os.WriteFile(filepath.Join(workspaceDir, "branch.txt"), []byte("local develop update\n"), 0o600); err != nil {
|
||||
t.Fatalf("write local branch marker: %v", err)
|
||||
}
|
||||
|
||||
reopened.mu.Lock()
|
||||
err := reopened.commitAndPushLocked("Update develop branch marker", "branch.txt")
|
||||
reopened.mu.Unlock()
|
||||
if err != nil {
|
||||
t.Fatalf("commitAndPushLocked: %v", err)
|
||||
}
|
||||
|
||||
assertRepositoryHeadBranch(t, workspaceDir, "develop")
|
||||
assertRemoteBranchContents(t, remoteDir, "develop", "local develop update\n")
|
||||
assertRemoteBranchContents(t, remoteDir, "master", "remote master branch\n")
|
||||
}
|
||||
|
||||
func TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranchCreatedAfterClone(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
remoteDir := setupGitRemoteRepository(t, root, "master",
|
||||
testBranchSpec{name: "master", contents: "remote master branch\n"},
|
||||
)
|
||||
|
||||
baseDir := filepath.Join(root, "workspace", "auths")
|
||||
store := NewGitTokenStore(remoteDir, "", "", "")
|
||||
store.SetBaseDir(baseDir)
|
||||
|
||||
if err := store.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository initial clone: %v", err)
|
||||
}
|
||||
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
|
||||
|
||||
advanceRemoteBranchFromNewBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch\n", "create release")
|
||||
|
||||
reopened := NewGitTokenStore(remoteDir, "", "", "release/2026")
|
||||
reopened.SetBaseDir(baseDir)
|
||||
|
||||
if err := reopened.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository reopen: %v", err)
|
||||
}
|
||||
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch\n")
|
||||
}
|
||||
|
||||
func TestEnsureRepositoryResetsToRemoteDefaultWhenBranchUnset(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
remoteDir := setupGitRemoteRepository(t, root, "master",
|
||||
testBranchSpec{name: "master", contents: "remote master branch\n"},
|
||||
testBranchSpec{name: "develop", contents: "remote develop branch\n"},
|
||||
)
|
||||
|
||||
baseDir := filepath.Join(root, "workspace", "auths")
|
||||
// First store pins to develop and prepares local workspace
|
||||
storePinned := NewGitTokenStore(remoteDir, "", "", "develop")
|
||||
storePinned.SetBaseDir(baseDir)
|
||||
if err := storePinned.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository pinned: %v", err)
|
||||
}
|
||||
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
|
||||
|
||||
// Second store has branch unset and should reset local workspace to remote default (master)
|
||||
storeDefault := NewGitTokenStore(remoteDir, "", "", "")
|
||||
storeDefault.SetBaseDir(baseDir)
|
||||
if err := storeDefault.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository default: %v", err)
|
||||
}
|
||||
// Local HEAD should now follow remote default (master)
|
||||
assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "master")
|
||||
|
||||
// Make a local change and push using the store with branch unset; push should update remote master
|
||||
workspaceDir := filepath.Join(root, "workspace")
|
||||
if err := os.WriteFile(filepath.Join(workspaceDir, "branch.txt"), []byte("local master update\n"), 0o600); err != nil {
|
||||
t.Fatalf("write local master marker: %v", err)
|
||||
}
|
||||
storeDefault.mu.Lock()
|
||||
if err := storeDefault.commitAndPushLocked("Update master marker", "branch.txt"); err != nil {
|
||||
storeDefault.mu.Unlock()
|
||||
t.Fatalf("commitAndPushLocked: %v", err)
|
||||
}
|
||||
storeDefault.mu.Unlock()
|
||||
|
||||
assertRemoteBranchContents(t, remoteDir, "master", "local master update\n")
|
||||
}
|
||||
|
||||
func TestEnsureRepositoryFollowsRenamedRemoteDefaultBranchWhenAvailable(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
remoteDir := setupGitRemoteRepository(t, root, "master",
|
||||
testBranchSpec{name: "master", contents: "remote master branch\n"},
|
||||
testBranchSpec{name: "main", contents: "remote main branch\n"},
|
||||
)
|
||||
|
||||
baseDir := filepath.Join(root, "workspace", "auths")
|
||||
store := NewGitTokenStore(remoteDir, "", "", "")
|
||||
store.SetBaseDir(baseDir)
|
||||
|
||||
if err := store.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository initial clone: %v", err)
|
||||
}
|
||||
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
|
||||
|
||||
setRemoteHeadBranch(t, remoteDir, "main")
|
||||
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "main", "remote main branch updated\n", "advance main")
|
||||
|
||||
reopened := NewGitTokenStore(remoteDir, "", "", "")
|
||||
reopened.SetBaseDir(baseDir)
|
||||
|
||||
if err := reopened.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository after remote default rename: %v", err)
|
||||
}
|
||||
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "main", "remote main branch updated\n")
|
||||
assertRemoteHeadBranch(t, remoteDir, "main")
|
||||
}
|
||||
|
||||
func TestEnsureRepositoryKeepsCurrentBranchWhenRemoteDefaultCannotBeResolved(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
remoteDir := setupGitRemoteRepository(t, root, "master",
|
||||
testBranchSpec{name: "master", contents: "remote master branch\n"},
|
||||
testBranchSpec{name: "develop", contents: "remote develop branch\n"},
|
||||
)
|
||||
|
||||
baseDir := filepath.Join(root, "workspace", "auths")
|
||||
pinned := NewGitTokenStore(remoteDir, "", "", "develop")
|
||||
pinned.SetBaseDir(baseDir)
|
||||
if err := pinned.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository pinned: %v", err)
|
||||
}
|
||||
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
|
||||
|
||||
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="git"`)
|
||||
http.Error(w, "auth required", http.StatusUnauthorized)
|
||||
}))
|
||||
defer authServer.Close()
|
||||
|
||||
repo, err := git.PlainOpen(filepath.Join(root, "workspace"))
|
||||
if err != nil {
|
||||
t.Fatalf("open workspace repo: %v", err)
|
||||
}
|
||||
cfg, err := repo.Config()
|
||||
if err != nil {
|
||||
t.Fatalf("read repo config: %v", err)
|
||||
}
|
||||
cfg.Remotes["origin"].URLs = []string{authServer.URL}
|
||||
if err := repo.SetConfig(cfg); err != nil {
|
||||
t.Fatalf("set repo config: %v", err)
|
||||
}
|
||||
|
||||
reopened := NewGitTokenStore(remoteDir, "", "", "")
|
||||
reopened.SetBaseDir(baseDir)
|
||||
|
||||
if err := reopened.EnsureRepository(); err != nil {
|
||||
t.Fatalf("EnsureRepository default branch fallback: %v", err)
|
||||
}
|
||||
assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "develop")
|
||||
}
|
||||
|
||||
func setupGitRemoteRepository(t *testing.T, root, defaultBranch string, branches ...testBranchSpec) string {
|
||||
t.Helper()
|
||||
|
||||
remoteDir := filepath.Join(root, "remote.git")
|
||||
if _, err := git.PlainInit(remoteDir, true); err != nil {
|
||||
t.Fatalf("init bare remote: %v", err)
|
||||
}
|
||||
|
||||
seedDir := filepath.Join(root, "seed")
|
||||
seedRepo, err := git.PlainInit(seedDir, false)
|
||||
if err != nil {
|
||||
t.Fatalf("init seed repo: %v", err)
|
||||
}
|
||||
if err := seedRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(defaultBranch))); err != nil {
|
||||
t.Fatalf("set seed HEAD: %v", err)
|
||||
}
|
||||
|
||||
worktree, err := seedRepo.Worktree()
|
||||
if err != nil {
|
||||
t.Fatalf("open seed worktree: %v", err)
|
||||
}
|
||||
|
||||
defaultSpec, ok := findBranchSpec(branches, defaultBranch)
|
||||
if !ok {
|
||||
t.Fatalf("missing default branch spec for %q", defaultBranch)
|
||||
}
|
||||
commitBranchMarker(t, seedDir, worktree, defaultSpec, "seed default branch")
|
||||
|
||||
for _, branch := range branches {
|
||||
if branch.name == defaultBranch {
|
||||
continue
|
||||
}
|
||||
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(defaultBranch)}); err != nil {
|
||||
t.Fatalf("checkout default branch %s: %v", defaultBranch, err)
|
||||
}
|
||||
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch.name), Create: true}); err != nil {
|
||||
t.Fatalf("create branch %s: %v", branch.name, err)
|
||||
}
|
||||
commitBranchMarker(t, seedDir, worktree, branch, "seed branch "+branch.name)
|
||||
}
|
||||
|
||||
if _, err := seedRepo.CreateRemote(&gitconfig.RemoteConfig{Name: "origin", URLs: []string{remoteDir}}); err != nil {
|
||||
t.Fatalf("create origin remote: %v", err)
|
||||
}
|
||||
if err := seedRepo.Push(&git.PushOptions{
|
||||
RemoteName: "origin",
|
||||
RefSpecs: []gitconfig.RefSpec{gitconfig.RefSpec("refs/heads/*:refs/heads/*")},
|
||||
}); err != nil {
|
||||
t.Fatalf("push seed branches: %v", err)
|
||||
}
|
||||
|
||||
remoteRepo, err := git.PlainOpen(remoteDir)
|
||||
if err != nil {
|
||||
t.Fatalf("open remote repo: %v", err)
|
||||
}
|
||||
if err := remoteRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(defaultBranch))); err != nil {
|
||||
t.Fatalf("set remote HEAD: %v", err)
|
||||
}
|
||||
|
||||
return remoteDir
|
||||
}
|
||||
|
||||
func commitBranchMarker(t *testing.T, seedDir string, worktree *git.Worktree, branch testBranchSpec, message string) {
|
||||
t.Helper()
|
||||
|
||||
if err := os.WriteFile(filepath.Join(seedDir, "branch.txt"), []byte(branch.contents), 0o600); err != nil {
|
||||
t.Fatalf("write branch marker for %s: %v", branch.name, err)
|
||||
}
|
||||
if _, err := worktree.Add("branch.txt"); err != nil {
|
||||
t.Fatalf("add branch marker for %s: %v", branch.name, err)
|
||||
}
|
||||
if _, err := worktree.Commit(message, &git.CommitOptions{
|
||||
Author: &object.Signature{
|
||||
Name: "CLIProxyAPI",
|
||||
Email: "cliproxy@local",
|
||||
When: time.Unix(1711929600, 0),
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("commit branch marker for %s: %v", branch.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
func advanceRemoteBranch(t *testing.T, seedDir, remoteDir, branch, contents, message string) {
|
||||
t.Helper()
|
||||
|
||||
seedRepo, err := git.PlainOpen(seedDir)
|
||||
if err != nil {
|
||||
t.Fatalf("open seed repo: %v", err)
|
||||
}
|
||||
worktree, err := seedRepo.Worktree()
|
||||
if err != nil {
|
||||
t.Fatalf("open seed worktree: %v", err)
|
||||
}
|
||||
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch)}); err != nil {
|
||||
t.Fatalf("checkout branch %s: %v", branch, err)
|
||||
}
|
||||
commitBranchMarker(t, seedDir, worktree, testBranchSpec{name: branch, contents: contents}, message)
|
||||
if err := seedRepo.Push(&git.PushOptions{
|
||||
RemoteName: "origin",
|
||||
RefSpecs: []gitconfig.RefSpec{
|
||||
gitconfig.RefSpec(plumbing.NewBranchReferenceName(branch).String() + ":" + plumbing.NewBranchReferenceName(branch).String()),
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("push branch %s update to %s: %v", branch, remoteDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
func advanceRemoteBranchFromNewBranch(t *testing.T, seedDir, remoteDir, branch, contents, message string) {
|
||||
t.Helper()
|
||||
|
||||
seedRepo, err := git.PlainOpen(seedDir)
|
||||
if err != nil {
|
||||
t.Fatalf("open seed repo: %v", err)
|
||||
}
|
||||
worktree, err := seedRepo.Worktree()
|
||||
if err != nil {
|
||||
t.Fatalf("open seed worktree: %v", err)
|
||||
}
|
||||
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName("master")}); err != nil {
|
||||
t.Fatalf("checkout master before creating %s: %v", branch, err)
|
||||
}
|
||||
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch), Create: true}); err != nil {
|
||||
t.Fatalf("create branch %s: %v", branch, err)
|
||||
}
|
||||
commitBranchMarker(t, seedDir, worktree, testBranchSpec{name: branch, contents: contents}, message)
|
||||
if err := seedRepo.Push(&git.PushOptions{
|
||||
RemoteName: "origin",
|
||||
RefSpecs: []gitconfig.RefSpec{
|
||||
gitconfig.RefSpec(plumbing.NewBranchReferenceName(branch).String() + ":" + plumbing.NewBranchReferenceName(branch).String()),
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("push new branch %s update to %s: %v", branch, remoteDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
func findBranchSpec(branches []testBranchSpec, name string) (testBranchSpec, bool) {
|
||||
for _, branch := range branches {
|
||||
if branch.name == name {
|
||||
return branch, true
|
||||
}
|
||||
}
|
||||
return testBranchSpec{}, false
|
||||
}
|
||||
|
||||
func assertRepositoryBranchAndContents(t *testing.T, repoDir, branch, wantContents string) {
|
||||
t.Helper()
|
||||
|
||||
repo, err := git.PlainOpen(repoDir)
|
||||
if err != nil {
|
||||
t.Fatalf("open local repo: %v", err)
|
||||
}
|
||||
head, err := repo.Head()
|
||||
if err != nil {
|
||||
t.Fatalf("local repo head: %v", err)
|
||||
}
|
||||
if got, want := head.Name(), plumbing.NewBranchReferenceName(branch); got != want {
|
||||
t.Fatalf("local head branch = %s, want %s", got, want)
|
||||
}
|
||||
contents, err := os.ReadFile(filepath.Join(repoDir, "branch.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("read branch marker: %v", err)
|
||||
}
|
||||
if got := string(contents); got != wantContents {
|
||||
t.Fatalf("branch marker contents = %q, want %q", got, wantContents)
|
||||
}
|
||||
}
|
||||
|
||||
func assertRepositoryHeadBranch(t *testing.T, repoDir, branch string) {
|
||||
t.Helper()
|
||||
|
||||
repo, err := git.PlainOpen(repoDir)
|
||||
if err != nil {
|
||||
t.Fatalf("open local repo: %v", err)
|
||||
}
|
||||
head, err := repo.Head()
|
||||
if err != nil {
|
||||
t.Fatalf("local repo head: %v", err)
|
||||
}
|
||||
if got, want := head.Name(), plumbing.NewBranchReferenceName(branch); got != want {
|
||||
t.Fatalf("local head branch = %s, want %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func assertRemoteHeadBranch(t *testing.T, remoteDir, branch string) {
|
||||
t.Helper()
|
||||
|
||||
remoteRepo, err := git.PlainOpen(remoteDir)
|
||||
if err != nil {
|
||||
t.Fatalf("open remote repo: %v", err)
|
||||
}
|
||||
head, err := remoteRepo.Reference(plumbing.HEAD, false)
|
||||
if err != nil {
|
||||
t.Fatalf("read remote HEAD: %v", err)
|
||||
}
|
||||
if got, want := head.Target(), plumbing.NewBranchReferenceName(branch); got != want {
|
||||
t.Fatalf("remote HEAD target = %s, want %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func setRemoteHeadBranch(t *testing.T, remoteDir, branch string) {
|
||||
t.Helper()
|
||||
|
||||
remoteRepo, err := git.PlainOpen(remoteDir)
|
||||
if err != nil {
|
||||
t.Fatalf("open remote repo: %v", err)
|
||||
}
|
||||
if err := remoteRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(branch))); err != nil {
|
||||
t.Fatalf("set remote HEAD to %s: %v", branch, err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertRemoteBranchExistsWithCommit(t *testing.T, remoteDir, branch string) {
|
||||
t.Helper()
|
||||
|
||||
remoteRepo, err := git.PlainOpen(remoteDir)
|
||||
if err != nil {
|
||||
t.Fatalf("open remote repo: %v", err)
|
||||
}
|
||||
ref, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false)
|
||||
if err != nil {
|
||||
t.Fatalf("read remote branch %s: %v", branch, err)
|
||||
}
|
||||
if got := ref.Hash(); got == plumbing.ZeroHash {
|
||||
t.Fatalf("remote branch %s hash = %s, want non-zero hash", branch, got)
|
||||
}
|
||||
}
|
||||
|
||||
func assertRemoteBranchDoesNotExist(t *testing.T, remoteDir, branch string) {
|
||||
t.Helper()
|
||||
|
||||
remoteRepo, err := git.PlainOpen(remoteDir)
|
||||
if err != nil {
|
||||
t.Fatalf("open remote repo: %v", err)
|
||||
}
|
||||
if _, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false); err == nil {
|
||||
t.Fatalf("remote branch %s exists, want missing", branch)
|
||||
} else if err != plumbing.ErrReferenceNotFound {
|
||||
t.Fatalf("read remote branch %s: %v", branch, err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertRemoteBranchContents(t *testing.T, remoteDir, branch, wantContents string) {
|
||||
t.Helper()
|
||||
|
||||
remoteRepo, err := git.PlainOpen(remoteDir)
|
||||
if err != nil {
|
||||
t.Fatalf("open remote repo: %v", err)
|
||||
}
|
||||
ref, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false)
|
||||
if err != nil {
|
||||
t.Fatalf("read remote branch %s: %v", branch, err)
|
||||
}
|
||||
commit, err := remoteRepo.CommitObject(ref.Hash())
|
||||
if err != nil {
|
||||
t.Fatalf("read remote branch %s commit: %v", branch, err)
|
||||
}
|
||||
tree, err := commit.Tree()
|
||||
if err != nil {
|
||||
t.Fatalf("read remote branch %s tree: %v", branch, err)
|
||||
}
|
||||
file, err := tree.File("branch.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("read remote branch %s file: %v", branch, err)
|
||||
}
|
||||
contents, err := file.Contents()
|
||||
if err != nil {
|
||||
t.Fatalf("read remote branch %s contents: %v", branch, err)
|
||||
}
|
||||
if contents != wantContents {
|
||||
t.Fatalf("remote branch %s contents = %q, want %q", branch, contents, wantContents)
|
||||
}
|
||||
}
|
||||
@@ -174,7 +174,8 @@ func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo
|
||||
// Ensure the request satisfies Claude constraints:
|
||||
// 1) Determine effective max_tokens (request overrides model default)
|
||||
// 2) If budget_tokens >= max_tokens, reduce budget_tokens to max_tokens-1
|
||||
// 3) If the adjusted budget falls below the model minimum, leave the request unchanged
|
||||
// 3) If the adjusted budget falls below the model minimum, try raising max_tokens
|
||||
// (clamped to MaxCompletionTokens); disable thinking if constraints are unsatisfiable
|
||||
// 4) If max_tokens came from model default, write it back into the request
|
||||
|
||||
effectiveMax, setDefaultMax := a.effectiveMaxTokens(body, modelInfo)
|
||||
@@ -193,8 +194,28 @@ func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo
|
||||
minBudget = modelInfo.Thinking.Min
|
||||
}
|
||||
if minBudget > 0 && adjustedBudget > 0 && adjustedBudget < minBudget {
|
||||
// If enforcing the max_tokens constraint would push the budget below the model minimum,
|
||||
// leave the request unchanged.
|
||||
// Enforcing budget_tokens < max_tokens pushed the budget below the model minimum.
|
||||
// Try raising max_tokens to fit the original budget.
|
||||
needed := budgetTokens + 1
|
||||
maxAllowed := 0
|
||||
if modelInfo != nil {
|
||||
maxAllowed = modelInfo.MaxCompletionTokens
|
||||
}
|
||||
if maxAllowed > 0 && needed > maxAllowed {
|
||||
// Cannot use original budget; cap max_tokens at model limit.
|
||||
needed = maxAllowed
|
||||
}
|
||||
cappedBudget := needed - 1
|
||||
if cappedBudget < minBudget {
|
||||
// Impossible to satisfy both budget >= minBudget and budget < max_tokens
|
||||
// within the model's completion limit. Disable thinking entirely.
|
||||
body, _ = sjson.DeleteBytes(body, "thinking")
|
||||
return body
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "max_tokens", needed)
|
||||
if cappedBudget != budgetTokens {
|
||||
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", cappedBudget)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
|
||||
99
internal/thinking/provider/claude/apply_test.go
Normal file
99
internal/thinking/provider/claude/apply_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNormalizeClaudeBudget_RaisesMaxTokens(t *testing.T) {
|
||||
a := &Applier{}
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000},
|
||||
}
|
||||
body := []byte(`{"max_tokens":1000,"thinking":{"type":"enabled","budget_tokens":5000}}`)
|
||||
|
||||
out := a.normalizeClaudeBudget(body, 5000, modelInfo)
|
||||
|
||||
maxTok := gjson.GetBytes(out, "max_tokens").Int()
|
||||
if maxTok != 5001 {
|
||||
t.Fatalf("max_tokens = %d, want 5001, body=%s", maxTok, string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeBudget_ClampsToModelMax(t *testing.T) {
|
||||
a := &Applier{}
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000},
|
||||
}
|
||||
body := []byte(`{"max_tokens":500,"thinking":{"type":"enabled","budget_tokens":200000}}`)
|
||||
|
||||
out := a.normalizeClaudeBudget(body, 200000, modelInfo)
|
||||
|
||||
maxTok := gjson.GetBytes(out, "max_tokens").Int()
|
||||
if maxTok != 64000 {
|
||||
t.Fatalf("max_tokens = %d, want 64000 (capped to model limit), body=%s", maxTok, string(out))
|
||||
}
|
||||
budget := gjson.GetBytes(out, "thinking.budget_tokens").Int()
|
||||
if budget != 63999 {
|
||||
t.Fatalf("budget_tokens = %d, want 63999 (max_tokens-1), body=%s", budget, string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeBudget_DisablesThinkingWhenUnsatisfiable(t *testing.T) {
|
||||
a := &Applier{}
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
MaxCompletionTokens: 1000,
|
||||
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000},
|
||||
}
|
||||
body := []byte(`{"max_tokens":500,"thinking":{"type":"enabled","budget_tokens":2000}}`)
|
||||
|
||||
out := a.normalizeClaudeBudget(body, 2000, modelInfo)
|
||||
|
||||
if gjson.GetBytes(out, "thinking").Exists() {
|
||||
t.Fatalf("thinking should be removed when constraints are unsatisfiable, body=%s", string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeBudget_NoClamping(t *testing.T) {
|
||||
a := &Applier{}
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000},
|
||||
}
|
||||
body := []byte(`{"max_tokens":32000,"thinking":{"type":"enabled","budget_tokens":16000}}`)
|
||||
|
||||
out := a.normalizeClaudeBudget(body, 16000, modelInfo)
|
||||
|
||||
maxTok := gjson.GetBytes(out, "max_tokens").Int()
|
||||
if maxTok != 32000 {
|
||||
t.Fatalf("max_tokens should remain 32000, got %d, body=%s", maxTok, string(out))
|
||||
}
|
||||
budget := gjson.GetBytes(out, "thinking.budget_tokens").Int()
|
||||
if budget != 16000 {
|
||||
t.Fatalf("budget_tokens should remain 16000, got %d, body=%s", budget, string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeBudget_AdjustsBudgetToMaxMinus1(t *testing.T) {
|
||||
a := &Applier{}
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
MaxCompletionTokens: 8192,
|
||||
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000},
|
||||
}
|
||||
body := []byte(`{"max_tokens":8192,"thinking":{"type":"enabled","budget_tokens":10000}}`)
|
||||
|
||||
out := a.normalizeClaudeBudget(body, 10000, modelInfo)
|
||||
|
||||
maxTok := gjson.GetBytes(out, "max_tokens").Int()
|
||||
if maxTok != 8192 {
|
||||
t.Fatalf("max_tokens = %d, want 8192 (unchanged), body=%s", maxTok, string(out))
|
||||
}
|
||||
budget := gjson.GetBytes(out, "thinking.budget_tokens").Int()
|
||||
if budget != 8191 {
|
||||
t.Fatalf("budget_tokens = %d, want 8191 (max_tokens-1), body=%s", budget, string(out))
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,9 @@ type ConvertCodexResponseToClaudeParams struct {
|
||||
HasToolCall bool
|
||||
BlockIndex int
|
||||
HasReceivedArgumentsDelta bool
|
||||
ThinkingBlockOpen bool
|
||||
ThinkingStopPending bool
|
||||
ThinkingSignature string
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
|
||||
@@ -44,7 +47,7 @@ type ConvertCodexResponseToClaudeParams struct {
|
||||
//
|
||||
// Returns:
|
||||
// - [][]byte: A slice of Claude Code-compatible JSON responses
|
||||
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &ConvertCodexResponseToClaudeParams{
|
||||
HasToolCall: false,
|
||||
@@ -52,7 +55,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
}
|
||||
}
|
||||
|
||||
// log.Debugf("rawJSON: %s", string(rawJSON))
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return [][]byte{}
|
||||
}
|
||||
@@ -60,9 +62,18 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
|
||||
output := make([]byte, 0, 512)
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
params := (*param).(*ConvertCodexResponseToClaudeParams)
|
||||
if params.ThinkingBlockOpen && params.ThinkingStopPending {
|
||||
switch rootResult.Get("type").String() {
|
||||
case "response.content_part.added", "response.completed":
|
||||
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||
}
|
||||
}
|
||||
|
||||
typeResult := rootResult.Get("type")
|
||||
typeStr := typeResult.String()
|
||||
var template []byte
|
||||
|
||||
if typeStr == "response.created" {
|
||||
template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`)
|
||||
template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String())
|
||||
@@ -70,43 +81,46 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2)
|
||||
} else if typeStr == "response.reasoning_summary_part.added" {
|
||||
if params.ThinkingBlockOpen && params.ThinkingStopPending {
|
||||
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||
}
|
||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||
params.ThinkingBlockOpen = true
|
||||
params.ThinkingStopPending = false
|
||||
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||
} else if typeStr == "response.reasoning_summary_text.delta" {
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String())
|
||||
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
} else if typeStr == "response.reasoning_summary_part.done" {
|
||||
template = []byte(`{"type":"content_block_stop","index":0}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
||||
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
||||
|
||||
params.ThinkingStopPending = true
|
||||
if params.ThinkingSignature != "" {
|
||||
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||
}
|
||||
} else if typeStr == "response.content_part.added" {
|
||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||
} else if typeStr == "response.output_text.delta" {
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String())
|
||||
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
} else if typeStr == "response.content_part.done" {
|
||||
template = []byte(`{"type":"content_block_stop","index":0}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
||||
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||
params.BlockIndex++
|
||||
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
||||
} else if typeStr == "response.completed" {
|
||||
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
|
||||
p := params.HasToolCall
|
||||
stopReason := rootResult.Get("response.stop_reason").String()
|
||||
if p {
|
||||
template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use")
|
||||
@@ -128,13 +142,13 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
itemResult := rootResult.Get("item")
|
||||
itemType := itemResult.Get("type").String()
|
||||
if itemType == "function_call" {
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
|
||||
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||
params.HasToolCall = true
|
||||
params.HasReceivedArgumentsDelta = false
|
||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
|
||||
{
|
||||
// Restore original tool name if shortened
|
||||
name := itemResult.Get("name").String()
|
||||
rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
@@ -146,37 +160,43 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
} else if itemType == "reasoning" {
|
||||
params.ThinkingSignature = itemResult.Get("encrypted_content").String()
|
||||
if params.ThinkingStopPending {
|
||||
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||
}
|
||||
}
|
||||
} else if typeStr == "response.output_item.done" {
|
||||
itemResult := rootResult.Get("item")
|
||||
itemType := itemResult.Get("type").String()
|
||||
if itemType == "function_call" {
|
||||
template = []byte(`{"type":"content_block_stop","index":0}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
||||
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||
params.BlockIndex++
|
||||
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
||||
} else if itemType == "reasoning" {
|
||||
if signature := itemResult.Get("encrypted_content").String(); signature != "" {
|
||||
params.ThinkingSignature = signature
|
||||
}
|
||||
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||
params.ThinkingSignature = ""
|
||||
}
|
||||
} else if typeStr == "response.function_call_arguments.delta" {
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
|
||||
params.HasReceivedArgumentsDelta = true
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String())
|
||||
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
} else if typeStr == "response.function_call_arguments.done" {
|
||||
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
|
||||
// in a single "done" event without preceding "delta" events.
|
||||
// Emit the full arguments as a single input_json_delta so the
|
||||
// downstream Claude client receives the complete tool input.
|
||||
// When delta events were already received, skip to avoid duplicating arguments.
|
||||
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
|
||||
if !params.HasReceivedArgumentsDelta {
|
||||
if args := rootResult.Get("arguments").String(); args != "" {
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "delta.partial_json", args)
|
||||
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
@@ -191,15 +211,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
// This function processes the complete Codex response and transforms it into a single Claude Code-compatible
|
||||
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
|
||||
// the information into a single response that matches the Claude Code API format.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request, used for cancellation and timeout handling
|
||||
// - modelName: The name of the model being used for the response (unused in current implementation)
|
||||
// - rawJSON: The raw JSON response from the Codex API
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - []byte: A Claude Code-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte {
|
||||
revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
|
||||
|
||||
@@ -230,6 +241,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
switch item.Get("type").String() {
|
||||
case "reasoning":
|
||||
thinkingBuilder := strings.Builder{}
|
||||
signature := item.Get("encrypted_content").String()
|
||||
if summary := item.Get("summary"); summary.Exists() {
|
||||
if summary.IsArray() {
|
||||
summary.ForEach(func(_, part gjson.Result) bool {
|
||||
@@ -260,9 +272,12 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
}
|
||||
}
|
||||
}
|
||||
if thinkingBuilder.Len() > 0 {
|
||||
if thinkingBuilder.Len() > 0 || signature != "" {
|
||||
block := []byte(`{"type":"thinking","thinking":""}`)
|
||||
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
|
||||
if signature != "" {
|
||||
block, _ = sjson.SetBytes(block, "signature", signature)
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "content.-1", block)
|
||||
}
|
||||
case "message":
|
||||
@@ -371,6 +386,30 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin
|
||||
return rev
|
||||
}
|
||||
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
|
||||
func ClaudeTokenCount(_ context.Context, count int64) []byte {
|
||||
return translatorcommon.ClaudeInputTokensJSON(count)
|
||||
}
|
||||
|
||||
func finalizeCodexThinkingBlock(params *ConvertCodexResponseToClaudeParams) []byte {
|
||||
if !params.ThinkingBlockOpen {
|
||||
return nil
|
||||
}
|
||||
|
||||
output := make([]byte, 0, 256)
|
||||
if params.ThinkingSignature != "" {
|
||||
signatureDelta := []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":""}}`)
|
||||
signatureDelta, _ = sjson.SetBytes(signatureDelta, "index", params.BlockIndex)
|
||||
signatureDelta, _ = sjson.SetBytes(signatureDelta, "delta.signature", params.ThinkingSignature)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", signatureDelta, 2)
|
||||
}
|
||||
|
||||
contentBlockStop := []byte(`{"type":"content_block_stop","index":0}`)
|
||||
contentBlockStop, _ = sjson.SetBytes(contentBlockStop, "index", params.BlockIndex)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", contentBlockStop, 2)
|
||||
|
||||
params.BlockIndex++
|
||||
params.ThinkingBlockOpen = false
|
||||
params.ThinkingStopPending = false
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
282
internal/translator/codex/claude/codex_claude_response_test.go
Normal file
282
internal/translator/codex/claude/codex_claude_response_test.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertCodexResponseToClaude_StreamThinkingIncludesSignature(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
originalRequest := []byte(`{"messages":[]}`)
|
||||
var param any
|
||||
|
||||
chunks := [][]byte{
|
||||
[]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_123\",\"model\":\"gpt-5\"}}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
|
||||
[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_123\"}}"),
|
||||
}
|
||||
|
||||
var outputs [][]byte
|
||||
for _, chunk := range chunks {
|
||||
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...)
|
||||
}
|
||||
|
||||
startFound := false
|
||||
signatureDeltaFound := false
|
||||
stopFound := false
|
||||
|
||||
for _, out := range outputs {
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
|
||||
switch data.Get("type").String() {
|
||||
case "content_block_start":
|
||||
if data.Get("content_block.type").String() == "thinking" {
|
||||
startFound = true
|
||||
if data.Get("content_block.signature").Exists() {
|
||||
t.Fatalf("thinking start block should NOT have signature field when signature is unknown: %s", line)
|
||||
}
|
||||
}
|
||||
case "content_block_delta":
|
||||
if data.Get("delta.type").String() == "signature_delta" {
|
||||
signatureDeltaFound = true
|
||||
if got := data.Get("delta.signature").String(); got != "enc_sig_123" {
|
||||
t.Fatalf("unexpected signature delta: %q", got)
|
||||
}
|
||||
}
|
||||
case "content_block_stop":
|
||||
stopFound = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !startFound {
|
||||
t.Fatal("expected thinking content_block_start event")
|
||||
}
|
||||
if !signatureDeltaFound {
|
||||
t.Fatal("expected signature_delta event for thinking block")
|
||||
}
|
||||
if !stopFound {
|
||||
t.Fatal("expected content_block_stop event for thinking block")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertCodexResponseToClaude_StreamThinkingWithoutReasoningItemStillIncludesSignatureField(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
originalRequest := []byte(`{"messages":[]}`)
|
||||
var param any
|
||||
|
||||
chunks := [][]byte{
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
|
||||
[]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}"),
|
||||
}
|
||||
|
||||
var outputs [][]byte
|
||||
for _, chunk := range chunks {
|
||||
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...)
|
||||
}
|
||||
|
||||
thinkingStartFound := false
|
||||
thinkingStopFound := false
|
||||
signatureDeltaFound := false
|
||||
|
||||
for _, out := range outputs {
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
|
||||
if data.Get("type").String() == "content_block_start" && data.Get("content_block.type").String() == "thinking" {
|
||||
thinkingStartFound = true
|
||||
if data.Get("content_block.signature").Exists() {
|
||||
t.Fatalf("thinking start block should NOT have signature field without encrypted_content: %s", line)
|
||||
}
|
||||
}
|
||||
if data.Get("type").String() == "content_block_stop" && data.Get("index").Int() == 0 {
|
||||
thinkingStopFound = true
|
||||
}
|
||||
if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" {
|
||||
signatureDeltaFound = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !thinkingStartFound {
|
||||
t.Fatal("expected thinking content_block_start event")
|
||||
}
|
||||
if !thinkingStopFound {
|
||||
t.Fatal("expected thinking content_block_stop event")
|
||||
}
|
||||
if signatureDeltaFound {
|
||||
t.Fatal("did not expect signature_delta without encrypted_content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertCodexResponseToClaude_StreamThinkingFinalizesPendingBlockBeforeNextSummaryPart(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
originalRequest := []byte(`{"messages":[]}`)
|
||||
var param any
|
||||
|
||||
chunks := [][]byte{
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"First part\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||
}
|
||||
|
||||
var outputs [][]byte
|
||||
for _, chunk := range chunks {
|
||||
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...)
|
||||
}
|
||||
|
||||
startCount := 0
|
||||
stopCount := 0
|
||||
for _, out := range outputs {
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
|
||||
if data.Get("type").String() == "content_block_start" && data.Get("content_block.type").String() == "thinking" {
|
||||
startCount++
|
||||
}
|
||||
if data.Get("type").String() == "content_block_stop" {
|
||||
stopCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if startCount != 2 {
|
||||
t.Fatalf("expected 2 thinking block starts, got %d", startCount)
|
||||
}
|
||||
if stopCount != 1 {
|
||||
t.Fatalf("expected pending thinking block to be finalized before second start, got %d stops", stopCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertCodexResponseToClaude_StreamThinkingRetainsSignatureAcrossMultipartReasoning(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
originalRequest := []byte(`{"messages":[]}`)
|
||||
var param any
|
||||
|
||||
chunks := [][]byte{
|
||||
[]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_multipart\"}}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"First part\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Second part\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
|
||||
[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\"}}"),
|
||||
}
|
||||
|
||||
var outputs [][]byte
|
||||
for _, chunk := range chunks {
|
||||
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...)
|
||||
}
|
||||
|
||||
signatureDeltaCount := 0
|
||||
for _, out := range outputs {
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
|
||||
if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" {
|
||||
signatureDeltaCount++
|
||||
if got := data.Get("delta.signature").String(); got != "enc_sig_multipart" {
|
||||
t.Fatalf("unexpected signature delta: %q", got)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if signatureDeltaCount != 2 {
|
||||
t.Fatalf("expected signature_delta for both multipart thinking blocks, got %d", signatureDeltaCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertCodexResponseToClaude_StreamThinkingUsesEarlyCapturedSignatureWhenDoneOmitsIt(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
originalRequest := []byte(`{"messages":[]}`)
|
||||
var param any
|
||||
|
||||
chunks := [][]byte{
|
||||
[]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_early\"}}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
|
||||
[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\"}}"),
|
||||
}
|
||||
|
||||
var outputs [][]byte
|
||||
for _, chunk := range chunks {
|
||||
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...)
|
||||
}
|
||||
|
||||
signatureDeltaCount := 0
|
||||
for _, out := range outputs {
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
|
||||
if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" {
|
||||
signatureDeltaCount++
|
||||
if got := data.Get("delta.signature").String(); got != "enc_sig_early" {
|
||||
t.Fatalf("unexpected signature delta: %q", got)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if signatureDeltaCount != 1 {
|
||||
t.Fatalf("expected signature_delta from early-captured signature, got %d", signatureDeltaCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertCodexResponseToClaudeNonStream_ThinkingIncludesSignature(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
originalRequest := []byte(`{"messages":[]}`)
|
||||
response := []byte(`{
|
||||
"type":"response.completed",
|
||||
"response":{
|
||||
"id":"resp_123",
|
||||
"model":"gpt-5",
|
||||
"usage":{"input_tokens":10,"output_tokens":20},
|
||||
"output":[
|
||||
{
|
||||
"type":"reasoning",
|
||||
"encrypted_content":"enc_sig_nonstream",
|
||||
"summary":[{"type":"summary_text","text":"internal reasoning"}]
|
||||
},
|
||||
{
|
||||
"type":"message",
|
||||
"content":[{"type":"output_text","text":"final answer"}]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
|
||||
out := ConvertCodexResponseToClaudeNonStream(ctx, "", originalRequest, nil, response, nil)
|
||||
parsed := gjson.ParseBytes(out)
|
||||
|
||||
thinking := parsed.Get("content.0")
|
||||
if thinking.Get("type").String() != "thinking" {
|
||||
t.Fatalf("expected first content block to be thinking, got %s", thinking.Raw)
|
||||
}
|
||||
if got := thinking.Get("signature").String(); got != "enc_sig_nonstream" {
|
||||
t.Fatalf("expected signature to be preserved, got %q", got)
|
||||
}
|
||||
if got := thinking.Get("thinking").String(); got != "internal reasoning" {
|
||||
t.Fatalf("unexpected thinking text: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -20,12 +20,14 @@ type oaiToResponsesStateReasoning struct {
|
||||
OutputIndex int
|
||||
}
|
||||
type oaiToResponsesState struct {
|
||||
Seq int
|
||||
ResponseID string
|
||||
Created int64
|
||||
Started bool
|
||||
ReasoningID string
|
||||
ReasoningIndex int
|
||||
Seq int
|
||||
ResponseID string
|
||||
Created int64
|
||||
Started bool
|
||||
CompletionPending bool
|
||||
CompletedEmitted bool
|
||||
ReasoningID string
|
||||
ReasoningIndex int
|
||||
// aggregation buffers for response.output
|
||||
// Per-output message text buffers by index
|
||||
MsgTextBuf map[int]*strings.Builder
|
||||
@@ -60,6 +62,141 @@ func emitRespEvent(event string, payload []byte) []byte {
|
||||
return translatorcommon.SSEEventData(event, payload)
|
||||
}
|
||||
|
||||
func buildResponsesCompletedEvent(st *oaiToResponsesState, requestRawJSON []byte, nextSeq func() int) []byte {
|
||||
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
|
||||
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
|
||||
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
|
||||
completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created)
|
||||
// Inject original request fields into response as per docs/response.completed.json
|
||||
if requestRawJSON != nil {
|
||||
req := gjson.ParseBytes(requestRawJSON)
|
||||
if v := req.Get("instructions"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
|
||||
}
|
||||
if v := req.Get("max_output_tokens"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
|
||||
}
|
||||
if v := req.Get("max_tool_calls"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
|
||||
}
|
||||
if v := req.Get("model"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
|
||||
}
|
||||
if v := req.Get("parallel_tool_calls"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
|
||||
}
|
||||
if v := req.Get("previous_response_id"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
|
||||
}
|
||||
if v := req.Get("prompt_cache_key"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
|
||||
}
|
||||
if v := req.Get("reasoning"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
|
||||
}
|
||||
if v := req.Get("safety_identifier"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
|
||||
}
|
||||
if v := req.Get("service_tier"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
|
||||
}
|
||||
if v := req.Get("store"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
|
||||
}
|
||||
if v := req.Get("temperature"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
|
||||
}
|
||||
if v := req.Get("text"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
|
||||
}
|
||||
if v := req.Get("tool_choice"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
|
||||
}
|
||||
if v := req.Get("tools"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
|
||||
}
|
||||
if v := req.Get("top_logprobs"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
|
||||
}
|
||||
if v := req.Get("top_p"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
|
||||
}
|
||||
if v := req.Get("truncation"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
|
||||
}
|
||||
if v := req.Get("user"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
|
||||
}
|
||||
if v := req.Get("metadata"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
|
||||
}
|
||||
}
|
||||
|
||||
outputsWrapper := []byte(`{"arr":[]}`)
|
||||
type completedOutputItem struct {
|
||||
index int
|
||||
raw []byte
|
||||
}
|
||||
outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
|
||||
if len(st.Reasonings) > 0 {
|
||||
for _, r := range st.Reasonings {
|
||||
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
|
||||
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
|
||||
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
|
||||
outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
|
||||
}
|
||||
}
|
||||
if len(st.MsgItemAdded) > 0 {
|
||||
for i := range st.MsgItemAdded {
|
||||
txt := ""
|
||||
if b := st.MsgTextBuf[i]; b != nil {
|
||||
txt = b.String()
|
||||
}
|
||||
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
|
||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
item, _ = sjson.SetBytes(item, "content.0.text", txt)
|
||||
outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
|
||||
}
|
||||
}
|
||||
if len(st.FuncArgsBuf) > 0 {
|
||||
for key := range st.FuncArgsBuf {
|
||||
args := ""
|
||||
if b := st.FuncArgsBuf[key]; b != nil {
|
||||
args = b.String()
|
||||
}
|
||||
callID := st.FuncCallIDs[key]
|
||||
name := st.FuncNames[key]
|
||||
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
|
||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||
item, _ = sjson.SetBytes(item, "arguments", args)
|
||||
item, _ = sjson.SetBytes(item, "call_id", callID)
|
||||
item, _ = sjson.SetBytes(item, "name", name)
|
||||
outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
|
||||
}
|
||||
}
|
||||
sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
|
||||
for _, item := range outputItems {
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
|
||||
}
|
||||
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
|
||||
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
|
||||
}
|
||||
if st.UsageSeen {
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens)
|
||||
if st.ReasoningTokens > 0 {
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
|
||||
}
|
||||
total := st.TotalTokens
|
||||
if total == 0 {
|
||||
total = st.PromptTokens + st.CompletionTokens
|
||||
}
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
|
||||
}
|
||||
return emitRespEvent("response.completed", completed)
|
||||
}
|
||||
|
||||
// ConvertOpenAIChatCompletionsResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
|
||||
// to OpenAI Responses SSE events (response.*).
|
||||
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
@@ -90,6 +227,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
return [][]byte{}
|
||||
}
|
||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
if st.CompletionPending && !st.CompletedEmitted {
|
||||
st.CompletedEmitted = true
|
||||
return [][]byte{buildResponsesCompletedEvent(st, requestRawJSON, func() int { st.Seq++; return st.Seq })}
|
||||
}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
@@ -165,6 +306,8 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
st.TotalTokens = 0
|
||||
st.ReasoningTokens = 0
|
||||
st.UsageSeen = false
|
||||
st.CompletionPending = false
|
||||
st.CompletedEmitted = false
|
||||
// response.created
|
||||
created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
|
||||
created, _ = sjson.SetBytes(created, "sequence_number", nextSeq())
|
||||
@@ -374,8 +517,9 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// finish_reason triggers finalization, including text done/content done/item done,
|
||||
// reasoning done/part.done, function args done/item done, and completed
|
||||
// finish_reason triggers item-level finalization. response.completed is
|
||||
// deferred until the terminal [DONE] marker so late usage-only chunks can
|
||||
// still populate response.usage.
|
||||
if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" {
|
||||
// Emit message done events for all indices that started a message
|
||||
if len(st.MsgItemAdded) > 0 {
|
||||
@@ -464,138 +608,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
st.FuncArgsDone[key] = true
|
||||
}
|
||||
}
|
||||
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
|
||||
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
|
||||
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
|
||||
completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created)
|
||||
// Inject original request fields into response as per docs/response.completed.json
|
||||
if requestRawJSON != nil {
|
||||
req := gjson.ParseBytes(requestRawJSON)
|
||||
if v := req.Get("instructions"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
|
||||
}
|
||||
if v := req.Get("max_output_tokens"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
|
||||
}
|
||||
if v := req.Get("max_tool_calls"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
|
||||
}
|
||||
if v := req.Get("model"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
|
||||
}
|
||||
if v := req.Get("parallel_tool_calls"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
|
||||
}
|
||||
if v := req.Get("previous_response_id"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
|
||||
}
|
||||
if v := req.Get("prompt_cache_key"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
|
||||
}
|
||||
if v := req.Get("reasoning"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
|
||||
}
|
||||
if v := req.Get("safety_identifier"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
|
||||
}
|
||||
if v := req.Get("service_tier"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
|
||||
}
|
||||
if v := req.Get("store"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
|
||||
}
|
||||
if v := req.Get("temperature"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
|
||||
}
|
||||
if v := req.Get("text"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
|
||||
}
|
||||
if v := req.Get("tool_choice"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
|
||||
}
|
||||
if v := req.Get("tools"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
|
||||
}
|
||||
if v := req.Get("top_logprobs"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
|
||||
}
|
||||
if v := req.Get("top_p"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
|
||||
}
|
||||
if v := req.Get("truncation"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
|
||||
}
|
||||
if v := req.Get("user"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
|
||||
}
|
||||
if v := req.Get("metadata"); v.Exists() {
|
||||
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
|
||||
}
|
||||
}
|
||||
// Build response.output using aggregated buffers
|
||||
outputsWrapper := []byte(`{"arr":[]}`)
|
||||
type completedOutputItem struct {
|
||||
index int
|
||||
raw []byte
|
||||
}
|
||||
outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
|
||||
if len(st.Reasonings) > 0 {
|
||||
for _, r := range st.Reasonings {
|
||||
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
|
||||
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
|
||||
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
|
||||
outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
|
||||
}
|
||||
}
|
||||
if len(st.MsgItemAdded) > 0 {
|
||||
for i := range st.MsgItemAdded {
|
||||
txt := ""
|
||||
if b := st.MsgTextBuf[i]; b != nil {
|
||||
txt = b.String()
|
||||
}
|
||||
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
|
||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
item, _ = sjson.SetBytes(item, "content.0.text", txt)
|
||||
outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
|
||||
}
|
||||
}
|
||||
if len(st.FuncArgsBuf) > 0 {
|
||||
for key := range st.FuncArgsBuf {
|
||||
args := ""
|
||||
if b := st.FuncArgsBuf[key]; b != nil {
|
||||
args = b.String()
|
||||
}
|
||||
callID := st.FuncCallIDs[key]
|
||||
name := st.FuncNames[key]
|
||||
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
|
||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||
item, _ = sjson.SetBytes(item, "arguments", args)
|
||||
item, _ = sjson.SetBytes(item, "call_id", callID)
|
||||
item, _ = sjson.SetBytes(item, "name", name)
|
||||
outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
|
||||
}
|
||||
}
|
||||
sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
|
||||
for _, item := range outputItems {
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
|
||||
}
|
||||
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
|
||||
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
|
||||
}
|
||||
if st.UsageSeen {
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens)
|
||||
if st.ReasoningTokens > 0 {
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
|
||||
}
|
||||
total := st.TotalTokens
|
||||
if total == 0 {
|
||||
total = st.PromptTokens + st.CompletionTokens
|
||||
}
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
|
||||
}
|
||||
out = append(out, emitRespEvent("response.completed", completed))
|
||||
st.CompletionPending = true
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
@@ -24,6 +24,120 @@ func parseOpenAIResponsesSSEEvent(t *testing.T, chunk []byte) (string, gjson.Res
|
||||
return event, gjson.Parse(dataLine)
|
||||
}
|
||||
|
||||
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_ResponseCompletedWaitsForDone(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in []string
|
||||
doneInputIndex int // Index in tt.in where the terminal [DONE] chunk arrives and response.completed must be emitted.
|
||||
hasUsage bool
|
||||
inputTokens int64
|
||||
outputTokens int64
|
||||
totalTokens int64
|
||||
}{
|
||||
{
|
||||
// A provider may send finish_reason first and only attach usage in a later chunk (e.g. Vertex AI),
|
||||
// so response.completed must wait for [DONE] to include that usage.
|
||||
name: "late usage after finish reason",
|
||||
in: []string{
|
||||
`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_late_usage","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}]}`,
|
||||
`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[],"usage":{"prompt_tokens":11,"completion_tokens":7,"total_tokens":18}}`,
|
||||
`data: [DONE]`,
|
||||
},
|
||||
doneInputIndex: 3,
|
||||
hasUsage: true,
|
||||
inputTokens: 11,
|
||||
outputTokens: 7,
|
||||
totalTokens: 18,
|
||||
},
|
||||
{
|
||||
// When usage arrives on the same chunk as finish_reason, we still expect a
|
||||
// single response.completed event and it should remain deferred until [DONE].
|
||||
name: "usage on finish reason chunk",
|
||||
in: []string{
|
||||
`data: {"id":"resp_usage_same_chunk","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_usage_same_chunk","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_usage_same_chunk","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":13,"completion_tokens":5,"total_tokens":18}}`,
|
||||
`data: [DONE]`,
|
||||
},
|
||||
doneInputIndex: 2,
|
||||
hasUsage: true,
|
||||
inputTokens: 13,
|
||||
outputTokens: 5,
|
||||
totalTokens: 18,
|
||||
},
|
||||
{
|
||||
// An OpenAI-compatible streams from a buggy server might never send usage, so response.completed should
|
||||
// still wait for [DONE] but omit the usage object entirely.
|
||||
name: "no usage chunk",
|
||||
in: []string{
|
||||
`data: {"id":"resp_no_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_no_usage","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_no_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}]}`,
|
||||
`data: [DONE]`,
|
||||
},
|
||||
doneInputIndex: 2,
|
||||
hasUsage: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
completedCount := 0
|
||||
completedInputIndex := -1
|
||||
var completedData gjson.Result
|
||||
|
||||
// Reuse converter state across input lines to simulate one streaming response.
|
||||
var param any
|
||||
|
||||
for i, line := range tt.in {
|
||||
// One upstream chunk can emit multiple downstream SSE events.
|
||||
for _, chunk := range ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m) {
|
||||
event, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||
if event != "response.completed" {
|
||||
continue
|
||||
}
|
||||
|
||||
completedCount++
|
||||
completedInputIndex = i
|
||||
completedData = data
|
||||
if i < tt.doneInputIndex {
|
||||
t.Fatalf("unexpected early response.completed on input index %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if completedCount != 1 {
|
||||
t.Fatalf("expected exactly 1 response.completed event, got %d", completedCount)
|
||||
}
|
||||
if completedInputIndex != tt.doneInputIndex {
|
||||
t.Fatalf("expected response.completed on terminal [DONE] chunk at input index %d, got %d", tt.doneInputIndex, completedInputIndex)
|
||||
}
|
||||
|
||||
// Missing upstream usage should stay omitted in the final completed event.
|
||||
if !tt.hasUsage {
|
||||
if completedData.Get("response.usage").Exists() {
|
||||
t.Fatalf("expected response.completed to omit usage when none was provided, got %s", completedData.Get("response.usage").Raw)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// When usage is present, the final response.completed event must preserve the usage values.
|
||||
if got := completedData.Get("response.usage.input_tokens").Int(); got != tt.inputTokens {
|
||||
t.Fatalf("unexpected response.usage.input_tokens: got %d want %d", got, tt.inputTokens)
|
||||
}
|
||||
if got := completedData.Get("response.usage.output_tokens").Int(); got != tt.outputTokens {
|
||||
t.Fatalf("unexpected response.usage.output_tokens: got %d want %d", got, tt.outputTokens)
|
||||
}
|
||||
if got := completedData.Get("response.usage.total_tokens").Int(); got != tt.totalTokens {
|
||||
t.Fatalf("unexpected response.usage.total_tokens: got %d want %d", got, tt.totalTokens)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCallsRemainSeparate(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
@@ -31,6 +145,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCalls
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.{yml,yaml}\"}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||
`data: [DONE]`,
|
||||
}
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
@@ -131,6 +246,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultiChoiceToolCa
|
||||
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice0","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||
`data: [DONE]`,
|
||||
}
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
@@ -213,6 +329,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MixedMessageAndTo
|
||||
in := []string{
|
||||
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":"hello","reasoning_content":null,"tool_calls":null},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"stop"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||
`data: [DONE]`,
|
||||
}
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
@@ -261,6 +378,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_FunctionCallDoneA
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||
`data: [DONE]`,
|
||||
}
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
|
||||
@@ -6,6 +6,7 @@ package handlers
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -493,6 +494,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
opts.Metadata = reqMeta
|
||||
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
err = enrichAuthSelectionError(err, providers, normalizedModel)
|
||||
status := http.StatusInternalServerError
|
||||
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
|
||||
if code := se.StatusCode(); code > 0 {
|
||||
@@ -539,6 +541,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
opts.Metadata = reqMeta
|
||||
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
err = enrichAuthSelectionError(err, providers, normalizedModel)
|
||||
status := http.StatusInternalServerError
|
||||
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
|
||||
if code := se.StatusCode(); code > 0 {
|
||||
@@ -589,6 +592,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
opts.Metadata = reqMeta
|
||||
streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
err = enrichAuthSelectionError(err, providers, normalizedModel)
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
status := http.StatusInternalServerError
|
||||
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
|
||||
@@ -698,7 +702,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
chunks = retryResult.Chunks
|
||||
continue outer
|
||||
}
|
||||
streamErr = retryErr
|
||||
streamErr = enrichAuthSelectionError(retryErr, providers, normalizedModel)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -841,6 +845,54 @@ func replaceHeader(dst http.Header, src http.Header) {
|
||||
}
|
||||
}
|
||||
|
||||
func enrichAuthSelectionError(err error, providers []string, model string) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var authErr *coreauth.Error
|
||||
if !errors.As(err, &authErr) || authErr == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(authErr.Code)
|
||||
if code != "auth_not_found" && code != "auth_unavailable" {
|
||||
return err
|
||||
}
|
||||
|
||||
providerText := strings.Join(providers, ",")
|
||||
if providerText == "" {
|
||||
providerText = "unknown"
|
||||
}
|
||||
modelText := strings.TrimSpace(model)
|
||||
if modelText == "" {
|
||||
modelText = "unknown"
|
||||
}
|
||||
|
||||
baseMessage := strings.TrimSpace(authErr.Message)
|
||||
if baseMessage == "" {
|
||||
baseMessage = "no auth available"
|
||||
}
|
||||
detail := fmt.Sprintf("%s (providers=%s, model=%s)", baseMessage, providerText, modelText)
|
||||
|
||||
// Clarify the most common alias confusion between Anthropic route names and internal provider keys.
|
||||
if strings.Contains(","+providerText+",", ",claude,") {
|
||||
detail += "; check Claude auth/key session and cooldown state via /v0/management/auth-files"
|
||||
}
|
||||
|
||||
status := authErr.HTTPStatus
|
||||
if status <= 0 {
|
||||
status = http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
return &coreauth.Error{
|
||||
Code: authErr.Code,
|
||||
Message: detail,
|
||||
Retryable: authErr.Retryable,
|
||||
HTTPStatus: status,
|
||||
}
|
||||
}
|
||||
|
||||
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
|
||||
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
|
||||
status := http.StatusInternalServerError
|
||||
|
||||
@@ -5,10 +5,12 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
@@ -66,3 +68,46 @@ func TestWriteErrorResponse_AddonHeadersEnabled(t *testing.T) {
|
||||
t.Fatalf("X-Request-Id = %#v, want %#v", got, []string{"new-1", "new-2"})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnrichAuthSelectionError_DefaultsTo503WithContext(t *testing.T) {
|
||||
in := &coreauth.Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
out := enrichAuthSelectionError(in, []string{"claude"}, "claude-sonnet-4-6")
|
||||
|
||||
var got *coreauth.Error
|
||||
if !errors.As(out, &got) || got == nil {
|
||||
t.Fatalf("expected coreauth.Error, got %T", out)
|
||||
}
|
||||
if got.StatusCode() != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status = %d, want %d", got.StatusCode(), http.StatusServiceUnavailable)
|
||||
}
|
||||
if !strings.Contains(got.Message, "providers=claude") {
|
||||
t.Fatalf("message missing provider context: %q", got.Message)
|
||||
}
|
||||
if !strings.Contains(got.Message, "model=claude-sonnet-4-6") {
|
||||
t.Fatalf("message missing model context: %q", got.Message)
|
||||
}
|
||||
if !strings.Contains(got.Message, "/v0/management/auth-files") {
|
||||
t.Fatalf("message missing management hint: %q", got.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnrichAuthSelectionError_PreservesExplicitStatus(t *testing.T) {
|
||||
in := &coreauth.Error{Code: "auth_unavailable", Message: "no auth available", HTTPStatus: http.StatusTooManyRequests}
|
||||
out := enrichAuthSelectionError(in, []string{"gemini"}, "gemini-2.5-pro")
|
||||
|
||||
var got *coreauth.Error
|
||||
if !errors.As(out, &got) || got == nil {
|
||||
t.Fatalf("expected coreauth.Error, got %T", out)
|
||||
}
|
||||
if got.StatusCode() != http.StatusTooManyRequests {
|
||||
t.Fatalf("status = %d, want %d", got.StatusCode(), http.StatusTooManyRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnrichAuthSelectionError_IgnoresOtherErrors(t *testing.T) {
|
||||
in := errors.New("boom")
|
||||
out := enrichAuthSelectionError(in, []string{"claude"}, "claude-sonnet-4-6")
|
||||
if out != in {
|
||||
t.Fatalf("expected original error to be returned unchanged")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,13 @@ package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -463,6 +466,76 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_EnrichesBootstrapRetryAuthUnavailableError(t *testing.T) {
|
||||
executor := &failOnceStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||
Streaming: sdkconfig.StreamingConfig{
|
||||
BootstrapRetries: 1,
|
||||
},
|
||||
}, manager)
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty payload, got %q", string(got))
|
||||
}
|
||||
|
||||
var gotErr *interfaces.ErrorMessage
|
||||
for msg := range errChan {
|
||||
if msg != nil {
|
||||
gotErr = msg
|
||||
}
|
||||
}
|
||||
if gotErr == nil {
|
||||
t.Fatalf("expected terminal error")
|
||||
}
|
||||
if gotErr.StatusCode != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status = %d, want %d", gotErr.StatusCode, http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
var authErr *coreauth.Error
|
||||
if !errors.As(gotErr.Error, &authErr) || authErr == nil {
|
||||
t.Fatalf("expected coreauth.Error, got %T", gotErr.Error)
|
||||
}
|
||||
if authErr.Code != "auth_unavailable" {
|
||||
t.Fatalf("code = %q, want %q", authErr.Code, "auth_unavailable")
|
||||
}
|
||||
if !strings.Contains(authErr.Message, "providers=codex") {
|
||||
t.Fatalf("message missing provider context: %q", authErr.Message)
|
||||
}
|
||||
if !strings.Contains(authErr.Message, "model=test-model") {
|
||||
t.Fatalf("message missing model context: %q", authErr.Message)
|
||||
}
|
||||
|
||||
if executor.Calls() != 1 {
|
||||
t.Fatalf("expected exactly one upstream call before retry path selection failure, got %d", executor.Calls())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) {
|
||||
executor := &authAwareStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
|
||||
@@ -234,6 +234,84 @@ func (m *Manager) RefreshSchedulerEntry(authID string) {
|
||||
m.scheduler.upsertAuth(snapshot)
|
||||
}
|
||||
|
||||
// ReconcileRegistryModelStates aligns per-model runtime state with the current
|
||||
// registry snapshot for one auth.
|
||||
//
|
||||
// Supported models are reset to a clean state because re-registration already
|
||||
// cleared the registry-side cooldown/suspension snapshot. ModelStates for
|
||||
// models that are no longer present in the registry are pruned entirely so
|
||||
// renamed/removed models cannot keep auth-level status stale.
|
||||
func (m *Manager) ReconcileRegistryModelStates(ctx context.Context, authID string) {
|
||||
if m == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
supportedModels := registry.GetGlobalRegistry().GetModelsForClient(authID)
|
||||
supported := make(map[string]struct{}, len(supportedModels))
|
||||
for _, model := range supportedModels {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
modelKey := canonicalModelKey(model.ID)
|
||||
if modelKey == "" {
|
||||
continue
|
||||
}
|
||||
supported[modelKey] = struct{}{}
|
||||
}
|
||||
|
||||
var snapshot *Auth
|
||||
now := time.Now()
|
||||
|
||||
m.mu.Lock()
|
||||
auth, ok := m.auths[authID]
|
||||
if ok && auth != nil && len(auth.ModelStates) > 0 {
|
||||
changed := false
|
||||
for modelKey, state := range auth.ModelStates {
|
||||
baseModel := canonicalModelKey(modelKey)
|
||||
if baseModel == "" {
|
||||
baseModel = strings.TrimSpace(modelKey)
|
||||
}
|
||||
if _, supportedModel := supported[baseModel]; !supportedModel {
|
||||
// Drop state for models that disappeared from the current registry
|
||||
// snapshot. Keeping them around leaks stale errors into auth-level
|
||||
// status, management output, and websocket fallback checks.
|
||||
delete(auth.ModelStates, modelKey)
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
if state == nil {
|
||||
continue
|
||||
}
|
||||
if modelStateIsClean(state) {
|
||||
continue
|
||||
}
|
||||
resetModelState(state, now)
|
||||
changed = true
|
||||
}
|
||||
if len(auth.ModelStates) == 0 {
|
||||
auth.ModelStates = nil
|
||||
}
|
||||
if changed {
|
||||
updateAggregatedAvailability(auth, now)
|
||||
if !hasModelError(auth, now) {
|
||||
auth.LastError = nil
|
||||
auth.StatusMessage = ""
|
||||
auth.Status = StatusActive
|
||||
}
|
||||
auth.UpdatedAt = now
|
||||
if errPersist := m.persist(ctx, auth); errPersist != nil {
|
||||
logEntryWithRequestID(ctx).WithField("auth_id", auth.ID).Warnf("failed to persist auth changes during model state reconciliation: %v", errPersist)
|
||||
}
|
||||
snapshot = auth.Clone()
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if m.scheduler != nil && snapshot != nil {
|
||||
m.scheduler.upsertAuth(snapshot)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) SetSelector(selector Selector) {
|
||||
if m == nil {
|
||||
return
|
||||
@@ -1838,6 +1916,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
} else {
|
||||
if result.Model != "" {
|
||||
if !isRequestScopedNotFoundResultError(result.Error) {
|
||||
disableCooling := quotaCooldownDisabledForAuth(auth)
|
||||
state := ensureModelState(auth, result.Model)
|
||||
state.Unavailable = true
|
||||
state.Status = StatusError
|
||||
@@ -1858,31 +1937,45 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
} else {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
next := now.Add(30 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "unauthorized"
|
||||
shouldSuspendModel = true
|
||||
if disableCooling {
|
||||
state.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
next := now.Add(30 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "unauthorized"
|
||||
shouldSuspendModel = true
|
||||
}
|
||||
case 402, 403:
|
||||
next := now.Add(30 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "payment_required"
|
||||
shouldSuspendModel = true
|
||||
if disableCooling {
|
||||
state.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
next := now.Add(30 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "payment_required"
|
||||
shouldSuspendModel = true
|
||||
}
|
||||
case 404:
|
||||
next := now.Add(12 * time.Hour)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "not_found"
|
||||
shouldSuspendModel = true
|
||||
if disableCooling {
|
||||
state.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
next := now.Add(12 * time.Hour)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "not_found"
|
||||
shouldSuspendModel = true
|
||||
}
|
||||
case 429:
|
||||
var next time.Time
|
||||
backoffLevel := state.Quota.BackoffLevel
|
||||
if result.RetryAfter != nil {
|
||||
next = now.Add(*result.RetryAfter)
|
||||
} else {
|
||||
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
|
||||
if cooldown > 0 {
|
||||
next = now.Add(cooldown)
|
||||
if !disableCooling {
|
||||
if result.RetryAfter != nil {
|
||||
next = now.Add(*result.RetryAfter)
|
||||
} else {
|
||||
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, disableCooling)
|
||||
if cooldown > 0 {
|
||||
next = now.Add(cooldown)
|
||||
}
|
||||
backoffLevel = nextLevel
|
||||
}
|
||||
backoffLevel = nextLevel
|
||||
}
|
||||
state.NextRetryAfter = next
|
||||
state.Quota = QuotaState{
|
||||
@@ -1891,11 +1984,13 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
NextRecoverAt: next,
|
||||
BackoffLevel: backoffLevel,
|
||||
}
|
||||
suspendReason = "quota"
|
||||
shouldSuspendModel = true
|
||||
setModelQuota = true
|
||||
if !disableCooling {
|
||||
suspendReason = "quota"
|
||||
shouldSuspendModel = true
|
||||
setModelQuota = true
|
||||
}
|
||||
case 408, 500, 502, 503, 504:
|
||||
if quotaCooldownDisabledForAuth(auth) {
|
||||
if disableCooling {
|
||||
state.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
next := now.Add(1 * time.Minute)
|
||||
@@ -1966,8 +2061,28 @@ func resetModelState(state *ModelState, now time.Time) {
|
||||
state.UpdatedAt = now
|
||||
}
|
||||
|
||||
func modelStateIsClean(state *ModelState) bool {
|
||||
if state == nil {
|
||||
return true
|
||||
}
|
||||
if state.Status != StatusActive {
|
||||
return false
|
||||
}
|
||||
if state.Unavailable || state.StatusMessage != "" || !state.NextRetryAfter.IsZero() || state.LastError != nil {
|
||||
return false
|
||||
}
|
||||
if state.Quota.Exceeded || state.Quota.Reason != "" || !state.Quota.NextRecoverAt.IsZero() || state.Quota.BackoffLevel != 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func updateAggregatedAvailability(auth *Auth, now time.Time) {
|
||||
if auth == nil || len(auth.ModelStates) == 0 {
|
||||
if auth == nil {
|
||||
return
|
||||
}
|
||||
if len(auth.ModelStates) == 0 {
|
||||
clearAggregatedAvailability(auth)
|
||||
return
|
||||
}
|
||||
allUnavailable := true
|
||||
@@ -1975,10 +2090,12 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) {
|
||||
quotaExceeded := false
|
||||
quotaRecover := time.Time{}
|
||||
maxBackoffLevel := 0
|
||||
hasState := false
|
||||
for _, state := range auth.ModelStates {
|
||||
if state == nil {
|
||||
continue
|
||||
}
|
||||
hasState = true
|
||||
stateUnavailable := false
|
||||
if state.Status == StatusDisabled {
|
||||
stateUnavailable = true
|
||||
@@ -2008,6 +2125,10 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasState {
|
||||
clearAggregatedAvailability(auth)
|
||||
return
|
||||
}
|
||||
auth.Unavailable = allUnavailable
|
||||
if allUnavailable {
|
||||
auth.NextRetryAfter = earliestRetry
|
||||
@@ -2027,6 +2148,15 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) {
|
||||
}
|
||||
}
|
||||
|
||||
func clearAggregatedAvailability(auth *Auth) {
|
||||
if auth == nil {
|
||||
return
|
||||
}
|
||||
auth.Unavailable = false
|
||||
auth.NextRetryAfter = time.Time{}
|
||||
auth.Quota = QuotaState{}
|
||||
}
|
||||
|
||||
func hasModelError(auth *Auth, now time.Time) bool {
|
||||
if auth == nil || len(auth.ModelStates) == 0 {
|
||||
return false
|
||||
@@ -2211,6 +2341,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
|
||||
if isRequestScopedNotFoundResultError(resultErr) {
|
||||
return
|
||||
}
|
||||
disableCooling := quotaCooldownDisabledForAuth(auth)
|
||||
auth.Unavailable = true
|
||||
auth.Status = StatusError
|
||||
auth.UpdatedAt = now
|
||||
@@ -2224,32 +2355,46 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
|
||||
switch statusCode {
|
||||
case 401:
|
||||
auth.StatusMessage = "unauthorized"
|
||||
auth.NextRetryAfter = now.Add(30 * time.Minute)
|
||||
if disableCooling {
|
||||
auth.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
auth.NextRetryAfter = now.Add(30 * time.Minute)
|
||||
}
|
||||
case 402, 403:
|
||||
auth.StatusMessage = "payment_required"
|
||||
auth.NextRetryAfter = now.Add(30 * time.Minute)
|
||||
if disableCooling {
|
||||
auth.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
auth.NextRetryAfter = now.Add(30 * time.Minute)
|
||||
}
|
||||
case 404:
|
||||
auth.StatusMessage = "not_found"
|
||||
auth.NextRetryAfter = now.Add(12 * time.Hour)
|
||||
if disableCooling {
|
||||
auth.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
auth.NextRetryAfter = now.Add(12 * time.Hour)
|
||||
}
|
||||
case 429:
|
||||
auth.StatusMessage = "quota exhausted"
|
||||
auth.Quota.Exceeded = true
|
||||
auth.Quota.Reason = "quota"
|
||||
var next time.Time
|
||||
if retryAfter != nil {
|
||||
next = now.Add(*retryAfter)
|
||||
} else {
|
||||
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, quotaCooldownDisabledForAuth(auth))
|
||||
if cooldown > 0 {
|
||||
next = now.Add(cooldown)
|
||||
if !disableCooling {
|
||||
if retryAfter != nil {
|
||||
next = now.Add(*retryAfter)
|
||||
} else {
|
||||
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, disableCooling)
|
||||
if cooldown > 0 {
|
||||
next = now.Add(cooldown)
|
||||
}
|
||||
auth.Quota.BackoffLevel = nextLevel
|
||||
}
|
||||
auth.Quota.BackoffLevel = nextLevel
|
||||
}
|
||||
auth.Quota.NextRecoverAt = next
|
||||
auth.NextRetryAfter = next
|
||||
case 408, 500, 502, 503, 504:
|
||||
auth.StatusMessage = "transient upstream error"
|
||||
if quotaCooldownDisabledForAuth(auth) {
|
||||
if disableCooling {
|
||||
auth.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
auth.NextRetryAfter = now.Add(1 * time.Minute)
|
||||
|
||||
@@ -180,6 +180,34 @@ func (e *authFallbackExecutor) StreamCalls() []string {
|
||||
return out
|
||||
}
|
||||
|
||||
type retryAfterStatusError struct {
|
||||
status int
|
||||
message string
|
||||
retryAfter time.Duration
|
||||
}
|
||||
|
||||
func (e *retryAfterStatusError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return e.message
|
||||
}
|
||||
|
||||
func (e *retryAfterStatusError) StatusCode() int {
|
||||
if e == nil {
|
||||
return 0
|
||||
}
|
||||
return e.status
|
||||
}
|
||||
|
||||
func (e *retryAfterStatusError) RetryAfter() *time.Duration {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
d := e.retryAfter
|
||||
return &d
|
||||
}
|
||||
|
||||
func newCredentialRetryLimitTestManager(t *testing.T, maxRetryCredentials int) (*Manager, *credentialRetryLimitExecutor) {
|
||||
t.Helper()
|
||||
|
||||
@@ -450,6 +478,174 @@ func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_MarkResult_RespectsAuthDisableCoolingOverride_On403(t *testing.T) {
|
||||
prev := quotaCooldownDisabled.Load()
|
||||
quotaCooldownDisabled.Store(false)
|
||||
t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
|
||||
|
||||
m := NewManager(nil, nil, nil)
|
||||
|
||||
auth := &Auth{
|
||||
ID: "auth-403",
|
||||
Provider: "claude",
|
||||
Metadata: map[string]any{
|
||||
"disable_cooling": true,
|
||||
},
|
||||
}
|
||||
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
|
||||
t.Fatalf("register auth: %v", errRegister)
|
||||
}
|
||||
|
||||
model := "test-model-403"
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}})
|
||||
t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
|
||||
|
||||
m.MarkResult(context.Background(), Result{
|
||||
AuthID: auth.ID,
|
||||
Provider: "claude",
|
||||
Model: model,
|
||||
Success: false,
|
||||
Error: &Error{HTTPStatus: http.StatusForbidden, Message: "forbidden"},
|
||||
})
|
||||
|
||||
updated, ok := m.GetByID(auth.ID)
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth to be present")
|
||||
}
|
||||
state := updated.ModelStates[model]
|
||||
if state == nil {
|
||||
t.Fatalf("expected model state to be present")
|
||||
}
|
||||
if !state.NextRetryAfter.IsZero() {
|
||||
t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
|
||||
}
|
||||
|
||||
if count := reg.GetModelCount(model); count <= 0 {
|
||||
t.Fatalf("expected model count > 0 when disable_cooling=true, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Execute_DisableCooling_DoesNotBlackoutAfter403(t *testing.T) {
|
||||
prev := quotaCooldownDisabled.Load()
|
||||
quotaCooldownDisabled.Store(false)
|
||||
t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
|
||||
|
||||
m := NewManager(nil, nil, nil)
|
||||
executor := &authFallbackExecutor{
|
||||
id: "claude",
|
||||
executeErrors: map[string]error{
|
||||
"auth-403-exec": &Error{
|
||||
HTTPStatus: http.StatusForbidden,
|
||||
Message: "forbidden",
|
||||
},
|
||||
},
|
||||
}
|
||||
m.RegisterExecutor(executor)
|
||||
|
||||
auth := &Auth{
|
||||
ID: "auth-403-exec",
|
||||
Provider: "claude",
|
||||
Metadata: map[string]any{
|
||||
"disable_cooling": true,
|
||||
},
|
||||
}
|
||||
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
|
||||
t.Fatalf("register auth: %v", errRegister)
|
||||
}
|
||||
|
||||
model := "test-model-403-exec"
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}})
|
||||
t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
|
||||
|
||||
req := cliproxyexecutor.Request{Model: model}
|
||||
_, errExecute1 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{})
|
||||
if errExecute1 == nil {
|
||||
t.Fatal("expected first execute error")
|
||||
}
|
||||
if statusCodeFromError(errExecute1) != http.StatusForbidden {
|
||||
t.Fatalf("first execute status = %d, want %d", statusCodeFromError(errExecute1), http.StatusForbidden)
|
||||
}
|
||||
|
||||
_, errExecute2 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{})
|
||||
if errExecute2 == nil {
|
||||
t.Fatal("expected second execute error")
|
||||
}
|
||||
if statusCodeFromError(errExecute2) != http.StatusForbidden {
|
||||
t.Fatalf("second execute status = %d, want %d", statusCodeFromError(errExecute2), http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Execute_DisableCooling_DoesNotBlackoutAfter429RetryAfter(t *testing.T) {
|
||||
prev := quotaCooldownDisabled.Load()
|
||||
quotaCooldownDisabled.Store(false)
|
||||
t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
|
||||
|
||||
m := NewManager(nil, nil, nil)
|
||||
executor := &authFallbackExecutor{
|
||||
id: "claude",
|
||||
executeErrors: map[string]error{
|
||||
"auth-429-exec": &retryAfterStatusError{
|
||||
status: http.StatusTooManyRequests,
|
||||
message: "quota exhausted",
|
||||
retryAfter: 2 * time.Minute,
|
||||
},
|
||||
},
|
||||
}
|
||||
m.RegisterExecutor(executor)
|
||||
|
||||
auth := &Auth{
|
||||
ID: "auth-429-exec",
|
||||
Provider: "claude",
|
||||
Metadata: map[string]any{
|
||||
"disable_cooling": true,
|
||||
},
|
||||
}
|
||||
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
|
||||
t.Fatalf("register auth: %v", errRegister)
|
||||
}
|
||||
|
||||
model := "test-model-429-exec"
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}})
|
||||
t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
|
||||
|
||||
req := cliproxyexecutor.Request{Model: model}
|
||||
_, errExecute1 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{})
|
||||
if errExecute1 == nil {
|
||||
t.Fatal("expected first execute error")
|
||||
}
|
||||
if statusCodeFromError(errExecute1) != http.StatusTooManyRequests {
|
||||
t.Fatalf("first execute status = %d, want %d", statusCodeFromError(errExecute1), http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
_, errExecute2 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{})
|
||||
if errExecute2 == nil {
|
||||
t.Fatal("expected second execute error")
|
||||
}
|
||||
if statusCodeFromError(errExecute2) != http.StatusTooManyRequests {
|
||||
t.Fatalf("second execute status = %d, want %d", statusCodeFromError(errExecute2), http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
calls := executor.ExecuteCalls()
|
||||
if len(calls) != 2 {
|
||||
t.Fatalf("execute calls = %d, want 2", len(calls))
|
||||
}
|
||||
|
||||
updated, ok := m.GetByID(auth.ID)
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth to be present")
|
||||
}
|
||||
state := updated.ModelStates[model]
|
||||
if state == nil {
|
||||
t.Fatalf("expected model state to be present")
|
||||
}
|
||||
if !state.NextRetryAfter.IsZero() {
|
||||
t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_MarkResult_RequestScopedNotFoundDoesNotCooldownAuth(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil)
|
||||
|
||||
|
||||
@@ -324,6 +324,7 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
|
||||
// This operation may block on network calls, but the auth configuration
|
||||
// is already effective at this point.
|
||||
s.registerModelsForAuth(auth)
|
||||
s.coreManager.ReconcileRegistryModelStates(ctx, auth.ID)
|
||||
|
||||
// Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt
|
||||
// from the now-populated global model registry. Without this, newly added auths
|
||||
@@ -1085,6 +1086,7 @@ func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool {
|
||||
s.ensureExecutorsForAuth(current)
|
||||
}
|
||||
s.registerModelsForAuth(current)
|
||||
s.coreManager.ReconcileRegistryModelStates(context.Background(), current.ID)
|
||||
|
||||
latest, ok := s.latestAuthForModelRegistration(current.ID)
|
||||
if !ok || latest.Disabled {
|
||||
@@ -1098,6 +1100,7 @@ func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool {
|
||||
// no auth fields changed, but keeps the refresh path simple and correct.
|
||||
s.ensureExecutorsForAuth(latest)
|
||||
s.registerModelsForAuth(latest)
|
||||
s.coreManager.ReconcileRegistryModelStates(context.Background(), latest.ID)
|
||||
s.coreManager.RefreshSchedulerEntry(current.ID)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func Parse(raw string) (Setting, error) {
|
||||
}
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "socks5", "http", "https":
|
||||
case "socks5", "socks5h", "http", "https":
|
||||
setting.Mode = ModeProxy
|
||||
setting.URL = parsedURL
|
||||
return setting, nil
|
||||
@@ -95,7 +95,7 @@ func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
|
||||
case ModeDirect:
|
||||
return NewDirectTransport(), setting.Mode, nil
|
||||
case ModeProxy:
|
||||
if setting.URL.Scheme == "socks5" {
|
||||
if setting.URL.Scheme == "socks5" || setting.URL.Scheme == "socks5h" {
|
||||
var proxyAuth *proxy.Auth
|
||||
if setting.URL.User != nil {
|
||||
username := setting.URL.User.Username()
|
||||
|
||||
@@ -30,6 +30,7 @@ func TestParse(t *testing.T) {
|
||||
{name: "http", input: "http://proxy.example.com:8080", want: ModeProxy},
|
||||
{name: "https", input: "https://proxy.example.com:8443", want: ModeProxy},
|
||||
{name: "socks5", input: "socks5://proxy.example.com:1080", want: ModeProxy},
|
||||
{name: "socks5h", input: "socks5h://proxy.example.com:1080", want: ModeProxy},
|
||||
{name: "invalid", input: "bad-value", want: ModeInvalid, wantErr: true},
|
||||
}
|
||||
|
||||
@@ -137,3 +138,24 @@ func TestBuildHTTPTransportSOCKS5ProxyInheritsDefaultTransportSettings(t *testin
|
||||
t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildHTTPTransportSOCKS5HProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
transport, mode, errBuild := BuildHTTPTransport("socks5h://proxy.example.com:1080")
|
||||
if errBuild != nil {
|
||||
t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
|
||||
}
|
||||
if mode != ModeProxy {
|
||||
t.Fatalf("mode = %d, want %d", mode, ModeProxy)
|
||||
}
|
||||
if transport == nil {
|
||||
t.Fatal("expected transport, got nil")
|
||||
}
|
||||
if transport.Proxy != nil {
|
||||
t.Fatal("expected SOCKS5H transport to bypass http proxy function")
|
||||
}
|
||||
if transport.DialContext == nil {
|
||||
t.Fatal("expected SOCKS5H transport to have custom DialContext")
|
||||
}
|
||||
}
|
||||
|
||||
106
test/claude_code_compatibility_sentinel_test.go
Normal file
106
test/claude_code_compatibility_sentinel_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type jsonObject = map[string]any
|
||||
|
||||
func loadClaudeCodeSentinelFixture(t *testing.T, name string) jsonObject {
|
||||
t.Helper()
|
||||
path := filepath.Join("testdata", "claude_code_sentinels", name)
|
||||
data := mustReadFile(t, path)
|
||||
var payload jsonObject
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
t.Fatalf("unmarshal %s: %v", name, err)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func mustReadFile(t *testing.T, path string) []byte {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read %s: %v", path, err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func requireStringField(t *testing.T, obj jsonObject, key string) string {
|
||||
t.Helper()
|
||||
value, ok := obj[key].(string)
|
||||
if !ok || value == "" {
|
||||
t.Fatalf("field %q missing or empty: %#v", key, obj[key])
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func TestClaudeCodeSentinel_ToolProgressShape(t *testing.T) {
|
||||
payload := loadClaudeCodeSentinelFixture(t, "tool_progress.json")
|
||||
if got := requireStringField(t, payload, "type"); got != "tool_progress" {
|
||||
t.Fatalf("type = %q, want tool_progress", got)
|
||||
}
|
||||
requireStringField(t, payload, "tool_use_id")
|
||||
requireStringField(t, payload, "tool_name")
|
||||
requireStringField(t, payload, "session_id")
|
||||
if _, ok := payload["elapsed_time_seconds"].(float64); !ok {
|
||||
t.Fatalf("elapsed_time_seconds missing or non-number: %#v", payload["elapsed_time_seconds"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCodeSentinel_SessionStateShape(t *testing.T) {
|
||||
payload := loadClaudeCodeSentinelFixture(t, "session_state_changed.json")
|
||||
if got := requireStringField(t, payload, "type"); got != "system" {
|
||||
t.Fatalf("type = %q, want system", got)
|
||||
}
|
||||
if got := requireStringField(t, payload, "subtype"); got != "session_state_changed" {
|
||||
t.Fatalf("subtype = %q, want session_state_changed", got)
|
||||
}
|
||||
state := requireStringField(t, payload, "state")
|
||||
switch state {
|
||||
case "idle", "running", "requires_action":
|
||||
default:
|
||||
t.Fatalf("unexpected session state %q", state)
|
||||
}
|
||||
requireStringField(t, payload, "session_id")
|
||||
}
|
||||
|
||||
func TestClaudeCodeSentinel_ToolUseSummaryShape(t *testing.T) {
|
||||
payload := loadClaudeCodeSentinelFixture(t, "tool_use_summary.json")
|
||||
if got := requireStringField(t, payload, "type"); got != "tool_use_summary" {
|
||||
t.Fatalf("type = %q, want tool_use_summary", got)
|
||||
}
|
||||
requireStringField(t, payload, "summary")
|
||||
rawIDs, ok := payload["preceding_tool_use_ids"].([]any)
|
||||
if !ok || len(rawIDs) == 0 {
|
||||
t.Fatalf("preceding_tool_use_ids missing or empty: %#v", payload["preceding_tool_use_ids"])
|
||||
}
|
||||
for i, raw := range rawIDs {
|
||||
if id, ok := raw.(string); !ok || id == "" {
|
||||
t.Fatalf("preceding_tool_use_ids[%d] invalid: %#v", i, raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCodeSentinel_ControlRequestCanUseToolShape(t *testing.T) {
|
||||
payload := loadClaudeCodeSentinelFixture(t, "control_request_can_use_tool.json")
|
||||
if got := requireStringField(t, payload, "type"); got != "control_request" {
|
||||
t.Fatalf("type = %q, want control_request", got)
|
||||
}
|
||||
requireStringField(t, payload, "request_id")
|
||||
request, ok := payload["request"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("request missing or invalid: %#v", payload["request"])
|
||||
}
|
||||
if got := requireStringField(t, request, "subtype"); got != "can_use_tool" {
|
||||
t.Fatalf("request.subtype = %q, want can_use_tool", got)
|
||||
}
|
||||
requireStringField(t, request, "tool_name")
|
||||
requireStringField(t, request, "tool_use_id")
|
||||
if input, ok := request["input"].(map[string]any); !ok || len(input) == 0 {
|
||||
t.Fatalf("request.input missing or empty: %#v", request["input"])
|
||||
}
|
||||
}
|
||||
11
test/testdata/claude_code_sentinels/control_request_can_use_tool.json
vendored
Normal file
11
test/testdata/claude_code_sentinels/control_request_can_use_tool.json
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"type": "control_request",
|
||||
"request_id": "req_123",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "Bash",
|
||||
"input": {"command": "npm test"},
|
||||
"tool_use_id": "toolu_123",
|
||||
"description": "Running npm test"
|
||||
}
|
||||
}
|
||||
7
test/testdata/claude_code_sentinels/session_state_changed.json
vendored
Normal file
7
test/testdata/claude_code_sentinels/session_state_changed.json
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"type": "system",
|
||||
"subtype": "session_state_changed",
|
||||
"state": "requires_action",
|
||||
"uuid": "22222222-2222-4222-8222-222222222222",
|
||||
"session_id": "sess_123"
|
||||
}
|
||||
10
test/testdata/claude_code_sentinels/tool_progress.json
vendored
Normal file
10
test/testdata/claude_code_sentinels/tool_progress.json
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"type": "tool_progress",
|
||||
"tool_use_id": "toolu_123",
|
||||
"tool_name": "Bash",
|
||||
"parent_tool_use_id": null,
|
||||
"elapsed_time_seconds": 2.5,
|
||||
"task_id": "task_123",
|
||||
"uuid": "11111111-1111-4111-8111-111111111111",
|
||||
"session_id": "sess_123"
|
||||
}
|
||||
7
test/testdata/claude_code_sentinels/tool_use_summary.json
vendored
Normal file
7
test/testdata/claude_code_sentinels/tool_use_summary.json
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"type": "tool_use_summary",
|
||||
"summary": "Searched in auth/",
|
||||
"preceding_tool_use_ids": ["toolu_1", "toolu_2"],
|
||||
"uuid": "33333333-3333-4333-8333-333333333333",
|
||||
"session_id": "sess_123"
|
||||
}
|
||||
Reference in New Issue
Block a user