fix: repair websocket custom tool calls

This commit is contained in:
Kai Wang
2026-04-03 17:11:41 +08:00
parent ab9ebea592
commit 8f0e66b72e

View File

@@ -266,15 +266,15 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
continue
}
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
switch itemType {
case "function_call_output":
switch {
case isResponsesToolCallOutputType(itemType):
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
continue
}
outputPresent[callID] = struct{}{}
outputCache.record(sessionKey, callID, item)
case "function_call":
case isResponsesToolCallType(itemType):
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
continue
@@ -293,7 +293,7 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
continue
}
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
if itemType == "function_call_output" {
if isResponsesToolCallOutputType(itemType) {
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
// Upstream rejects tool outputs without a call_id; drop it.
@@ -325,7 +325,7 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
// Drop orphaned function_call_output items; upstream rejects transcripts with missing calls.
continue
}
if itemType != "function_call" {
if !isResponsesToolCallType(itemType) {
filtered = append(filtered, item)
continue
}
@@ -376,7 +376,7 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO
return
}
for _, item := range output.Array() {
if strings.TrimSpace(item.Get("type").String()) != "function_call" {
if !isResponsesToolCallType(item.Get("type").String()) {
continue
}
callID := strings.TrimSpace(item.Get("call_id").String())
@@ -390,7 +390,7 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO
if !item.Exists() || !item.IsObject() {
return
}
if strings.TrimSpace(item.Get("type").String()) != "function_call" {
if !isResponsesToolCallType(item.Get("type").String()) {
return
}
callID := strings.TrimSpace(item.Get("call_id").String())
@@ -400,3 +400,21 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO
cache.record(sessionKey, callID, json.RawMessage(item.Raw))
}
}
func isResponsesToolCallType(itemType string) bool {
switch strings.TrimSpace(itemType) {
case "function_call", "custom_tool_call":
return true
default:
return false
}
}
func isResponsesToolCallOutputType(itemType string) bool {
switch strings.TrimSpace(itemType) {
case "function_call_output", "custom_tool_call_output":
return true
default:
return false
}
}