mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-04 11:41:20 +00:00
Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c2ad4cda2 | ||
|
|
db63f9b5d6 | ||
|
|
25f6c4a250 | ||
|
|
b24ae74216 | ||
|
|
59ad8f40dc | ||
|
|
ff03dc6a2c | ||
|
|
dc7187ca5b | ||
|
|
b1dcff778c | ||
|
|
cef2aeeb08 | ||
|
|
bcd1e8cc34 | ||
|
|
198b3f4a40 | ||
|
|
9fee7f488e | ||
|
|
1b46d39b8b | ||
|
|
c1241a98e2 | ||
|
|
8d8f5970ee | ||
|
|
f90120f846 | ||
|
|
0b94d36c4a | ||
|
|
c8cee6a209 | ||
|
|
b5701f416b | ||
|
|
4b1a404fcb | ||
|
|
b93cce5412 | ||
|
|
c6cb24039d | ||
|
|
5382408489 | ||
|
|
67669196ed | ||
|
|
58fd9bf964 | ||
|
|
7b3dfc67bc | ||
|
|
cdd24052d3 | ||
|
|
733fd8edab | ||
|
|
af27f2b8bc | ||
|
|
2e1925d762 | ||
|
|
77254bd074 | ||
|
|
5b6342e6ac | ||
|
|
3960c93d51 | ||
|
|
339a81b650 | ||
|
|
560c020477 | ||
|
|
aec65e3be3 | ||
|
|
f44f0702f8 | ||
|
|
b76b79068f |
6
.github/workflows/docker-image.yml
vendored
6
.github/workflows/docker-image.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- name: Build and push (amd64)
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- name: Build and push (arm64)
|
||||
@@ -98,7 +98,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- name: Create and push multi-arch manifests
|
||||
|
||||
5
.github/workflows/release.yaml
vendored
5
.github/workflows/release.yaml
vendored
@@ -27,15 +27,14 @@ jobs:
|
||||
cache: true
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
VERSION=$(git describe --tags --always --dirty)
|
||||
echo "VERSION=${VERSION}" >> $GITHUB_ENV
|
||||
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
distribution: goreleaser
|
||||
version: latest
|
||||
args: release --clean
|
||||
args: release --clean --skip=validate
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
VERSION: ${{ env.VERSION }}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
version: 2
|
||||
|
||||
builds:
|
||||
- id: "cli-proxy-api-plus"
|
||||
env:
|
||||
|
||||
@@ -244,11 +244,11 @@ nonstream-keepalive-interval: 0
|
||||
# - name: "kimi-k2.5"
|
||||
# alias: "claude-opus-4.66"
|
||||
|
||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||
# Vertex API keys (Vertex-compatible endpoints, base-url is optional)
|
||||
# vertex-api-key:
|
||||
# - api-key: "vk-123..." # x-goog-api-key header
|
||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
||||
# base-url: "https://example.com/api" # optional, e.g. https://zenmux.ai/api; falls back to Google Vertex when omitted
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||
# headers:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
@@ -56,3 +57,57 @@ func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
|
||||
t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
geminiAuth := &coreauth.Auth{
|
||||
ID: "gemini:apikey:123",
|
||||
Provider: "gemini",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "shared-key",
|
||||
},
|
||||
}
|
||||
compatAuth := &coreauth.Auth{
|
||||
ID: "openai-compatibility:bohe:456",
|
||||
Provider: "bohe",
|
||||
Label: "bohe",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "shared-key",
|
||||
"compat_name": "bohe",
|
||||
"provider_key": "bohe",
|
||||
},
|
||||
}
|
||||
|
||||
if _, errRegister := manager.Register(context.Background(), geminiAuth); errRegister != nil {
|
||||
t.Fatalf("register gemini auth: %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), compatAuth); errRegister != nil {
|
||||
t.Fatalf("register compat auth: %v", errRegister)
|
||||
}
|
||||
|
||||
geminiIndex := geminiAuth.EnsureIndex()
|
||||
compatIndex := compatAuth.EnsureIndex()
|
||||
if geminiIndex == compatIndex {
|
||||
t.Fatalf("shared api key produced duplicate auth_index %q", geminiIndex)
|
||||
}
|
||||
|
||||
h := &Handler{authManager: manager}
|
||||
|
||||
gotGemini := h.authByIndex(geminiIndex)
|
||||
if gotGemini == nil {
|
||||
t.Fatal("expected gemini auth by index")
|
||||
}
|
||||
if gotGemini.ID != geminiAuth.ID {
|
||||
t.Fatalf("authByIndex(gemini) returned %q, want %q", gotGemini.ID, geminiAuth.ID)
|
||||
}
|
||||
|
||||
gotCompat := h.authByIndex(compatIndex)
|
||||
if gotCompat == nil {
|
||||
t.Fatal("expected compat auth by index")
|
||||
}
|
||||
if gotCompat.ID != compatAuth.ID {
|
||||
t.Fatalf("authByIndex(compat) returned %q, want %q", gotCompat.ID, compatAuth.ID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,6 +341,21 @@ func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
|
||||
emailValue := gjson.GetBytes(data, "email").String()
|
||||
fileData["type"] = typeValue
|
||||
fileData["email"] = emailValue
|
||||
if pv := gjson.GetBytes(data, "priority"); pv.Exists() {
|
||||
switch pv.Type {
|
||||
case gjson.Number:
|
||||
fileData["priority"] = int(pv.Int())
|
||||
case gjson.String:
|
||||
if parsed, errAtoi := strconv.Atoi(strings.TrimSpace(pv.String())); errAtoi == nil {
|
||||
fileData["priority"] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
if nv := gjson.GetBytes(data, "note"); nv.Exists() && nv.Type == gjson.String {
|
||||
if trimmed := strings.TrimSpace(nv.String()); trimmed != "" {
|
||||
fileData["note"] = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
files = append(files, fileData)
|
||||
@@ -424,6 +439,37 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
||||
if claims := extractCodexIDTokenClaims(auth); claims != nil {
|
||||
entry["id_token"] = claims
|
||||
}
|
||||
// Expose priority from Attributes (set by synthesizer from JSON "priority" field).
|
||||
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
|
||||
if p := strings.TrimSpace(authAttribute(auth, "priority")); p != "" {
|
||||
if parsed, err := strconv.Atoi(p); err == nil {
|
||||
entry["priority"] = parsed
|
||||
}
|
||||
} else if auth.Metadata != nil {
|
||||
if rawPriority, ok := auth.Metadata["priority"]; ok {
|
||||
switch v := rawPriority.(type) {
|
||||
case float64:
|
||||
entry["priority"] = int(v)
|
||||
case int:
|
||||
entry["priority"] = v
|
||||
case string:
|
||||
if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
entry["priority"] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Expose note from Attributes (set by synthesizer from JSON "note" field).
|
||||
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
|
||||
if note := strings.TrimSpace(authAttribute(auth, "note")); note != "" {
|
||||
entry["note"] = note
|
||||
} else if auth.Metadata != nil {
|
||||
if rawNote, ok := auth.Metadata["note"].(string); ok {
|
||||
if trimmed := strings.TrimSpace(rawNote); trimmed != "" {
|
||||
entry["note"] = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
@@ -848,7 +894,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||
}
|
||||
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority, note) of an auth file.
|
||||
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
if h.authManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
@@ -860,6 +906,7 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
Prefix *string `json:"prefix"`
|
||||
ProxyURL *string `json:"proxy_url"`
|
||||
Priority *int `json:"priority"`
|
||||
Note *string `json:"note"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
@@ -902,14 +949,32 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
targetAuth.ProxyURL = *req.ProxyURL
|
||||
changed = true
|
||||
}
|
||||
if req.Priority != nil {
|
||||
if req.Priority != nil || req.Note != nil {
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if *req.Priority == 0 {
|
||||
delete(targetAuth.Metadata, "priority")
|
||||
} else {
|
||||
targetAuth.Metadata["priority"] = *req.Priority
|
||||
if targetAuth.Attributes == nil {
|
||||
targetAuth.Attributes = make(map[string]string)
|
||||
}
|
||||
|
||||
if req.Priority != nil {
|
||||
if *req.Priority == 0 {
|
||||
delete(targetAuth.Metadata, "priority")
|
||||
delete(targetAuth.Attributes, "priority")
|
||||
} else {
|
||||
targetAuth.Metadata["priority"] = *req.Priority
|
||||
targetAuth.Attributes["priority"] = strconv.Itoa(*req.Priority)
|
||||
}
|
||||
}
|
||||
if req.Note != nil {
|
||||
trimmedNote := strings.TrimSpace(*req.Note)
|
||||
if trimmedNote == "" {
|
||||
delete(targetAuth.Metadata, "note")
|
||||
delete(targetAuth.Attributes, "note")
|
||||
} else {
|
||||
targetAuth.Metadata["note"] = trimmedNote
|
||||
targetAuth.Attributes["note"] = trimmedNote
|
||||
}
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
@@ -2438,17 +2503,20 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
if label == "" {
|
||||
label = username
|
||||
}
|
||||
metadata, errMeta := copilotTokenMetadata(tokenStorage)
|
||||
if errMeta != nil {
|
||||
log.Errorf("Failed to build token metadata: %v", errMeta)
|
||||
SetOAuthSessionError(state, "Failed to build token metadata")
|
||||
return
|
||||
}
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "github-copilot",
|
||||
Label: label,
|
||||
FileName: fileName,
|
||||
Storage: tokenStorage,
|
||||
Metadata: map[string]any{
|
||||
"email": userInfo.Email,
|
||||
"username": username,
|
||||
"name": userInfo.Name,
|
||||
},
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
@@ -2473,6 +2541,21 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func copilotTokenMetadata(storage *copilot.CopilotTokenStorage) (map[string]any, error) {
|
||||
if storage == nil {
|
||||
return nil, fmt.Errorf("token storage is nil")
|
||||
}
|
||||
payload, errMarshal := json.Marshal(storage)
|
||||
if errMarshal != nil {
|
||||
return nil, fmt.Errorf("marshal token storage: %w", errMarshal)
|
||||
}
|
||||
metadata := make(map[string]any)
|
||||
if errUnmarshal := json.Unmarshal(payload, &metadata); errUnmarshal != nil {
|
||||
return nil, fmt.Errorf("unmarshal token storage: %w", errUnmarshal)
|
||||
}
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@@ -509,8 +509,12 @@ func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
|
||||
}
|
||||
for i := range arr {
|
||||
normalizeVertexCompatKey(&arr[i])
|
||||
if arr[i].APIKey == "" {
|
||||
c.JSON(400, gin.H{"error": fmt.Sprintf("vertex-api-key[%d].api-key is required", i)})
|
||||
return
|
||||
}
|
||||
}
|
||||
h.cfg.VertexCompatAPIKey = arr
|
||||
h.cfg.VertexCompatAPIKey = append([]config.VertexCompatKey(nil), arr...)
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
@@ -685,7 +685,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||
cfg.SanitizeGeminiKeys()
|
||||
|
||||
// Sanitize Vertex-compatible API keys: drop entries without base-url
|
||||
// Sanitize Vertex-compatible API keys.
|
||||
cfg.SanitizeVertexCompatKeys()
|
||||
|
||||
// Sanitize Codex keys: drop entries without base-url
|
||||
|
||||
@@ -20,9 +20,9 @@ type VertexCompatKey struct {
|
||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
||||
// BaseURL optionally overrides the Vertex-compatible API endpoint.
|
||||
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
||||
// When empty, requests fall back to the default Vertex API base URL.
|
||||
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||
|
||||
// ProxyURL optionally overrides the global proxy for this API key.
|
||||
@@ -71,10 +71,6 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
if entry.BaseURL == "" {
|
||||
// BaseURL is required for Vertex API key entries
|
||||
continue
|
||||
}
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||
|
||||
@@ -205,6 +205,10 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Request usage data in the final streaming chunk so that token statistics
|
||||
// are captured even when the upstream is an OpenAI-compatible provider.
|
||||
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -68,6 +69,10 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
contentsJSON := "[]"
|
||||
hasContents := false
|
||||
|
||||
// tool_use_id → tool_name lookup, populated incrementally during the main loop.
|
||||
// Claude's tool_result references tool_use by ID; Gemini requires functionResponse.name.
|
||||
toolNameByID := make(map[string]string)
|
||||
|
||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messageResults := messagesResult.Array()
|
||||
@@ -170,6 +175,10 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
argsResult := contentResult.Get("input")
|
||||
functionID := contentResult.Get("id").String()
|
||||
|
||||
if functionID != "" && functionName != "" {
|
||||
toolNameByID[functionID] = functionName
|
||||
}
|
||||
|
||||
// Handle both object and string input formats
|
||||
var argsRaw string
|
||||
if argsResult.IsObject() {
|
||||
@@ -206,10 +215,19 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||
toolCallID := contentResult.Get("tool_use_id").String()
|
||||
if toolCallID != "" {
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-")
|
||||
funcName, ok := toolNameByID[toolCallID]
|
||||
if !ok {
|
||||
// Fallback: derive a semantic name from the ID by stripping
|
||||
// the last two dash-separated segments (e.g. "get_weather-call-123" → "get_weather").
|
||||
// Only use the raw ID as a last resort when the heuristic produces an empty string.
|
||||
parts := strings.Split(toolCallID, "-")
|
||||
if len(parts) > 2 {
|
||||
funcName = strings.Join(parts[:len(parts)-2], "-")
|
||||
}
|
||||
if funcName == "" {
|
||||
funcName = toolCallID
|
||||
}
|
||||
log.Warnf("antigravity claude request: tool_result references unknown tool_use_id=%s, derived function name=%s", toolCallID, funcName)
|
||||
}
|
||||
functionResponseResult := contentResult.Get("content")
|
||||
|
||||
|
||||
@@ -365,6 +365,17 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "get_weather-call-123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "Paris"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -382,13 +393,177 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
||||
outputStr := string(output)
|
||||
|
||||
// Check function response conversion
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Error("functionResponse should exist")
|
||||
}
|
||||
if funcResp.Get("id").String() != "get_weather-call-123" {
|
||||
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
|
||||
}
|
||||
if funcResp.Get("name").String() != "get_weather" {
|
||||
t.Errorf("Expected function name 'get_weather', got '%s'", funcResp.Get("name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_TouluFormat(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||
"name": "Glob",
|
||||
"input": {"pattern": "**/*.py"}
|
||||
},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
|
||||
"name": "Bash",
|
||||
"input": {"command": "ls"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||
"content": "file1.py\nfile2.py"
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
|
||||
"content": "total 10"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp0 := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||
if !funcResp0.Exists() {
|
||||
t.Fatal("first functionResponse should exist")
|
||||
}
|
||||
if got := funcResp0.Get("name").String(); got != "Glob" {
|
||||
t.Errorf("Expected name 'Glob' for toolu_ format, got '%s'", got)
|
||||
}
|
||||
|
||||
funcResp1 := gjson.Get(outputStr, "request.contents.1.parts.1.functionResponse")
|
||||
if !funcResp1.Exists() {
|
||||
t.Fatal("second functionResponse should exist")
|
||||
}
|
||||
if got := funcResp1.Get("name").String(); got != "Bash" {
|
||||
t.Errorf("Expected name 'Bash' for toolu_ format, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_CustomFormat(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "Read-1773420180464065165-1327",
|
||||
"name": "Read",
|
||||
"input": {"file_path": "/tmp/test.py"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Read-1773420180464065165-1327",
|
||||
"content": "file content here"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
if got := funcResp.Get("name").String(); got != "Read" {
|
||||
t.Errorf("Expected name 'Read', got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_Heuristic(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "get_weather-call-123",
|
||||
"content": "22C sunny"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
if got := funcResp.Get("name").String(); got != "get_weather" {
|
||||
t.Errorf("Expected heuristic-derived name 'get_weather', got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_RawID(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||
"content": "result data"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
got := funcResp.Get("name").String()
|
||||
if got == "" {
|
||||
t.Error("functionResponse.name must not be empty")
|
||||
}
|
||||
if got != "toolu_tool-48fca351f12844eabf49dad8b63886d2" {
|
||||
t.Errorf("Expected raw ID as last-resort name, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
|
||||
|
||||
@@ -197,7 +197,12 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
}
|
||||
}
|
||||
|
||||
out, _ = sjson.SetRaw(out, "input.-1", msg)
|
||||
// Don't emit empty assistant messages when only tool_calls
|
||||
// are present — Responses API needs function_call items
|
||||
// directly, otherwise call_id matching fails (#2132).
|
||||
if role != "assistant" || len(gjson.Get(msg, "content").Array()) > 0 {
|
||||
out, _ = sjson.SetRaw(out, "input.-1", msg)
|
||||
}
|
||||
|
||||
// Handle tool calls for assistant messages as separate top-level objects
|
||||
if role == "assistant" {
|
||||
|
||||
@@ -0,0 +1,635 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Basic tool-call: system + user + assistant(tool_calls, no content) + tool result.
|
||||
// Expects developer msg + user msg + function_call + function_call_output.
|
||||
// No empty assistant message should appear between user and function_call.
|
||||
func TestToolCallSimple(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is the weather in Paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"Paris\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": "sunny, 22C"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather for a city",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
if len(items) != 4 {
|
||||
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
// system -> developer
|
||||
if items[0].Get("type").String() != "message" {
|
||||
t.Errorf("item 0: expected type 'message', got '%s'", items[0].Get("type").String())
|
||||
}
|
||||
if items[0].Get("role").String() != "developer" {
|
||||
t.Errorf("item 0: expected role 'developer', got '%s'", items[0].Get("role").String())
|
||||
}
|
||||
|
||||
// user
|
||||
if items[1].Get("type").String() != "message" {
|
||||
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
|
||||
}
|
||||
if items[1].Get("role").String() != "user" {
|
||||
t.Errorf("item 1: expected role 'user', got '%s'", items[1].Get("role").String())
|
||||
}
|
||||
|
||||
// function_call, not an empty assistant msg
|
||||
if items[2].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
|
||||
}
|
||||
if items[2].Get("call_id").String() != "call_1" {
|
||||
t.Errorf("item 2: expected call_id 'call_1', got '%s'", items[2].Get("call_id").String())
|
||||
}
|
||||
if items[2].Get("name").String() != "get_weather" {
|
||||
t.Errorf("item 2: expected name 'get_weather', got '%s'", items[2].Get("name").String())
|
||||
}
|
||||
if items[2].Get("arguments").String() != `{"city":"Paris"}` {
|
||||
t.Errorf("item 2: unexpected arguments: %s", items[2].Get("arguments").String())
|
||||
}
|
||||
|
||||
// function_call_output
|
||||
if items[3].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
|
||||
}
|
||||
if items[3].Get("call_id").String() != "call_1" {
|
||||
t.Errorf("item 3: expected call_id 'call_1', got '%s'", items[3].Get("call_id").String())
|
||||
}
|
||||
if items[3].Get("output").String() != "sunny, 22C" {
|
||||
t.Errorf("item 3: expected output 'sunny, 22C', got '%s'", items[3].Get("output").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Assistant has both text content and tool_calls — the message should
|
||||
// be emitted (non-empty content), followed by function_call items.
|
||||
func TestToolCallWithContent(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Let me check the weather for you.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc",
|
||||
"content": "rainy, 15C"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
// user + assistant(with content) + function_call + function_call_output
|
||||
if len(items) != 4 {
|
||||
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
if items[0].Get("role").String() != "user" {
|
||||
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
|
||||
}
|
||||
|
||||
// assistant with content — should be kept
|
||||
if items[1].Get("type").String() != "message" {
|
||||
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
|
||||
}
|
||||
if items[1].Get("role").String() != "assistant" {
|
||||
t.Errorf("item 1: expected role 'assistant', got '%s'", items[1].Get("role").String())
|
||||
}
|
||||
contentParts := items[1].Get("content").Array()
|
||||
if len(contentParts) == 0 {
|
||||
t.Errorf("item 1: assistant message should have content parts")
|
||||
}
|
||||
|
||||
if items[2].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
|
||||
}
|
||||
if items[2].Get("call_id").String() != "call_abc" {
|
||||
t.Errorf("item 2: expected call_id 'call_abc', got '%s'", items[2].Get("call_id").String())
|
||||
}
|
||||
|
||||
if items[3].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
|
||||
}
|
||||
if items[3].Get("call_id").String() != "call_abc" {
|
||||
t.Errorf("item 3: expected call_id 'call_abc', got '%s'", items[3].Get("call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Parallel tool calls: assistant invokes 3 tools at once, all call_ids
|
||||
// and outputs must be translated and paired correctly.
|
||||
func TestMultipleToolCalls(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Compare weather in Paris, London and Tokyo"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_paris",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"Paris\"}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "call_london",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"London\"}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "call_tokyo",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"Tokyo\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_paris", "content": "sunny, 22C"},
|
||||
{"role": "tool", "tool_call_id": "call_london", "content": "cloudy, 14C"},
|
||||
{"role": "tool", "tool_call_id": "call_tokyo", "content": "humid, 28C"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
// user + 3 function_call + 3 function_call_output = 7
|
||||
if len(items) != 7 {
|
||||
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
if items[0].Get("role").String() != "user" {
|
||||
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
|
||||
}
|
||||
|
||||
expectedCallIDs := []string{"call_paris", "call_london", "call_tokyo"}
|
||||
for i, expectedID := range expectedCallIDs {
|
||||
idx := i + 1
|
||||
if items[idx].Get("type").String() != "function_call" {
|
||||
t.Errorf("item %d: expected type 'function_call', got '%s'", idx, items[idx].Get("type").String())
|
||||
}
|
||||
if items[idx].Get("call_id").String() != expectedID {
|
||||
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedID, items[idx].Get("call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
expectedOutputs := []string{"sunny, 22C", "cloudy, 14C", "humid, 28C"}
|
||||
for i, expectedOutput := range expectedOutputs {
|
||||
idx := i + 4
|
||||
if items[idx].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item %d: expected type 'function_call_output', got '%s'", idx, items[idx].Get("type").String())
|
||||
}
|
||||
if items[idx].Get("call_id").String() != expectedCallIDs[i] {
|
||||
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedCallIDs[i], items[idx].Get("call_id").String())
|
||||
}
|
||||
if items[idx].Get("output").String() != expectedOutput {
|
||||
t.Errorf("item %d: expected output '%s', got '%s'", idx, expectedOutput, items[idx].Get("output").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Regression test for #2132: tool-call-only assistant messages (content:null)
|
||||
// must not produce an empty message item in the translated output.
|
||||
func TestNoSpuriousEmptyAssistantMessage(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Call a tool"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_x",
|
||||
"type": "function",
|
||||
"function": {"name": "do_thing", "arguments": "{}"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_x", "content": "done"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "do_thing",
|
||||
"description": "Do a thing",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
for i, item := range items {
|
||||
typ := item.Get("type").String()
|
||||
role := item.Get("role").String()
|
||||
if typ == "message" && role == "assistant" {
|
||||
contentArr := item.Get("content").Array()
|
||||
if len(contentArr) == 0 {
|
||||
t.Errorf("item %d: empty assistant message breaks call_id matching. item: %s", i, item.Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// should be exactly: user + function_call + function_call_output
|
||||
if len(items) != 3 {
|
||||
t.Fatalf("expected 3 input items (user + function_call + function_call_output), got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
if items[0].Get("type").String() != "message" || items[0].Get("role").String() != "user" {
|
||||
t.Errorf("item 0: expected user message")
|
||||
}
|
||||
if items[1].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
|
||||
}
|
||||
if items[2].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Two rounds of tool calling in one conversation, with a text reply in between.
|
||||
func TestMultiTurnToolCalling(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Weather in Paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{"id": "call_r1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}}]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_r1", "content": "sunny"},
|
||||
{"role": "assistant", "content": "It is sunny in Paris."},
|
||||
{"role": "user", "content": "And London?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{"id": "call_r2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"London\"}"}}]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_r2", "content": "rainy"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
// user, func_call(r1), func_output(r1), assistant text, user, func_call(r2), func_output(r2)
|
||||
if len(items) != 7 {
|
||||
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
|
||||
if len(item.Get("content").Array()) == 0 {
|
||||
t.Errorf("item %d: unexpected empty assistant message", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// round 1
|
||||
if items[1].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
|
||||
}
|
||||
if items[1].Get("call_id").String() != "call_r1" {
|
||||
t.Errorf("item 1: expected call_id 'call_r1', got '%s'", items[1].Get("call_id").String())
|
||||
}
|
||||
if items[2].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
|
||||
}
|
||||
|
||||
// text reply between rounds
|
||||
if items[3].Get("type").String() != "message" || items[3].Get("role").String() != "assistant" {
|
||||
t.Errorf("item 3: expected assistant message, got type=%s role=%s", items[3].Get("type").String(), items[3].Get("role").String())
|
||||
}
|
||||
|
||||
// round 2
|
||||
if items[5].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 5: expected function_call, got %s", items[5].Get("type").String())
|
||||
}
|
||||
if items[5].Get("call_id").String() != "call_r2" {
|
||||
t.Errorf("item 5: expected call_id 'call_r2', got '%s'", items[5].Get("call_id").String())
|
||||
}
|
||||
if items[6].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 6: expected function_call_output, got %s", items[6].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Tool names over 64 chars get shortened, call_id stays the same.
|
||||
func TestToolNameShortening(t *testing.T) {
|
||||
longName := "a_very_long_tool_name_that_exceeds_sixty_four_characters_limit_here_test"
|
||||
if len(longName) <= 64 {
|
||||
t.Fatalf("test setup error: name must be > 64 chars, got %d", len(longName))
|
||||
}
|
||||
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Do it"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_long",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "` + longName + `",
|
||||
"arguments": "{}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_long", "content": "ok"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "` + longName + `",
|
||||
"description": "A tool with a very long name",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
// find function_call
|
||||
var funcCallItem gjson.Result
|
||||
for _, item := range items {
|
||||
if item.Get("type").String() == "function_call" {
|
||||
funcCallItem = item
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !funcCallItem.Exists() {
|
||||
t.Fatal("no function_call item found in output")
|
||||
}
|
||||
|
||||
// call_id unchanged
|
||||
if funcCallItem.Get("call_id").String() != "call_long" {
|
||||
t.Errorf("call_id changed: expected 'call_long', got '%s'", funcCallItem.Get("call_id").String())
|
||||
}
|
||||
|
||||
// name must be truncated
|
||||
translatedName := funcCallItem.Get("name").String()
|
||||
if translatedName == longName {
|
||||
t.Errorf("tool name was NOT shortened: still '%s'", translatedName)
|
||||
}
|
||||
if len(translatedName) > 64 {
|
||||
t.Errorf("shortened name still > 64 chars: len=%d name='%s'", len(translatedName), translatedName)
|
||||
}
|
||||
}
|
||||
|
||||
// content:"" (empty string, not null) should be treated the same as null.
|
||||
func TestEmptyStringContent(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Do something"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_empty",
|
||||
"type": "function",
|
||||
"function": {"name": "action", "arguments": "{}"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_empty", "content": "result"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "action",
|
||||
"description": "An action",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
for i, item := range items {
|
||||
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
|
||||
if len(item.Get("content").Array()) == 0 {
|
||||
t.Errorf("item %d: empty assistant message from content:\"\"", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// user + function_call + function_call_output
|
||||
if len(items) != 3 {
|
||||
t.Errorf("expected 3 input items, got %d", len(items))
|
||||
}
|
||||
}
|
||||
|
||||
// Every function_call_output must have a matching function_call by call_id.
|
||||
func TestCallIDsMatchBetweenCallAndOutput(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Multi-tool"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{"id": "id_a", "type": "function", "function": {"name": "tool_a", "arguments": "{}"}},
|
||||
{"id": "id_b", "type": "function", "function": {"name": "tool_b", "arguments": "{}"}}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "id_a", "content": "res_a"},
|
||||
{"role": "tool", "tool_call_id": "id_b", "content": "res_b"}
|
||||
],
|
||||
"tools": [
|
||||
{"type": "function", "function": {"name": "tool_a", "description": "A", "parameters": {"type": "object", "properties": {}}}},
|
||||
{"type": "function", "function": {"name": "tool_b", "description": "B", "parameters": {"type": "object", "properties": {}}}}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
// collect call_ids from function_call items
|
||||
callIDs := make(map[string]bool)
|
||||
for _, item := range items {
|
||||
if item.Get("type").String() == "function_call" {
|
||||
callIDs[item.Get("call_id").String()] = true
|
||||
}
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
if item.Get("type").String() == "function_call_output" {
|
||||
outID := item.Get("call_id").String()
|
||||
if !callIDs[outID] {
|
||||
t.Errorf("item %d: function_call_output has call_id '%s' with no matching function_call", i, outID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2 calls, 2 outputs
|
||||
funcCallCount := 0
|
||||
funcOutputCount := 0
|
||||
for _, item := range items {
|
||||
switch item.Get("type").String() {
|
||||
case "function_call":
|
||||
funcCallCount++
|
||||
case "function_call_output":
|
||||
funcOutputCount++
|
||||
}
|
||||
}
|
||||
if funcCallCount != 2 {
|
||||
t.Errorf("expected 2 function_calls, got %d", funcCallCount)
|
||||
}
|
||||
if funcOutputCount != 2 {
|
||||
t.Errorf("expected 2 function_call_outputs, got %d", funcOutputCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Tools array should carry over to the Responses format output.
|
||||
func TestToolsDefinitionTranslated(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hi"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"description": "Search the web",
|
||||
"parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
tools := gjson.Get(result, "tools").Array()
|
||||
if len(tools) == 0 {
|
||||
t.Fatal("no tools found in output")
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, tool := range tools {
|
||||
if tool.Get("name").String() == "search" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("tool 'search' not found in output tools: %s", gjson.Get(result, "tools").Raw)
|
||||
}
|
||||
}
|
||||
@@ -6,10 +6,10 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -36,7 +36,6 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
out := `{"model":"","request":{"contents":[]}}`
|
||||
@@ -149,7 +148,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
|
||||
inputSchemaResult := toolResult.Get("input_schema")
|
||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||
inputSchema := inputSchemaResult.Raw
|
||||
inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw)
|
||||
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
|
||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
||||
tool, _ = sjson.Delete(tool, "strict")
|
||||
@@ -157,6 +156,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
tool, _ = sjson.Delete(tool, "type")
|
||||
tool, _ = sjson.Delete(tool, "cache_control")
|
||||
tool, _ = sjson.Delete(tool, "defer_loading")
|
||||
tool, _ = sjson.Delete(tool, "eager_input_streaming")
|
||||
if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
|
||||
if !hasTools {
|
||||
out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`)
|
||||
|
||||
@@ -111,6 +111,23 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
|
||||
return true
|
||||
})
|
||||
|
||||
// Filter out contents with empty parts to avoid Gemini API error:
|
||||
// "required oneof field 'data' must have one initialized field"
|
||||
filteredContents := "[]"
|
||||
hasFiltered := false
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(_, content gjson.Result) bool {
|
||||
parts := content.Get("parts")
|
||||
if !parts.IsArray() || len(parts.Array()) == 0 {
|
||||
hasFiltered = true
|
||||
return true
|
||||
}
|
||||
filteredContents, _ = sjson.SetRaw(filteredContents, "-1", content.Raw)
|
||||
return true
|
||||
})
|
||||
if hasFiltered {
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents", []byte(filteredContents))
|
||||
}
|
||||
|
||||
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
|
||||
}
|
||||
|
||||
|
||||
@@ -114,6 +114,21 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
part, _ = sjson.Set(part, "functionResponse.name", funcName)
|
||||
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
|
||||
case "image":
|
||||
source := contentResult.Get("source")
|
||||
if source.Get("type").String() != "base64" {
|
||||
return true
|
||||
}
|
||||
mimeType := source.Get("media_type").String()
|
||||
data := source.Get("data").String()
|
||||
if mimeType == "" || data == "" {
|
||||
return true
|
||||
}
|
||||
part := `{"inline_data":{"mime_type":"","data":""}}`
|
||||
part, _ = sjson.Set(part, "inline_data.mime_type", mimeType)
|
||||
part, _ = sjson.Set(part, "inline_data.data", data)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -40,3 +40,41 @@ func TestConvertClaudeRequestToGemini_ToolChoice_SpecificTool(t *testing.T) {
|
||||
t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.GetBytes(output, "toolConfig.functionCallingConfig.allowedFunctionNames").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToGemini_ImageContent(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gemini-3-flash-preview",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "describe this image"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "aGVsbG8="
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToGemini("gemini-3-flash-preview", inputJSON, false)
|
||||
|
||||
parts := gjson.GetBytes(output, "contents.0.parts").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("Expected 2 parts, got %d", len(parts))
|
||||
}
|
||||
if got := parts[0].Get("text").String(); got != "describe this image" {
|
||||
t.Fatalf("Expected first part text 'describe this image', got '%s'", got)
|
||||
}
|
||||
if got := parts[1].Get("inline_data.mime_type").String(); got != "image/png" {
|
||||
t.Fatalf("Expected image mime type 'image/png', got '%s'", got)
|
||||
}
|
||||
if got := parts[1].Get("inline_data.data").String(); got != "aGVsbG8=" {
|
||||
t.Fatalf("Expected image data 'aGVsbG8=', got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -340,7 +341,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
// Set the raw JSON output directly (preserves string encoding)
|
||||
if outputRaw != "" && outputRaw != "null" {
|
||||
output := gjson.Parse(outputRaw)
|
||||
if output.Type == gjson.JSON {
|
||||
if output.Type == gjson.JSON && json.Valid([]byte(output.Raw)) {
|
||||
functionResponse, _ = sjson.SetRaw(functionResponse, "functionResponse.response.result", output.Raw)
|
||||
} else {
|
||||
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.response.result", outputRaw)
|
||||
|
||||
@@ -149,6 +149,14 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []
|
||||
}
|
||||
}
|
||||
}
|
||||
// Read note from auth file.
|
||||
if rawNote, ok := metadata["note"]; ok {
|
||||
if note, isStr := rawNote.(string); isStr {
|
||||
if trimmed := strings.TrimSpace(note); trimmed != "" {
|
||||
a.Attributes["note"] = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||
// For codex auth files, extract plan_type from the JWT id_token.
|
||||
if provider == "codex" {
|
||||
@@ -221,6 +229,10 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an
|
||||
if priorityVal, hasPriority := primary.Attributes["priority"]; hasPriority && priorityVal != "" {
|
||||
attrs["priority"] = priorityVal
|
||||
}
|
||||
// Propagate note from primary auth to virtual auths
|
||||
if noteVal, hasNote := primary.Attributes["note"]; hasNote && noteVal != "" {
|
||||
attrs["note"] = noteVal
|
||||
}
|
||||
metadataCopy := map[string]any{
|
||||
"email": email,
|
||||
"project_id": projectID,
|
||||
|
||||
@@ -744,3 +744,200 @@ func TestBuildGeminiVirtualID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_NotePropagated(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
Attributes: map[string]string{
|
||||
"source": "test-source",
|
||||
"path": "/path/to/auth",
|
||||
"priority": "5",
|
||||
"note": "my test note",
|
||||
},
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "proj-a, proj-b",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 2 {
|
||||
t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
|
||||
for i, v := range virtuals {
|
||||
if got := v.Attributes["note"]; got != "my test note" {
|
||||
t.Errorf("virtual %d: expected note %q, got %q", i, "my test note", got)
|
||||
}
|
||||
if got := v.Attributes["priority"]; got != "5" {
|
||||
t.Errorf("virtual %d: expected priority %q, got %q", i, "5", got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_NoteAbsentWhenEmpty(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
Attributes: map[string]string{
|
||||
"source": "test-source",
|
||||
"path": "/path/to/auth",
|
||||
},
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "proj-a, proj-b",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 2 {
|
||||
t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
|
||||
for i, v := range virtuals {
|
||||
if _, hasNote := v.Attributes["note"]; hasNote {
|
||||
t.Errorf("virtual %d: expected no note attribute when primary has no note", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_NoteParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
note any
|
||||
want string
|
||||
hasValue bool
|
||||
}{
|
||||
{
|
||||
name: "valid string note",
|
||||
note: "hello world",
|
||||
want: "hello world",
|
||||
hasValue: true,
|
||||
},
|
||||
{
|
||||
name: "string note with whitespace",
|
||||
note: " trimmed note ",
|
||||
want: "trimmed note",
|
||||
hasValue: true,
|
||||
},
|
||||
{
|
||||
name: "empty string note",
|
||||
note: "",
|
||||
hasValue: false,
|
||||
},
|
||||
{
|
||||
name: "whitespace only note",
|
||||
note: " ",
|
||||
hasValue: false,
|
||||
},
|
||||
{
|
||||
name: "non-string note ignored",
|
||||
note: 12345,
|
||||
hasValue: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"note": tt.note,
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
errWriteFile := os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644)
|
||||
if errWriteFile != nil {
|
||||
t.Fatalf("failed to write auth file: %v", errWriteFile)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, errSynthesize := synth.Synthesize(ctx)
|
||||
if errSynthesize != nil {
|
||||
t.Fatalf("unexpected error: %v", errSynthesize)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
value, ok := auths[0].Attributes["note"]
|
||||
if tt.hasValue {
|
||||
if !ok {
|
||||
t.Fatal("expected note attribute to be set")
|
||||
}
|
||||
if value != tt.want {
|
||||
t.Fatalf("expected note %q, got %q", tt.want, value)
|
||||
}
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("expected note attribute to be absent, got %q", value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_MultiProjectGeminiWithNote(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authData := map[string]any{
|
||||
"type": "gemini",
|
||||
"email": "multi@example.com",
|
||||
"project_id": "project-a, project-b",
|
||||
"priority": 5,
|
||||
"note": "production keys",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "gemini-multi.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Should have 3 auths: 1 primary (disabled) + 2 virtuals
|
||||
if len(auths) != 3 {
|
||||
t.Fatalf("expected 3 auths (1 primary + 2 virtuals), got %d", len(auths))
|
||||
}
|
||||
|
||||
primary := auths[0]
|
||||
if gotNote := primary.Attributes["note"]; gotNote != "production keys" {
|
||||
t.Errorf("expected primary note %q, got %q", "production keys", gotNote)
|
||||
}
|
||||
|
||||
// Verify virtuals inherit note
|
||||
for i := 1; i < len(auths); i++ {
|
||||
v := auths[i]
|
||||
if gotNote := v.Attributes["note"]; gotNote != "production keys" {
|
||||
t.Errorf("expected virtual %d note %q, got %q", i, "production keys", gotNote)
|
||||
}
|
||||
if gotPriority := v.Attributes["priority"]; gotPriority != "5" {
|
||||
t.Errorf("expected virtual %d priority %q, got %q", i, "5", gotPriority)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,7 +177,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
|
||||
} else {
|
||||
cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
|
||||
pinnedAuthID = strings.TrimSpace(authID)
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" || h == nil || h.AuthManager == nil {
|
||||
return
|
||||
}
|
||||
selectedAuth, ok := h.AuthManager.GetByID(authID)
|
||||
if !ok || selectedAuth == nil {
|
||||
return
|
||||
}
|
||||
if websocketUpstreamSupportsIncrementalInput(selectedAuth.Attributes, selectedAuth.Metadata) {
|
||||
pinnedAuthID = authID
|
||||
}
|
||||
})
|
||||
}
|
||||
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -26,6 +27,78 @@ type websocketCaptureExecutor struct {
|
||||
payloads [][]byte
|
||||
}
|
||||
|
||||
type orderedWebsocketSelector struct {
|
||||
mu sync.Mutex
|
||||
order []string
|
||||
cursor int
|
||||
}
|
||||
|
||||
func (s *orderedWebsocketSelector) Pick(_ context.Context, _ string, _ string, _ coreexecutor.Options, auths []*coreauth.Auth) (*coreauth.Auth, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if len(auths) == 0 {
|
||||
return nil, errors.New("no auth available")
|
||||
}
|
||||
for len(s.order) > 0 && s.cursor < len(s.order) {
|
||||
authID := strings.TrimSpace(s.order[s.cursor])
|
||||
s.cursor++
|
||||
for _, auth := range auths {
|
||||
if auth != nil && auth.ID == authID {
|
||||
return auth, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, auth := range auths {
|
||||
if auth != nil {
|
||||
return auth, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("no auth available")
|
||||
}
|
||||
|
||||
type websocketAuthCaptureExecutor struct {
|
||||
mu sync.Mutex
|
||||
authIDs []string
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, _ coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
e.mu.Lock()
|
||||
if auth != nil {
|
||||
e.authIDs = append(e.authIDs, auth.ID)
|
||||
}
|
||||
e.mu.Unlock()
|
||||
|
||||
chunks := make(chan coreexecutor.StreamChunk, 1)
|
||||
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
|
||||
close(chunks)
|
||||
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) AuthIDs() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return append([]string(nil), e.authIDs...)
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
@@ -519,3 +592,73 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
||||
t.Fatalf("unexpected forwarded input: %s", forwarded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
selector := &orderedWebsocketSelector{order: []string{"auth-sse", "auth-ws"}}
|
||||
executor := &websocketAuthCaptureExecutor{}
|
||||
manager := coreauth.NewManager(nil, selector, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
authSSE := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
|
||||
if _, err := manager.Register(context.Background(), authSSE); err != nil {
|
||||
t.Fatalf("Register SSE auth: %v", err)
|
||||
}
|
||||
authWS := &coreauth.Auth{
|
||||
ID: "auth-ws",
|
||||
Provider: executor.Identifier(),
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{"websockets": "true"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), authWS); err != nil {
|
||||
t.Fatalf("Register websocket auth: %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(authSSE.ID, authSSE.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
registry.GetGlobalRegistry().RegisterClient(authWS.ID, authWS.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(authSSE.ID)
|
||||
registry.GetGlobalRegistry().UnregisterClient(authWS.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
t.Fatalf("close websocket: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
requests := []string{
|
||||
`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`,
|
||||
`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`,
|
||||
}
|
||||
for i := range requests {
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil {
|
||||
t.Fatalf("write websocket message %d: %v", i+1, errWrite)
|
||||
}
|
||||
_, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read websocket message %d: %v", i+1, errReadMessage)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
|
||||
t.Fatalf("message %d payload type = %s, want %s", i+1, got, wsEventTypeCompleted)
|
||||
}
|
||||
}
|
||||
|
||||
if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth-sse" || got[1] != "auth-ws" {
|
||||
t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,7 +162,60 @@ func stableAuthIndex(seed string) string {
|
||||
return hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
// EnsureIndex returns a stable index derived from the auth file name or API key.
|
||||
func (a *Auth) indexSeed() string {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if fileName := strings.TrimSpace(a.FileName); fileName != "" {
|
||||
return "file:" + fileName
|
||||
}
|
||||
|
||||
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
compatName := ""
|
||||
baseURL := ""
|
||||
apiKey := ""
|
||||
source := ""
|
||||
if a.Attributes != nil {
|
||||
if value := strings.TrimSpace(a.Attributes["provider_key"]); value != "" {
|
||||
providerKey = strings.ToLower(value)
|
||||
}
|
||||
compatName = strings.ToLower(strings.TrimSpace(a.Attributes["compat_name"]))
|
||||
baseURL = strings.TrimSpace(a.Attributes["base_url"])
|
||||
apiKey = strings.TrimSpace(a.Attributes["api_key"])
|
||||
source = strings.TrimSpace(a.Attributes["source"])
|
||||
}
|
||||
|
||||
proxyURL := strings.TrimSpace(a.ProxyURL)
|
||||
hasCredentialIdentity := compatName != "" || baseURL != "" || proxyURL != "" || apiKey != "" || source != ""
|
||||
if providerKey != "" && hasCredentialIdentity {
|
||||
parts := []string{"provider=" + providerKey}
|
||||
if compatName != "" {
|
||||
parts = append(parts, "compat="+compatName)
|
||||
}
|
||||
if baseURL != "" {
|
||||
parts = append(parts, "base="+baseURL)
|
||||
}
|
||||
if proxyURL != "" {
|
||||
parts = append(parts, "proxy="+proxyURL)
|
||||
}
|
||||
if apiKey != "" {
|
||||
parts = append(parts, "api_key="+apiKey)
|
||||
}
|
||||
if source != "" {
|
||||
parts = append(parts, "source="+source)
|
||||
}
|
||||
return "config:" + strings.Join(parts, "\x00")
|
||||
}
|
||||
|
||||
if id := strings.TrimSpace(a.ID); id != "" {
|
||||
return "id:" + id
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// EnsureIndex returns a stable index derived from the auth file name or credential identity.
|
||||
func (a *Auth) EnsureIndex() string {
|
||||
if a == nil {
|
||||
return ""
|
||||
@@ -171,20 +224,9 @@ func (a *Auth) EnsureIndex() string {
|
||||
return a.Index
|
||||
}
|
||||
|
||||
seed := strings.TrimSpace(a.FileName)
|
||||
if seed != "" {
|
||||
seed = "file:" + seed
|
||||
} else if a.Attributes != nil {
|
||||
if apiKey := strings.TrimSpace(a.Attributes["api_key"]); apiKey != "" {
|
||||
seed = "api_key:" + apiKey
|
||||
}
|
||||
}
|
||||
seed := a.indexSeed()
|
||||
if seed == "" {
|
||||
if id := strings.TrimSpace(a.ID); id != "" {
|
||||
seed = "id:" + id
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
idx := stableAuthIndex(seed)
|
||||
|
||||
@@ -33,3 +33,66 @@ func TestToolPrefixDisabled(t *testing.T) {
|
||||
t.Error("should return false when set to false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureIndexUsesCredentialIdentity(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
geminiAuth := &Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "shared-key",
|
||||
"source": "config:gemini[abc123]",
|
||||
},
|
||||
}
|
||||
compatAuth := &Auth{
|
||||
Provider: "bohe",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "shared-key",
|
||||
"compat_name": "bohe",
|
||||
"provider_key": "bohe",
|
||||
"source": "config:bohe[def456]",
|
||||
},
|
||||
}
|
||||
geminiAltBase := &Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "shared-key",
|
||||
"base_url": "https://alt.example.com",
|
||||
"source": "config:gemini[ghi789]",
|
||||
},
|
||||
}
|
||||
geminiDuplicate := &Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "shared-key",
|
||||
"source": "config:gemini[abc123-1]",
|
||||
},
|
||||
}
|
||||
|
||||
geminiIndex := geminiAuth.EnsureIndex()
|
||||
compatIndex := compatAuth.EnsureIndex()
|
||||
altBaseIndex := geminiAltBase.EnsureIndex()
|
||||
duplicateIndex := geminiDuplicate.EnsureIndex()
|
||||
|
||||
if geminiIndex == "" {
|
||||
t.Fatal("gemini index should not be empty")
|
||||
}
|
||||
if compatIndex == "" {
|
||||
t.Fatal("compat index should not be empty")
|
||||
}
|
||||
if altBaseIndex == "" {
|
||||
t.Fatal("alt base index should not be empty")
|
||||
}
|
||||
if duplicateIndex == "" {
|
||||
t.Fatal("duplicate index should not be empty")
|
||||
}
|
||||
if geminiIndex == compatIndex {
|
||||
t.Fatalf("shared api key produced duplicate auth_index %q", geminiIndex)
|
||||
}
|
||||
if geminiIndex == altBaseIndex {
|
||||
t.Fatalf("same provider/key with different base_url produced duplicate auth_index %q", geminiIndex)
|
||||
}
|
||||
if geminiIndex == duplicateIndex {
|
||||
t.Fatalf("duplicate config entries should be separated by source-derived seed, got %q", geminiIndex)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -915,7 +915,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
models = registry.GetCodexProModels()
|
||||
case "plus":
|
||||
models = registry.GetCodexPlusModels()
|
||||
case "team":
|
||||
case "team", "business", "go":
|
||||
models = registry.GetCodexTeamModels()
|
||||
case "free":
|
||||
models = registry.GetCodexFreeModels()
|
||||
|
||||
Reference in New Issue
Block a user