diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 591552ae..b8076601 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -33,6 +33,8 @@ const ( wsDoneMarker = "[DONE]" wsTurnStateHeader = "x-codex-turn-state" wsRequestBodyKey = "REQUEST_BODY_OVERRIDE" + wsBodyLogMaxSize = 32 * 1024 + wsBodyLogTruncated = "\n...[truncated]\n" ) var responsesWebsocketUpgrader = websocket.Upgrader{ @@ -52,6 +54,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { return } passthroughSessionID := uuid.NewString() + downstreamSessionKey := websocketDownstreamSessionKey(c.Request) clientRemoteAddr := "" if c != nil && c.Request != nil { clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr) @@ -164,6 +167,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } continue } + + requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON) + updatedLastRequest = bytes.Clone(requestJSON) lastRequest = updatedLastRequest modelName := gjson.GetBytes(requestJSON, "model").String() @@ -324,6 +330,10 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last Error: fmt.Errorf("invalid request input: %w", errMerge), } } + dedupedInput, errDedupeFunctionCalls := dedupeFunctionCallsByCallID(mergedInput) + if errDedupeFunctionCalls == nil { + mergedInput = dedupedInput + } normalized, errDelete := sjson.DeleteBytes(rawJSON, "type") if errDelete != nil { @@ -355,7 +365,8 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last } func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bool { - if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate { + requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) + if requestType != wsRequestTypeCreate && requestType != wsRequestTypeAppend { return false } if strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()) != "" { @@ -402,6 +413,42 @@ func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte) return bytes.Clone(normalized) } +func dedupeFunctionCallsByCallID(rawArray string) (string, error) { + rawArray = strings.TrimSpace(rawArray) + if rawArray == "" { + return "[]", nil + } + var items []json.RawMessage + if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil { + return "", errUnmarshal + } + + seenCallIDs := make(map[string]struct{}, len(items)) + filtered := make([]json.RawMessage, 0, len(items)) + for _, item := range items { + if len(item) == 0 { + continue + } + itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String()) + if itemType == "function_call" { + callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) + if callID != "" { + if _, ok := seenCallIDs[callID]; ok { + continue + } + seenCallIDs[callID] = struct{}{} + } + } + filtered = append(filtered, item) + } + + out, errMarshal := json.Marshal(filtered) + if errMarshal != nil { + return "", errMarshal + } + return string(out), nil +} + func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool { if len(attributes) > 0 { if raw := strings.TrimSpace(attributes["websockets"]); raw != "" { @@ -667,6 +714,10 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( ) ([]byte, error) { completed := false completedOutput := []byte("[]") + downstreamSessionKey := "" + if c != nil && c.Request != nil { + downstreamSessionKey = websocketDownstreamSessionKey(c.Request) + } for { select { @@ -744,6 +795,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( payloads := websocketJSONPayloadsFromChunk(chunk) for i := range payloads { + recordResponsesWebsocketToolCallsFromPayload(downstreamSessionKey, payloads[i]) eventType := gjson.GetBytes(payloads[i], "type").String() if eventType == wsEventTypeCompleted { completed = true @@ -891,18 +943,53 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload [] if builder == nil { return } + if builder.Len() >= wsBodyLogMaxSize { + return + } trimmedPayload := bytes.TrimSpace(payload) if len(trimmedPayload) == 0 { return } + + separator := []byte{} if builder.Len() > 0 { - builder.WriteString("\n") + separator = []byte("\n") } - builder.WriteString("websocket.") - builder.WriteString(eventType) - builder.WriteString("\n") - builder.Write(trimmedPayload) - builder.WriteString("\n") + header := []byte("websocket." + eventType + "\n") + footer := []byte("\n") + entryLen := len(separator) + len(header) + len(trimmedPayload) + len(footer) + remaining := wsBodyLogMaxSize - builder.Len() + + if entryLen <= remaining { + builder.Write(separator) + builder.Write(header) + builder.Write(trimmedPayload) + builder.Write(footer) + return + } + + marker := []byte(wsBodyLogTruncated) + if len(marker) > remaining { + builder.Write(marker[:remaining]) + return + } + + allowed := remaining - len(marker) + parts := [][]byte{separator, header, trimmedPayload, footer} + for _, part := range parts { + if allowed <= 0 { + break + } + if len(part) <= allowed { + builder.Write(part) + allowed -= len(part) + continue + } + builder.Write(part[:allowed]) + allowed = 0 + break + } + builder.Write(marker) } func websocketPayloadEventType(payload []byte) string { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index 5619e6b1..9e2a1ed6 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -10,6 +10,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -442,6 +443,108 @@ func TestSetWebsocketRequestBody(t *testing.T) { } } +func TestRepairResponsesWebsocketToolCallsInsertsCachedOutput(t *testing.T) { + cache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + cacheWarm := []byte(`{"previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","output":"ok"}]}`) + warmed := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, cacheWarm) + if gjson.GetBytes(warmed, "input.0.call_id").String() != "call-1" { + t.Fatalf("expected warmup output to remain") + } + + raw := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`) + repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw) + + input := gjson.GetBytes(repaired, "input").Array() + if len(input) != 3 { + t.Fatalf("repaired input len = %d, want 3", len(input)) + } + if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" { + t.Fatalf("unexpected first item: %s", input[0].Raw) + } + if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" { + t.Fatalf("missing inserted output: %s", input[1].Raw) + } + if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" { + t.Fatalf("unexpected trailing item: %s", input[2].Raw) + } +} + +func TestRepairResponsesWebsocketToolCallsDropsOrphanFunctionCall(t *testing.T) { + cache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + raw := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`) + repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw) + + input := gjson.GetBytes(repaired, "input").Array() + if len(input) != 1 { + t.Fatalf("repaired input len = %d, want 1", len(input)) + } + if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" { + t.Fatalf("unexpected remaining item: %s", input[0].Raw) + } +} + +func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForOrphanOutput(t *testing.T) { + outputCache := newWebsocketToolOutputCache(time.Minute, 10) + callCache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + callCache.record(sessionKey, "call-1", []byte(`{"type":"function_call","call_id":"call-1","name":"tool"}`)) + + raw := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`) + repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw) + + input := gjson.GetBytes(repaired, "input").Array() + if len(input) != 3 { + t.Fatalf("repaired input len = %d, want 3", len(input)) + } + if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" { + t.Fatalf("missing inserted call: %s", input[0].Raw) + } + if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" { + t.Fatalf("unexpected output item: %s", input[1].Raw) + } + if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" { + t.Fatalf("unexpected trailing item: %s", input[2].Raw) + } +} + +func TestRepairResponsesWebsocketToolCallsDropsOrphanOutputWhenCallMissing(t *testing.T) { + outputCache := newWebsocketToolOutputCache(time.Minute, 10) + callCache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + raw := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`) + repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw) + + input := gjson.GetBytes(repaired, "input").Array() + if len(input) != 1 { + t.Fatalf("repaired input len = %d, want 1", len(input)) + } + if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" { + t.Fatalf("unexpected remaining item: %s", input[0].Raw) + } +} + +func TestRecordResponsesWebsocketToolCallsFromPayloadWithCache(t *testing.T) { + cache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool","arguments":"{}"}]}}`) + recordResponsesWebsocketToolCallsFromPayloadWithCache(cache, sessionKey, payload) + + cached, ok := cache.get(sessionKey, "call-1") + if !ok { + t.Fatalf("expected cached tool call") + } + if gjson.GetBytes(cached, "type").String() != "function_call" || gjson.GetBytes(cached, "call_id").String() != "call-1" { + t.Fatalf("unexpected cached tool call: %s", cached) + } +} + func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { gin.SetMode(gin.TestMode) @@ -767,6 +870,29 @@ func TestNormalizeResponsesWebsocketRequestDoesNotTreatDeveloperMessageAsReplace } } +func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t *testing.T) { + lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"}]}`) + lastResponseOutput := []byte(`[ + {"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"} + ]`) + raw := []byte(`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`) + + normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput) + if errMsg != nil { + t.Fatalf("unexpected error: %v", errMsg.Error) + } + + items := gjson.GetBytes(normalized, "input").Array() + if len(items) != 3 { + t.Fatalf("merged input len = %d, want 3: %s", len(items), normalized) + } + if items[0].Get("id").String() != "fc-1" || + items[1].Get("id").String() != "tool-out-1" || + items[2].Get("id").String() != "msg-2" { + t.Fatalf("unexpected merged input order: %s", normalized) + } +} + func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go b/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go new file mode 100644 index 00000000..8333bce6 --- /dev/null +++ b/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go @@ -0,0 +1,327 @@ +package openai + +import ( + "encoding/json" + "net/http" + "strings" + "sync" + "time" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + websocketToolOutputCacheMaxPerSession = 256 + websocketToolOutputCacheTTL = 30 * time.Minute +) + +var defaultWebsocketToolOutputCache = newWebsocketToolOutputCache(websocketToolOutputCacheTTL, websocketToolOutputCacheMaxPerSession) +var defaultWebsocketToolCallCache = newWebsocketToolOutputCache(websocketToolOutputCacheTTL, websocketToolOutputCacheMaxPerSession) + +type websocketToolOutputCache struct { + mu sync.Mutex + ttl time.Duration + maxPerSession int + sessions map[string]*websocketToolOutputSession +} + +type websocketToolOutputSession struct { + lastSeen time.Time + outputs map[string]json.RawMessage + order []string +} + +func newWebsocketToolOutputCache(ttl time.Duration, maxPerSession int) *websocketToolOutputCache { + if ttl <= 0 { + ttl = websocketToolOutputCacheTTL + } + if maxPerSession <= 0 { + maxPerSession = websocketToolOutputCacheMaxPerSession + } + return &websocketToolOutputCache{ + ttl: ttl, + maxPerSession: maxPerSession, + sessions: make(map[string]*websocketToolOutputSession), + } +} + +func (c *websocketToolOutputCache) record(sessionKey string, callID string, item json.RawMessage) { + sessionKey = strings.TrimSpace(sessionKey) + callID = strings.TrimSpace(callID) + if sessionKey == "" || callID == "" || c == nil { + return + } + + now := time.Now() + c.mu.Lock() + defer c.mu.Unlock() + + c.cleanupLocked(now) + + session, ok := c.sessions[sessionKey] + if !ok || session == nil { + session = &websocketToolOutputSession{ + lastSeen: now, + outputs: make(map[string]json.RawMessage), + } + c.sessions[sessionKey] = session + } + session.lastSeen = now + + if _, exists := session.outputs[callID]; !exists { + session.order = append(session.order, callID) + } + session.outputs[callID] = append(json.RawMessage(nil), item...) + + for len(session.order) > c.maxPerSession { + evict := session.order[0] + session.order = session.order[1:] + delete(session.outputs, evict) + } +} + +func (c *websocketToolOutputCache) get(sessionKey string, callID string) (json.RawMessage, bool) { + sessionKey = strings.TrimSpace(sessionKey) + callID = strings.TrimSpace(callID) + if sessionKey == "" || callID == "" || c == nil { + return nil, false + } + + now := time.Now() + c.mu.Lock() + defer c.mu.Unlock() + + c.cleanupLocked(now) + + session, ok := c.sessions[sessionKey] + if !ok || session == nil { + return nil, false + } + session.lastSeen = now + item, ok := session.outputs[callID] + if !ok || len(item) == 0 { + return nil, false + } + return append(json.RawMessage(nil), item...), true +} + +func (c *websocketToolOutputCache) cleanupLocked(now time.Time) { + if c == nil || c.ttl <= 0 { + return + } + + for key, session := range c.sessions { + if session == nil { + delete(c.sessions, key) + continue + } + if now.Sub(session.lastSeen) > c.ttl { + delete(c.sessions, key) + } + } +} + +func websocketDownstreamSessionKey(req *http.Request) string { + if req == nil { + return "" + } + if sessionID := strings.TrimSpace(req.Header.Get("Session_id")); sessionID != "" { + return sessionID + } + if requestID := strings.TrimSpace(req.Header.Get("X-Client-Request-Id")); requestID != "" { + return requestID + } + if raw := strings.TrimSpace(req.Header.Get("X-Codex-Turn-Metadata")); raw != "" { + if sessionID := strings.TrimSpace(gjson.Get(raw, "session_id").String()); sessionID != "" { + return sessionID + } + } + return "" +} + +func repairResponsesWebsocketToolCalls(sessionKey string, payload []byte) []byte { + return repairResponsesWebsocketToolCallsWithCaches(defaultWebsocketToolOutputCache, defaultWebsocketToolCallCache, sessionKey, payload) +} + +func repairResponsesWebsocketToolCallsWithCache(cache *websocketToolOutputCache, sessionKey string, payload []byte) []byte { + return repairResponsesWebsocketToolCallsWithCaches(cache, nil, sessionKey, payload) +} + +func repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache *websocketToolOutputCache, sessionKey string, payload []byte) []byte { + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" || outputCache == nil || len(payload) == 0 { + return payload + } + + input := gjson.GetBytes(payload, "input") + if !input.Exists() || !input.IsArray() { + return payload + } + + allowOrphanOutputs := strings.TrimSpace(gjson.GetBytes(payload, "previous_response_id").String()) != "" + updatedRaw, errRepair := repairResponsesToolCallsArray(outputCache, callCache, sessionKey, input.Raw, allowOrphanOutputs) + if errRepair != nil || updatedRaw == "" || updatedRaw == input.Raw { + return payload + } + + updated, errSet := sjson.SetRawBytes(payload, "input", []byte(updatedRaw)) + if errSet != nil { + return payload + } + return updated +} + +func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCache, sessionKey string, rawArray string, allowOrphanOutputs bool) (string, error) { + rawArray = strings.TrimSpace(rawArray) + if rawArray == "" { + return "[]", nil + } + + var items []json.RawMessage + if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil { + return "", errUnmarshal + } + + // First pass: record tool outputs and remember which call_ids have outputs in this payload. + outputPresent := make(map[string]struct{}, len(items)) + callPresent := make(map[string]struct{}, len(items)) + for _, item := range items { + if len(item) == 0 { + continue + } + itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String()) + switch itemType { + case "function_call_output": + callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) + if callID == "" { + continue + } + outputPresent[callID] = struct{}{} + outputCache.record(sessionKey, callID, item) + case "function_call": + callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) + if callID == "" { + continue + } + callPresent[callID] = struct{}{} + if callCache != nil { + callCache.record(sessionKey, callID, item) + } + } + } + + filtered := make([]json.RawMessage, 0, len(items)) + insertedCalls := make(map[string]struct{}, len(items)) + for _, item := range items { + if len(item) == 0 { + continue + } + itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String()) + if itemType == "function_call_output" { + callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) + if callID == "" { + // Upstream rejects tool outputs without a call_id; drop it. + continue + } + + if allowOrphanOutputs { + filtered = append(filtered, item) + continue + } + + if _, ok := callPresent[callID]; ok { + filtered = append(filtered, item) + continue + } + + if callCache != nil { + if cached, ok := callCache.get(sessionKey, callID); ok { + if _, already := insertedCalls[callID]; !already { + filtered = append(filtered, cached) + insertedCalls[callID] = struct{}{} + callPresent[callID] = struct{}{} + } + filtered = append(filtered, item) + continue + } + } + + // Drop orphaned function_call_output items; upstream rejects transcripts with missing calls. + continue + } + if itemType != "function_call" { + filtered = append(filtered, item) + continue + } + + callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) + if callID == "" { + // Upstream rejects tool calls without a call_id; drop it. + continue + } + + if _, ok := outputPresent[callID]; ok { + filtered = append(filtered, item) + continue + } + + if cached, ok := outputCache.get(sessionKey, callID); ok { + filtered = append(filtered, item) + filtered = append(filtered, cached) + outputPresent[callID] = struct{}{} + continue + } + + // Drop orphaned function_call items; upstream rejects transcripts with missing outputs. + } + + out, errMarshal := json.Marshal(filtered) + if errMarshal != nil { + return "", errMarshal + } + return string(out), nil +} + +func recordResponsesWebsocketToolCallsFromPayload(sessionKey string, payload []byte) { + recordResponsesWebsocketToolCallsFromPayloadWithCache(defaultWebsocketToolCallCache, sessionKey, payload) +} + +func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolOutputCache, sessionKey string, payload []byte) { + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" || cache == nil || len(payload) == 0 { + return + } + + eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String()) + switch eventType { + case "response.completed": + output := gjson.GetBytes(payload, "response.output") + if !output.Exists() || !output.IsArray() { + return + } + for _, item := range output.Array() { + if strings.TrimSpace(item.Get("type").String()) != "function_call" { + continue + } + callID := strings.TrimSpace(item.Get("call_id").String()) + if callID == "" { + continue + } + cache.record(sessionKey, callID, json.RawMessage(item.Raw)) + } + case "response.output_item.added", "response.output_item.done": + item := gjson.GetBytes(payload, "item") + if !item.Exists() || !item.IsObject() { + return + } + if strings.TrimSpace(item.Get("type").String()) != "function_call" { + return + } + callID := strings.TrimSpace(item.Get("call_id").String()) + if callID == "" { + return + } + cache.record(sessionKey, callID, json.RawMessage(item.Raw)) + } +}