From c1083cbfc61d4ebe8762c6f898a68b04e5fb5c99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=A7=9C=E6=81=92?= Date: Wed, 25 Mar 2026 17:03:14 +0800 Subject: [PATCH] fix(cursor): MCP tool call resume, H2 flow control, and token usage - Rewrite tool call mechanism from interrupt-resume to inline-wait mode: processH2SessionFrames no longer exits on mcpArgs; instead blocks on toolResultCh while continuing to handle KV/heartbeat messages, then sends MCP result and continues processing text in the same goroutine. Fixes the issue where server stopped generating text after resume. - Add switchable output channel (outMu/currentOut) so first HTTP response closes after tool_calls+[DONE], and resumed text goes to a new channel returned by resumeWithToolResults. Reset streamParam on switch so Translator produces fresh message_start/content_block_start events. - Implement send-side H2 flow control: track server's initial window size and WINDOW_UPDATE increments; Write() blocks when window exhausted. Fixes RST_STREAM FLOW_CONTROL_ERROR on large requests (178KB+). - Decode new InteractionUpdate fields: TurnEndedUpdate (field 14) as stream termination signal, HeartbeatUpdate (field 13) silently ignored, TokenDeltaUpdate (field 8) for token usage tracking. - Include token usage in final stop chunk (prompt_tokens estimated from payload size, completion_tokens from accumulated TokenDeltaUpdate deltas) so Claude CLI status bar shows non-zero token counts. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/auth/cursor/proto/decode.go | 52 ++ internal/auth/cursor/proto/h2stream.go | 144 ++++-- internal/runtime/executor/cursor_executor.go | 504 ++++++++++--------- 3 files changed, 399 insertions(+), 301 deletions(-) diff --git a/internal/auth/cursor/proto/decode.go b/internal/auth/cursor/proto/decode.go index cc10d483..898ca932 100644 --- a/internal/auth/cursor/proto/decode.go +++ b/internal/auth/cursor/proto/decode.go @@ -32,6 +32,9 @@ const ( ServerMsgExecBgShellSpawn // Rejected: background shell ServerMsgExecWriteShellStdin // Rejected: write shell stdin ServerMsgExecOther // Other exec types (respond with empty) + ServerMsgTurnEnded // Turn has ended (no more output) + ServerMsgHeartbeat // Server heartbeat + ServerMsgTokenDelta // Token usage delta ) // DecodedServerMessage holds parsed data from an AgentServerMessage. @@ -63,6 +66,9 @@ type DecodedServerMessage struct { // For other exec - the raw field number for building a response ExecFieldNumber int + + // For TokenDeltaUpdate + TokenDelta int64 } // DecodeAgentServerMessage parses an AgentServerMessage and returns @@ -160,6 +166,24 @@ func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) { case 3: // tool_call_completed - ignore but log log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)") + case 8: + // token_delta - extract token count + msg.Type = ServerMsgTokenDelta + msg.TokenDelta = decodeVarintField(val, 1) + log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta) + case 13: + // heartbeat from server + msg.Type = ServerMsgHeartbeat + case 14: + // turn_ended - critical: model finished generating + msg.Type = ServerMsgTurnEnded + log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end") + case 16: + // step_started - ignore + log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)") + case 17: + // step_completed - ignore + log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)") default: log.Debugf("decodeInteractionUpdate: unknown field %d", num) } @@ -500,6 +524,34 @@ func decodeBytesField(data []byte, targetField protowire.Number) []byte { return nil } +// decodeVarintField extracts an int64 from the first matching varint field in a submessage. +func decodeVarintField(data []byte, targetField protowire.Number) int64 { + for len(data) > 0 { + num, typ, n := protowire.ConsumeTag(data) + if n < 0 { + return 0 + } + data = data[n:] + if typ == protowire.VarintType { + val, n := protowire.ConsumeVarint(data) + if n < 0 { + return 0 + } + data = data[n:] + if num == targetField { + return int64(val) + } + } else { + n := protowire.ConsumeFieldValue(num, typ, data) + if n < 0 { + return 0 + } + data = data[n:] + } + } + return 0 +} + // BlobIdHex returns the hex string of a blob ID for use as a map key. func BlobIdHex(blobId []byte) string { return hex.EncodeToString(blobId) diff --git a/internal/auth/cursor/proto/h2stream.go b/internal/auth/cursor/proto/h2stream.go index d08d099e..be3f7905 100644 --- a/internal/auth/cursor/proto/h2stream.go +++ b/internal/auth/cursor/proto/h2stream.go @@ -13,6 +13,11 @@ import ( "golang.org/x/net/http2/hpack" ) +const ( + defaultInitialWindowSize = 65535 // HTTP/2 default + maxFramePayload = 16384 // HTTP/2 default max frame size +) + // H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol. // Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer. type H2Stream struct { @@ -21,11 +26,17 @@ type H2Stream struct { streamID uint32 mu sync.Mutex id string // unique identifier for debugging - frameNum int64 // sequential frame counter for debugging + frameNum int64 // sequential frame counter for debugging dataCh chan []byte doneCh chan struct{} err error + + // Send-side flow control + sendWindow int32 // available bytes we can send on this stream + connWindow int32 // available bytes on the connection level + windowCond *sync.Cond // signaled when window is updated + windowMu sync.Mutex // protects sendWindow, connWindow } // ID returns the unique identifier for this stream (for logging). @@ -59,7 +70,7 @@ func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) { return nil, fmt.Errorf("h2: preface write failed: %w", err) } - // Send initial SETTINGS (with large initial window) + // Send initial SETTINGS (tell server how much WE can receive) if err := framer.WriteSettings( http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024}, http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100}, @@ -68,14 +79,17 @@ func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) { return nil, fmt.Errorf("h2: settings write failed: %w", err) } - // Connection-level window update (default is 65535, bump it up) + // Connection-level window update (for receiving) if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil { tlsConn.Close() return nil, fmt.Errorf("h2: window update failed: %w", err) } // Read and handle initial server frames (SETTINGS, WINDOW_UPDATE) - for i := 0; i < 5; i++ { + // Track server's initial window size (how much WE can send) + serverInitialWindowSize := int32(defaultInitialWindowSize) + connWindowSize := int32(defaultInitialWindowSize) // connection-level send window + for i := 0; i < 10; i++ { f, err := framer.ReadFrame() if err != nil { tlsConn.Close() @@ -84,12 +98,22 @@ func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) { switch sf := f.(type) { case *http2.SettingsFrame: if !sf.IsAck() { + sf.ForeachSetting(func(s http2.Setting) error { + if s.ID == http2.SettingInitialWindowSize { + serverInitialWindowSize = int32(s.Val) + log.Debugf("h2: server initial window size: %d", s.Val) + } + return nil + }) framer.WriteSettingsAck() } else { goto handshakeDone } case *http2.WindowUpdateFrame: - // ignore + if sf.StreamID == 0 { + connWindowSize += int32(sf.Increment) + log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize) + } default: // unexpected but continue } @@ -124,36 +148,53 @@ handshakeDone: } s := &H2Stream{ - framer: framer, - conn: tlsConn, - streamID: streamID, - dataCh: make(chan []byte, 256), - doneCh: make(chan struct{}), - id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")), - frameNum: 0, + framer: framer, + conn: tlsConn, + streamID: streamID, + dataCh: make(chan []byte, 256), + doneCh: make(chan struct{}), + id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")), + frameNum: 0, + sendWindow: serverInitialWindowSize, + connWindow: connWindowSize, } + s.windowCond = sync.NewCond(&s.windowMu) go s.readLoop() return s, nil } -// Write sends a DATA frame on the stream. +// Write sends a DATA frame on the stream, respecting flow control. func (s *H2Stream) Write(data []byte) error { - s.mu.Lock() - defer s.mu.Unlock() - const maxFrame = 16384 for len(data) > 0 { chunk := data - if len(chunk) > maxFrame { - chunk = data[:maxFrame] + if len(chunk) > maxFramePayload { + chunk = data[:maxFramePayload] } - data = data[len(chunk):] - if err := s.framer.WriteData(s.streamID, false, chunk); err != nil { + + // Wait for flow control window + s.windowMu.Lock() + for s.sendWindow <= 0 || s.connWindow <= 0 { + s.windowCond.Wait() + } + // Limit chunk to available window + allowed := int(s.sendWindow) + if int(s.connWindow) < allowed { + allowed = int(s.connWindow) + } + if len(chunk) > allowed { + chunk = chunk[:allowed] + } + s.sendWindow -= int32(len(chunk)) + s.connWindow -= int32(len(chunk)) + s.windowMu.Unlock() + + s.mu.Lock() + err := s.framer.WriteData(s.streamID, false, chunk) + s.mu.Unlock() + if err != nil { return err } - } - // Try to flush the underlying connection if it supports it - if flusher, ok := s.conn.(interface{ Flush() error }); ok { - flusher.Flush() + data = data[len(chunk):] } return nil } @@ -167,12 +208,13 @@ func (s *H2Stream) Done() <-chan struct{} { return s.doneCh } // Close tears down the connection. func (s *H2Stream) Close() { s.conn.Close() + // Unblock any writers waiting on flow control + s.windowCond.Broadcast() } func (s *H2Stream) readLoop() { defer close(s.doneCh) defer close(s.dataCh) - log.Debugf("h2stream[%s]: readLoop started for streamID=%d", s.id, s.streamID) for { f, err := s.framer.ReadFrame() @@ -180,71 +222,47 @@ func (s *H2Stream) readLoop() { if err != io.EOF { s.err = err log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err) - } else { - log.Debugf("h2stream[%s]: readLoop EOF", s.id) } return } - // Increment frame counter for debugging + // Increment frame counter s.mu.Lock() s.frameNum++ - frameNum := s.frameNum s.mu.Unlock() switch frame := f.(type) { case *http2.DataFrame: - log.Debugf("h2stream[%s]: frame#%d received DATA frame streamID=%d, len=%d, endStream=%v", s.id, frameNum, frame.StreamID, len(frame.Data()), frame.StreamEnded()) if frame.StreamID == s.streamID && len(frame.Data()) > 0 { cp := make([]byte, len(frame.Data())) copy(cp, frame.Data()) - // Log first 20 bytes for debugging - previewLen := len(cp) - if previewLen > 20 { - previewLen = 20 - } - log.Debugf("h2stream[%s]: frame#%d sending to dataCh: len=%d, dataCh len=%d/%d, first bytes: %x (%q)", s.id, frameNum, len(cp), len(s.dataCh), cap(s.dataCh), cp[:previewLen], string(cp[:previewLen])) s.dataCh <- cp - // Flow control: send WINDOW_UPDATE + // Flow control: send WINDOW_UPDATE for received data s.mu.Lock() s.framer.WriteWindowUpdate(0, uint32(len(cp))) s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp))) s.mu.Unlock() } if frame.StreamEnded() { - log.Debugf("h2stream[%s]: frame#%d DATA frame has END_STREAM flag, stream ending", s.id, frameNum) return } case *http2.HeadersFrame: - // Decode HPACK headers for debugging - decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) { - log.Debugf("h2stream[%s]: frame#%d header: %s = %q", s.id, frameNum, hf.Name, hf.Value) - // Check for error status - if hf.Name == "grpc-status" || hf.Name == ":status" && hf.Value != "200" { - log.Warnf("h2stream[%s]: frame#%d received error status header: %s = %q", s.id, frameNum, hf.Name, hf.Value) - } - }) - decoder.Write(frame.HeaderBlockFragment()) - log.Debugf("h2stream[%s]: frame#%d received HEADERS frame streamID=%d, endStream=%v", s.id, frameNum, frame.StreamID, frame.StreamEnded()) if frame.StreamEnded() { - log.Debugf("h2stream[%s]: frame#%d HEADERS frame has END_STREAM flag, stream ending", s.id, frameNum) return } case *http2.RSTStreamFrame: s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode) - log.Debugf("h2stream[%s]: frame#%d received RST_STREAM code=%d", s.id, frameNum, frame.ErrCode) + log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode) return case *http2.GoAwayFrame: s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode) - log.Debugf("h2stream[%s]: received GOAWAY code=%d", s.id, frame.ErrCode) return case *http2.PingFrame: - log.Debugf("h2stream[%s]: received PING frame, isAck=%v", s.id, frame.IsAck()) if !frame.IsAck() { s.mu.Lock() s.framer.WritePing(true, frame.Data) @@ -252,15 +270,33 @@ func (s *H2Stream) readLoop() { } case *http2.SettingsFrame: - log.Debugf("h2stream[%s]: received SETTINGS frame, isAck=%v, numSettings=%d", s.id, frame.IsAck(), frame.NumSettings()) if !frame.IsAck() { + // Check for window size changes + frame.ForeachSetting(func(setting http2.Setting) error { + if setting.ID == http2.SettingInitialWindowSize { + s.windowMu.Lock() + delta := int32(setting.Val) - s.sendWindow + s.sendWindow += delta + s.windowMu.Unlock() + s.windowCond.Broadcast() + } + return nil + }) s.mu.Lock() s.framer.WriteSettingsAck() s.mu.Unlock() } case *http2.WindowUpdateFrame: - log.Debugf("h2stream[%s]: received WINDOW_UPDATE frame", s.id) + // Update send-side flow control window + s.windowMu.Lock() + if frame.StreamID == 0 { + s.connWindow += int32(frame.Increment) + } else if frame.StreamID == s.streamID { + s.sendWindow += int32(frame.Increment) + } + s.windowMu.Unlock() + s.windowCond.Broadcast() } } } diff --git a/internal/runtime/executor/cursor_executor.go b/internal/runtime/executor/cursor_executor.go index 5519e92d..bba06bc7 100644 --- a/internal/runtime/executor/cursor_executor.go +++ b/internal/runtime/executor/cursor_executor.go @@ -6,7 +6,6 @@ import ( "crypto/sha256" "crypto/tls" "encoding/base64" - "encoding/binary" "encoding/hex" "encoding/json" "fmt" @@ -47,12 +46,15 @@ type CursorExecutor struct { } type cursorSession struct { - stream *cursorproto.H2Stream - blobStore map[string][]byte - mcpTools []cursorproto.McpToolDef - pending []pendingMcpExec - cancel context.CancelFunc // cancels the session-scoped heartbeat (NOT tied to HTTP request) - createdAt time.Time + stream *cursorproto.H2Stream + blobStore map[string][]byte + mcpTools []cursorproto.McpToolDef + pending []pendingMcpExec + cancel context.CancelFunc // cancels the session-scoped heartbeat (NOT tied to HTTP request) + createdAt time.Time + toolResultCh chan []toolResultInfo // receives tool results from the next HTTP request + resumeOutCh chan cliproxyexecutor.StreamChunk // output channel for resumed response + switchOutput func(ch chan cliproxyexecutor.StreamChunk) // callback to switch output channel } type pendingMcpExec struct { @@ -235,6 +237,8 @@ func (e *CursorExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r fullText.WriteString(text) }, nil, + nil, + nil, // tokenUsage - non-streaming ) id := "chatcmpl-" + uuid.New().String()[:28] @@ -341,9 +345,30 @@ func (e *CursorExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A chatId := "chatcmpl-" + uuid.New().String()[:28] created := time.Now().Unix() - // sendChunk builds an OpenAI SSE line and optionally translates to target format var streamParam any - sendChunk := func(delta string, finishReason string) { + + // Tool result channel for inline mode. processH2SessionFrames blocks on it + // when mcpArgs is received, while continuing to handle KV/heartbeat. + toolResultCh := make(chan []toolResultInfo, 1) + + // Switchable output: initially writes to `chunks`. After mcpArgs, the + // onMcpExec callback closes `chunks` (ending the first HTTP response), + // then processH2SessionFrames blocks on toolResultCh. When results arrive, + // it switches to `resumeOutCh` (created by resumeWithToolResults). + var outMu sync.Mutex + currentOut := chunks + + emitToOut := func(chunk cliproxyexecutor.StreamChunk) { + outMu.Lock() + out := currentOut + outMu.Unlock() + if out != nil { + out <- chunk + } + } + + // Wrap sendChunk/sendDone to use emitToOut + sendChunkSwitchable := func(delta string, finishReason string) { fr := "null" if finishReason != "" { fr = finishReason @@ -355,95 +380,146 @@ func (e *CursorExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A if needsTranslate { translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam) for _, t := range translated { - chunks <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)} + emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)}) } } else { - chunks <- cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)} + emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)}) } } - sendDone := func() { + sendDoneSwitchable := func() { if needsTranslate { done := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, []byte("data: [DONE]\n"), &streamParam) for _, d := range done { - chunks <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(d)} + emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(d)}) } } else { - chunks <- cliproxyexecutor.StreamChunk{Payload: []byte("[DONE]")} + emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte("[DONE]")}) } } go func() { - defer close(chunks) - + var resumeOutCh chan cliproxyexecutor.StreamChunk + _ = resumeOutCh thinkingActive := false toolCallIndex := 0 - mcpExecReceived := false + usage := &cursorTokenUsage{} + usage.setInputEstimate(len(payload)) processH2SessionFrames(sessionCtx, stream, params.BlobStore, params.McpTools, func(text string, isThinking bool) { if isThinking { if !thinkingActive { thinkingActive = true - sendChunk(`{"role":"assistant","content":""}`, "") + sendChunkSwitchable(`{"role":"assistant","content":""}`, "") } - sendChunk(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "") + sendChunkSwitchable(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "") } else { if thinkingActive { thinkingActive = false - sendChunk(`{"content":""}`, "") + sendChunkSwitchable(`{"content":""}`, "") } - sendChunk(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "") + sendChunkSwitchable(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "") } }, func(exec pendingMcpExec) { - mcpExecReceived = true if thinkingActive { thinkingActive = false - sendChunk(`{"content":""}`, "") + sendChunkSwitchable(`{"content":""}`, "") } toolCallJSON := fmt.Sprintf(`{"tool_calls":[{"index":%d,"id":"%s","type":"function","function":{"name":"%s","arguments":%s}}]}`, toolCallIndex, exec.ToolCallId, exec.ToolName, jsonString(exec.Args)) toolCallIndex++ - sendChunk(toolCallJSON, "") - sendChunk(`{}`, `"tool_calls"`) - sendDone() + sendChunkSwitchable(toolCallJSON, "") + sendChunkSwitchable(`{}`, `"tool_calls"`) + sendDoneSwitchable() - // Save session for resume — keep stream alive. - // The heartbeat goroutine continues running (session-scoped context), - // keeping the H2 connection alive while the MCP tool executes. - log.Debugf("cursor: saving session %s for MCP tool resume (tool=%s, streamID=%s)", sessionKey, exec.ToolName, stream.ID()) + // Close current output to end the current HTTP SSE response + outMu.Lock() + if currentOut != nil { + close(currentOut) + currentOut = nil + } + outMu.Unlock() + + // Create new resume output channel, reuse the same toolResultCh + resumeOut := make(chan cliproxyexecutor.StreamChunk, 64) + log.Debugf("cursor: saving session %s for MCP tool resume (tool=%s)", sessionKey, exec.ToolName) e.mu.Lock() e.sessions[sessionKey] = &cursorSession{ - stream: stream, - blobStore: params.BlobStore, - mcpTools: params.McpTools, - pending: []pendingMcpExec{exec}, - cancel: sessionCancel, - createdAt: time.Now(), + stream: stream, + blobStore: params.BlobStore, + mcpTools: params.McpTools, + pending: []pendingMcpExec{exec}, + cancel: sessionCancel, + createdAt: time.Now(), + toolResultCh: toolResultCh, // reuse same channel across rounds + resumeOutCh: resumeOut, + switchOutput: func(ch chan cliproxyexecutor.StreamChunk) { + outMu.Lock() + currentOut = ch + // Reset translator state so the new HTTP response gets + // a fresh message_start, content_block_start, etc. + streamParam = nil + // New response needs its own message ID + chatId = "chatcmpl-" + uuid.New().String()[:28] + created = time.Now().Unix() + outMu.Unlock() + }, } e.mu.Unlock() + resumeOutCh = resumeOut + + // processH2SessionFrames will now block on toolResultCh (inline wait loop) + // while continuing to handle KV messages }, + toolResultCh, + usage, ) - if !mcpExecReceived { - if thinkingActive { - sendChunk(`{"content":""}`, "") - } - sendChunk(`{}`, `"stop"`) - sendDone() - sessionCancel() - stream.Close() + // processH2SessionFrames returned — stream is done + if thinkingActive { + sendChunkSwitchable(`{"content":""}`, "") } - // If mcpExecReceived, do NOT close stream or cancel — session resume will handle it + // Include token usage in the final stop chunk + inputTok, outputTok := usage.get() + stopDelta := fmt.Sprintf(`{},"usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}`, + inputTok, outputTok, inputTok+outputTok) + // Build the stop chunk with usage embedded in the choices array level + fr := `"stop"` + openaiJSON := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"%s","choices":[{"index":0,"delta":{},"finish_reason":%s}],"usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}}`, + chatId, created, parsed.Model, fr, inputTok, outputTok, inputTok+outputTok) + sseLine := []byte("data: " + openaiJSON + "\n") + if needsTranslate { + translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam) + for _, t := range translated { + emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)}) + } + } else { + emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)}) + } + sendDoneSwitchable() + _ = stopDelta // unused + + // Close whatever output channel is still active + outMu.Lock() + if currentOut != nil { + close(currentOut) + currentOut = nil + } + outMu.Unlock() + sessionCancel() + stream.Close() }() return &cliproxyexecutor.StreamResult{Chunks: chunks}, nil } -// resumeWithToolResults sends MCP tool results back on the existing H2 stream, -// then continues reading the stream for the model's response. -// Mirrors resumeWithToolResults() in cursor-fetch.ts. +// resumeWithToolResults injects tool results into the running processH2SessionFrames +// via the toolResultCh channel. The original goroutine from ExecuteStream is still alive, +// blocking on toolResultCh. Once we send the results, it sends the MCP result to Cursor +// and continues processing the response text — all in the same goroutine that has been +// handling KV messages the whole time. func (e *CursorExecutor) resumeWithToolResults( ctx context.Context, session *cursorSession, @@ -453,208 +529,29 @@ func (e *CursorExecutor) resumeWithToolResults( originalPayload, payload []byte, needsTranslate bool, ) (*cliproxyexecutor.StreamResult, error) { - stream := session.stream - log.Debugf("cursor: resumeWithToolResults: using stream ID=%s", stream.ID()) + log.Debugf("cursor: resumeWithToolResults: injecting %d tool results via channel", len(parsed.ToolResults)) - // Cancel old session-scoped heartbeat before starting a new one - session.cancel() - - // CRITICAL: Process any pending messages from the channel before sending MCP result. - // After the initial processH2SessionFrames returned (upon receiving MCP args), - // the server may have sent more data (KV messages, text deltas) that are now buffered in dataCh. - // We must process KV messages (respond to them) but discard text deltas (stale responses). - drainedCount := 0 - drainedBytes := 0 - kvProcessedCount := 0 - for { - select { - case staleData, ok := <-stream.Data(): - if !ok { - log.Debugf("cursor: resumeWithToolResults: dataCh closed during drain") - break - } - drainedCount++ - drainedBytes += len(staleData) - log.Debugf("cursor: resumeWithToolResults: processing stale data #%d: len=%d, first bytes: %x (%q)", drainedCount, len(staleData), staleData[:min(20, len(staleData))], string(staleData[:min(20, len(staleData))])) - - // Try to decode and handle KV messages (they need responses) - if len(staleData) > 5 { - frameLen := binary.BigEndian.Uint32(staleData[1:5]) - if int(frameLen)+5 <= len(staleData) { - payload := staleData[5 : 5+frameLen] - msg, err := cursorproto.DecodeAgentServerMessage(payload) - if err == nil && msg.Type == cursorproto.ServerMsgKvGetBlob { - // Respond to KV getBlob - blobKey := cursorproto.BlobIdHex(msg.BlobId) - data := session.blobStore[blobKey] - log.Debugf("cursor: resumeWithToolResults: responding to stale KV getBlob kvId=%d blobKey=%s found=%v", msg.KvId, blobKey, len(data) > 0) - resp := cursorproto.EncodeKvGetBlobResult(msg.KvId, data) - stream.Write(cursorproto.FrameConnectMessage(resp, 0)) - kvProcessedCount++ - continue - } else if err == nil && msg.Type == cursorproto.ServerMsgKvSetBlob { - // Respond to KV setBlob - blobKey := cursorproto.BlobIdHex(msg.BlobId) - session.blobStore[blobKey] = append([]byte(nil), msg.BlobData...) - log.Debugf("cursor: resumeWithToolResults: responding to stale KV setBlob kvId=%d blobKey=%s", msg.KvId, blobKey) - resp := cursorproto.EncodeKvSetBlobResult(msg.KvId) - stream.Write(cursorproto.FrameConnectMessage(resp, 0)) - kvProcessedCount++ - continue - } - } - } - log.Debugf("cursor: resumeWithToolResults: discarding non-KV stale data") - default: - // No more data in channel - goto drainDone - } + if session.toolResultCh == nil { + return nil, fmt.Errorf("cursor: session has no toolResultCh (stale session?)") } -drainDone: - if drainedCount > 0 { - log.Debugf("cursor: resumeWithToolResults: processed %d stale frames (%d bytes total, %d KV responded)", drainedCount, drainedBytes, kvProcessedCount) + if session.resumeOutCh == nil { + return nil, fmt.Errorf("cursor: session has no resumeOutCh") } - // Send MCP results back on the same H2 stream - for _, exec := range session.pending { - var content string - var isError bool - found := false - for _, tr := range parsed.ToolResults { - if tr.ToolCallId == exec.ToolCallId { - content = tr.Content - found = true - break - } - } - if !found { - content = "Tool result not provided" - isError = true - } - log.Debugf("cursor: sending MCP result for tool=%s callId=%s execMsgId=%d execId=%s contentLen=%d isError=%v", - exec.ToolName, exec.ToolCallId, exec.ExecMsgId, exec.ExecId, len(content), isError) - resultBytes := cursorproto.EncodeExecMcpResult(exec.ExecMsgId, exec.ExecId, content, isError) - framedResult := cursorproto.FrameConnectMessage(resultBytes, 0) - // Log the framed result details for debugging - log.Debugf("cursor: MCP result frame size=%d bytes", len(framedResult)) - log.Debugf("cursor: MCP result frame header: flags=%d, len=%d", framedResult[0], binary.BigEndian.Uint32(framedResult[1:5])) - log.Debugf("cursor: MCP result protobuf hex (first 50 bytes): %x", resultBytes[:min(50, len(resultBytes))]) - if err := stream.Write(framedResult); err != nil { - stream.Close() - return nil, fmt.Errorf("cursor: failed to send MCP result: %w", err) - } - log.Debugf("cursor: MCP result sent successfully for tool=%s", exec.ToolName) + log.Debugf("cursor: resumeWithToolResults: switching output to resumeOutCh and injecting results") + + // Switch the output channel BEFORE injecting results, so that when + // processH2SessionFrames unblocks and starts emitting text, it writes + // to the resumeOutCh which the new HTTP handler is reading from. + if session.switchOutput != nil { + session.switchOutput(session.resumeOutCh) } - // Start new session-scoped heartbeat (independent of HTTP request context) - sessionCtx, sessionCancel := context.WithCancel(context.Background()) - go cursorH2Heartbeat(sessionCtx, stream) - log.Debugf("cursor: started new heartbeat for resumed session, waiting for Cursor response...") + // Inject tool results — this unblocks the waiting processH2SessionFrames + session.toolResultCh <- parsed.ToolResults - chunks := make(chan cliproxyexecutor.StreamChunk, 64) - chatId := "chatcmpl-" + uuid.New().String()[:28] - created := time.Now().Unix() - sessionKey := deriveSessionKey(parsed.Model, parsed.Messages) - - var streamParam any - sendChunk := func(delta string, finishReason string) { - fr := "null" - if finishReason != "" { - fr = finishReason - } - openaiJSON := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"%s","choices":[{"index":0,"delta":%s,"finish_reason":%s}]}`, - chatId, created, parsed.Model, delta, fr) - sseLine := []byte("data: " + openaiJSON + "\n") - - if needsTranslate { - translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam) - for _, t := range translated { - chunks <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)} - } - } else { - chunks <- cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)} - } - } - - sendDone := func() { - if needsTranslate { - done := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, []byte("data: [DONE]\n"), &streamParam) - for _, d := range done { - chunks <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(d)} - } - } else { - chunks <- cliproxyexecutor.StreamChunk{Payload: []byte("[DONE]")} - } - } - - go func() { - defer func() { - log.Debugf("cursor: resume goroutine exiting, closing chunks channel") - close(chunks) - }() - log.Debugf("cursor: resume goroutine started, entering processH2SessionFrames") - - thinkingActive := false - toolCallIndex := 0 - mcpExecReceived := false - - processH2SessionFrames(sessionCtx, stream, session.blobStore, session.mcpTools, - func(text string, isThinking bool) { - log.Debugf("cursor: resume received text (isThinking=%v, len=%d)", isThinking, len(text)) - if isThinking { - if !thinkingActive { - thinkingActive = true - sendChunk(`{"role":"assistant","content":""}`, "") - } - sendChunk(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "") - } else { - if thinkingActive { - thinkingActive = false - sendChunk(`{"content":""}`, "") - } - sendChunk(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "") - } - }, - func(exec pendingMcpExec) { - mcpExecReceived = true - if thinkingActive { - thinkingActive = false - sendChunk(`{"content":""}`, "") - } - toolCallJSON := fmt.Sprintf(`{"tool_calls":[{"index":%d,"id":"%s","type":"function","function":{"name":"%s","arguments":%s}}]}`, - toolCallIndex, exec.ToolCallId, exec.ToolName, jsonString(exec.Args)) - toolCallIndex++ - sendChunk(toolCallJSON, "") - sendChunk(`{}`, `"tool_calls"`) - sendDone() - - // Save session again for another round of tool calls - log.Debugf("cursor: saving session %s for another MCP tool resume (tool=%s, streamID=%s)", sessionKey, exec.ToolName, stream.ID()) - e.mu.Lock() - e.sessions[sessionKey] = &cursorSession{ - stream: stream, - blobStore: session.blobStore, - mcpTools: session.mcpTools, - pending: []pendingMcpExec{exec}, - cancel: sessionCancel, - createdAt: time.Now(), - } - e.mu.Unlock() - }, - ) - - if !mcpExecReceived { - if thinkingActive { - sendChunk(`{"content":""}`, "") - } - sendChunk(`{}`, `"stop"`) - sendDone() - sessionCancel() - stream.Close() - } - }() - - return &cliproxyexecutor.StreamResult{Chunks: chunks}, nil + // Return the resumeOutCh for the new HTTP handler to read from + return &cliproxyexecutor.StreamResult{Chunks: session.resumeOutCh}, nil } // --- H2Stream helpers --- @@ -693,6 +590,35 @@ func cursorH2Heartbeat(ctx context.Context, stream *cursorproto.H2Stream) { // --- Response processing --- +// cursorTokenUsage tracks token counts from Cursor's TokenDeltaUpdate messages. +type cursorTokenUsage struct { + mu sync.Mutex + outputTokens int64 + inputTokensEst int64 // estimated from request payload size +} + +func (u *cursorTokenUsage) addOutput(delta int64) { + u.mu.Lock() + defer u.mu.Unlock() + u.outputTokens += delta +} + +func (u *cursorTokenUsage) setInputEstimate(payloadBytes int) { + u.mu.Lock() + defer u.mu.Unlock() + // Rough estimate: ~4 bytes per token for mixed content + u.inputTokensEst = int64(payloadBytes / 4) + if u.inputTokensEst < 1 { + u.inputTokensEst = 1 + } +} + +func (u *cursorTokenUsage) get() (input, output int64) { + u.mu.Lock() + defer u.mu.Unlock() + return u.inputTokensEst, u.outputTokens +} + func processH2SessionFrames( ctx context.Context, stream *cursorproto.H2Stream, @@ -700,6 +626,8 @@ func processH2SessionFrames( mcpTools []cursorproto.McpToolDef, onText func(text string, isThinking bool), onMcpExec func(exec pendingMcpExec), + toolResultCh <-chan []toolResultInfo, // nil for no tool result injection; non-nil to wait for results + tokenUsage *cursorTokenUsage, // tracks accumulated token usage (may be nil) ) { var buf bytes.Buffer rejectReason := "Tool not available in this environment. Use the MCP tools provided instead." @@ -762,6 +690,20 @@ func processH2SessionFrames( case cursorproto.ServerMsgThinkingCompleted: // Handled by caller + case cursorproto.ServerMsgTurnEnded: + log.Debugf("cursor: TurnEnded received, stream will finish") + return + + case cursorproto.ServerMsgHeartbeat: + // Server heartbeat, ignore silently + continue + + case cursorproto.ServerMsgTokenDelta: + if tokenUsage != nil && msg.TokenDelta > 0 { + tokenUsage.addOutput(msg.TokenDelta) + } + continue + case cursorproto.ServerMsgKvGetBlob: blobKey := cursorproto.BlobIdHex(msg.BlobId) data := blobStore[blobKey] @@ -785,17 +727,85 @@ func processH2SessionFrames( if toolCallId == "" { toolCallId = uuid.New().String() } - // Debug: log the received execId from server log.Debugf("cursor: received mcpArgs from server: execMsgId=%d execId=%q toolName=%s toolCallId=%s", msg.ExecMsgId, msg.ExecId, msg.McpToolName, toolCallId) - onMcpExec(pendingMcpExec{ + pending := pendingMcpExec{ ExecMsgId: msg.ExecMsgId, ExecId: msg.ExecId, ToolCallId: toolCallId, ToolName: msg.McpToolName, Args: decodedArgs, - }) - return + } + onMcpExec(pending) + + if toolResultCh == nil { + return + } + + // Inline mode: wait for tool result while handling KV/heartbeat + log.Debugf("cursor: waiting for tool result on channel (inline mode)...") + var toolResults []toolResultInfo + waitLoop: + for { + select { + case <-ctx.Done(): + return + case results, ok := <-toolResultCh: + if !ok { + return + } + toolResults = results + break waitLoop + case waitData, ok := <-stream.Data(): + if !ok { + return + } + buf.Write(waitData) + for { + cb := buf.Bytes() + if len(cb) == 0 { + break + } + wf, wp, wc, wok := cursorproto.ParseConnectFrame(cb) + if !wok { + break + } + buf.Next(wc) + if wf&cursorproto.ConnectEndStreamFlag != 0 { + continue + } + wmsg, werr := cursorproto.DecodeAgentServerMessage(wp) + if werr != nil { + continue + } + switch wmsg.Type { + case cursorproto.ServerMsgKvGetBlob: + blobKey := cursorproto.BlobIdHex(wmsg.BlobId) + d := blobStore[blobKey] + stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeKvGetBlobResult(wmsg.KvId, d), 0)) + case cursorproto.ServerMsgKvSetBlob: + blobKey := cursorproto.BlobIdHex(wmsg.BlobId) + blobStore[blobKey] = append([]byte(nil), wmsg.BlobData...) + stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeKvSetBlobResult(wmsg.KvId), 0)) + case cursorproto.ServerMsgExecRequestCtx: + stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecRequestContextResult(wmsg.ExecMsgId, wmsg.ExecId, mcpTools), 0)) + } + } + case <-stream.Done(): + return + } + } + + // Send MCP result + for _, tr := range toolResults { + if tr.ToolCallId == pending.ToolCallId { + log.Debugf("cursor: sending inline MCP result for tool=%s", pending.ToolName) + resultBytes := cursorproto.EncodeExecMcpResult(pending.ExecMsgId, pending.ExecId, tr.Content, false) + stream.Write(cursorproto.FrameConnectMessage(resultBytes, 0)) + break + } + } + continue } case cursorproto.ServerMsgExecReadArgs: