Merge pull request #2513 from lonr-6/codex/fix-ws-custom-tool-repair-v2

fix: repair responses websocket custom tool call pairing
This commit is contained in:
Luis Pater
2026-04-03 23:45:38 +08:00
committed by GitHub
3 changed files with 300 additions and 9 deletions

View File

@@ -379,7 +379,7 @@ func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bo
for _, item := range nextInput.Array() {
switch strings.TrimSpace(item.Get("type").String()) {
case "function_call":
case "function_call", "custom_tool_call":
return true
case "message":
role := strings.TrimSpace(item.Get("role").String())
@@ -431,7 +431,7 @@ func dedupeFunctionCallsByCallID(rawArray string) (string, error) {
continue
}
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
if itemType == "function_call" {
if isResponsesToolCallType(itemType) {
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID != "" {
if _, ok := seenCallIDs[callID]; ok {

View File

@@ -520,6 +520,92 @@ func TestRepairResponsesWebsocketToolCallsDropsOrphanOutputWhenCallMissing(t *te
}
}
func TestRepairResponsesWebsocketToolCallsInsertsCachedCustomToolOutput(t *testing.T) {
cache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
cacheWarm := []byte(`{"previous_response_id":"resp-1","input":[{"type":"custom_tool_call_output","call_id":"call-1","output":"ok"}]}`)
warmed := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, cacheWarm)
if gjson.GetBytes(warmed, "input.0.call_id").String() != "call-1" {
t.Fatalf("expected warmup output to remain")
}
raw := []byte(`{"input":[{"type":"custom_tool_call","call_id":"call-1","name":"apply_patch"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw)
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 3 {
t.Fatalf("repaired input len = %d, want 3", len(input))
}
if input[0].Get("type").String() != "custom_tool_call" || input[0].Get("call_id").String() != "call-1" {
t.Fatalf("unexpected first item: %s", input[0].Raw)
}
if input[1].Get("type").String() != "custom_tool_call_output" || input[1].Get("call_id").String() != "call-1" {
t.Fatalf("missing inserted output: %s", input[1].Raw)
}
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
}
}
func TestRepairResponsesWebsocketToolCallsDropsOrphanCustomToolCall(t *testing.T) {
cache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
raw := []byte(`{"input":[{"type":"custom_tool_call","call_id":"call-1","name":"apply_patch"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw)
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 1 {
t.Fatalf("repaired input len = %d, want 1", len(input))
}
if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" {
t.Fatalf("unexpected remaining item: %s", input[0].Raw)
}
}
func TestRepairResponsesWebsocketToolCallsInsertsCachedCustomToolCallForOrphanOutput(t *testing.T) {
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
callCache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
callCache.record(sessionKey, "call-1", []byte(`{"type":"custom_tool_call","call_id":"call-1","name":"apply_patch"}`))
raw := []byte(`{"input":[{"type":"custom_tool_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 3 {
t.Fatalf("repaired input len = %d, want 3", len(input))
}
if input[0].Get("type").String() != "custom_tool_call" || input[0].Get("call_id").String() != "call-1" {
t.Fatalf("missing inserted call: %s", input[0].Raw)
}
if input[1].Get("type").String() != "custom_tool_call_output" || input[1].Get("call_id").String() != "call-1" {
t.Fatalf("unexpected output item: %s", input[1].Raw)
}
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
}
}
func TestRepairResponsesWebsocketToolCallsDropsOrphanCustomToolOutputWhenCallMissing(t *testing.T) {
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
callCache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
raw := []byte(`{"input":[{"type":"custom_tool_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 1 {
t.Fatalf("repaired input len = %d, want 1", len(input))
}
if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" {
t.Fatalf("unexpected remaining item: %s", input[0].Raw)
}
}
func TestRecordResponsesWebsocketToolCallsFromPayloadWithCache(t *testing.T) {
cache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
@@ -536,6 +622,38 @@ func TestRecordResponsesWebsocketToolCallsFromPayloadWithCache(t *testing.T) {
}
}
func TestRecordResponsesWebsocketCustomToolCallsFromCompletedPayloadWithCache(t *testing.T) {
cache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch","input":"*** Begin Patch"}]}}`)
recordResponsesWebsocketToolCallsFromPayloadWithCache(cache, sessionKey, payload)
cached, ok := cache.get(sessionKey, "call-1")
if !ok {
t.Fatalf("expected cached custom tool call")
}
if gjson.GetBytes(cached, "type").String() != "custom_tool_call" || gjson.GetBytes(cached, "call_id").String() != "call-1" {
t.Fatalf("unexpected cached custom tool call: %s", cached)
}
}
func TestRecordResponsesWebsocketCustomToolCallsFromOutputItemDoneWithCache(t *testing.T) {
cache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
payload := []byte(`{"type":"response.output_item.done","item":{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch","input":"*** Begin Patch"}}`)
recordResponsesWebsocketToolCallsFromPayloadWithCache(cache, sessionKey, payload)
cached, ok := cache.get(sessionKey, "call-1")
if !ok {
t.Fatalf("expected cached custom tool call")
}
if gjson.GetBytes(cached, "type").String() != "custom_tool_call" || gjson.GetBytes(cached, "call_id").String() != "call-1" {
t.Fatalf("unexpected cached custom tool call: %s", cached)
}
}
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -1023,6 +1141,161 @@ func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t
}
}
func TestNormalizeResponsesWebsocketRequestTreatsCustomToolTranscriptReplacementAsReset(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch"},{"type":"custom_tool_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`)
lastResponseOutput := []byte(`[
{"type":"message","id":"assistant-1","role":"assistant"}
]`)
raw := []byte(`{"type":"response.create","input":[{"type":"custom_tool_call","id":"ctc-compact","call_id":"call-1","name":"apply_patch"},{"type":"custom_tool_call_output","id":"tool-out-compact","call_id":"call-1"},{"type":"message","id":"msg-2"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
if gjson.GetBytes(normalized, "previous_response_id").Exists() {
t.Fatalf("previous_response_id must not exist in transcript replacement mode")
}
items := gjson.GetBytes(normalized, "input").Array()
if len(items) != 3 {
t.Fatalf("replacement input len = %d, want 3: %s", len(items), normalized)
}
if items[0].Get("id").String() != "ctc-compact" ||
items[1].Get("id").String() != "tool-out-compact" ||
items[2].Get("id").String() != "msg-2" {
t.Fatalf("replacement transcript was not preserved as-is: %s", normalized)
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match replacement request")
}
}
func TestNormalizeResponsesWebsocketRequestDropsDuplicateCustomToolCallsByCallID(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch"},{"type":"custom_tool_call_output","id":"tool-out-1","call_id":"call-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch"}
]`)
raw := []byte(`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`)
normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
items := gjson.GetBytes(normalized, "input").Array()
if len(items) != 3 {
t.Fatalf("merged input len = %d, want 3: %s", len(items), normalized)
}
if items[0].Get("id").String() != "ctc-1" ||
items[1].Get("id").String() != "tool-out-1" ||
items[2].Get("id").String() != "msg-2" {
t.Fatalf("unexpected merged input order: %s", normalized)
}
}
func TestResponsesWebsocketCompactionResetsTurnStateOnCustomToolTranscriptReplacement(t *testing.T) {
gin.SetMode(gin.TestMode)
executor := &websocketCompactionCaptureExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("Register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
router := gin.New()
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
router.POST("/v1/responses/compact", h.Compact)
server := httptest.NewServer(router)
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() {
if errClose := conn.Close(); errClose != nil {
t.Fatalf("close websocket: %v", errClose)
}
}()
requests := []string{
`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`,
`{"type":"response.create","input":[{"type":"custom_tool_call_output","call_id":"call-1","id":"tool-out-1"}]}`,
}
for i := range requests {
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil {
t.Fatalf("write websocket message %d: %v", i+1, errWrite)
}
_, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read websocket message %d: %v", i+1, errReadMessage)
}
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
t.Fatalf("message %d payload type = %s, want %s", i+1, got, wsEventTypeCompleted)
}
}
compactResp, errPost := server.Client().Post(
server.URL+"/v1/responses/compact",
"application/json",
strings.NewReader(`{"model":"test-model","input":[{"type":"message","id":"summary-1"}]}`),
)
if errPost != nil {
t.Fatalf("compact request failed: %v", errPost)
}
if errClose := compactResp.Body.Close(); errClose != nil {
t.Fatalf("close compact response body: %v", errClose)
}
if compactResp.StatusCode != http.StatusOK {
t.Fatalf("compact status = %d, want %d", compactResp.StatusCode, http.StatusOK)
}
postCompact := `{"type":"response.create","input":[{"type":"custom_tool_call","id":"ctc-compact","call_id":"call-1","name":"apply_patch"},{"type":"custom_tool_call_output","id":"tool-out-compact","call_id":"call-1"},{"type":"message","id":"msg-2"}]}`
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(postCompact)); errWrite != nil {
t.Fatalf("write post-compact websocket message: %v", errWrite)
}
_, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read post-compact websocket message: %v", errReadMessage)
}
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
t.Fatalf("post-compact payload type = %s, want %s", got, wsEventTypeCompleted)
}
executor.mu.Lock()
defer executor.mu.Unlock()
if executor.compactPayload == nil {
t.Fatalf("compact payload was not captured")
}
if len(executor.streamPayloads) != 3 {
t.Fatalf("stream payload count = %d, want 3", len(executor.streamPayloads))
}
merged := executor.streamPayloads[2]
items := gjson.GetBytes(merged, "input").Array()
if len(items) != 3 {
t.Fatalf("merged input len = %d, want 3: %s", len(items), merged)
}
if items[0].Get("id").String() != "ctc-compact" ||
items[1].Get("id").String() != "tool-out-compact" ||
items[2].Get("id").String() != "msg-2" {
t.Fatalf("unexpected post-compact input order: %s", merged)
}
if items[0].Get("call_id").String() != "call-1" {
t.Fatalf("post-compact custom tool call id = %s, want call-1", items[0].Get("call_id").String())
}
}
func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -266,15 +266,15 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
continue
}
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
switch itemType {
case "function_call_output":
switch {
case isResponsesToolCallOutputType(itemType):
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
continue
}
outputPresent[callID] = struct{}{}
outputCache.record(sessionKey, callID, item)
case "function_call":
case isResponsesToolCallType(itemType):
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
continue
@@ -293,7 +293,7 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
continue
}
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
if itemType == "function_call_output" {
if isResponsesToolCallOutputType(itemType) {
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
// Upstream rejects tool outputs without a call_id; drop it.
@@ -325,7 +325,7 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
// Drop orphaned function_call_output items; upstream rejects transcripts with missing calls.
continue
}
if itemType != "function_call" {
if !isResponsesToolCallType(itemType) {
filtered = append(filtered, item)
continue
}
@@ -376,7 +376,7 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO
return
}
for _, item := range output.Array() {
if strings.TrimSpace(item.Get("type").String()) != "function_call" {
if !isResponsesToolCallType(item.Get("type").String()) {
continue
}
callID := strings.TrimSpace(item.Get("call_id").String())
@@ -390,7 +390,7 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO
if !item.Exists() || !item.IsObject() {
return
}
if strings.TrimSpace(item.Get("type").String()) != "function_call" {
if !isResponsesToolCallType(item.Get("type").String()) {
return
}
callID := strings.TrimSpace(item.Get("call_id").String())
@@ -400,3 +400,21 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO
cache.record(sessionKey, callID, json.RawMessage(item.Raw))
}
}
func isResponsesToolCallType(itemType string) bool {
switch strings.TrimSpace(itemType) {
case "function_call", "custom_tool_call":
return true
default:
return false
}
}
func isResponsesToolCallOutputType(itemType string) bool {
switch strings.TrimSpace(itemType) {
case "function_call_output", "custom_tool_call_output":
return true
default:
return false
}
}