diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 5c68f40e..211b8b81 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -277,6 +277,15 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last } } + // Compaction can cause clients to replace local websocket history with a new + // compact transcript on the next `response.create`. When the input already + // contains historical model output items, treating it as an incremental append + // duplicates stale turn-state and can leave late orphaned function_call items. + if shouldReplaceWebsocketTranscript(rawJSON, nextInput) { + normalized := normalizeResponseTranscriptReplacement(rawJSON, lastRequest) + return normalized, bytes.Clone(normalized), nil + } + // Websocket v2 mode uses response.create with previous_response_id + incremental input. // Do not expand it into a full input transcript; upstream expects the incremental payload. if allowIncrementalInputWithPreviousResponseID { @@ -348,6 +357,54 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last return normalized, bytes.Clone(normalized), nil } +func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bool { + if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate { + return false + } + if strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()) != "" { + return false + } + if !nextInput.Exists() || !nextInput.IsArray() { + return false + } + + for _, item := range nextInput.Array() { + switch strings.TrimSpace(item.Get("type").String()) { + case "function_call": + return true + case "message": + role := strings.TrimSpace(item.Get("role").String()) + if role == "assistant" || role == "developer" { + return true + } + } + } + + return false +} + +func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte) []byte { + normalized, errDelete := sjson.DeleteBytes(rawJSON, "type") + if errDelete != nil { + normalized = bytes.Clone(rawJSON) + } + normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id") + if !gjson.GetBytes(normalized, "model").Exists() { + modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + if modelName != "" { + normalized, _ = sjson.SetBytes(normalized, "model", modelName) + } + } + if !gjson.GetBytes(normalized, "instructions").Exists() { + instructions := gjson.GetBytes(lastRequest, "instructions") + if instructions.Exists() { + normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw)) + } + } + normalized, _ = sjson.SetBytes(normalized, "stream", true) + return bytes.Clone(normalized) +} + func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool { if len(attributes) > 0 { if raw := strings.TrimSpace(attributes["websockets"]); raw != "" { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index b3a32c5c..b1440a95 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -27,6 +27,12 @@ type websocketCaptureExecutor struct { payloads [][]byte } +type websocketCompactionCaptureExecutor struct { + mu sync.Mutex + streamPayloads [][]byte + compactPayload []byte +} + type orderedWebsocketSelector struct { mu sync.Mutex order []string @@ -126,6 +132,52 @@ func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, return nil, errors.New("not implemented") } +func (e *websocketCompactionCaptureExecutor) Identifier() string { return "test-provider" } + +func (e *websocketCompactionCaptureExecutor) Execute(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (coreexecutor.Response, error) { + e.mu.Lock() + e.compactPayload = bytes.Clone(req.Payload) + e.mu.Unlock() + if opts.Alt != "responses/compact" { + return coreexecutor.Response{}, fmt.Errorf("unexpected non-compact execute alt: %q", opts.Alt) + } + return coreexecutor.Response{Payload: []byte(`{"id":"cmp-1","object":"response.compaction"}`)}, nil +} + +func (e *websocketCompactionCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) { + e.mu.Lock() + callIndex := len(e.streamPayloads) + e.streamPayloads = append(e.streamPayloads, bytes.Clone(req.Payload)) + e.mu.Unlock() + + var payload []byte + switch callIndex { + case 0: + payload = []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}]}}`) + case 1: + payload = []byte(`{"type":"response.completed","response":{"id":"resp-2","output":[{"type":"message","id":"assistant-1"}]}}`) + default: + payload = []byte(`{"type":"response.completed","response":{"id":"resp-3","output":[{"type":"message","id":"assistant-2"}]}}`) + } + + chunks := make(chan coreexecutor.StreamChunk, 1) + chunks <- coreexecutor.StreamChunk{Payload: payload} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil +} + +func (e *websocketCompactionCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *websocketCompactionCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketCompactionCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { + return nil, errors.New("not implemented") +} + func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) { raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`) @@ -662,3 +714,134 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) { t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got) } } + +func TestNormalizeResponsesWebsocketRequestTreatsTranscriptReplacementAsReset(t *testing.T) { + lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`) + lastResponseOutput := []byte(`[ + {"type":"message","id":"assistant-1","role":"assistant"} + ]`) + raw := []byte(`{"type":"response.create","input":[{"type":"function_call","id":"fc-compact","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-2"}]}`) + + normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput) + if errMsg != nil { + t.Fatalf("unexpected error: %v", errMsg.Error) + } + if gjson.GetBytes(normalized, "previous_response_id").Exists() { + t.Fatalf("previous_response_id must not exist in transcript replacement mode") + } + items := gjson.GetBytes(normalized, "input").Array() + if len(items) != 2 { + t.Fatalf("replacement input len = %d, want 2: %s", len(items), normalized) + } + if items[0].Get("id").String() != "fc-compact" || items[1].Get("id").String() != "msg-2" { + t.Fatalf("replacement transcript was not preserved as-is: %s", normalized) + } + if !bytes.Equal(next, normalized) { + t.Fatalf("next request snapshot should match replacement request") + } +} + +func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &websocketCompactionCaptureExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive} + if _, err := manager.Register(context.Background(), auth); err != nil { + t.Fatalf("Register auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + router := gin.New() + router.GET("/v1/responses/ws", h.ResponsesWebsocket) + router.POST("/v1/responses/compact", h.Compact) + + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { + if errClose := conn.Close(); errClose != nil { + t.Fatalf("close websocket: %v", errClose) + } + }() + + requests := []string{ + `{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`, + `{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`, + } + for i := range requests { + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil { + t.Fatalf("write websocket message %d: %v", i+1, errWrite) + } + _, payload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + t.Fatalf("read websocket message %d: %v", i+1, errReadMessage) + } + if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted { + t.Fatalf("message %d payload type = %s, want %s", i+1, got, wsEventTypeCompleted) + } + } + + compactResp, errPost := server.Client().Post( + server.URL+"/v1/responses/compact", + "application/json", + strings.NewReader(`{"model":"test-model","input":[{"type":"message","id":"summary-1"}]}`), + ) + if errPost != nil { + t.Fatalf("compact request failed: %v", errPost) + } + if errClose := compactResp.Body.Close(); errClose != nil { + t.Fatalf("close compact response body: %v", errClose) + } + if compactResp.StatusCode != http.StatusOK { + t.Fatalf("compact status = %d, want %d", compactResp.StatusCode, http.StatusOK) + } + + // Simulate a post-compaction client turn that replaces local history with a compacted transcript. + // The websocket handler must treat this as a state reset, not append it to stale pre-compaction state. + postCompact := `{"type":"response.create","input":[{"type":"function_call","id":"fc-compact","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-2"}]}` + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(postCompact)); errWrite != nil { + t.Fatalf("write post-compact websocket message: %v", errWrite) + } + _, payload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + t.Fatalf("read post-compact websocket message: %v", errReadMessage) + } + if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted { + t.Fatalf("post-compact payload type = %s, want %s", got, wsEventTypeCompleted) + } + + executor.mu.Lock() + defer executor.mu.Unlock() + + if executor.compactPayload == nil { + t.Fatalf("compact payload was not captured") + } + if len(executor.streamPayloads) != 3 { + t.Fatalf("stream payload count = %d, want 3", len(executor.streamPayloads)) + } + + merged := executor.streamPayloads[2] + items := gjson.GetBytes(merged, "input").Array() + if len(items) != 2 { + t.Fatalf("merged input len = %d, want 2: %s", len(items), merged) + } + if items[0].Get("id").String() != "fc-compact" || + items[1].Get("id").String() != "msg-2" { + t.Fatalf("unexpected post-compact input order: %s", merged) + } + if items[0].Get("call_id").String() != "call-1" { + t.Fatalf("post-compact function call id = %s, want call-1", items[0].Get("call_id").String()) + } +}