mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-22 09:10:30 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c2ad4cda2 | ||
|
|
db63f9b5d6 | ||
|
|
25f6c4a250 | ||
|
|
b24ae74216 | ||
|
|
59ad8f40dc | ||
|
|
ff03dc6a2c | ||
|
|
dc7187ca5b | ||
|
|
b1dcff778c | ||
|
|
c1241a98e2 | ||
|
|
8d8f5970ee | ||
|
|
f90120f846 | ||
|
|
0b94d36c4a | ||
|
|
c8cee6a209 | ||
|
|
5b6342e6ac |
@@ -341,6 +341,21 @@ func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
|
||||
emailValue := gjson.GetBytes(data, "email").String()
|
||||
fileData["type"] = typeValue
|
||||
fileData["email"] = emailValue
|
||||
if pv := gjson.GetBytes(data, "priority"); pv.Exists() {
|
||||
switch pv.Type {
|
||||
case gjson.Number:
|
||||
fileData["priority"] = int(pv.Int())
|
||||
case gjson.String:
|
||||
if parsed, errAtoi := strconv.Atoi(strings.TrimSpace(pv.String())); errAtoi == nil {
|
||||
fileData["priority"] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
if nv := gjson.GetBytes(data, "note"); nv.Exists() && nv.Type == gjson.String {
|
||||
if trimmed := strings.TrimSpace(nv.String()); trimmed != "" {
|
||||
fileData["note"] = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
files = append(files, fileData)
|
||||
@@ -424,6 +439,37 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
||||
if claims := extractCodexIDTokenClaims(auth); claims != nil {
|
||||
entry["id_token"] = claims
|
||||
}
|
||||
// Expose priority from Attributes (set by synthesizer from JSON "priority" field).
|
||||
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
|
||||
if p := strings.TrimSpace(authAttribute(auth, "priority")); p != "" {
|
||||
if parsed, err := strconv.Atoi(p); err == nil {
|
||||
entry["priority"] = parsed
|
||||
}
|
||||
} else if auth.Metadata != nil {
|
||||
if rawPriority, ok := auth.Metadata["priority"]; ok {
|
||||
switch v := rawPriority.(type) {
|
||||
case float64:
|
||||
entry["priority"] = int(v)
|
||||
case int:
|
||||
entry["priority"] = v
|
||||
case string:
|
||||
if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
entry["priority"] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Expose note from Attributes (set by synthesizer from JSON "note" field).
|
||||
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
|
||||
if note := strings.TrimSpace(authAttribute(auth, "note")); note != "" {
|
||||
entry["note"] = note
|
||||
} else if auth.Metadata != nil {
|
||||
if rawNote, ok := auth.Metadata["note"].(string); ok {
|
||||
if trimmed := strings.TrimSpace(rawNote); trimmed != "" {
|
||||
entry["note"] = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
@@ -848,7 +894,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||
}
|
||||
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority, note) of an auth file.
|
||||
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
if h.authManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
@@ -860,6 +906,7 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
Prefix *string `json:"prefix"`
|
||||
ProxyURL *string `json:"proxy_url"`
|
||||
Priority *int `json:"priority"`
|
||||
Note *string `json:"note"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
@@ -902,14 +949,32 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
targetAuth.ProxyURL = *req.ProxyURL
|
||||
changed = true
|
||||
}
|
||||
if req.Priority != nil {
|
||||
if req.Priority != nil || req.Note != nil {
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if *req.Priority == 0 {
|
||||
delete(targetAuth.Metadata, "priority")
|
||||
} else {
|
||||
targetAuth.Metadata["priority"] = *req.Priority
|
||||
if targetAuth.Attributes == nil {
|
||||
targetAuth.Attributes = make(map[string]string)
|
||||
}
|
||||
|
||||
if req.Priority != nil {
|
||||
if *req.Priority == 0 {
|
||||
delete(targetAuth.Metadata, "priority")
|
||||
delete(targetAuth.Attributes, "priority")
|
||||
} else {
|
||||
targetAuth.Metadata["priority"] = *req.Priority
|
||||
targetAuth.Attributes["priority"] = strconv.Itoa(*req.Priority)
|
||||
}
|
||||
}
|
||||
if req.Note != nil {
|
||||
trimmedNote := strings.TrimSpace(*req.Note)
|
||||
if trimmedNote == "" {
|
||||
delete(targetAuth.Metadata, "note")
|
||||
delete(targetAuth.Attributes, "note")
|
||||
} else {
|
||||
targetAuth.Metadata["note"] = trimmedNote
|
||||
targetAuth.Attributes["note"] = trimmedNote
|
||||
}
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -68,6 +69,10 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
contentsJSON := "[]"
|
||||
hasContents := false
|
||||
|
||||
// tool_use_id → tool_name lookup, populated incrementally during the main loop.
|
||||
// Claude's tool_result references tool_use by ID; Gemini requires functionResponse.name.
|
||||
toolNameByID := make(map[string]string)
|
||||
|
||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messageResults := messagesResult.Array()
|
||||
@@ -170,6 +175,10 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
argsResult := contentResult.Get("input")
|
||||
functionID := contentResult.Get("id").String()
|
||||
|
||||
if functionID != "" && functionName != "" {
|
||||
toolNameByID[functionID] = functionName
|
||||
}
|
||||
|
||||
// Handle both object and string input formats
|
||||
var argsRaw string
|
||||
if argsResult.IsObject() {
|
||||
@@ -206,10 +215,19 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||
toolCallID := contentResult.Get("tool_use_id").String()
|
||||
if toolCallID != "" {
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-")
|
||||
funcName, ok := toolNameByID[toolCallID]
|
||||
if !ok {
|
||||
// Fallback: derive a semantic name from the ID by stripping
|
||||
// the last two dash-separated segments (e.g. "get_weather-call-123" → "get_weather").
|
||||
// Only use the raw ID as a last resort when the heuristic produces an empty string.
|
||||
parts := strings.Split(toolCallID, "-")
|
||||
if len(parts) > 2 {
|
||||
funcName = strings.Join(parts[:len(parts)-2], "-")
|
||||
}
|
||||
if funcName == "" {
|
||||
funcName = toolCallID
|
||||
}
|
||||
log.Warnf("antigravity claude request: tool_result references unknown tool_use_id=%s, derived function name=%s", toolCallID, funcName)
|
||||
}
|
||||
functionResponseResult := contentResult.Get("content")
|
||||
|
||||
|
||||
@@ -365,6 +365,17 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "get_weather-call-123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "Paris"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -382,13 +393,177 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
||||
outputStr := string(output)
|
||||
|
||||
// Check function response conversion
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Error("functionResponse should exist")
|
||||
}
|
||||
if funcResp.Get("id").String() != "get_weather-call-123" {
|
||||
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
|
||||
}
|
||||
if funcResp.Get("name").String() != "get_weather" {
|
||||
t.Errorf("Expected function name 'get_weather', got '%s'", funcResp.Get("name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_TouluFormat(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||
"name": "Glob",
|
||||
"input": {"pattern": "**/*.py"}
|
||||
},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
|
||||
"name": "Bash",
|
||||
"input": {"command": "ls"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||
"content": "file1.py\nfile2.py"
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
|
||||
"content": "total 10"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp0 := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||
if !funcResp0.Exists() {
|
||||
t.Fatal("first functionResponse should exist")
|
||||
}
|
||||
if got := funcResp0.Get("name").String(); got != "Glob" {
|
||||
t.Errorf("Expected name 'Glob' for toolu_ format, got '%s'", got)
|
||||
}
|
||||
|
||||
funcResp1 := gjson.Get(outputStr, "request.contents.1.parts.1.functionResponse")
|
||||
if !funcResp1.Exists() {
|
||||
t.Fatal("second functionResponse should exist")
|
||||
}
|
||||
if got := funcResp1.Get("name").String(); got != "Bash" {
|
||||
t.Errorf("Expected name 'Bash' for toolu_ format, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_CustomFormat(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "Read-1773420180464065165-1327",
|
||||
"name": "Read",
|
||||
"input": {"file_path": "/tmp/test.py"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Read-1773420180464065165-1327",
|
||||
"content": "file content here"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
if got := funcResp.Get("name").String(); got != "Read" {
|
||||
t.Errorf("Expected name 'Read', got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_Heuristic(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "get_weather-call-123",
|
||||
"content": "22C sunny"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
if got := funcResp.Get("name").String(); got != "get_weather" {
|
||||
t.Errorf("Expected heuristic-derived name 'get_weather', got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_RawID(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||
"content": "result data"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
got := funcResp.Get("name").String()
|
||||
if got == "" {
|
||||
t.Error("functionResponse.name must not be empty")
|
||||
}
|
||||
if got != "toolu_tool-48fca351f12844eabf49dad8b63886d2" {
|
||||
t.Errorf("Expected raw ID as last-resort name, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
|
||||
|
||||
@@ -197,7 +197,12 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
}
|
||||
}
|
||||
|
||||
out, _ = sjson.SetRaw(out, "input.-1", msg)
|
||||
// Don't emit empty assistant messages when only tool_calls
|
||||
// are present — Responses API needs function_call items
|
||||
// directly, otherwise call_id matching fails (#2132).
|
||||
if role != "assistant" || len(gjson.Get(msg, "content").Array()) > 0 {
|
||||
out, _ = sjson.SetRaw(out, "input.-1", msg)
|
||||
}
|
||||
|
||||
// Handle tool calls for assistant messages as separate top-level objects
|
||||
if role == "assistant" {
|
||||
|
||||
@@ -0,0 +1,635 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Basic tool-call: system + user + assistant(tool_calls, no content) + tool result.
|
||||
// Expects developer msg + user msg + function_call + function_call_output.
|
||||
// No empty assistant message should appear between user and function_call.
|
||||
func TestToolCallSimple(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is the weather in Paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"Paris\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": "sunny, 22C"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather for a city",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
if len(items) != 4 {
|
||||
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
// system -> developer
|
||||
if items[0].Get("type").String() != "message" {
|
||||
t.Errorf("item 0: expected type 'message', got '%s'", items[0].Get("type").String())
|
||||
}
|
||||
if items[0].Get("role").String() != "developer" {
|
||||
t.Errorf("item 0: expected role 'developer', got '%s'", items[0].Get("role").String())
|
||||
}
|
||||
|
||||
// user
|
||||
if items[1].Get("type").String() != "message" {
|
||||
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
|
||||
}
|
||||
if items[1].Get("role").String() != "user" {
|
||||
t.Errorf("item 1: expected role 'user', got '%s'", items[1].Get("role").String())
|
||||
}
|
||||
|
||||
// function_call, not an empty assistant msg
|
||||
if items[2].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
|
||||
}
|
||||
if items[2].Get("call_id").String() != "call_1" {
|
||||
t.Errorf("item 2: expected call_id 'call_1', got '%s'", items[2].Get("call_id").String())
|
||||
}
|
||||
if items[2].Get("name").String() != "get_weather" {
|
||||
t.Errorf("item 2: expected name 'get_weather', got '%s'", items[2].Get("name").String())
|
||||
}
|
||||
if items[2].Get("arguments").String() != `{"city":"Paris"}` {
|
||||
t.Errorf("item 2: unexpected arguments: %s", items[2].Get("arguments").String())
|
||||
}
|
||||
|
||||
// function_call_output
|
||||
if items[3].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
|
||||
}
|
||||
if items[3].Get("call_id").String() != "call_1" {
|
||||
t.Errorf("item 3: expected call_id 'call_1', got '%s'", items[3].Get("call_id").String())
|
||||
}
|
||||
if items[3].Get("output").String() != "sunny, 22C" {
|
||||
t.Errorf("item 3: expected output 'sunny, 22C', got '%s'", items[3].Get("output").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Assistant has both text content and tool_calls — the message should
|
||||
// be emitted (non-empty content), followed by function_call items.
|
||||
func TestToolCallWithContent(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Let me check the weather for you.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc",
|
||||
"content": "rainy, 15C"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
// user + assistant(with content) + function_call + function_call_output
|
||||
if len(items) != 4 {
|
||||
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
if items[0].Get("role").String() != "user" {
|
||||
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
|
||||
}
|
||||
|
||||
// assistant with content — should be kept
|
||||
if items[1].Get("type").String() != "message" {
|
||||
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
|
||||
}
|
||||
if items[1].Get("role").String() != "assistant" {
|
||||
t.Errorf("item 1: expected role 'assistant', got '%s'", items[1].Get("role").String())
|
||||
}
|
||||
contentParts := items[1].Get("content").Array()
|
||||
if len(contentParts) == 0 {
|
||||
t.Errorf("item 1: assistant message should have content parts")
|
||||
}
|
||||
|
||||
if items[2].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
|
||||
}
|
||||
if items[2].Get("call_id").String() != "call_abc" {
|
||||
t.Errorf("item 2: expected call_id 'call_abc', got '%s'", items[2].Get("call_id").String())
|
||||
}
|
||||
|
||||
if items[3].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
|
||||
}
|
||||
if items[3].Get("call_id").String() != "call_abc" {
|
||||
t.Errorf("item 3: expected call_id 'call_abc', got '%s'", items[3].Get("call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Parallel tool calls: assistant invokes 3 tools at once, all call_ids
|
||||
// and outputs must be translated and paired correctly.
|
||||
func TestMultipleToolCalls(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Compare weather in Paris, London and Tokyo"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_paris",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"Paris\"}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "call_london",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"London\"}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "call_tokyo",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"Tokyo\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_paris", "content": "sunny, 22C"},
|
||||
{"role": "tool", "tool_call_id": "call_london", "content": "cloudy, 14C"},
|
||||
{"role": "tool", "tool_call_id": "call_tokyo", "content": "humid, 28C"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
// user + 3 function_call + 3 function_call_output = 7
|
||||
if len(items) != 7 {
|
||||
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
if items[0].Get("role").String() != "user" {
|
||||
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
|
||||
}
|
||||
|
||||
expectedCallIDs := []string{"call_paris", "call_london", "call_tokyo"}
|
||||
for i, expectedID := range expectedCallIDs {
|
||||
idx := i + 1
|
||||
if items[idx].Get("type").String() != "function_call" {
|
||||
t.Errorf("item %d: expected type 'function_call', got '%s'", idx, items[idx].Get("type").String())
|
||||
}
|
||||
if items[idx].Get("call_id").String() != expectedID {
|
||||
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedID, items[idx].Get("call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
expectedOutputs := []string{"sunny, 22C", "cloudy, 14C", "humid, 28C"}
|
||||
for i, expectedOutput := range expectedOutputs {
|
||||
idx := i + 4
|
||||
if items[idx].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item %d: expected type 'function_call_output', got '%s'", idx, items[idx].Get("type").String())
|
||||
}
|
||||
if items[idx].Get("call_id").String() != expectedCallIDs[i] {
|
||||
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedCallIDs[i], items[idx].Get("call_id").String())
|
||||
}
|
||||
if items[idx].Get("output").String() != expectedOutput {
|
||||
t.Errorf("item %d: expected output '%s', got '%s'", idx, expectedOutput, items[idx].Get("output").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Regression test for #2132: tool-call-only assistant messages (content:null)
|
||||
// must not produce an empty message item in the translated output.
|
||||
func TestNoSpuriousEmptyAssistantMessage(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Call a tool"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_x",
|
||||
"type": "function",
|
||||
"function": {"name": "do_thing", "arguments": "{}"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_x", "content": "done"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "do_thing",
|
||||
"description": "Do a thing",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
for i, item := range items {
|
||||
typ := item.Get("type").String()
|
||||
role := item.Get("role").String()
|
||||
if typ == "message" && role == "assistant" {
|
||||
contentArr := item.Get("content").Array()
|
||||
if len(contentArr) == 0 {
|
||||
t.Errorf("item %d: empty assistant message breaks call_id matching. item: %s", i, item.Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// should be exactly: user + function_call + function_call_output
|
||||
if len(items) != 3 {
|
||||
t.Fatalf("expected 3 input items (user + function_call + function_call_output), got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
if items[0].Get("type").String() != "message" || items[0].Get("role").String() != "user" {
|
||||
t.Errorf("item 0: expected user message")
|
||||
}
|
||||
if items[1].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
|
||||
}
|
||||
if items[2].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Two rounds of tool calling in one conversation, with a text reply in between.
|
||||
func TestMultiTurnToolCalling(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Weather in Paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{"id": "call_r1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}}]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_r1", "content": "sunny"},
|
||||
{"role": "assistant", "content": "It is sunny in Paris."},
|
||||
{"role": "user", "content": "And London?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{"id": "call_r2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"London\"}"}}]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_r2", "content": "rainy"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
// user, func_call(r1), func_output(r1), assistant text, user, func_call(r2), func_output(r2)
|
||||
if len(items) != 7 {
|
||||
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
|
||||
if len(item.Get("content").Array()) == 0 {
|
||||
t.Errorf("item %d: unexpected empty assistant message", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// round 1
|
||||
if items[1].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
|
||||
}
|
||||
if items[1].Get("call_id").String() != "call_r1" {
|
||||
t.Errorf("item 1: expected call_id 'call_r1', got '%s'", items[1].Get("call_id").String())
|
||||
}
|
||||
if items[2].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
|
||||
}
|
||||
|
||||
// text reply between rounds
|
||||
if items[3].Get("type").String() != "message" || items[3].Get("role").String() != "assistant" {
|
||||
t.Errorf("item 3: expected assistant message, got type=%s role=%s", items[3].Get("type").String(), items[3].Get("role").String())
|
||||
}
|
||||
|
||||
// round 2
|
||||
if items[5].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 5: expected function_call, got %s", items[5].Get("type").String())
|
||||
}
|
||||
if items[5].Get("call_id").String() != "call_r2" {
|
||||
t.Errorf("item 5: expected call_id 'call_r2', got '%s'", items[5].Get("call_id").String())
|
||||
}
|
||||
if items[6].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 6: expected function_call_output, got %s", items[6].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Tool names over 64 chars get shortened, call_id stays the same.
|
||||
func TestToolNameShortening(t *testing.T) {
|
||||
longName := "a_very_long_tool_name_that_exceeds_sixty_four_characters_limit_here_test"
|
||||
if len(longName) <= 64 {
|
||||
t.Fatalf("test setup error: name must be > 64 chars, got %d", len(longName))
|
||||
}
|
||||
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Do it"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_long",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "` + longName + `",
|
||||
"arguments": "{}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_long", "content": "ok"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "` + longName + `",
|
||||
"description": "A tool with a very long name",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
// find function_call
|
||||
var funcCallItem gjson.Result
|
||||
for _, item := range items {
|
||||
if item.Get("type").String() == "function_call" {
|
||||
funcCallItem = item
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !funcCallItem.Exists() {
|
||||
t.Fatal("no function_call item found in output")
|
||||
}
|
||||
|
||||
// call_id unchanged
|
||||
if funcCallItem.Get("call_id").String() != "call_long" {
|
||||
t.Errorf("call_id changed: expected 'call_long', got '%s'", funcCallItem.Get("call_id").String())
|
||||
}
|
||||
|
||||
// name must be truncated
|
||||
translatedName := funcCallItem.Get("name").String()
|
||||
if translatedName == longName {
|
||||
t.Errorf("tool name was NOT shortened: still '%s'", translatedName)
|
||||
}
|
||||
if len(translatedName) > 64 {
|
||||
t.Errorf("shortened name still > 64 chars: len=%d name='%s'", len(translatedName), translatedName)
|
||||
}
|
||||
}
|
||||
|
||||
// content:"" (empty string, not null) should be treated the same as null.
|
||||
func TestEmptyStringContent(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Do something"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_empty",
|
||||
"type": "function",
|
||||
"function": {"name": "action", "arguments": "{}"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_empty", "content": "result"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "action",
|
||||
"description": "An action",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
for i, item := range items {
|
||||
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
|
||||
if len(item.Get("content").Array()) == 0 {
|
||||
t.Errorf("item %d: empty assistant message from content:\"\"", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// user + function_call + function_call_output
|
||||
if len(items) != 3 {
|
||||
t.Errorf("expected 3 input items, got %d", len(items))
|
||||
}
|
||||
}
|
||||
|
||||
// Every function_call_output must have a matching function_call by call_id.
|
||||
func TestCallIDsMatchBetweenCallAndOutput(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Multi-tool"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{"id": "id_a", "type": "function", "function": {"name": "tool_a", "arguments": "{}"}},
|
||||
{"id": "id_b", "type": "function", "function": {"name": "tool_b", "arguments": "{}"}}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "id_a", "content": "res_a"},
|
||||
{"role": "tool", "tool_call_id": "id_b", "content": "res_b"}
|
||||
],
|
||||
"tools": [
|
||||
{"type": "function", "function": {"name": "tool_a", "description": "A", "parameters": {"type": "object", "properties": {}}}},
|
||||
{"type": "function", "function": {"name": "tool_b", "description": "B", "parameters": {"type": "object", "properties": {}}}}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
// collect call_ids from function_call items
|
||||
callIDs := make(map[string]bool)
|
||||
for _, item := range items {
|
||||
if item.Get("type").String() == "function_call" {
|
||||
callIDs[item.Get("call_id").String()] = true
|
||||
}
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
if item.Get("type").String() == "function_call_output" {
|
||||
outID := item.Get("call_id").String()
|
||||
if !callIDs[outID] {
|
||||
t.Errorf("item %d: function_call_output has call_id '%s' with no matching function_call", i, outID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2 calls, 2 outputs
|
||||
funcCallCount := 0
|
||||
funcOutputCount := 0
|
||||
for _, item := range items {
|
||||
switch item.Get("type").String() {
|
||||
case "function_call":
|
||||
funcCallCount++
|
||||
case "function_call_output":
|
||||
funcOutputCount++
|
||||
}
|
||||
}
|
||||
if funcCallCount != 2 {
|
||||
t.Errorf("expected 2 function_calls, got %d", funcCallCount)
|
||||
}
|
||||
if funcOutputCount != 2 {
|
||||
t.Errorf("expected 2 function_call_outputs, got %d", funcOutputCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Tools array should carry over to the Responses format output.
|
||||
func TestToolsDefinitionTranslated(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hi"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"description": "Search the web",
|
||||
"parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
tools := gjson.Get(result, "tools").Array()
|
||||
if len(tools) == 0 {
|
||||
t.Fatal("no tools found in output")
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, tool := range tools {
|
||||
if tool.Get("name").String() == "search" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("tool 'search' not found in output tools: %s", gjson.Get(result, "tools").Raw)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -340,7 +341,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
// Set the raw JSON output directly (preserves string encoding)
|
||||
if outputRaw != "" && outputRaw != "null" {
|
||||
output := gjson.Parse(outputRaw)
|
||||
if output.Type == gjson.JSON {
|
||||
if output.Type == gjson.JSON && json.Valid([]byte(output.Raw)) {
|
||||
functionResponse, _ = sjson.SetRaw(functionResponse, "functionResponse.response.result", output.Raw)
|
||||
} else {
|
||||
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.response.result", outputRaw)
|
||||
|
||||
@@ -149,6 +149,14 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []
|
||||
}
|
||||
}
|
||||
}
|
||||
// Read note from auth file.
|
||||
if rawNote, ok := metadata["note"]; ok {
|
||||
if note, isStr := rawNote.(string); isStr {
|
||||
if trimmed := strings.TrimSpace(note); trimmed != "" {
|
||||
a.Attributes["note"] = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||
// For codex auth files, extract plan_type from the JWT id_token.
|
||||
if provider == "codex" {
|
||||
@@ -221,6 +229,10 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an
|
||||
if priorityVal, hasPriority := primary.Attributes["priority"]; hasPriority && priorityVal != "" {
|
||||
attrs["priority"] = priorityVal
|
||||
}
|
||||
// Propagate note from primary auth to virtual auths
|
||||
if noteVal, hasNote := primary.Attributes["note"]; hasNote && noteVal != "" {
|
||||
attrs["note"] = noteVal
|
||||
}
|
||||
metadataCopy := map[string]any{
|
||||
"email": email,
|
||||
"project_id": projectID,
|
||||
|
||||
@@ -744,3 +744,200 @@ func TestBuildGeminiVirtualID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_NotePropagated(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
Attributes: map[string]string{
|
||||
"source": "test-source",
|
||||
"path": "/path/to/auth",
|
||||
"priority": "5",
|
||||
"note": "my test note",
|
||||
},
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "proj-a, proj-b",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 2 {
|
||||
t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
|
||||
for i, v := range virtuals {
|
||||
if got := v.Attributes["note"]; got != "my test note" {
|
||||
t.Errorf("virtual %d: expected note %q, got %q", i, "my test note", got)
|
||||
}
|
||||
if got := v.Attributes["priority"]; got != "5" {
|
||||
t.Errorf("virtual %d: expected priority %q, got %q", i, "5", got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_NoteAbsentWhenEmpty(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
Attributes: map[string]string{
|
||||
"source": "test-source",
|
||||
"path": "/path/to/auth",
|
||||
},
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "proj-a, proj-b",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 2 {
|
||||
t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
|
||||
for i, v := range virtuals {
|
||||
if _, hasNote := v.Attributes["note"]; hasNote {
|
||||
t.Errorf("virtual %d: expected no note attribute when primary has no note", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_NoteParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
note any
|
||||
want string
|
||||
hasValue bool
|
||||
}{
|
||||
{
|
||||
name: "valid string note",
|
||||
note: "hello world",
|
||||
want: "hello world",
|
||||
hasValue: true,
|
||||
},
|
||||
{
|
||||
name: "string note with whitespace",
|
||||
note: " trimmed note ",
|
||||
want: "trimmed note",
|
||||
hasValue: true,
|
||||
},
|
||||
{
|
||||
name: "empty string note",
|
||||
note: "",
|
||||
hasValue: false,
|
||||
},
|
||||
{
|
||||
name: "whitespace only note",
|
||||
note: " ",
|
||||
hasValue: false,
|
||||
},
|
||||
{
|
||||
name: "non-string note ignored",
|
||||
note: 12345,
|
||||
hasValue: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"note": tt.note,
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
errWriteFile := os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644)
|
||||
if errWriteFile != nil {
|
||||
t.Fatalf("failed to write auth file: %v", errWriteFile)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, errSynthesize := synth.Synthesize(ctx)
|
||||
if errSynthesize != nil {
|
||||
t.Fatalf("unexpected error: %v", errSynthesize)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
value, ok := auths[0].Attributes["note"]
|
||||
if tt.hasValue {
|
||||
if !ok {
|
||||
t.Fatal("expected note attribute to be set")
|
||||
}
|
||||
if value != tt.want {
|
||||
t.Fatalf("expected note %q, got %q", tt.want, value)
|
||||
}
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("expected note attribute to be absent, got %q", value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_MultiProjectGeminiWithNote(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authData := map[string]any{
|
||||
"type": "gemini",
|
||||
"email": "multi@example.com",
|
||||
"project_id": "project-a, project-b",
|
||||
"priority": 5,
|
||||
"note": "production keys",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "gemini-multi.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Should have 3 auths: 1 primary (disabled) + 2 virtuals
|
||||
if len(auths) != 3 {
|
||||
t.Fatalf("expected 3 auths (1 primary + 2 virtuals), got %d", len(auths))
|
||||
}
|
||||
|
||||
primary := auths[0]
|
||||
if gotNote := primary.Attributes["note"]; gotNote != "production keys" {
|
||||
t.Errorf("expected primary note %q, got %q", "production keys", gotNote)
|
||||
}
|
||||
|
||||
// Verify virtuals inherit note
|
||||
for i := 1; i < len(auths); i++ {
|
||||
v := auths[i]
|
||||
if gotNote := v.Attributes["note"]; gotNote != "production keys" {
|
||||
t.Errorf("expected virtual %d note %q, got %q", i, "production keys", gotNote)
|
||||
}
|
||||
if gotPriority := v.Attributes["priority"]; gotPriority != "5" {
|
||||
t.Errorf("expected virtual %d priority %q, got %q", i, "5", gotPriority)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,7 +177,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
|
||||
} else {
|
||||
cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
|
||||
pinnedAuthID = strings.TrimSpace(authID)
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" || h == nil || h.AuthManager == nil {
|
||||
return
|
||||
}
|
||||
selectedAuth, ok := h.AuthManager.GetByID(authID)
|
||||
if !ok || selectedAuth == nil {
|
||||
return
|
||||
}
|
||||
if websocketUpstreamSupportsIncrementalInput(selectedAuth.Attributes, selectedAuth.Metadata) {
|
||||
pinnedAuthID = authID
|
||||
}
|
||||
})
|
||||
}
|
||||
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -26,6 +27,78 @@ type websocketCaptureExecutor struct {
|
||||
payloads [][]byte
|
||||
}
|
||||
|
||||
type orderedWebsocketSelector struct {
|
||||
mu sync.Mutex
|
||||
order []string
|
||||
cursor int
|
||||
}
|
||||
|
||||
func (s *orderedWebsocketSelector) Pick(_ context.Context, _ string, _ string, _ coreexecutor.Options, auths []*coreauth.Auth) (*coreauth.Auth, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if len(auths) == 0 {
|
||||
return nil, errors.New("no auth available")
|
||||
}
|
||||
for len(s.order) > 0 && s.cursor < len(s.order) {
|
||||
authID := strings.TrimSpace(s.order[s.cursor])
|
||||
s.cursor++
|
||||
for _, auth := range auths {
|
||||
if auth != nil && auth.ID == authID {
|
||||
return auth, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, auth := range auths {
|
||||
if auth != nil {
|
||||
return auth, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("no auth available")
|
||||
}
|
||||
|
||||
type websocketAuthCaptureExecutor struct {
|
||||
mu sync.Mutex
|
||||
authIDs []string
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, _ coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
e.mu.Lock()
|
||||
if auth != nil {
|
||||
e.authIDs = append(e.authIDs, auth.ID)
|
||||
}
|
||||
e.mu.Unlock()
|
||||
|
||||
chunks := make(chan coreexecutor.StreamChunk, 1)
|
||||
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
|
||||
close(chunks)
|
||||
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) AuthIDs() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return append([]string(nil), e.authIDs...)
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
@@ -519,3 +592,73 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
||||
t.Fatalf("unexpected forwarded input: %s", forwarded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
selector := &orderedWebsocketSelector{order: []string{"auth-sse", "auth-ws"}}
|
||||
executor := &websocketAuthCaptureExecutor{}
|
||||
manager := coreauth.NewManager(nil, selector, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
authSSE := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
|
||||
if _, err := manager.Register(context.Background(), authSSE); err != nil {
|
||||
t.Fatalf("Register SSE auth: %v", err)
|
||||
}
|
||||
authWS := &coreauth.Auth{
|
||||
ID: "auth-ws",
|
||||
Provider: executor.Identifier(),
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{"websockets": "true"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), authWS); err != nil {
|
||||
t.Fatalf("Register websocket auth: %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(authSSE.ID, authSSE.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
registry.GetGlobalRegistry().RegisterClient(authWS.ID, authWS.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(authSSE.ID)
|
||||
registry.GetGlobalRegistry().UnregisterClient(authWS.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
t.Fatalf("close websocket: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
requests := []string{
|
||||
`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`,
|
||||
`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`,
|
||||
}
|
||||
for i := range requests {
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil {
|
||||
t.Fatalf("write websocket message %d: %v", i+1, errWrite)
|
||||
}
|
||||
_, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read websocket message %d: %v", i+1, errReadMessage)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
|
||||
t.Fatalf("message %d payload type = %s, want %s", i+1, got, wsEventTypeCompleted)
|
||||
}
|
||||
}
|
||||
|
||||
if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth-sse" || got[1] != "auth-ws" {
|
||||
t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user