diff --git a/internal/auth/cursor/proto/decode.go b/internal/auth/cursor/proto/decode.go
index cc10d483..898ca932 100644
--- a/internal/auth/cursor/proto/decode.go
+++ b/internal/auth/cursor/proto/decode.go
@@ -32,6 +32,9 @@ const (
ServerMsgExecBgShellSpawn // Rejected: background shell
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
ServerMsgExecOther // Other exec types (respond with empty)
+ ServerMsgTurnEnded // Turn has ended (no more output)
+ ServerMsgHeartbeat // Server heartbeat
+ ServerMsgTokenDelta // Token usage delta
)
// DecodedServerMessage holds parsed data from an AgentServerMessage.
@@ -63,6 +66,9 @@ type DecodedServerMessage struct {
// For other exec - the raw field number for building a response
ExecFieldNumber int
+
+ // For TokenDeltaUpdate
+ TokenDelta int64
}
// DecodeAgentServerMessage parses an AgentServerMessage and returns
@@ -160,6 +166,24 @@ func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
case 3:
// tool_call_completed - ignore but log
log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
+ case 8:
+ // token_delta - extract token count
+ msg.Type = ServerMsgTokenDelta
+ msg.TokenDelta = decodeVarintField(val, 1)
+ log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta)
+ case 13:
+ // heartbeat from server
+ msg.Type = ServerMsgHeartbeat
+ case 14:
+ // turn_ended - critical: model finished generating
+ msg.Type = ServerMsgTurnEnded
+ log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end")
+ case 16:
+ // step_started - ignore
+ log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)")
+ case 17:
+ // step_completed - ignore
+ log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)")
default:
log.Debugf("decodeInteractionUpdate: unknown field %d", num)
}
@@ -500,6 +524,34 @@ func decodeBytesField(data []byte, targetField protowire.Number) []byte {
return nil
}
+// decodeVarintField extracts an int64 from the first matching varint field in a submessage.
+// It scans top-level fields only (no nested descent) and returns 0 when the
+// field is absent or the buffer is malformed — callers cannot distinguish
+// "missing" from an explicit zero value.
+func decodeVarintField(data []byte, targetField protowire.Number) int64 {
+	for len(data) > 0 {
+		num, typ, n := protowire.ConsumeTag(data)
+		if n < 0 {
+			return 0 // malformed tag
+		}
+		data = data[n:]
+		if typ == protowire.VarintType {
+			val, n := protowire.ConsumeVarint(data)
+			if n < 0 {
+				return 0 // truncated varint
+			}
+			data = data[n:]
+			if num == targetField {
+				return int64(val)
+			}
+		} else {
+			// Skip non-varint fields (including a non-varint encoding
+			// of targetField, which is treated as absent).
+			n := protowire.ConsumeFieldValue(num, typ, data)
+			if n < 0 {
+				return 0
+			}
+			data = data[n:]
+		}
+	}
+	return 0
+}
+
// BlobIdHex returns the hex string of a blob ID for use as a map key.
func BlobIdHex(blobId []byte) string {
return hex.EncodeToString(blobId)
diff --git a/internal/auth/cursor/proto/h2stream.go b/internal/auth/cursor/proto/h2stream.go
index d08d099e..be3f7905 100644
--- a/internal/auth/cursor/proto/h2stream.go
+++ b/internal/auth/cursor/proto/h2stream.go
@@ -13,6 +13,11 @@ import (
"golang.org/x/net/http2/hpack"
)
+const (
+ defaultInitialWindowSize = 65535 // HTTP/2 default
+ maxFramePayload = 16384 // HTTP/2 default max frame size
+)
+
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
type H2Stream struct {
@@ -21,11 +26,17 @@ type H2Stream struct {
streamID uint32
mu sync.Mutex
id string // unique identifier for debugging
- frameNum int64 // sequential frame counter for debugging
+ frameNum int64 // sequential frame counter for debugging
dataCh chan []byte
doneCh chan struct{}
err error
+
+ // Send-side flow control
+ sendWindow int32 // available bytes we can send on this stream
+ connWindow int32 // available bytes on the connection level
+ windowCond *sync.Cond // signaled when window is updated
+ windowMu sync.Mutex // protects sendWindow, connWindow
}
// ID returns the unique identifier for this stream (for logging).
@@ -59,7 +70,7 @@ func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
return nil, fmt.Errorf("h2: preface write failed: %w", err)
}
- // Send initial SETTINGS (with large initial window)
+ // Send initial SETTINGS (tell server how much WE can receive)
if err := framer.WriteSettings(
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
@@ -68,14 +79,17 @@ func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
return nil, fmt.Errorf("h2: settings write failed: %w", err)
}
- // Connection-level window update (default is 65535, bump it up)
+ // Connection-level window update (for receiving)
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: window update failed: %w", err)
}
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
- for i := 0; i < 5; i++ {
+ // Track server's initial window size (how much WE can send)
+ serverInitialWindowSize := int32(defaultInitialWindowSize)
+ connWindowSize := int32(defaultInitialWindowSize) // connection-level send window
+ for i := 0; i < 10; i++ {
f, err := framer.ReadFrame()
if err != nil {
tlsConn.Close()
@@ -84,12 +98,22 @@ func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
switch sf := f.(type) {
case *http2.SettingsFrame:
if !sf.IsAck() {
+ sf.ForeachSetting(func(s http2.Setting) error {
+ if s.ID == http2.SettingInitialWindowSize {
+ serverInitialWindowSize = int32(s.Val)
+ log.Debugf("h2: server initial window size: %d", s.Val)
+ }
+ return nil
+ })
framer.WriteSettingsAck()
} else {
goto handshakeDone
}
case *http2.WindowUpdateFrame:
- // ignore
+ if sf.StreamID == 0 {
+ connWindowSize += int32(sf.Increment)
+ log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize)
+ }
default:
// unexpected but continue
}
@@ -124,36 +148,53 @@ handshakeDone:
}
s := &H2Stream{
- framer: framer,
- conn: tlsConn,
- streamID: streamID,
- dataCh: make(chan []byte, 256),
- doneCh: make(chan struct{}),
- id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
- frameNum: 0,
+ framer: framer,
+ conn: tlsConn,
+ streamID: streamID,
+ dataCh: make(chan []byte, 256),
+ doneCh: make(chan struct{}),
+ id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
+ frameNum: 0,
+ sendWindow: serverInitialWindowSize,
+ connWindow: connWindowSize,
}
+ s.windowCond = sync.NewCond(&s.windowMu)
go s.readLoop()
return s, nil
}
-// Write sends a DATA frame on the stream.
+// Write sends a DATA frame on the stream, splitting the payload into
+// <= maxFramePayload chunks and blocking on send-side flow control
+// (both the stream and connection windows) before each chunk.
+// It returns an error if the stream is closed while waiting.
 func (s *H2Stream) Write(data []byte) error {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	const maxFrame = 16384
 	for len(data) > 0 {
 		chunk := data
-		if len(chunk) > maxFrame {
-			chunk = data[:maxFrame]
+		if len(chunk) > maxFramePayload {
+			chunk = data[:maxFramePayload]
 		}
-		data = data[len(chunk):]
-		if err := s.framer.WriteData(s.streamID, false, chunk); err != nil {
+
+		// Wait until both windows have room. Re-check doneCh on every
+		// wakeup: Close() broadcasts windowCond without opening the
+		// window, so without this check a writer parked here would
+		// sleep forever after the stream is torn down.
+		// NOTE(review): doneCh is closed by readLoop after Close()'s
+		// broadcast; a wakeup that races the close can re-Wait with no
+		// later broadcast — readLoop should broadcast on exit too.
+		s.windowMu.Lock()
+		for s.sendWindow <= 0 || s.connWindow <= 0 {
+			select {
+			case <-s.doneCh:
+				s.windowMu.Unlock()
+				return fmt.Errorf("h2: stream closed while waiting for flow control window")
+			default:
+			}
+			s.windowCond.Wait()
+		}
+		// Shrink the chunk to whatever the smaller window permits.
+		allowed := int(s.sendWindow)
+		if int(s.connWindow) < allowed {
+			allowed = int(s.connWindow)
+		}
+		if len(chunk) > allowed {
+			chunk = chunk[:allowed]
+		}
+		s.sendWindow -= int32(len(chunk))
+		s.connWindow -= int32(len(chunk))
+		s.windowMu.Unlock()
+
+		// mu serializes framer writes against readLoop's control frames.
+		s.mu.Lock()
+		err := s.framer.WriteData(s.streamID, false, chunk)
+		s.mu.Unlock()
+		if err != nil {
 			return err
 		}
-	}
-	// Try to flush the underlying connection if it supports it
-	if flusher, ok := s.conn.(interface{ Flush() error }); ok {
-		flusher.Flush()
+		data = data[len(chunk):]
 	}
 	return nil
 }
@@ -167,12 +208,13 @@ func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }
// Close tears down the connection.
func (s *H2Stream) Close() {
s.conn.Close()
+ // Unblock any writers waiting on flow control
+ s.windowCond.Broadcast()
}
func (s *H2Stream) readLoop() {
defer close(s.doneCh)
defer close(s.dataCh)
- log.Debugf("h2stream[%s]: readLoop started for streamID=%d", s.id, s.streamID)
for {
f, err := s.framer.ReadFrame()
@@ -180,71 +222,47 @@ func (s *H2Stream) readLoop() {
if err != io.EOF {
s.err = err
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
- } else {
- log.Debugf("h2stream[%s]: readLoop EOF", s.id)
}
return
}
- // Increment frame counter for debugging
+ // Increment frame counter
s.mu.Lock()
s.frameNum++
- frameNum := s.frameNum
s.mu.Unlock()
switch frame := f.(type) {
case *http2.DataFrame:
- log.Debugf("h2stream[%s]: frame#%d received DATA frame streamID=%d, len=%d, endStream=%v", s.id, frameNum, frame.StreamID, len(frame.Data()), frame.StreamEnded())
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
cp := make([]byte, len(frame.Data()))
copy(cp, frame.Data())
- // Log first 20 bytes for debugging
- previewLen := len(cp)
- if previewLen > 20 {
- previewLen = 20
- }
- log.Debugf("h2stream[%s]: frame#%d sending to dataCh: len=%d, dataCh len=%d/%d, first bytes: %x (%q)", s.id, frameNum, len(cp), len(s.dataCh), cap(s.dataCh), cp[:previewLen], string(cp[:previewLen]))
s.dataCh <- cp
- // Flow control: send WINDOW_UPDATE
+ // Flow control: send WINDOW_UPDATE for received data
s.mu.Lock()
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
s.mu.Unlock()
}
if frame.StreamEnded() {
- log.Debugf("h2stream[%s]: frame#%d DATA frame has END_STREAM flag, stream ending", s.id, frameNum)
return
}
case *http2.HeadersFrame:
- // Decode HPACK headers for debugging
- decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {
- log.Debugf("h2stream[%s]: frame#%d header: %s = %q", s.id, frameNum, hf.Name, hf.Value)
- // Check for error status
- if hf.Name == "grpc-status" || hf.Name == ":status" && hf.Value != "200" {
- log.Warnf("h2stream[%s]: frame#%d received error status header: %s = %q", s.id, frameNum, hf.Name, hf.Value)
- }
- })
- decoder.Write(frame.HeaderBlockFragment())
- log.Debugf("h2stream[%s]: frame#%d received HEADERS frame streamID=%d, endStream=%v", s.id, frameNum, frame.StreamID, frame.StreamEnded())
if frame.StreamEnded() {
- log.Debugf("h2stream[%s]: frame#%d HEADERS frame has END_STREAM flag, stream ending", s.id, frameNum)
return
}
case *http2.RSTStreamFrame:
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
- log.Debugf("h2stream[%s]: frame#%d received RST_STREAM code=%d", s.id, frameNum, frame.ErrCode)
+ log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode)
return
case *http2.GoAwayFrame:
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
- log.Debugf("h2stream[%s]: received GOAWAY code=%d", s.id, frame.ErrCode)
return
case *http2.PingFrame:
- log.Debugf("h2stream[%s]: received PING frame, isAck=%v", s.id, frame.IsAck())
if !frame.IsAck() {
s.mu.Lock()
s.framer.WritePing(true, frame.Data)
@@ -252,15 +270,33 @@ func (s *H2Stream) readLoop() {
}
 		case *http2.SettingsFrame:
-			log.Debugf("h2stream[%s]: received SETTINGS frame, isAck=%v, numSettings=%d", s.id, frame.IsAck(), frame.NumSettings())
 			if !frame.IsAck() {
+				// Check for window size changes
+				// NOTE(review): `delta := Val - s.sendWindow; s.sendWindow += delta`
+				// reduces to `s.sendWindow = Val`, i.e. it resets the stream
+				// window to the new initial value and forgets bytes already
+				// consumed. RFC 7540 §6.9.2 requires adjusting the current
+				// window by (newInitial - oldInitial); track the previous
+				// initial window size and apply the difference — TODO fix.
+				frame.ForeachSetting(func(setting http2.Setting) error {
+					if setting.ID == http2.SettingInitialWindowSize {
+						s.windowMu.Lock()
+						delta := int32(setting.Val) - s.sendWindow
+						s.sendWindow += delta
+						s.windowMu.Unlock()
+						s.windowCond.Broadcast()
+					}
+					return nil
+				})
 				s.mu.Lock()
 				s.framer.WriteSettingsAck()
 				s.mu.Unlock()
 			}
 		case *http2.WindowUpdateFrame:
-			log.Debugf("h2stream[%s]: received WINDOW_UPDATE frame", s.id)
+			// Update send-side flow control window
+			// Credit the matching send window, then wake any Write
+			// blocked in its flow-control wait loop.
+			s.windowMu.Lock()
+			if frame.StreamID == 0 {
+				s.connWindow += int32(frame.Increment)
+			} else if frame.StreamID == s.streamID {
+				s.sendWindow += int32(frame.Increment)
+			}
+			s.windowMu.Unlock()
+			s.windowCond.Broadcast()
 		}
}
}
diff --git a/internal/runtime/executor/cursor_executor.go b/internal/runtime/executor/cursor_executor.go
index 5519e92d..bba06bc7 100644
--- a/internal/runtime/executor/cursor_executor.go
+++ b/internal/runtime/executor/cursor_executor.go
@@ -6,7 +6,6 @@ import (
"crypto/sha256"
"crypto/tls"
"encoding/base64"
- "encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
@@ -47,12 +46,15 @@ type CursorExecutor struct {
}
type cursorSession struct {
- stream *cursorproto.H2Stream
- blobStore map[string][]byte
- mcpTools []cursorproto.McpToolDef
- pending []pendingMcpExec
- cancel context.CancelFunc // cancels the session-scoped heartbeat (NOT tied to HTTP request)
- createdAt time.Time
+ stream *cursorproto.H2Stream
+ blobStore map[string][]byte
+ mcpTools []cursorproto.McpToolDef
+ pending []pendingMcpExec
+ cancel context.CancelFunc // cancels the session-scoped heartbeat (NOT tied to HTTP request)
+ createdAt time.Time
+ toolResultCh chan []toolResultInfo // receives tool results from the next HTTP request
+ resumeOutCh chan cliproxyexecutor.StreamChunk // output channel for resumed response
+ switchOutput func(ch chan cliproxyexecutor.StreamChunk) // callback to switch output channel
}
type pendingMcpExec struct {
@@ -235,6 +237,8 @@ func (e *CursorExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
fullText.WriteString(text)
},
nil,
+ nil,
+ nil, // tokenUsage - non-streaming
)
id := "chatcmpl-" + uuid.New().String()[:28]
@@ -341,9 +345,30 @@ func (e *CursorExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
chatId := "chatcmpl-" + uuid.New().String()[:28]
created := time.Now().Unix()
- // sendChunk builds an OpenAI SSE line and optionally translates to target format
var streamParam any
- sendChunk := func(delta string, finishReason string) {
+
+ // Tool result channel for inline mode. processH2SessionFrames blocks on it
+ // when mcpArgs is received, while continuing to handle KV/heartbeat.
+ toolResultCh := make(chan []toolResultInfo, 1)
+
+ // Switchable output: initially writes to `chunks`. After mcpArgs, the
+ // onMcpExec callback closes `chunks` (ending the first HTTP response),
+ // then processH2SessionFrames blocks on toolResultCh. When results arrive,
+ // it switches to `resumeOutCh` (created by resumeWithToolResults).
+ var outMu sync.Mutex
+ currentOut := chunks
+
+ emitToOut := func(chunk cliproxyexecutor.StreamChunk) {
+ outMu.Lock()
+ out := currentOut
+ outMu.Unlock()
+ if out != nil {
+ out <- chunk
+ }
+ }
+
+ // Wrap sendChunk/sendDone to use emitToOut
+ sendChunkSwitchable := func(delta string, finishReason string) {
fr := "null"
if finishReason != "" {
fr = finishReason
@@ -355,95 +380,146 @@ func (e *CursorExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if needsTranslate {
translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam)
for _, t := range translated {
- chunks <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)}
+ emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)})
}
} else {
- chunks <- cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)}
+ emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)})
}
}
- sendDone := func() {
+ sendDoneSwitchable := func() {
if needsTranslate {
done := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, []byte("data: [DONE]\n"), &streamParam)
for _, d := range done {
- chunks <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(d)}
+ emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(d)})
}
} else {
- chunks <- cliproxyexecutor.StreamChunk{Payload: []byte("[DONE]")}
+ emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte("[DONE]")})
}
}
go func() {
- defer close(chunks)
-
+ var resumeOutCh chan cliproxyexecutor.StreamChunk
+ _ = resumeOutCh
thinkingActive := false
toolCallIndex := 0
- mcpExecReceived := false
+ usage := &cursorTokenUsage{}
+ usage.setInputEstimate(len(payload))
processH2SessionFrames(sessionCtx, stream, params.BlobStore, params.McpTools,
func(text string, isThinking bool) {
if isThinking {
if !thinkingActive {
thinkingActive = true
- sendChunk(`{"role":"assistant","content":""}`, "")
+ sendChunkSwitchable(`{"role":"assistant","content":""}`, "")
}
- sendChunk(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
+ sendChunkSwitchable(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
} else {
if thinkingActive {
thinkingActive = false
- sendChunk(`{"content":""}`, "")
+ sendChunkSwitchable(`{"content":""}`, "")
}
- sendChunk(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
+ sendChunkSwitchable(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
}
},
func(exec pendingMcpExec) {
- mcpExecReceived = true
if thinkingActive {
thinkingActive = false
- sendChunk(`{"content":""}`, "")
+ sendChunkSwitchable(`{"content":""}`, "")
}
toolCallJSON := fmt.Sprintf(`{"tool_calls":[{"index":%d,"id":"%s","type":"function","function":{"name":"%s","arguments":%s}}]}`,
toolCallIndex, exec.ToolCallId, exec.ToolName, jsonString(exec.Args))
toolCallIndex++
- sendChunk(toolCallJSON, "")
- sendChunk(`{}`, `"tool_calls"`)
- sendDone()
+ sendChunkSwitchable(toolCallJSON, "")
+ sendChunkSwitchable(`{}`, `"tool_calls"`)
+ sendDoneSwitchable()
- // Save session for resume — keep stream alive.
- // The heartbeat goroutine continues running (session-scoped context),
- // keeping the H2 connection alive while the MCP tool executes.
- log.Debugf("cursor: saving session %s for MCP tool resume (tool=%s, streamID=%s)", sessionKey, exec.ToolName, stream.ID())
+ // Close current output to end the current HTTP SSE response
+ outMu.Lock()
+ if currentOut != nil {
+ close(currentOut)
+ currentOut = nil
+ }
+ outMu.Unlock()
+
+ // Create new resume output channel, reuse the same toolResultCh
+ resumeOut := make(chan cliproxyexecutor.StreamChunk, 64)
+ log.Debugf("cursor: saving session %s for MCP tool resume (tool=%s)", sessionKey, exec.ToolName)
e.mu.Lock()
e.sessions[sessionKey] = &cursorSession{
- stream: stream,
- blobStore: params.BlobStore,
- mcpTools: params.McpTools,
- pending: []pendingMcpExec{exec},
- cancel: sessionCancel,
- createdAt: time.Now(),
+ stream: stream,
+ blobStore: params.BlobStore,
+ mcpTools: params.McpTools,
+ pending: []pendingMcpExec{exec},
+ cancel: sessionCancel,
+ createdAt: time.Now(),
+ toolResultCh: toolResultCh, // reuse same channel across rounds
+ resumeOutCh: resumeOut,
+ switchOutput: func(ch chan cliproxyexecutor.StreamChunk) {
+ outMu.Lock()
+ currentOut = ch
+ // Reset translator state so the new HTTP response gets
+ // a fresh message_start, content_block_start, etc.
+ streamParam = nil
+ // New response needs its own message ID
+ chatId = "chatcmpl-" + uuid.New().String()[:28]
+ created = time.Now().Unix()
+ outMu.Unlock()
+ },
}
e.mu.Unlock()
+ resumeOutCh = resumeOut
+
+ // processH2SessionFrames will now block on toolResultCh (inline wait loop)
+ // while continuing to handle KV messages
},
+ toolResultCh,
+ usage,
)
- if !mcpExecReceived {
- if thinkingActive {
- sendChunk(`{"content":""}`, "")
- }
- sendChunk(`{}`, `"stop"`)
- sendDone()
- sessionCancel()
- stream.Close()
+		// processH2SessionFrames returned — the model's turn is over.
+		if thinkingActive {
+			sendChunkSwitchable(`{"content":""}`, "")
+		}
-	// If mcpExecReceived, do NOT close stream or cancel — session resume will handle it
+		// Emit the final stop chunk with accumulated token usage.
+		// NOTE(review): chatId/created are mutated by switchOutput under
+		// outMu but read here (and in sendChunkSwitchable) without it —
+		// possible data race; guard both sides with outMu — TODO confirm.
+		inputTok, outputTok := usage.get()
+		openaiJSON := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"%s","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}}`,
+			chatId, created, parsed.Model, inputTok, outputTok, inputTok+outputTok)
+		sseLine := []byte("data: " + openaiJSON + "\n")
+		if needsTranslate {
+			translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam)
+			for _, t := range translated {
+				emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)})
+			}
+		} else {
+			emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)})
+		}
+		sendDoneSwitchable()
+
+		// Close whatever output channel is still active so the reading
+		// HTTP handler terminates its SSE response.
+		outMu.Lock()
+		if currentOut != nil {
+			close(currentOut)
+			currentOut = nil
+		}
+		outMu.Unlock()
+		sessionCancel()
+		stream.Close()
 	}()
return &cliproxyexecutor.StreamResult{Chunks: chunks}, nil
}
-// resumeWithToolResults sends MCP tool results back on the existing H2 stream,
-// then continues reading the stream for the model's response.
-// Mirrors resumeWithToolResults() in cursor-fetch.ts.
+// resumeWithToolResults injects tool results into the running processH2SessionFrames
+// via the toolResultCh channel. The original goroutine from ExecuteStream is still alive,
+// blocking on toolResultCh. Once we send the results, it sends the MCP result to Cursor
+// and continues processing the response text — all in the same goroutine that has been
+// handling KV messages the whole time.
func (e *CursorExecutor) resumeWithToolResults(
ctx context.Context,
session *cursorSession,
@@ -453,208 +529,29 @@ func (e *CursorExecutor) resumeWithToolResults(
originalPayload, payload []byte,
needsTranslate bool,
) (*cliproxyexecutor.StreamResult, error) {
- stream := session.stream
- log.Debugf("cursor: resumeWithToolResults: using stream ID=%s", stream.ID())
+ log.Debugf("cursor: resumeWithToolResults: injecting %d tool results via channel", len(parsed.ToolResults))
- // Cancel old session-scoped heartbeat before starting a new one
- session.cancel()
-
- // CRITICAL: Process any pending messages from the channel before sending MCP result.
- // After the initial processH2SessionFrames returned (upon receiving MCP args),
- // the server may have sent more data (KV messages, text deltas) that are now buffered in dataCh.
- // We must process KV messages (respond to them) but discard text deltas (stale responses).
- drainedCount := 0
- drainedBytes := 0
- kvProcessedCount := 0
- for {
- select {
- case staleData, ok := <-stream.Data():
- if !ok {
- log.Debugf("cursor: resumeWithToolResults: dataCh closed during drain")
- break
- }
- drainedCount++
- drainedBytes += len(staleData)
- log.Debugf("cursor: resumeWithToolResults: processing stale data #%d: len=%d, first bytes: %x (%q)", drainedCount, len(staleData), staleData[:min(20, len(staleData))], string(staleData[:min(20, len(staleData))]))
-
- // Try to decode and handle KV messages (they need responses)
- if len(staleData) > 5 {
- frameLen := binary.BigEndian.Uint32(staleData[1:5])
- if int(frameLen)+5 <= len(staleData) {
- payload := staleData[5 : 5+frameLen]
- msg, err := cursorproto.DecodeAgentServerMessage(payload)
- if err == nil && msg.Type == cursorproto.ServerMsgKvGetBlob {
- // Respond to KV getBlob
- blobKey := cursorproto.BlobIdHex(msg.BlobId)
- data := session.blobStore[blobKey]
- log.Debugf("cursor: resumeWithToolResults: responding to stale KV getBlob kvId=%d blobKey=%s found=%v", msg.KvId, blobKey, len(data) > 0)
- resp := cursorproto.EncodeKvGetBlobResult(msg.KvId, data)
- stream.Write(cursorproto.FrameConnectMessage(resp, 0))
- kvProcessedCount++
- continue
- } else if err == nil && msg.Type == cursorproto.ServerMsgKvSetBlob {
- // Respond to KV setBlob
- blobKey := cursorproto.BlobIdHex(msg.BlobId)
- session.blobStore[blobKey] = append([]byte(nil), msg.BlobData...)
- log.Debugf("cursor: resumeWithToolResults: responding to stale KV setBlob kvId=%d blobKey=%s", msg.KvId, blobKey)
- resp := cursorproto.EncodeKvSetBlobResult(msg.KvId)
- stream.Write(cursorproto.FrameConnectMessage(resp, 0))
- kvProcessedCount++
- continue
- }
- }
- }
- log.Debugf("cursor: resumeWithToolResults: discarding non-KV stale data")
- default:
- // No more data in channel
- goto drainDone
- }
+ if session.toolResultCh == nil {
+ return nil, fmt.Errorf("cursor: session has no toolResultCh (stale session?)")
}
-drainDone:
- if drainedCount > 0 {
- log.Debugf("cursor: resumeWithToolResults: processed %d stale frames (%d bytes total, %d KV responded)", drainedCount, drainedBytes, kvProcessedCount)
+ if session.resumeOutCh == nil {
+ return nil, fmt.Errorf("cursor: session has no resumeOutCh")
}
- // Send MCP results back on the same H2 stream
- for _, exec := range session.pending {
- var content string
- var isError bool
- found := false
- for _, tr := range parsed.ToolResults {
- if tr.ToolCallId == exec.ToolCallId {
- content = tr.Content
- found = true
- break
- }
- }
- if !found {
- content = "Tool result not provided"
- isError = true
- }
- log.Debugf("cursor: sending MCP result for tool=%s callId=%s execMsgId=%d execId=%s contentLen=%d isError=%v",
- exec.ToolName, exec.ToolCallId, exec.ExecMsgId, exec.ExecId, len(content), isError)
- resultBytes := cursorproto.EncodeExecMcpResult(exec.ExecMsgId, exec.ExecId, content, isError)
- framedResult := cursorproto.FrameConnectMessage(resultBytes, 0)
- // Log the framed result details for debugging
- log.Debugf("cursor: MCP result frame size=%d bytes", len(framedResult))
- log.Debugf("cursor: MCP result frame header: flags=%d, len=%d", framedResult[0], binary.BigEndian.Uint32(framedResult[1:5]))
- log.Debugf("cursor: MCP result protobuf hex (first 50 bytes): %x", resultBytes[:min(50, len(resultBytes))])
- if err := stream.Write(framedResult); err != nil {
- stream.Close()
- return nil, fmt.Errorf("cursor: failed to send MCP result: %w", err)
- }
- log.Debugf("cursor: MCP result sent successfully for tool=%s", exec.ToolName)
+ log.Debugf("cursor: resumeWithToolResults: switching output to resumeOutCh and injecting results")
+
+ // Switch the output channel BEFORE injecting results, so that when
+ // processH2SessionFrames unblocks and starts emitting text, it writes
+ // to the resumeOutCh which the new HTTP handler is reading from.
+ if session.switchOutput != nil {
+ session.switchOutput(session.resumeOutCh)
}
- // Start new session-scoped heartbeat (independent of HTTP request context)
- sessionCtx, sessionCancel := context.WithCancel(context.Background())
- go cursorH2Heartbeat(sessionCtx, stream)
- log.Debugf("cursor: started new heartbeat for resumed session, waiting for Cursor response...")
+ // Inject tool results — this unblocks the waiting processH2SessionFrames
+ session.toolResultCh <- parsed.ToolResults
- chunks := make(chan cliproxyexecutor.StreamChunk, 64)
- chatId := "chatcmpl-" + uuid.New().String()[:28]
- created := time.Now().Unix()
- sessionKey := deriveSessionKey(parsed.Model, parsed.Messages)
-
- var streamParam any
- sendChunk := func(delta string, finishReason string) {
- fr := "null"
- if finishReason != "" {
- fr = finishReason
- }
- openaiJSON := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"%s","choices":[{"index":0,"delta":%s,"finish_reason":%s}]}`,
- chatId, created, parsed.Model, delta, fr)
- sseLine := []byte("data: " + openaiJSON + "\n")
-
- if needsTranslate {
- translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam)
- for _, t := range translated {
- chunks <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)}
- }
- } else {
- chunks <- cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)}
- }
- }
-
- sendDone := func() {
- if needsTranslate {
- done := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, []byte("data: [DONE]\n"), &streamParam)
- for _, d := range done {
- chunks <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(d)}
- }
- } else {
- chunks <- cliproxyexecutor.StreamChunk{Payload: []byte("[DONE]")}
- }
- }
-
- go func() {
- defer func() {
- log.Debugf("cursor: resume goroutine exiting, closing chunks channel")
- close(chunks)
- }()
- log.Debugf("cursor: resume goroutine started, entering processH2SessionFrames")
-
- thinkingActive := false
- toolCallIndex := 0
- mcpExecReceived := false
-
- processH2SessionFrames(sessionCtx, stream, session.blobStore, session.mcpTools,
- func(text string, isThinking bool) {
- log.Debugf("cursor: resume received text (isThinking=%v, len=%d)", isThinking, len(text))
- if isThinking {
- if !thinkingActive {
- thinkingActive = true
- sendChunk(`{"role":"assistant","content":""}`, "")
- }
- sendChunk(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
- } else {
- if thinkingActive {
- thinkingActive = false
- sendChunk(`{"content":""}`, "")
- }
- sendChunk(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
- }
- },
- func(exec pendingMcpExec) {
- mcpExecReceived = true
- if thinkingActive {
- thinkingActive = false
- sendChunk(`{"content":""}`, "")
- }
- toolCallJSON := fmt.Sprintf(`{"tool_calls":[{"index":%d,"id":"%s","type":"function","function":{"name":"%s","arguments":%s}}]}`,
- toolCallIndex, exec.ToolCallId, exec.ToolName, jsonString(exec.Args))
- toolCallIndex++
- sendChunk(toolCallJSON, "")
- sendChunk(`{}`, `"tool_calls"`)
- sendDone()
-
- // Save session again for another round of tool calls
- log.Debugf("cursor: saving session %s for another MCP tool resume (tool=%s, streamID=%s)", sessionKey, exec.ToolName, stream.ID())
- e.mu.Lock()
- e.sessions[sessionKey] = &cursorSession{
- stream: stream,
- blobStore: session.blobStore,
- mcpTools: session.mcpTools,
- pending: []pendingMcpExec{exec},
- cancel: sessionCancel,
- createdAt: time.Now(),
- }
- e.mu.Unlock()
- },
- )
-
- if !mcpExecReceived {
- if thinkingActive {
- sendChunk(`{"content":""}`, "")
- }
- sendChunk(`{}`, `"stop"`)
- sendDone()
- sessionCancel()
- stream.Close()
- }
- }()
-
- return &cliproxyexecutor.StreamResult{Chunks: chunks}, nil
+ // Return the resumeOutCh for the new HTTP handler to read from
+ return &cliproxyexecutor.StreamResult{Chunks: session.resumeOutCh}, nil
}
// --- H2Stream helpers ---
@@ -693,6 +590,35 @@ func cursorH2Heartbeat(ctx context.Context, stream *cursorproto.H2Stream) {
// --- Response processing ---
+// cursorTokenUsage tracks token counts from Cursor's TokenDeltaUpdate messages.
+// Output tokens are accumulated from server-reported deltas; input tokens are
+// only a local size-based estimate, since the server does not report them here.
+// All methods are safe for concurrent use.
+type cursorTokenUsage struct {
+	mu             sync.Mutex
+	outputTokens   int64
+	inputTokensEst int64 // estimated from request payload size
+}
+
+// addOutput accumulates an output-token delta reported by the server.
+func (u *cursorTokenUsage) addOutput(delta int64) {
+	u.mu.Lock()
+	defer u.mu.Unlock()
+	u.outputTokens += delta
+}
+
+// setInputEstimate derives a rough input-token count from the request size.
+func (u *cursorTokenUsage) setInputEstimate(payloadBytes int) {
+	u.mu.Lock()
+	defer u.mu.Unlock()
+	// Rough estimate: ~4 bytes per token for mixed content
+	u.inputTokensEst = int64(payloadBytes / 4)
+	if u.inputTokensEst < 1 {
+		u.inputTokensEst = 1
+	}
+}
+
+// get returns the (estimated) input and accumulated output token counts.
+func (u *cursorTokenUsage) get() (input, output int64) {
+	u.mu.Lock()
+	defer u.mu.Unlock()
+	return u.inputTokensEst, u.outputTokens
+}
+
+
func processH2SessionFrames(
ctx context.Context,
stream *cursorproto.H2Stream,
@@ -700,6 +626,8 @@ func processH2SessionFrames(
mcpTools []cursorproto.McpToolDef,
onText func(text string, isThinking bool),
onMcpExec func(exec pendingMcpExec),
+ toolResultCh <-chan []toolResultInfo, // nil for no tool result injection; non-nil to wait for results
+ tokenUsage *cursorTokenUsage, // tracks accumulated token usage (may be nil)
) {
var buf bytes.Buffer
rejectReason := "Tool not available in this environment. Use the MCP tools provided instead."
@@ -762,6 +690,20 @@ func processH2SessionFrames(
case cursorproto.ServerMsgThinkingCompleted:
// Handled by caller
+ case cursorproto.ServerMsgTurnEnded:
+ log.Debugf("cursor: TurnEnded received, stream will finish")
+ return
+
+ case cursorproto.ServerMsgHeartbeat:
+ // Server heartbeat, ignore silently
+ continue
+
+ case cursorproto.ServerMsgTokenDelta:
+ if tokenUsage != nil && msg.TokenDelta > 0 {
+ tokenUsage.addOutput(msg.TokenDelta)
+ }
+ continue
+
case cursorproto.ServerMsgKvGetBlob:
blobKey := cursorproto.BlobIdHex(msg.BlobId)
data := blobStore[blobKey]
@@ -785,17 +727,85 @@ func processH2SessionFrames(
if toolCallId == "" {
toolCallId = uuid.New().String()
}
- // Debug: log the received execId from server
log.Debugf("cursor: received mcpArgs from server: execMsgId=%d execId=%q toolName=%s toolCallId=%s",
msg.ExecMsgId, msg.ExecId, msg.McpToolName, toolCallId)
- onMcpExec(pendingMcpExec{
+ pending := pendingMcpExec{
ExecMsgId: msg.ExecMsgId,
ExecId: msg.ExecId,
ToolCallId: toolCallId,
ToolName: msg.McpToolName,
Args: decodedArgs,
- })
- return
+ }
+ onMcpExec(pending)
+
+ if toolResultCh == nil {
+ return
+ }
+
+ // Inline mode: wait for tool result while handling KV/heartbeat
+ log.Debugf("cursor: waiting for tool result on channel (inline mode)...")
+ var toolResults []toolResultInfo
+ waitLoop:
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case results, ok := <-toolResultCh:
+ if !ok {
+ return
+ }
+ toolResults = results
+ break waitLoop
+ case waitData, ok := <-stream.Data():
+ if !ok {
+ return
+ }
+ buf.Write(waitData)
+ for {
+ cb := buf.Bytes()
+ if len(cb) == 0 {
+ break
+ }
+ wf, wp, wc, wok := cursorproto.ParseConnectFrame(cb)
+ if !wok {
+ break
+ }
+ buf.Next(wc)
+ if wf&cursorproto.ConnectEndStreamFlag != 0 {
+ continue
+ }
+ wmsg, werr := cursorproto.DecodeAgentServerMessage(wp)
+ if werr != nil {
+ continue
+ }
+ switch wmsg.Type {
+ case cursorproto.ServerMsgKvGetBlob:
+ blobKey := cursorproto.BlobIdHex(wmsg.BlobId)
+ d := blobStore[blobKey]
+ stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeKvGetBlobResult(wmsg.KvId, d), 0))
+ case cursorproto.ServerMsgKvSetBlob:
+ blobKey := cursorproto.BlobIdHex(wmsg.BlobId)
+ blobStore[blobKey] = append([]byte(nil), wmsg.BlobData...)
+ stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeKvSetBlobResult(wmsg.KvId), 0))
+ case cursorproto.ServerMsgExecRequestCtx:
+ stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecRequestContextResult(wmsg.ExecMsgId, wmsg.ExecId, mcpTools), 0))
+ }
+ }
+ case <-stream.Done():
+ return
+ }
+ }
+
+ // Send MCP result
+ for _, tr := range toolResults {
+ if tr.ToolCallId == pending.ToolCallId {
+ log.Debugf("cursor: sending inline MCP result for tool=%s", pending.ToolName)
+ resultBytes := cursorproto.EncodeExecMcpResult(pending.ExecMsgId, pending.ExecId, tr.Content, false)
+ stream.Write(cursorproto.FrameConnectMessage(resultBytes, 0))
+ break
+ }
+ }
+ continue
}
case cursorproto.ServerMsgExecReadArgs: