mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-07 22:33:30 +00:00
Merge branch 'router-for-me:main' into main
This commit is contained in:
@@ -212,6 +212,33 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
} else {
|
} else {
|
||||||
log.Warnf("Unknown file name extension '%s' in user message, skip", ext)
|
log.Warnf("Unknown file name extension '%s' in user message, skip", ext)
|
||||||
}
|
}
|
||||||
|
case "input_audio":
|
||||||
|
audioData := item.Get("input_audio.data").String()
|
||||||
|
audioFormat := item.Get("input_audio.format").String()
|
||||||
|
if audioData != "" {
|
||||||
|
audioMimeMap := map[string]string{
|
||||||
|
"mp3": "audio/mpeg",
|
||||||
|
"wav": "audio/wav",
|
||||||
|
"ogg": "audio/ogg",
|
||||||
|
"flac": "audio/flac",
|
||||||
|
"aac": "audio/aac",
|
||||||
|
"webm": "audio/webm",
|
||||||
|
"pcm16": "audio/pcm",
|
||||||
|
"g711_ulaw": "audio/basic",
|
||||||
|
"g711_alaw": "audio/basic",
|
||||||
|
}
|
||||||
|
mimeType := "audio/wav"
|
||||||
|
if audioFormat != "" {
|
||||||
|
if mapped, ok := audioMimeMap[audioFormat]; ok {
|
||||||
|
mimeType = mapped
|
||||||
|
} else {
|
||||||
|
mimeType = "audio/" + audioFormat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
|
||||||
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", audioData)
|
||||||
|
p++
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -203,46 +203,9 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
|||||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||||
} else if contentResult.Exists() && contentResult.IsArray() {
|
} else if contentResult.Exists() && contentResult.IsArray() {
|
||||||
contentResult.ForEach(func(_, part gjson.Result) bool {
|
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||||
partType := part.Get("type").String()
|
claudePart := convertOpenAIContentPartToClaudePart(part)
|
||||||
|
if claudePart != "" {
|
||||||
switch partType {
|
msg, _ = sjson.SetRaw(msg, "content.-1", claudePart)
|
||||||
case "text":
|
|
||||||
textPart := `{"type":"text","text":""}`
|
|
||||||
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
|
|
||||||
msg, _ = sjson.SetRaw(msg, "content.-1", textPart)
|
|
||||||
|
|
||||||
case "image_url":
|
|
||||||
// Convert OpenAI image format to Claude Code format
|
|
||||||
imageURL := part.Get("image_url.url").String()
|
|
||||||
if strings.HasPrefix(imageURL, "data:") {
|
|
||||||
// Extract base64 data and media type from data URL
|
|
||||||
parts := strings.Split(imageURL, ",")
|
|
||||||
if len(parts) == 2 {
|
|
||||||
mediaTypePart := strings.Split(parts[0], ";")[0]
|
|
||||||
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
|
|
||||||
data := parts[1]
|
|
||||||
|
|
||||||
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
|
||||||
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
|
|
||||||
imagePart, _ = sjson.Set(imagePart, "source.data", data)
|
|
||||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "file":
|
|
||||||
fileData := part.Get("file.file_data").String()
|
|
||||||
if strings.HasPrefix(fileData, "data:") {
|
|
||||||
semicolonIdx := strings.Index(fileData, ";")
|
|
||||||
commaIdx := strings.Index(fileData, ",")
|
|
||||||
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
|
||||||
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
|
||||||
data := fileData[commaIdx+1:]
|
|
||||||
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
|
||||||
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
|
||||||
docPart, _ = sjson.Set(docPart, "source.data", data)
|
|
||||||
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
@@ -291,11 +254,16 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
|||||||
case "tool":
|
case "tool":
|
||||||
// Handle tool result messages conversion
|
// Handle tool result messages conversion
|
||||||
toolCallID := message.Get("tool_call_id").String()
|
toolCallID := message.Get("tool_call_id").String()
|
||||||
content := message.Get("content").String()
|
toolContentResult := message.Get("content")
|
||||||
|
|
||||||
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
|
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
|
||||||
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
|
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
|
||||||
msg, _ = sjson.Set(msg, "content.0.content", content)
|
toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult)
|
||||||
|
if toolResultContentRaw {
|
||||||
|
msg, _ = sjson.SetRaw(msg, "content.0.content", toolResultContent)
|
||||||
|
} else {
|
||||||
|
msg, _ = sjson.Set(msg, "content.0.content", toolResultContent)
|
||||||
|
}
|
||||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||||
messageIndex++
|
messageIndex++
|
||||||
}
|
}
|
||||||
@@ -358,3 +326,110 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
|||||||
|
|
||||||
return []byte(out)
|
return []byte(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
|
||||||
|
switch part.Get("type").String() {
|
||||||
|
case "text":
|
||||||
|
textPart := `{"type":"text","text":""}`
|
||||||
|
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
|
||||||
|
return textPart
|
||||||
|
|
||||||
|
case "image_url":
|
||||||
|
return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String())
|
||||||
|
|
||||||
|
case "file":
|
||||||
|
fileData := part.Get("file.file_data").String()
|
||||||
|
if strings.HasPrefix(fileData, "data:") {
|
||||||
|
semicolonIdx := strings.Index(fileData, ";")
|
||||||
|
commaIdx := strings.Index(fileData, ",")
|
||||||
|
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||||
|
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||||
|
data := fileData[commaIdx+1:]
|
||||||
|
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||||
|
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||||
|
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||||
|
return docPart
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertOpenAIImageURLToClaudePart(imageURL string) string {
|
||||||
|
if imageURL == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(imageURL, "data:") {
|
||||||
|
parts := strings.SplitN(imageURL, ",", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
mediaTypePart := strings.SplitN(parts[0], ";", 2)[0]
|
||||||
|
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
|
||||||
|
if mediaType == "" {
|
||||||
|
mediaType = "application/octet-stream"
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
||||||
|
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
|
||||||
|
imagePart, _ = sjson.Set(imagePart, "source.data", parts[1])
|
||||||
|
return imagePart
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePart := `{"type":"image","source":{"type":"url","url":""}}`
|
||||||
|
imagePart, _ = sjson.Set(imagePart, "source.url", imageURL)
|
||||||
|
return imagePart
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
|
||||||
|
if !content.Exists() {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
return content.String(), false
|
||||||
|
}
|
||||||
|
|
||||||
|
if content.IsArray() {
|
||||||
|
claudeContent := "[]"
|
||||||
|
partCount := 0
|
||||||
|
|
||||||
|
content.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
if part.Type == gjson.String {
|
||||||
|
textPart := `{"type":"text","text":""}`
|
||||||
|
textPart, _ = sjson.Set(textPart, "text", part.String())
|
||||||
|
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", textPart)
|
||||||
|
partCount++
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
claudePart := convertOpenAIContentPartToClaudePart(part)
|
||||||
|
if claudePart != "" {
|
||||||
|
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
|
||||||
|
partCount++
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if partCount > 0 || len(content.Array()) == 0 {
|
||||||
|
return claudeContent, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return content.Raw, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if content.IsObject() {
|
||||||
|
claudePart := convertOpenAIContentPartToClaudePart(content)
|
||||||
|
if claudePart != "" {
|
||||||
|
claudeContent := "[]"
|
||||||
|
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
|
||||||
|
return claudeContent, true
|
||||||
|
}
|
||||||
|
return content.Raw, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return content.Raw, false
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,137 @@
|
|||||||
|
package chat_completions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertOpenAIRequestToClaude_ToolResultTextAndBase64Image(t *testing.T) {
|
||||||
|
inputJSON := `{
|
||||||
|
"model": "gpt-4.1",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "do_work",
|
||||||
|
"arguments": "{\"a\":1}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_1",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "tool ok"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg=="
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
|
||||||
|
resultJSON := gjson.ParseBytes(result)
|
||||||
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
|
if len(messages) != 2 {
|
||||||
|
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResult := messages[1].Get("content.0")
|
||||||
|
if got := toolResult.Get("type").String(); got != "tool_result" {
|
||||||
|
t.Fatalf("Expected content[0].type %q, got %q", "tool_result", got)
|
||||||
|
}
|
||||||
|
if got := toolResult.Get("tool_use_id").String(); got != "call_1" {
|
||||||
|
t.Fatalf("Expected tool_use_id %q, got %q", "call_1", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolContent := toolResult.Get("content")
|
||||||
|
if !toolContent.IsArray() {
|
||||||
|
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("0.type").String(); got != "text" {
|
||||||
|
t.Fatalf("Expected first tool_result part type %q, got %q", "text", got)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("0.text").String(); got != "tool ok" {
|
||||||
|
t.Fatalf("Expected first tool_result part text %q, got %q", "tool ok", got)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("1.type").String(); got != "image" {
|
||||||
|
t.Fatalf("Expected second tool_result part type %q, got %q", "image", got)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("1.source.type").String(); got != "base64" {
|
||||||
|
t.Fatalf("Expected image source type %q, got %q", "base64", got)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("1.source.media_type").String(); got != "image/png" {
|
||||||
|
t.Fatalf("Expected image media type %q, got %q", "image/png", got)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("1.source.data").String(); got != "iVBORw0KGgoAAAANSUhEUg==" {
|
||||||
|
t.Fatalf("Unexpected base64 image data: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) {
|
||||||
|
inputJSON := `{
|
||||||
|
"model": "gpt-4.1",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "do_work",
|
||||||
|
"arguments": "{\"a\":1}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_1",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://example.com/tool.png"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
|
||||||
|
resultJSON := gjson.ParseBytes(result)
|
||||||
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
|
if len(messages) != 2 {
|
||||||
|
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolContent := messages[1].Get("content.0.content")
|
||||||
|
if !toolContent.IsArray() {
|
||||||
|
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("0.type").String(); got != "image" {
|
||||||
|
t.Fatalf("Expected tool_result part type %q, got %q", "image", got)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("0.source.type").String(); got != "url" {
|
||||||
|
t.Fatalf("Expected image source type %q, got %q", "url", got)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("0.source.url").String(); got != "https://example.com/tool.png" {
|
||||||
|
t.Fatalf("Unexpected image URL: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -237,6 +237,33 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
partJSON, _ = sjson.Set(partJSON, "inline_data.data", data)
|
partJSON, _ = sjson.Set(partJSON, "inline_data.data", data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case "input_audio":
|
||||||
|
audioData := contentItem.Get("data").String()
|
||||||
|
audioFormat := contentItem.Get("format").String()
|
||||||
|
if audioData != "" {
|
||||||
|
audioMimeMap := map[string]string{
|
||||||
|
"mp3": "audio/mpeg",
|
||||||
|
"wav": "audio/wav",
|
||||||
|
"ogg": "audio/ogg",
|
||||||
|
"flac": "audio/flac",
|
||||||
|
"aac": "audio/aac",
|
||||||
|
"webm": "audio/webm",
|
||||||
|
"pcm16": "audio/pcm",
|
||||||
|
"g711_ulaw": "audio/basic",
|
||||||
|
"g711_alaw": "audio/basic",
|
||||||
|
}
|
||||||
|
mimeType := "audio/wav"
|
||||||
|
if audioFormat != "" {
|
||||||
|
if mapped, ok := audioMimeMap[audioFormat]; ok {
|
||||||
|
mimeType = mapped
|
||||||
|
} else {
|
||||||
|
mimeType = "audio/" + audioFormat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
partJSON = `{"inline_data":{"mime_type":"","data":""}}`
|
||||||
|
partJSON, _ = sjson.Set(partJSON, "inline_data.mime_type", mimeType)
|
||||||
|
partJSON, _ = sjson.Set(partJSON, "inline_data.data", audioData)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if partJSON != "" {
|
if partJSON != "" {
|
||||||
|
|||||||
@@ -183,7 +183,12 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
|||||||
// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
|
// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
|
||||||
toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
|
toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
|
||||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
|
toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
|
||||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", convertClaudeToolResultContentToString(part.Get("content")))
|
toolResultContent, toolResultContentRaw := convertClaudeToolResultContent(part.Get("content"))
|
||||||
|
if toolResultContentRaw {
|
||||||
|
toolResultJSON, _ = sjson.SetRaw(toolResultJSON, "content", toolResultContent)
|
||||||
|
} else {
|
||||||
|
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", toolResultContent)
|
||||||
|
}
|
||||||
toolResults = append(toolResults, toolResultJSON)
|
toolResults = append(toolResults, toolResultJSON)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
@@ -374,21 +379,41 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertClaudeToolResultContentToString(content gjson.Result) string {
|
func convertClaudeToolResultContent(content gjson.Result) (string, bool) {
|
||||||
if !content.Exists() {
|
if !content.Exists() {
|
||||||
return ""
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
if content.Type == gjson.String {
|
if content.Type == gjson.String {
|
||||||
return content.String()
|
return content.String(), false
|
||||||
}
|
}
|
||||||
|
|
||||||
if content.IsArray() {
|
if content.IsArray() {
|
||||||
var parts []string
|
var parts []string
|
||||||
|
contentJSON := "[]"
|
||||||
|
hasImagePart := false
|
||||||
content.ForEach(func(_, item gjson.Result) bool {
|
content.ForEach(func(_, item gjson.Result) bool {
|
||||||
switch {
|
switch {
|
||||||
case item.Type == gjson.String:
|
case item.Type == gjson.String:
|
||||||
parts = append(parts, item.String())
|
text := item.String()
|
||||||
|
parts = append(parts, text)
|
||||||
|
textContent := `{"type":"text","text":""}`
|
||||||
|
textContent, _ = sjson.Set(textContent, "text", text)
|
||||||
|
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
|
||||||
|
case item.IsObject() && item.Get("type").String() == "text":
|
||||||
|
text := item.Get("text").String()
|
||||||
|
parts = append(parts, text)
|
||||||
|
textContent := `{"type":"text","text":""}`
|
||||||
|
textContent, _ = sjson.Set(textContent, "text", text)
|
||||||
|
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
|
||||||
|
case item.IsObject() && item.Get("type").String() == "image":
|
||||||
|
contentItem, ok := convertClaudeContentPart(item)
|
||||||
|
if ok {
|
||||||
|
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
|
||||||
|
hasImagePart = true
|
||||||
|
} else {
|
||||||
|
parts = append(parts, item.Raw)
|
||||||
|
}
|
||||||
case item.IsObject() && item.Get("text").Exists() && item.Get("text").Type == gjson.String:
|
case item.IsObject() && item.Get("text").Exists() && item.Get("text").Type == gjson.String:
|
||||||
parts = append(parts, item.Get("text").String())
|
parts = append(parts, item.Get("text").String())
|
||||||
default:
|
default:
|
||||||
@@ -397,19 +422,31 @@ func convertClaudeToolResultContentToString(content gjson.Result) string {
|
|||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if hasImagePart {
|
||||||
|
return contentJSON, true
|
||||||
|
}
|
||||||
|
|
||||||
joined := strings.Join(parts, "\n\n")
|
joined := strings.Join(parts, "\n\n")
|
||||||
if strings.TrimSpace(joined) != "" {
|
if strings.TrimSpace(joined) != "" {
|
||||||
return joined
|
return joined, false
|
||||||
}
|
}
|
||||||
return content.Raw
|
return content.Raw, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if content.IsObject() {
|
if content.IsObject() {
|
||||||
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
|
if content.Get("type").String() == "image" {
|
||||||
return text.String()
|
contentItem, ok := convertClaudeContentPart(content)
|
||||||
|
if ok {
|
||||||
|
contentJSON := "[]"
|
||||||
|
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
|
||||||
|
return contentJSON, true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return content.Raw
|
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
|
||||||
|
return text.String(), false
|
||||||
|
}
|
||||||
|
return content.Raw, false
|
||||||
}
|
}
|
||||||
|
|
||||||
return content.Raw
|
return content.Raw, false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -488,6 +488,114 @@ func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToOpenAI_ToolResultTextAndImageContent(t *testing.T) {
|
||||||
|
inputJSON := `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "call_1",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "tool ok"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/png",
|
||||||
|
"data": "iVBORw0KGgoAAAANSUhEUg=="
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||||
|
resultJSON := gjson.ParseBytes(result)
|
||||||
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
|
if len(messages) != 2 {
|
||||||
|
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolContent := messages[1].Get("content")
|
||||||
|
if !toolContent.IsArray() {
|
||||||
|
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("0.type").String(); got != "text" {
|
||||||
|
t.Fatalf("Expected first tool content type %q, got %q", "text", got)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("0.text").String(); got != "tool ok" {
|
||||||
|
t.Fatalf("Expected first tool content text %q, got %q", "tool ok", got)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("1.type").String(); got != "image_url" {
|
||||||
|
t.Fatalf("Expected second tool content type %q, got %q", "image_url", got)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("1.image_url.url").String(); got != "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg==" {
|
||||||
|
t.Fatalf("Unexpected image_url: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToOpenAI_ToolResultURLImageOnly(t *testing.T) {
|
||||||
|
inputJSON := `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "call_1",
|
||||||
|
"content": {
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "url",
|
||||||
|
"url": "https://example.com/tool.png"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||||
|
resultJSON := gjson.ParseBytes(result)
|
||||||
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
|
if len(messages) != 2 {
|
||||||
|
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolContent := messages[1].Get("content")
|
||||||
|
if !toolContent.IsArray() {
|
||||||
|
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("0.type").String(); got != "image_url" {
|
||||||
|
t.Fatalf("Expected tool content type %q, got %q", "image_url", got)
|
||||||
|
}
|
||||||
|
if got := toolContent.Get("0.image_url.url").String(); got != "https://example.com/tool.png" {
|
||||||
|
t.Fatalf("Unexpected image_url: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
|
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
|
||||||
inputJSON := `{
|
inputJSON := `{
|
||||||
"model": "claude-3-opus",
|
"model": "claude-3-opus",
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@@ -75,6 +76,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
|||||||
|
|
||||||
w.lastAuthHashes = make(map[string]string)
|
w.lastAuthHashes = make(map[string]string)
|
||||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||||
|
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
|
||||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||||
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
|
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
|
||||||
} else if resolvedAuthDir != "" {
|
} else if resolvedAuthDir != "" {
|
||||||
@@ -92,6 +94,17 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
|||||||
if errParse := json.Unmarshal(data, &auth); errParse == nil {
|
if errParse := json.Unmarshal(data, &auth); errParse == nil {
|
||||||
w.lastAuthContents[normalizedPath] = &auth
|
w.lastAuthContents[normalizedPath] = &auth
|
||||||
}
|
}
|
||||||
|
ctx := &synthesizer.SynthesisContext{
|
||||||
|
Config: cfg,
|
||||||
|
AuthDir: resolvedAuthDir,
|
||||||
|
Now: time.Now(),
|
||||||
|
IDGenerator: synthesizer.NewStableIDGenerator(),
|
||||||
|
}
|
||||||
|
if generated := synthesizer.SynthesizeAuthFile(ctx, path, data); len(generated) > 0 {
|
||||||
|
if pathAuths := authSliceToMap(generated); len(pathAuths) > 0 {
|
||||||
|
w.fileAuthsByPath[normalizedPath] = pathAuths
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -143,13 +156,14 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.clientsMutex.Lock()
|
w.clientsMutex.Lock()
|
||||||
|
if w.config == nil {
|
||||||
cfg := w.config
|
|
||||||
if cfg == nil {
|
|
||||||
log.Error("config is nil, cannot add or update client")
|
log.Error("config is nil, cannot add or update client")
|
||||||
w.clientsMutex.Unlock()
|
w.clientsMutex.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if w.fileAuthsByPath == nil {
|
||||||
|
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
|
||||||
|
}
|
||||||
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
|
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
|
||||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
|
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
|
||||||
w.clientsMutex.Unlock()
|
w.clientsMutex.Unlock()
|
||||||
@@ -177,34 +191,86 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
|||||||
}
|
}
|
||||||
w.lastAuthContents[normalized] = &newAuth
|
w.lastAuthContents[normalized] = &newAuth
|
||||||
|
|
||||||
w.clientsMutex.Unlock() // Unlock before the callback
|
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
|
||||||
|
for id, a := range w.fileAuthsByPath[normalized] {
|
||||||
w.refreshAuthState(false)
|
oldByID[id] = a
|
||||||
|
|
||||||
if w.reloadCallback != nil {
|
|
||||||
log.Debugf("triggering server update callback after add/update")
|
|
||||||
w.triggerServerUpdate(cfg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build synthesized auth entries for this single file only.
|
||||||
|
sctx := &synthesizer.SynthesisContext{
|
||||||
|
Config: w.config,
|
||||||
|
AuthDir: w.authDir,
|
||||||
|
Now: time.Now(),
|
||||||
|
IDGenerator: synthesizer.NewStableIDGenerator(),
|
||||||
|
}
|
||||||
|
generated := synthesizer.SynthesizeAuthFile(sctx, path, data)
|
||||||
|
newByID := authSliceToMap(generated)
|
||||||
|
if len(newByID) > 0 {
|
||||||
|
w.fileAuthsByPath[normalized] = newByID
|
||||||
|
} else {
|
||||||
|
delete(w.fileAuthsByPath, normalized)
|
||||||
|
}
|
||||||
|
updates := w.computePerPathUpdatesLocked(oldByID, newByID)
|
||||||
|
w.clientsMutex.Unlock()
|
||||||
|
|
||||||
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
|
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
|
||||||
|
w.dispatchAuthUpdates(updates)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Watcher) removeClient(path string) {
|
func (w *Watcher) removeClient(path string) {
|
||||||
normalized := w.normalizeAuthPath(path)
|
normalized := w.normalizeAuthPath(path)
|
||||||
w.clientsMutex.Lock()
|
w.clientsMutex.Lock()
|
||||||
|
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
|
||||||
cfg := w.config
|
for id, a := range w.fileAuthsByPath[normalized] {
|
||||||
|
oldByID[id] = a
|
||||||
|
}
|
||||||
delete(w.lastAuthHashes, normalized)
|
delete(w.lastAuthHashes, normalized)
|
||||||
delete(w.lastAuthContents, normalized)
|
delete(w.lastAuthContents, normalized)
|
||||||
|
delete(w.fileAuthsByPath, normalized)
|
||||||
|
|
||||||
w.clientsMutex.Unlock() // Release the lock before the callback
|
updates := w.computePerPathUpdatesLocked(oldByID, map[string]*coreauth.Auth{})
|
||||||
|
w.clientsMutex.Unlock()
|
||||||
|
|
||||||
w.refreshAuthState(false)
|
|
||||||
|
|
||||||
if w.reloadCallback != nil {
|
|
||||||
log.Debugf("triggering server update callback after removal")
|
|
||||||
w.triggerServerUpdate(cfg)
|
|
||||||
}
|
|
||||||
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
|
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
|
||||||
|
w.dispatchAuthUpdates(updates)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Watcher) computePerPathUpdatesLocked(oldByID, newByID map[string]*coreauth.Auth) []AuthUpdate {
|
||||||
|
if w.currentAuths == nil {
|
||||||
|
w.currentAuths = make(map[string]*coreauth.Auth)
|
||||||
|
}
|
||||||
|
updates := make([]AuthUpdate, 0, len(oldByID)+len(newByID))
|
||||||
|
for id, newAuth := range newByID {
|
||||||
|
existing, ok := w.currentAuths[id]
|
||||||
|
if !ok {
|
||||||
|
w.currentAuths[id] = newAuth.Clone()
|
||||||
|
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: newAuth.Clone()})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !authEqual(existing, newAuth) {
|
||||||
|
w.currentAuths[id] = newAuth.Clone()
|
||||||
|
updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: newAuth.Clone()})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for id := range oldByID {
|
||||||
|
if _, stillExists := newByID[id]; stillExists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(w.currentAuths, id)
|
||||||
|
updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
|
||||||
|
}
|
||||||
|
return updates
|
||||||
|
}
|
||||||
|
|
||||||
|
func authSliceToMap(auths []*coreauth.Auth) map[string]*coreauth.Auth {
|
||||||
|
byID := make(map[string]*coreauth.Auth, len(auths))
|
||||||
|
for _, a := range auths {
|
||||||
|
if a == nil || strings.TrimSpace(a.ID) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
byID[a.ID] = a
|
||||||
|
}
|
||||||
|
return byID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Watcher) loadFileClients(cfg *config.Config) int {
|
func (w *Watcher) loadFileClients(cfg *config.Config) int {
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import (
|
|||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var snapshotCoreAuthsFunc = snapshotCoreAuths
|
||||||
|
|
||||||
func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) {
|
func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) {
|
||||||
w.clientsMutex.Lock()
|
w.clientsMutex.Lock()
|
||||||
defer w.clientsMutex.Unlock()
|
defer w.clientsMutex.Unlock()
|
||||||
@@ -76,7 +78,11 @@ func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *Watcher) refreshAuthState(force bool) {
|
func (w *Watcher) refreshAuthState(force bool) {
|
||||||
auths := w.SnapshotCoreAuths()
|
w.clientsMutex.RLock()
|
||||||
|
cfg := w.config
|
||||||
|
authDir := w.authDir
|
||||||
|
w.clientsMutex.RUnlock()
|
||||||
|
auths := snapshotCoreAuthsFunc(cfg, authDir)
|
||||||
w.clientsMutex.Lock()
|
w.clientsMutex.Lock()
|
||||||
if len(w.runtimeAuths) > 0 {
|
if len(w.runtimeAuths) > 0 {
|
||||||
for _, a := range w.runtimeAuths {
|
for _, a := range w.runtimeAuths {
|
||||||
|
|||||||
@@ -36,9 +36,6 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
now := ctx.Now
|
|
||||||
cfg := ctx.Config
|
|
||||||
|
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
if e.IsDir() {
|
if e.IsDir() {
|
||||||
continue
|
continue
|
||||||
@@ -52,99 +49,120 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
|
|||||||
if errRead != nil || len(data) == 0 {
|
if errRead != nil || len(data) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var metadata map[string]any
|
auths := synthesizeFileAuths(ctx, full, data)
|
||||||
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
|
if len(auths) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
t, _ := metadata["type"].(string)
|
out = append(out, auths...)
|
||||||
if t == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
provider := strings.ToLower(t)
|
|
||||||
if provider == "gemini" {
|
|
||||||
provider = "gemini-cli"
|
|
||||||
}
|
|
||||||
label := provider
|
|
||||||
if email, _ := metadata["email"].(string); email != "" {
|
|
||||||
label = email
|
|
||||||
}
|
|
||||||
// Use relative path under authDir as ID to stay consistent with the file-based token store
|
|
||||||
id := full
|
|
||||||
if rel, errRel := filepath.Rel(ctx.AuthDir, full); errRel == nil && rel != "" {
|
|
||||||
id = rel
|
|
||||||
}
|
|
||||||
// On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths.
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
id = strings.ToLower(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyURL := ""
|
|
||||||
if p, ok := metadata["proxy_url"].(string); ok {
|
|
||||||
proxyURL = p
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix := ""
|
|
||||||
if rawPrefix, ok := metadata["prefix"].(string); ok {
|
|
||||||
trimmed := strings.TrimSpace(rawPrefix)
|
|
||||||
trimmed = strings.Trim(trimmed, "/")
|
|
||||||
if trimmed != "" && !strings.Contains(trimmed, "/") {
|
|
||||||
prefix = trimmed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
disabled, _ := metadata["disabled"].(bool)
|
|
||||||
status := coreauth.StatusActive
|
|
||||||
if disabled {
|
|
||||||
status = coreauth.StatusDisabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read per-account excluded models from the OAuth JSON file
|
|
||||||
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
|
|
||||||
|
|
||||||
a := &coreauth.Auth{
|
|
||||||
ID: id,
|
|
||||||
Provider: provider,
|
|
||||||
Label: label,
|
|
||||||
Prefix: prefix,
|
|
||||||
Status: status,
|
|
||||||
Disabled: disabled,
|
|
||||||
Attributes: map[string]string{
|
|
||||||
"source": full,
|
|
||||||
"path": full,
|
|
||||||
},
|
|
||||||
ProxyURL: proxyURL,
|
|
||||||
Metadata: metadata,
|
|
||||||
CreatedAt: now,
|
|
||||||
UpdatedAt: now,
|
|
||||||
}
|
|
||||||
// Read priority from auth file
|
|
||||||
if rawPriority, ok := metadata["priority"]; ok {
|
|
||||||
switch v := rawPriority.(type) {
|
|
||||||
case float64:
|
|
||||||
a.Attributes["priority"] = strconv.Itoa(int(v))
|
|
||||||
case string:
|
|
||||||
priority := strings.TrimSpace(v)
|
|
||||||
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
|
|
||||||
a.Attributes["priority"] = priority
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
|
||||||
if provider == "gemini-cli" {
|
|
||||||
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
|
||||||
for _, v := range virtuals {
|
|
||||||
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
|
|
||||||
}
|
|
||||||
out = append(out, a)
|
|
||||||
out = append(out, virtuals...)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out = append(out, a)
|
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SynthesizeAuthFile generates Auth entries for one auth JSON file payload.
|
||||||
|
// It shares exactly the same mapping behavior as FileSynthesizer.Synthesize.
|
||||||
|
func SynthesizeAuthFile(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
|
||||||
|
return synthesizeFileAuths(ctx, fullPath, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
|
||||||
|
if ctx == nil || len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
now := ctx.Now
|
||||||
|
cfg := ctx.Config
|
||||||
|
var metadata map[string]any
|
||||||
|
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
t, _ := metadata["type"].(string)
|
||||||
|
if t == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
provider := strings.ToLower(t)
|
||||||
|
if provider == "gemini" {
|
||||||
|
provider = "gemini-cli"
|
||||||
|
}
|
||||||
|
label := provider
|
||||||
|
if email, _ := metadata["email"].(string); email != "" {
|
||||||
|
label = email
|
||||||
|
}
|
||||||
|
// Use relative path under authDir as ID to stay consistent with the file-based token store.
|
||||||
|
id := fullPath
|
||||||
|
if strings.TrimSpace(ctx.AuthDir) != "" {
|
||||||
|
if rel, errRel := filepath.Rel(ctx.AuthDir, fullPath); errRel == nil && rel != "" {
|
||||||
|
id = rel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
id = strings.ToLower(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL := ""
|
||||||
|
if p, ok := metadata["proxy_url"].(string); ok {
|
||||||
|
proxyURL = p
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := ""
|
||||||
|
if rawPrefix, ok := metadata["prefix"].(string); ok {
|
||||||
|
trimmed := strings.TrimSpace(rawPrefix)
|
||||||
|
trimmed = strings.Trim(trimmed, "/")
|
||||||
|
if trimmed != "" && !strings.Contains(trimmed, "/") {
|
||||||
|
prefix = trimmed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
disabled, _ := metadata["disabled"].(bool)
|
||||||
|
status := coreauth.StatusActive
|
||||||
|
if disabled {
|
||||||
|
status = coreauth.StatusDisabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read per-account excluded models from the OAuth JSON file.
|
||||||
|
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
|
||||||
|
|
||||||
|
a := &coreauth.Auth{
|
||||||
|
ID: id,
|
||||||
|
Provider: provider,
|
||||||
|
Label: label,
|
||||||
|
Prefix: prefix,
|
||||||
|
Status: status,
|
||||||
|
Disabled: disabled,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"source": fullPath,
|
||||||
|
"path": fullPath,
|
||||||
|
},
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Metadata: metadata,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
// Read priority from auth file.
|
||||||
|
if rawPriority, ok := metadata["priority"]; ok {
|
||||||
|
switch v := rawPriority.(type) {
|
||||||
|
case float64:
|
||||||
|
a.Attributes["priority"] = strconv.Itoa(int(v))
|
||||||
|
case string:
|
||||||
|
priority := strings.TrimSpace(v)
|
||||||
|
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
|
||||||
|
a.Attributes["priority"] = priority
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||||
|
if provider == "gemini-cli" {
|
||||||
|
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||||
|
for _, v := range virtuals {
|
||||||
|
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
|
||||||
|
}
|
||||||
|
out := make([]*coreauth.Auth, 0, 1+len(virtuals))
|
||||||
|
out = append(out, a)
|
||||||
|
out = append(out, virtuals...)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return []*coreauth.Auth{a}
|
||||||
|
}
|
||||||
|
|
||||||
// SynthesizeGeminiVirtualAuths creates virtual Auth entries for multi-project Gemini credentials.
|
// SynthesizeGeminiVirtualAuths creates virtual Auth entries for multi-project Gemini credentials.
|
||||||
// It disables the primary auth and creates one virtual auth per project.
|
// It disables the primary auth and creates one virtual auth per project.
|
||||||
func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth {
|
func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth {
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ type Watcher struct {
|
|||||||
watcher *fsnotify.Watcher
|
watcher *fsnotify.Watcher
|
||||||
lastAuthHashes map[string]string
|
lastAuthHashes map[string]string
|
||||||
lastAuthContents map[string]*coreauth.Auth
|
lastAuthContents map[string]*coreauth.Auth
|
||||||
|
fileAuthsByPath map[string]map[string]*coreauth.Auth
|
||||||
lastRemoveTimes map[string]time.Time
|
lastRemoveTimes map[string]time.Time
|
||||||
lastConfigHash string
|
lastConfigHash string
|
||||||
authQueue chan<- AuthUpdate
|
authQueue chan<- AuthUpdate
|
||||||
@@ -92,11 +93,12 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config))
|
|||||||
return nil, errNewWatcher
|
return nil, errNewWatcher
|
||||||
}
|
}
|
||||||
w := &Watcher{
|
w := &Watcher{
|
||||||
configPath: configPath,
|
configPath: configPath,
|
||||||
authDir: authDir,
|
authDir: authDir,
|
||||||
reloadCallback: reloadCallback,
|
reloadCallback: reloadCallback,
|
||||||
watcher: watcher,
|
watcher: watcher,
|
||||||
lastAuthHashes: make(map[string]string),
|
lastAuthHashes: make(map[string]string),
|
||||||
|
fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
|
||||||
}
|
}
|
||||||
w.dispatchCond = sync.NewCond(&w.dispatchMu)
|
w.dispatchCond = sync.NewCond(&w.dispatchMu)
|
||||||
if store := sdkAuth.GetTokenStore(); store != nil {
|
if store := sdkAuth.GetTokenStore(); store != nil {
|
||||||
|
|||||||
@@ -406,8 +406,8 @@ func TestAddOrUpdateClientTriggersReloadAndHash(t *testing.T) {
|
|||||||
|
|
||||||
w.addOrUpdateClient(authFile)
|
w.addOrUpdateClient(authFile)
|
||||||
|
|
||||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
if got := atomic.LoadInt32(&reloads); got != 0 {
|
||||||
t.Fatalf("expected reload callback once, got %d", got)
|
t.Fatalf("expected no reload callback for auth update, got %d", got)
|
||||||
}
|
}
|
||||||
// Use normalizeAuthPath to match how addOrUpdateClient stores the key
|
// Use normalizeAuthPath to match how addOrUpdateClient stores the key
|
||||||
normalized := w.normalizeAuthPath(authFile)
|
normalized := w.normalizeAuthPath(authFile)
|
||||||
@@ -436,8 +436,110 @@ func TestRemoveClientRemovesHash(t *testing.T) {
|
|||||||
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
||||||
t.Fatal("expected hash to be removed after deletion")
|
t.Fatal("expected hash to be removed after deletion")
|
||||||
}
|
}
|
||||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
if got := atomic.LoadInt32(&reloads); got != 0 {
|
||||||
t.Fatalf("expected reload callback once, got %d", got)
|
t.Fatalf("expected no reload callback for auth removal, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFileEventsDoNotInvokeSnapshotCoreAuths(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authFile := filepath.Join(tmpDir, "sample.json")
|
||||||
|
if err := os.WriteFile(authFile, []byte(`{"type":"codex","email":"u@example.com"}`), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to create auth file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
origSnapshot := snapshotCoreAuthsFunc
|
||||||
|
var snapshotCalls int32
|
||||||
|
snapshotCoreAuthsFunc = func(cfg *config.Config, authDir string) []*coreauth.Auth {
|
||||||
|
atomic.AddInt32(&snapshotCalls, 1)
|
||||||
|
return origSnapshot(cfg, authDir)
|
||||||
|
}
|
||||||
|
defer func() { snapshotCoreAuthsFunc = origSnapshot }()
|
||||||
|
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: tmpDir,
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
lastAuthContents: make(map[string]*coreauth.Auth),
|
||||||
|
fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
|
||||||
|
}
|
||||||
|
w.SetConfig(&config.Config{AuthDir: tmpDir})
|
||||||
|
|
||||||
|
w.addOrUpdateClient(authFile)
|
||||||
|
w.removeClient(authFile)
|
||||||
|
|
||||||
|
if got := atomic.LoadInt32(&snapshotCalls); got != 0 {
|
||||||
|
t.Fatalf("expected auth file events to avoid full snapshot, got %d calls", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthSliceToMap(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
valid1 := &coreauth.Auth{ID: "a"}
|
||||||
|
valid2 := &coreauth.Auth{ID: "b"}
|
||||||
|
dupOld := &coreauth.Auth{ID: "dup", Label: "old"}
|
||||||
|
dupNew := &coreauth.Auth{ID: "dup", Label: "new"}
|
||||||
|
empty := &coreauth.Auth{ID: " "}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in []*coreauth.Auth
|
||||||
|
want map[string]*coreauth.Auth
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil input",
|
||||||
|
in: nil,
|
||||||
|
want: map[string]*coreauth.Auth{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input",
|
||||||
|
in: []*coreauth.Auth{},
|
||||||
|
want: map[string]*coreauth.Auth{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filters invalid auths",
|
||||||
|
in: []*coreauth.Auth{nil, empty},
|
||||||
|
want: map[string]*coreauth.Auth{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "keeps valid auths",
|
||||||
|
in: []*coreauth.Auth{valid1, nil, valid2},
|
||||||
|
want: map[string]*coreauth.Auth{"a": valid1, "b": valid2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "last duplicate wins",
|
||||||
|
in: []*coreauth.Auth{dupOld, dupNew},
|
||||||
|
want: map[string]*coreauth.Auth{"dup": dupNew},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
got := authSliceToMap(tc.in)
|
||||||
|
if len(tc.want) == 0 {
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("expected empty map, got nil")
|
||||||
|
}
|
||||||
|
if len(got) != 0 {
|
||||||
|
t.Fatalf("expected empty map, got %#v", got)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(got) != len(tc.want) {
|
||||||
|
t.Fatalf("unexpected map length: got %d, want %d", len(got), len(tc.want))
|
||||||
|
}
|
||||||
|
for id, wantAuth := range tc.want {
|
||||||
|
gotAuth, ok := got[id]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("missing id %q in result map", id)
|
||||||
|
}
|
||||||
|
if !authEqual(gotAuth, wantAuth) {
|
||||||
|
t.Fatalf("unexpected auth for id %q: got %#v, want %#v", id, gotAuth, wantAuth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -695,8 +797,8 @@ func TestHandleEventRemovesAuthFile(t *testing.T) {
|
|||||||
|
|
||||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
||||||
|
|
||||||
if atomic.LoadInt32(&reloads) != 1 {
|
if atomic.LoadInt32(&reloads) != 0 {
|
||||||
t.Fatalf("expected reload callback once, got %d", reloads)
|
t.Fatalf("expected no reload callback for auth removal, got %d", reloads)
|
||||||
}
|
}
|
||||||
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
||||||
t.Fatal("expected hash entry to be removed")
|
t.Fatal("expected hash entry to be removed")
|
||||||
@@ -893,8 +995,8 @@ func TestHandleEventAuthWriteTriggersUpdate(t *testing.T) {
|
|||||||
w.SetConfig(&config.Config{AuthDir: authDir})
|
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||||
|
|
||||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write})
|
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write})
|
||||||
if atomic.LoadInt32(&reloads) != 1 {
|
if atomic.LoadInt32(&reloads) != 0 {
|
||||||
t.Fatalf("expected auth write to trigger reload callback, got %d", reloads)
|
t.Fatalf("expected auth write to avoid global reload callback, got %d", reloads)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -990,8 +1092,8 @@ func TestHandleEventAtomicReplaceChangedTriggersUpdate(t *testing.T) {
|
|||||||
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(oldSum[:])
|
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(oldSum[:])
|
||||||
|
|
||||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
|
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
|
||||||
if atomic.LoadInt32(&reloads) != 1 {
|
if atomic.LoadInt32(&reloads) != 0 {
|
||||||
t.Fatalf("expected changed atomic replace to trigger update, got %d", reloads)
|
t.Fatalf("expected changed atomic replace to avoid global reload, got %d", reloads)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1045,8 +1147,8 @@ func TestHandleEventRemoveKnownFileDeletes(t *testing.T) {
|
|||||||
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
|
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
|
||||||
|
|
||||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
||||||
if atomic.LoadInt32(&reloads) != 1 {
|
if atomic.LoadInt32(&reloads) != 0 {
|
||||||
t.Fatalf("expected known remove to trigger reload, got %d", reloads)
|
t.Fatalf("expected known remove to avoid global reload, got %d", reloads)
|
||||||
}
|
}
|
||||||
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
||||||
t.Fatal("expected known auth hash to be deleted")
|
t.Fatal("expected known auth hash to be deleted")
|
||||||
|
|||||||
@@ -14,7 +14,11 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -100,11 +104,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
// )
|
// )
|
||||||
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
||||||
|
|
||||||
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
|
allowIncrementalInputWithPreviousResponseID := false
|
||||||
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||||
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||||
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||||
|
if requestModelName == "" {
|
||||||
|
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||||
|
}
|
||||||
|
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestJSON []byte
|
var requestJSON []byte
|
||||||
@@ -139,6 +149,22 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
|
||||||
|
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
|
||||||
|
requestJSON = updated
|
||||||
|
}
|
||||||
|
if updated, errDelete := sjson.DeleteBytes(updatedLastRequest, "generate"); errDelete == nil {
|
||||||
|
updatedLastRequest = updated
|
||||||
|
}
|
||||||
|
lastRequest = updatedLastRequest
|
||||||
|
lastResponseOutput = []byte("[]")
|
||||||
|
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil {
|
||||||
|
wsTerminateErr = errWrite
|
||||||
|
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
lastRequest = updatedLastRequest
|
lastRequest = updatedLastRequest
|
||||||
|
|
||||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||||
@@ -339,6 +365,192 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
|
||||||
|
if h == nil || h.AuthManager == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedModelName := modelName
|
||||||
|
initialSuffix := thinking.ParseSuffix(modelName)
|
||||||
|
if initialSuffix.ModelName == "auto" {
|
||||||
|
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
|
||||||
|
if initialSuffix.HasSuffix {
|
||||||
|
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
|
||||||
|
} else {
|
||||||
|
resolvedModelName = resolvedBase
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
resolvedModelName = util.ResolveAutoModel(modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed := thinking.ParseSuffix(resolvedModelName)
|
||||||
|
baseModel := strings.TrimSpace(parsed.ModelName)
|
||||||
|
providers := util.GetProviderName(baseModel)
|
||||||
|
if len(providers) == 0 && baseModel != resolvedModelName {
|
||||||
|
providers = util.GetProviderName(resolvedModelName)
|
||||||
|
}
|
||||||
|
if len(providers) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
providerSet := make(map[string]struct{}, len(providers))
|
||||||
|
for i := 0; i < len(providers); i++ {
|
||||||
|
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
|
||||||
|
if providerKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
providerSet[providerKey] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(providerSet) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
modelKey := baseModel
|
||||||
|
if modelKey == "" {
|
||||||
|
modelKey = strings.TrimSpace(resolvedModelName)
|
||||||
|
}
|
||||||
|
registryRef := registry.GetGlobalRegistry()
|
||||||
|
now := time.Now()
|
||||||
|
auths := h.AuthManager.List()
|
||||||
|
for i := 0; i < len(auths); i++ {
|
||||||
|
auth := auths[i]
|
||||||
|
if auth == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
|
||||||
|
if _, ok := providerSet[providerKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
|
||||||
|
if auth == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if auth.Disabled || auth.Status == coreauth.StatusDisabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if modelName != "" && len(auth.ModelStates) > 0 {
|
||||||
|
state, ok := auth.ModelStates[modelName]
|
||||||
|
if (!ok || state == nil) && modelName != "" {
|
||||||
|
baseModel := strings.TrimSpace(thinking.ParseSuffix(modelName).ModelName)
|
||||||
|
if baseModel != "" && baseModel != modelName {
|
||||||
|
state, ok = auth.ModelStates[baseModel]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ok && state != nil {
|
||||||
|
if state.Status == coreauth.StatusDisabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if state.Unavailable && !state.NextRetryAfter.IsZero() && state.NextRetryAfter.After(now) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if auth.Unavailable && !auth.NextRetryAfter.IsZero() && auth.NextRetryAfter.After(now) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest []byte, allowIncrementalInputWithPreviousResponseID bool) bool {
|
||||||
|
if allowIncrementalInputWithPreviousResponseID || len(lastRequest) != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
generateResult := gjson.GetBytes(rawJSON, "generate")
|
||||||
|
return generateResult.Exists() && !generateResult.Bool()
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeResponsesWebsocketSyntheticPrewarm(
|
||||||
|
c *gin.Context,
|
||||||
|
conn *websocket.Conn,
|
||||||
|
requestJSON []byte,
|
||||||
|
wsBodyLog *strings.Builder,
|
||||||
|
sessionID string,
|
||||||
|
) error {
|
||||||
|
payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
|
||||||
|
if errPayloads != nil {
|
||||||
|
return errPayloads
|
||||||
|
}
|
||||||
|
for i := 0; i < len(payloads); i++ {
|
||||||
|
markAPIResponseTimestamp(c)
|
||||||
|
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
||||||
|
// log.Infof(
|
||||||
|
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||||
|
// sessionID,
|
||||||
|
// websocket.TextMessage,
|
||||||
|
// websocketPayloadEventType(payloads[i]),
|
||||||
|
// websocketPayloadPreview(payloads[i]),
|
||||||
|
// )
|
||||||
|
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
||||||
|
log.Warnf(
|
||||||
|
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||||
|
sessionID,
|
||||||
|
websocketPayloadEventType(payloads[i]),
|
||||||
|
errWrite,
|
||||||
|
)
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func syntheticResponsesWebsocketPrewarmPayloads(requestJSON []byte) ([][]byte, error) {
|
||||||
|
responseID := "resp_prewarm_" + uuid.NewString()
|
||||||
|
createdAt := time.Now().Unix()
|
||||||
|
modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String())
|
||||||
|
|
||||||
|
createdPayload := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
|
||||||
|
var errSet error
|
||||||
|
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.id", responseID)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.created_at", createdAt)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
if modelName != "" {
|
||||||
|
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.model", modelName)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
completedPayload := []byte(`{"type":"response.completed","sequence_number":1,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
|
||||||
|
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.id", responseID)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.created_at", createdAt)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
if modelName != "" {
|
||||||
|
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.model", modelName)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return [][]byte{createdPayload, completedPayload}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
|
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
|
||||||
existingRaw = strings.TrimSpace(existingRaw)
|
existingRaw = strings.TrimSpace(existingRaw)
|
||||||
appendRaw = strings.TrimSpace(appendRaw)
|
appendRaw = strings.TrimSpace(appendRaw)
|
||||||
@@ -550,47 +762,63 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error
|
|||||||
}
|
}
|
||||||
|
|
||||||
body := handlers.BuildErrorResponseBody(status, errText)
|
body := handlers.BuildErrorResponseBody(status, errText)
|
||||||
payload := map[string]any{
|
payload := []byte(`{}`)
|
||||||
"type": wsEventTypeError,
|
var errSet error
|
||||||
"status": status,
|
payload, errSet = sjson.SetBytes(payload, "type", wsEventTypeError)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
payload, errSet = sjson.SetBytes(payload, "status", status)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
}
|
}
|
||||||
|
|
||||||
if errMsg != nil && errMsg.Addon != nil {
|
if errMsg != nil && errMsg.Addon != nil {
|
||||||
headers := map[string]any{}
|
headers := []byte(`{}`)
|
||||||
|
hasHeaders := false
|
||||||
for key, values := range errMsg.Addon {
|
for key, values := range errMsg.Addon {
|
||||||
if len(values) == 0 {
|
if len(values) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
headers[key] = values[0]
|
headerPath := strings.ReplaceAll(strings.ReplaceAll(key, `\\`, `\\\\`), ".", `\\.`)
|
||||||
|
headers, errSet = sjson.SetBytes(headers, headerPath, values[0])
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
hasHeaders = true
|
||||||
}
|
}
|
||||||
if len(headers) > 0 {
|
if hasHeaders {
|
||||||
payload["headers"] = headers
|
payload, errSet = sjson.SetRawBytes(payload, "headers", headers)
|
||||||
}
|
if errSet != nil {
|
||||||
}
|
return nil, errSet
|
||||||
|
|
||||||
if len(body) > 0 && json.Valid(body) {
|
|
||||||
var decoded map[string]any
|
|
||||||
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
|
|
||||||
if inner, ok := decoded["error"]; ok {
|
|
||||||
payload["error"] = inner
|
|
||||||
} else {
|
|
||||||
payload["error"] = decoded
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := payload["error"]; !ok {
|
if len(body) > 0 && json.Valid(body) {
|
||||||
payload["error"] = map[string]any{
|
errorNode := gjson.GetBytes(body, "error")
|
||||||
"type": "server_error",
|
if errorNode.Exists() {
|
||||||
"message": errText,
|
payload, errSet = sjson.SetRawBytes(payload, "error", []byte(errorNode.Raw))
|
||||||
|
} else {
|
||||||
|
payload, errSet = sjson.SetRawBytes(payload, "error", body)
|
||||||
|
}
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(payload)
|
if !gjson.GetBytes(payload, "error").Exists() {
|
||||||
if err != nil {
|
payload, errSet = sjson.SetBytes(payload, "error.type", "server_error")
|
||||||
return nil, err
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
payload, errSet = sjson.SetBytes(payload, "error.message", errText)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return data, conn.WriteMessage(websocket.TextMessage, data)
|
|
||||||
|
return payload, conn.WriteMessage(websocket.TextMessage, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -11,9 +13,46 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type websocketCaptureExecutor struct {
|
||||||
|
streamCalls int
|
||||||
|
payloads [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
|
||||||
|
|
||||||
|
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
|
e.streamCalls++
|
||||||
|
e.payloads = append(e.payloads, bytes.Clone(req.Payload))
|
||||||
|
chunks := make(chan coreexecutor.StreamChunk, 1)
|
||||||
|
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
|
||||||
|
close(chunks)
|
||||||
|
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
||||||
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||||
|
|
||||||
@@ -326,3 +365,130 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
|||||||
t.Fatalf("server error: %v", errServer)
|
t.Fatalf("server error: %v", errServer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
auth := &coreauth.Auth{
|
||||||
|
ID: "auth-ws",
|
||||||
|
Provider: "test-provider",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Attributes: map[string]string{"websockets": "true"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||||
|
t.Fatalf("Register auth: %v", err)
|
||||||
|
}
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||||
|
h := NewOpenAIResponsesAPIHandler(base)
|
||||||
|
if !h.websocketUpstreamSupportsIncrementalInputForModel("test-model") {
|
||||||
|
t.Fatalf("expected websocket-capable upstream for test-model")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
executor := &websocketCaptureExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
|
||||||
|
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||||
|
t.Fatalf("Register auth: %v", err)
|
||||||
|
}
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||||
|
h := NewOpenAIResponsesAPIHandler(base)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||||
|
|
||||||
|
server := httptest.NewServer(router)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial websocket: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
errClose := conn.Close()
|
||||||
|
if errClose != nil {
|
||||||
|
t.Fatalf("close websocket: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","generate":false}`))
|
||||||
|
if errWrite != nil {
|
||||||
|
t.Fatalf("write prewarm websocket message: %v", errWrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, createdPayload, errReadMessage := conn.ReadMessage()
|
||||||
|
if errReadMessage != nil {
|
||||||
|
t.Fatalf("read prewarm created message: %v", errReadMessage)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(createdPayload, "type").String() != "response.created" {
|
||||||
|
t.Fatalf("created payload type = %s, want response.created", gjson.GetBytes(createdPayload, "type").String())
|
||||||
|
}
|
||||||
|
prewarmResponseID := gjson.GetBytes(createdPayload, "response.id").String()
|
||||||
|
if prewarmResponseID == "" {
|
||||||
|
t.Fatalf("prewarm response id is empty")
|
||||||
|
}
|
||||||
|
if executor.streamCalls != 0 {
|
||||||
|
t.Fatalf("stream calls after prewarm = %d, want 0", executor.streamCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, completedPayload, errReadMessage := conn.ReadMessage()
|
||||||
|
if errReadMessage != nil {
|
||||||
|
t.Fatalf("read prewarm completed message: %v", errReadMessage)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(completedPayload, "type").String() != wsEventTypeCompleted {
|
||||||
|
t.Fatalf("completed payload type = %s, want %s", gjson.GetBytes(completedPayload, "type").String(), wsEventTypeCompleted)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(completedPayload, "response.id").String() != prewarmResponseID {
|
||||||
|
t.Fatalf("completed response id = %s, want %s", gjson.GetBytes(completedPayload, "response.id").String(), prewarmResponseID)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int() != 0 {
|
||||||
|
t.Fatalf("prewarm total tokens = %d, want 0", gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
secondRequest := fmt.Sprintf(`{"type":"response.create","previous_response_id":%q,"input":[{"type":"message","id":"msg-1"}]}`, prewarmResponseID)
|
||||||
|
errWrite = conn.WriteMessage(websocket.TextMessage, []byte(secondRequest))
|
||||||
|
if errWrite != nil {
|
||||||
|
t.Fatalf("write follow-up websocket message: %v", errWrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, upstreamPayload, errReadMessage := conn.ReadMessage()
|
||||||
|
if errReadMessage != nil {
|
||||||
|
t.Fatalf("read upstream completed message: %v", errReadMessage)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(upstreamPayload, "type").String() != wsEventTypeCompleted {
|
||||||
|
t.Fatalf("upstream payload type = %s, want %s", gjson.GetBytes(upstreamPayload, "type").String(), wsEventTypeCompleted)
|
||||||
|
}
|
||||||
|
if executor.streamCalls != 1 {
|
||||||
|
t.Fatalf("stream calls after follow-up = %d, want 1", executor.streamCalls)
|
||||||
|
}
|
||||||
|
if len(executor.payloads) != 1 {
|
||||||
|
t.Fatalf("captured upstream payloads = %d, want 1", len(executor.payloads))
|
||||||
|
}
|
||||||
|
forwarded := executor.payloads[0]
|
||||||
|
if gjson.GetBytes(forwarded, "previous_response_id").Exists() {
|
||||||
|
t.Fatalf("previous_response_id leaked upstream: %s", forwarded)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(forwarded, "generate").Exists() {
|
||||||
|
t.Fatalf("generate leaked upstream: %s", forwarded)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(forwarded, "model").String() != "test-model" {
|
||||||
|
t.Fatalf("forwarded model = %s, want test-model", gjson.GetBytes(forwarded, "model").String())
|
||||||
|
}
|
||||||
|
input := gjson.GetBytes(forwarded, "input").Array()
|
||||||
|
if len(input) != 1 || input[0].Get("id").String() != "msg-1" {
|
||||||
|
t.Fatalf("unexpected forwarded input: %s", forwarded)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user