mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-23 23:33:49 +00:00
fix(cursor): MCP tool call resume, H2 flow control, and token usage
- Rewrite tool call mechanism from interrupt-resume to inline-wait mode: processH2SessionFrames no longer exits on mcpArgs; instead it blocks on toolResultCh while continuing to handle KV/heartbeat messages, then sends the MCP result and continues processing text in the same goroutine. Fixes the issue where the server stopped generating text after resume.
- Add switchable output channel (outMu/currentOut) so the first HTTP response closes after tool_calls+[DONE], and resumed text goes to a new channel returned by resumeWithToolResults. Reset streamParam on switch so the Translator produces fresh message_start/content_block_start events.
- Implement send-side H2 flow control: track the server's initial window size and WINDOW_UPDATE increments; Write() blocks when the window is exhausted. Fixes RST_STREAM FLOW_CONTROL_ERROR on large requests (178KB+).
- Decode new InteractionUpdate fields: TurnEndedUpdate (field 14) as stream termination signal, HeartbeatUpdate (field 13) silently ignored, TokenDeltaUpdate (field 8) for token usage tracking.
- Include token usage in the final stop chunk (prompt_tokens estimated from payload size, completion_tokens from accumulated TokenDeltaUpdate deltas) so the Claude CLI status bar shows non-zero token counts.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -32,6 +32,9 @@ const (
|
|||||||
ServerMsgExecBgShellSpawn // Rejected: background shell
|
ServerMsgExecBgShellSpawn // Rejected: background shell
|
||||||
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
|
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
|
||||||
ServerMsgExecOther // Other exec types (respond with empty)
|
ServerMsgExecOther // Other exec types (respond with empty)
|
||||||
|
ServerMsgTurnEnded // Turn has ended (no more output)
|
||||||
|
ServerMsgHeartbeat // Server heartbeat
|
||||||
|
ServerMsgTokenDelta // Token usage delta
|
||||||
)
|
)
|
||||||
|
|
||||||
// DecodedServerMessage holds parsed data from an AgentServerMessage.
|
// DecodedServerMessage holds parsed data from an AgentServerMessage.
|
||||||
@@ -63,6 +66,9 @@ type DecodedServerMessage struct {
|
|||||||
|
|
||||||
// For other exec - the raw field number for building a response
|
// For other exec - the raw field number for building a response
|
||||||
ExecFieldNumber int
|
ExecFieldNumber int
|
||||||
|
|
||||||
|
// For TokenDeltaUpdate
|
||||||
|
TokenDelta int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeAgentServerMessage parses an AgentServerMessage and returns
|
// DecodeAgentServerMessage parses an AgentServerMessage and returns
|
||||||
@@ -160,6 +166,24 @@ func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
|
|||||||
case 3:
|
case 3:
|
||||||
// tool_call_completed - ignore but log
|
// tool_call_completed - ignore but log
|
||||||
log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
|
log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
|
||||||
|
case 8:
|
||||||
|
// token_delta - extract token count
|
||||||
|
msg.Type = ServerMsgTokenDelta
|
||||||
|
msg.TokenDelta = decodeVarintField(val, 1)
|
||||||
|
log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta)
|
||||||
|
case 13:
|
||||||
|
// heartbeat from server
|
||||||
|
msg.Type = ServerMsgHeartbeat
|
||||||
|
case 14:
|
||||||
|
// turn_ended - critical: model finished generating
|
||||||
|
msg.Type = ServerMsgTurnEnded
|
||||||
|
log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end")
|
||||||
|
case 16:
|
||||||
|
// step_started - ignore
|
||||||
|
log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)")
|
||||||
|
case 17:
|
||||||
|
// step_completed - ignore
|
||||||
|
log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)")
|
||||||
default:
|
default:
|
||||||
log.Debugf("decodeInteractionUpdate: unknown field %d", num)
|
log.Debugf("decodeInteractionUpdate: unknown field %d", num)
|
||||||
}
|
}
|
||||||
@@ -500,6 +524,34 @@ func decodeBytesField(data []byte, targetField protowire.Number) []byte {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// decodeVarintField extracts an int64 from the first matching varint field in a submessage.
|
||||||
|
func decodeVarintField(data []byte, targetField protowire.Number) int64 {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if typ == protowire.VarintType {
|
||||||
|
val, n := protowire.ConsumeVarint(data)
|
||||||
|
if n < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if num == targetField {
|
||||||
|
return int64(val)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
// BlobIdHex returns the hex string of a blob ID for use as a map key.
|
// BlobIdHex returns the hex string of a blob ID for use as a map key.
|
||||||
func BlobIdHex(blobId []byte) string {
|
func BlobIdHex(blobId []byte) string {
|
||||||
return hex.EncodeToString(blobId)
|
return hex.EncodeToString(blobId)
|
||||||
|
|||||||
@@ -13,6 +13,11 @@ import (
|
|||||||
"golang.org/x/net/http2/hpack"
|
"golang.org/x/net/http2/hpack"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultInitialWindowSize = 65535 // HTTP/2 default
|
||||||
|
maxFramePayload = 16384 // HTTP/2 default max frame size
|
||||||
|
)
|
||||||
|
|
||||||
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
|
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
|
||||||
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
|
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
|
||||||
type H2Stream struct {
|
type H2Stream struct {
|
||||||
@@ -21,11 +26,17 @@ type H2Stream struct {
|
|||||||
streamID uint32
|
streamID uint32
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
id string // unique identifier for debugging
|
id string // unique identifier for debugging
|
||||||
frameNum int64 // sequential frame counter for debugging
|
frameNum int64 // sequential frame counter for debugging
|
||||||
|
|
||||||
dataCh chan []byte
|
dataCh chan []byte
|
||||||
doneCh chan struct{}
|
doneCh chan struct{}
|
||||||
err error
|
err error
|
||||||
|
|
||||||
|
// Send-side flow control
|
||||||
|
sendWindow int32 // available bytes we can send on this stream
|
||||||
|
connWindow int32 // available bytes on the connection level
|
||||||
|
windowCond *sync.Cond // signaled when window is updated
|
||||||
|
windowMu sync.Mutex // protects sendWindow, connWindow
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the unique identifier for this stream (for logging).
|
// ID returns the unique identifier for this stream (for logging).
|
||||||
@@ -59,7 +70,7 @@ func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
|
|||||||
return nil, fmt.Errorf("h2: preface write failed: %w", err)
|
return nil, fmt.Errorf("h2: preface write failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send initial SETTINGS (with large initial window)
|
// Send initial SETTINGS (tell server how much WE can receive)
|
||||||
if err := framer.WriteSettings(
|
if err := framer.WriteSettings(
|
||||||
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
|
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
|
||||||
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
|
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
|
||||||
@@ -68,14 +79,17 @@ func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
|
|||||||
return nil, fmt.Errorf("h2: settings write failed: %w", err)
|
return nil, fmt.Errorf("h2: settings write failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connection-level window update (default is 65535, bump it up)
|
// Connection-level window update (for receiving)
|
||||||
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
|
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
|
||||||
tlsConn.Close()
|
tlsConn.Close()
|
||||||
return nil, fmt.Errorf("h2: window update failed: %w", err)
|
return nil, fmt.Errorf("h2: window update failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
|
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
|
||||||
for i := 0; i < 5; i++ {
|
// Track server's initial window size (how much WE can send)
|
||||||
|
serverInitialWindowSize := int32(defaultInitialWindowSize)
|
||||||
|
connWindowSize := int32(defaultInitialWindowSize) // connection-level send window
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
f, err := framer.ReadFrame()
|
f, err := framer.ReadFrame()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlsConn.Close()
|
tlsConn.Close()
|
||||||
@@ -84,12 +98,22 @@ func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
|
|||||||
switch sf := f.(type) {
|
switch sf := f.(type) {
|
||||||
case *http2.SettingsFrame:
|
case *http2.SettingsFrame:
|
||||||
if !sf.IsAck() {
|
if !sf.IsAck() {
|
||||||
|
sf.ForeachSetting(func(s http2.Setting) error {
|
||||||
|
if s.ID == http2.SettingInitialWindowSize {
|
||||||
|
serverInitialWindowSize = int32(s.Val)
|
||||||
|
log.Debugf("h2: server initial window size: %d", s.Val)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
framer.WriteSettingsAck()
|
framer.WriteSettingsAck()
|
||||||
} else {
|
} else {
|
||||||
goto handshakeDone
|
goto handshakeDone
|
||||||
}
|
}
|
||||||
case *http2.WindowUpdateFrame:
|
case *http2.WindowUpdateFrame:
|
||||||
// ignore
|
if sf.StreamID == 0 {
|
||||||
|
connWindowSize += int32(sf.Increment)
|
||||||
|
log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
// unexpected but continue
|
// unexpected but continue
|
||||||
}
|
}
|
||||||
@@ -124,36 +148,53 @@ handshakeDone:
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := &H2Stream{
|
s := &H2Stream{
|
||||||
framer: framer,
|
framer: framer,
|
||||||
conn: tlsConn,
|
conn: tlsConn,
|
||||||
streamID: streamID,
|
streamID: streamID,
|
||||||
dataCh: make(chan []byte, 256),
|
dataCh: make(chan []byte, 256),
|
||||||
doneCh: make(chan struct{}),
|
doneCh: make(chan struct{}),
|
||||||
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
|
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
|
||||||
frameNum: 0,
|
frameNum: 0,
|
||||||
|
sendWindow: serverInitialWindowSize,
|
||||||
|
connWindow: connWindowSize,
|
||||||
}
|
}
|
||||||
|
s.windowCond = sync.NewCond(&s.windowMu)
|
||||||
go s.readLoop()
|
go s.readLoop()
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write sends a DATA frame on the stream.
|
// Write sends a DATA frame on the stream, respecting flow control.
|
||||||
func (s *H2Stream) Write(data []byte) error {
|
func (s *H2Stream) Write(data []byte) error {
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
const maxFrame = 16384
|
|
||||||
for len(data) > 0 {
|
for len(data) > 0 {
|
||||||
chunk := data
|
chunk := data
|
||||||
if len(chunk) > maxFrame {
|
if len(chunk) > maxFramePayload {
|
||||||
chunk = data[:maxFrame]
|
chunk = data[:maxFramePayload]
|
||||||
}
|
}
|
||||||
data = data[len(chunk):]
|
|
||||||
if err := s.framer.WriteData(s.streamID, false, chunk); err != nil {
|
// Wait for flow control window
|
||||||
|
s.windowMu.Lock()
|
||||||
|
for s.sendWindow <= 0 || s.connWindow <= 0 {
|
||||||
|
s.windowCond.Wait()
|
||||||
|
}
|
||||||
|
// Limit chunk to available window
|
||||||
|
allowed := int(s.sendWindow)
|
||||||
|
if int(s.connWindow) < allowed {
|
||||||
|
allowed = int(s.connWindow)
|
||||||
|
}
|
||||||
|
if len(chunk) > allowed {
|
||||||
|
chunk = chunk[:allowed]
|
||||||
|
}
|
||||||
|
s.sendWindow -= int32(len(chunk))
|
||||||
|
s.connWindow -= int32(len(chunk))
|
||||||
|
s.windowMu.Unlock()
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
err := s.framer.WriteData(s.streamID, false, chunk)
|
||||||
|
s.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
data = data[len(chunk):]
|
||||||
// Try to flush the underlying connection if it supports it
|
|
||||||
if flusher, ok := s.conn.(interface{ Flush() error }); ok {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -167,12 +208,13 @@ func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }
|
|||||||
// Close tears down the connection.
|
// Close tears down the connection.
|
||||||
func (s *H2Stream) Close() {
|
func (s *H2Stream) Close() {
|
||||||
s.conn.Close()
|
s.conn.Close()
|
||||||
|
// Unblock any writers waiting on flow control
|
||||||
|
s.windowCond.Broadcast()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *H2Stream) readLoop() {
|
func (s *H2Stream) readLoop() {
|
||||||
defer close(s.doneCh)
|
defer close(s.doneCh)
|
||||||
defer close(s.dataCh)
|
defer close(s.dataCh)
|
||||||
log.Debugf("h2stream[%s]: readLoop started for streamID=%d", s.id, s.streamID)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
f, err := s.framer.ReadFrame()
|
f, err := s.framer.ReadFrame()
|
||||||
@@ -180,71 +222,47 @@ func (s *H2Stream) readLoop() {
|
|||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
s.err = err
|
s.err = err
|
||||||
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
|
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
|
||||||
} else {
|
|
||||||
log.Debugf("h2stream[%s]: readLoop EOF", s.id)
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment frame counter for debugging
|
// Increment frame counter
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.frameNum++
|
s.frameNum++
|
||||||
frameNum := s.frameNum
|
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
switch frame := f.(type) {
|
switch frame := f.(type) {
|
||||||
case *http2.DataFrame:
|
case *http2.DataFrame:
|
||||||
log.Debugf("h2stream[%s]: frame#%d received DATA frame streamID=%d, len=%d, endStream=%v", s.id, frameNum, frame.StreamID, len(frame.Data()), frame.StreamEnded())
|
|
||||||
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
|
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
|
||||||
cp := make([]byte, len(frame.Data()))
|
cp := make([]byte, len(frame.Data()))
|
||||||
copy(cp, frame.Data())
|
copy(cp, frame.Data())
|
||||||
// Log first 20 bytes for debugging
|
|
||||||
previewLen := len(cp)
|
|
||||||
if previewLen > 20 {
|
|
||||||
previewLen = 20
|
|
||||||
}
|
|
||||||
log.Debugf("h2stream[%s]: frame#%d sending to dataCh: len=%d, dataCh len=%d/%d, first bytes: %x (%q)", s.id, frameNum, len(cp), len(s.dataCh), cap(s.dataCh), cp[:previewLen], string(cp[:previewLen]))
|
|
||||||
s.dataCh <- cp
|
s.dataCh <- cp
|
||||||
|
|
||||||
// Flow control: send WINDOW_UPDATE
|
// Flow control: send WINDOW_UPDATE for received data
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
|
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
|
||||||
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
|
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
if frame.StreamEnded() {
|
if frame.StreamEnded() {
|
||||||
log.Debugf("h2stream[%s]: frame#%d DATA frame has END_STREAM flag, stream ending", s.id, frameNum)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
case *http2.HeadersFrame:
|
case *http2.HeadersFrame:
|
||||||
// Decode HPACK headers for debugging
|
|
||||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {
|
|
||||||
log.Debugf("h2stream[%s]: frame#%d header: %s = %q", s.id, frameNum, hf.Name, hf.Value)
|
|
||||||
// Check for error status
|
|
||||||
if hf.Name == "grpc-status" || hf.Name == ":status" && hf.Value != "200" {
|
|
||||||
log.Warnf("h2stream[%s]: frame#%d received error status header: %s = %q", s.id, frameNum, hf.Name, hf.Value)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
decoder.Write(frame.HeaderBlockFragment())
|
|
||||||
log.Debugf("h2stream[%s]: frame#%d received HEADERS frame streamID=%d, endStream=%v", s.id, frameNum, frame.StreamID, frame.StreamEnded())
|
|
||||||
if frame.StreamEnded() {
|
if frame.StreamEnded() {
|
||||||
log.Debugf("h2stream[%s]: frame#%d HEADERS frame has END_STREAM flag, stream ending", s.id, frameNum)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
case *http2.RSTStreamFrame:
|
case *http2.RSTStreamFrame:
|
||||||
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
|
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
|
||||||
log.Debugf("h2stream[%s]: frame#%d received RST_STREAM code=%d", s.id, frameNum, frame.ErrCode)
|
log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode)
|
||||||
return
|
return
|
||||||
|
|
||||||
case *http2.GoAwayFrame:
|
case *http2.GoAwayFrame:
|
||||||
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
|
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
|
||||||
log.Debugf("h2stream[%s]: received GOAWAY code=%d", s.id, frame.ErrCode)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
case *http2.PingFrame:
|
case *http2.PingFrame:
|
||||||
log.Debugf("h2stream[%s]: received PING frame, isAck=%v", s.id, frame.IsAck())
|
|
||||||
if !frame.IsAck() {
|
if !frame.IsAck() {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.framer.WritePing(true, frame.Data)
|
s.framer.WritePing(true, frame.Data)
|
||||||
@@ -252,15 +270,33 @@ func (s *H2Stream) readLoop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case *http2.SettingsFrame:
|
case *http2.SettingsFrame:
|
||||||
log.Debugf("h2stream[%s]: received SETTINGS frame, isAck=%v, numSettings=%d", s.id, frame.IsAck(), frame.NumSettings())
|
|
||||||
if !frame.IsAck() {
|
if !frame.IsAck() {
|
||||||
|
// Check for window size changes
|
||||||
|
frame.ForeachSetting(func(setting http2.Setting) error {
|
||||||
|
if setting.ID == http2.SettingInitialWindowSize {
|
||||||
|
s.windowMu.Lock()
|
||||||
|
delta := int32(setting.Val) - s.sendWindow
|
||||||
|
s.sendWindow += delta
|
||||||
|
s.windowMu.Unlock()
|
||||||
|
s.windowCond.Broadcast()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.framer.WriteSettingsAck()
|
s.framer.WriteSettingsAck()
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
case *http2.WindowUpdateFrame:
|
case *http2.WindowUpdateFrame:
|
||||||
log.Debugf("h2stream[%s]: received WINDOW_UPDATE frame", s.id)
|
// Update send-side flow control window
|
||||||
|
s.windowMu.Lock()
|
||||||
|
if frame.StreamID == 0 {
|
||||||
|
s.connWindow += int32(frame.Increment)
|
||||||
|
} else if frame.StreamID == s.streamID {
|
||||||
|
s.sendWindow += int32(frame.Increment)
|
||||||
|
}
|
||||||
|
s.windowMu.Unlock()
|
||||||
|
s.windowCond.Broadcast()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/binary"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -47,12 +46,15 @@ type CursorExecutor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type cursorSession struct {
|
type cursorSession struct {
|
||||||
stream *cursorproto.H2Stream
|
stream *cursorproto.H2Stream
|
||||||
blobStore map[string][]byte
|
blobStore map[string][]byte
|
||||||
mcpTools []cursorproto.McpToolDef
|
mcpTools []cursorproto.McpToolDef
|
||||||
pending []pendingMcpExec
|
pending []pendingMcpExec
|
||||||
cancel context.CancelFunc // cancels the session-scoped heartbeat (NOT tied to HTTP request)
|
cancel context.CancelFunc // cancels the session-scoped heartbeat (NOT tied to HTTP request)
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
|
toolResultCh chan []toolResultInfo // receives tool results from the next HTTP request
|
||||||
|
resumeOutCh chan cliproxyexecutor.StreamChunk // output channel for resumed response
|
||||||
|
switchOutput func(ch chan cliproxyexecutor.StreamChunk) // callback to switch output channel
|
||||||
}
|
}
|
||||||
|
|
||||||
type pendingMcpExec struct {
|
type pendingMcpExec struct {
|
||||||
@@ -235,6 +237,8 @@ func (e *CursorExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
fullText.WriteString(text)
|
fullText.WriteString(text)
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
|
nil, // tokenUsage - non-streaming
|
||||||
)
|
)
|
||||||
|
|
||||||
id := "chatcmpl-" + uuid.New().String()[:28]
|
id := "chatcmpl-" + uuid.New().String()[:28]
|
||||||
@@ -341,9 +345,30 @@ func (e *CursorExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
chatId := "chatcmpl-" + uuid.New().String()[:28]
|
chatId := "chatcmpl-" + uuid.New().String()[:28]
|
||||||
created := time.Now().Unix()
|
created := time.Now().Unix()
|
||||||
|
|
||||||
// sendChunk builds an OpenAI SSE line and optionally translates to target format
|
|
||||||
var streamParam any
|
var streamParam any
|
||||||
sendChunk := func(delta string, finishReason string) {
|
|
||||||
|
// Tool result channel for inline mode. processH2SessionFrames blocks on it
|
||||||
|
// when mcpArgs is received, while continuing to handle KV/heartbeat.
|
||||||
|
toolResultCh := make(chan []toolResultInfo, 1)
|
||||||
|
|
||||||
|
// Switchable output: initially writes to `chunks`. After mcpArgs, the
|
||||||
|
// onMcpExec callback closes `chunks` (ending the first HTTP response),
|
||||||
|
// then processH2SessionFrames blocks on toolResultCh. When results arrive,
|
||||||
|
// it switches to `resumeOutCh` (created by resumeWithToolResults).
|
||||||
|
var outMu sync.Mutex
|
||||||
|
currentOut := chunks
|
||||||
|
|
||||||
|
emitToOut := func(chunk cliproxyexecutor.StreamChunk) {
|
||||||
|
outMu.Lock()
|
||||||
|
out := currentOut
|
||||||
|
outMu.Unlock()
|
||||||
|
if out != nil {
|
||||||
|
out <- chunk
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap sendChunk/sendDone to use emitToOut
|
||||||
|
sendChunkSwitchable := func(delta string, finishReason string) {
|
||||||
fr := "null"
|
fr := "null"
|
||||||
if finishReason != "" {
|
if finishReason != "" {
|
||||||
fr = finishReason
|
fr = finishReason
|
||||||
@@ -355,95 +380,146 @@ func (e *CursorExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if needsTranslate {
|
if needsTranslate {
|
||||||
translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam)
|
translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam)
|
||||||
for _, t := range translated {
|
for _, t := range translated {
|
||||||
chunks <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)}
|
emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
chunks <- cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)}
|
emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sendDone := func() {
|
sendDoneSwitchable := func() {
|
||||||
if needsTranslate {
|
if needsTranslate {
|
||||||
done := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, []byte("data: [DONE]\n"), &streamParam)
|
done := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, []byte("data: [DONE]\n"), &streamParam)
|
||||||
for _, d := range done {
|
for _, d := range done {
|
||||||
chunks <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(d)}
|
emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(d)})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
chunks <- cliproxyexecutor.StreamChunk{Payload: []byte("[DONE]")}
|
emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte("[DONE]")})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(chunks)
|
var resumeOutCh chan cliproxyexecutor.StreamChunk
|
||||||
|
_ = resumeOutCh
|
||||||
thinkingActive := false
|
thinkingActive := false
|
||||||
toolCallIndex := 0
|
toolCallIndex := 0
|
||||||
mcpExecReceived := false
|
usage := &cursorTokenUsage{}
|
||||||
|
usage.setInputEstimate(len(payload))
|
||||||
|
|
||||||
processH2SessionFrames(sessionCtx, stream, params.BlobStore, params.McpTools,
|
processH2SessionFrames(sessionCtx, stream, params.BlobStore, params.McpTools,
|
||||||
func(text string, isThinking bool) {
|
func(text string, isThinking bool) {
|
||||||
if isThinking {
|
if isThinking {
|
||||||
if !thinkingActive {
|
if !thinkingActive {
|
||||||
thinkingActive = true
|
thinkingActive = true
|
||||||
sendChunk(`{"role":"assistant","content":"<think>"}`, "")
|
sendChunkSwitchable(`{"role":"assistant","content":"<think>"}`, "")
|
||||||
}
|
}
|
||||||
sendChunk(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
|
sendChunkSwitchable(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
|
||||||
} else {
|
} else {
|
||||||
if thinkingActive {
|
if thinkingActive {
|
||||||
thinkingActive = false
|
thinkingActive = false
|
||||||
sendChunk(`{"content":"</think>"}`, "")
|
sendChunkSwitchable(`{"content":"</think>"}`, "")
|
||||||
}
|
}
|
||||||
sendChunk(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
|
sendChunkSwitchable(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
func(exec pendingMcpExec) {
|
func(exec pendingMcpExec) {
|
||||||
mcpExecReceived = true
|
|
||||||
if thinkingActive {
|
if thinkingActive {
|
||||||
thinkingActive = false
|
thinkingActive = false
|
||||||
sendChunk(`{"content":"</think>"}`, "")
|
sendChunkSwitchable(`{"content":"</think>"}`, "")
|
||||||
}
|
}
|
||||||
toolCallJSON := fmt.Sprintf(`{"tool_calls":[{"index":%d,"id":"%s","type":"function","function":{"name":"%s","arguments":%s}}]}`,
|
toolCallJSON := fmt.Sprintf(`{"tool_calls":[{"index":%d,"id":"%s","type":"function","function":{"name":"%s","arguments":%s}}]}`,
|
||||||
toolCallIndex, exec.ToolCallId, exec.ToolName, jsonString(exec.Args))
|
toolCallIndex, exec.ToolCallId, exec.ToolName, jsonString(exec.Args))
|
||||||
toolCallIndex++
|
toolCallIndex++
|
||||||
sendChunk(toolCallJSON, "")
|
sendChunkSwitchable(toolCallJSON, "")
|
||||||
sendChunk(`{}`, `"tool_calls"`)
|
sendChunkSwitchable(`{}`, `"tool_calls"`)
|
||||||
sendDone()
|
sendDoneSwitchable()
|
||||||
|
|
||||||
// Save session for resume — keep stream alive.
|
// Close current output to end the current HTTP SSE response
|
||||||
// The heartbeat goroutine continues running (session-scoped context),
|
outMu.Lock()
|
||||||
// keeping the H2 connection alive while the MCP tool executes.
|
if currentOut != nil {
|
||||||
log.Debugf("cursor: saving session %s for MCP tool resume (tool=%s, streamID=%s)", sessionKey, exec.ToolName, stream.ID())
|
close(currentOut)
|
||||||
|
currentOut = nil
|
||||||
|
}
|
||||||
|
outMu.Unlock()
|
||||||
|
|
||||||
|
// Create new resume output channel, reuse the same toolResultCh
|
||||||
|
resumeOut := make(chan cliproxyexecutor.StreamChunk, 64)
|
||||||
|
log.Debugf("cursor: saving session %s for MCP tool resume (tool=%s)", sessionKey, exec.ToolName)
|
||||||
e.mu.Lock()
|
e.mu.Lock()
|
||||||
e.sessions[sessionKey] = &cursorSession{
|
e.sessions[sessionKey] = &cursorSession{
|
||||||
stream: stream,
|
stream: stream,
|
||||||
blobStore: params.BlobStore,
|
blobStore: params.BlobStore,
|
||||||
mcpTools: params.McpTools,
|
mcpTools: params.McpTools,
|
||||||
pending: []pendingMcpExec{exec},
|
pending: []pendingMcpExec{exec},
|
||||||
cancel: sessionCancel,
|
cancel: sessionCancel,
|
||||||
createdAt: time.Now(),
|
createdAt: time.Now(),
|
||||||
|
toolResultCh: toolResultCh, // reuse same channel across rounds
|
||||||
|
resumeOutCh: resumeOut,
|
||||||
|
switchOutput: func(ch chan cliproxyexecutor.StreamChunk) {
|
||||||
|
outMu.Lock()
|
||||||
|
currentOut = ch
|
||||||
|
// Reset translator state so the new HTTP response gets
|
||||||
|
// a fresh message_start, content_block_start, etc.
|
||||||
|
streamParam = nil
|
||||||
|
// New response needs its own message ID
|
||||||
|
chatId = "chatcmpl-" + uuid.New().String()[:28]
|
||||||
|
created = time.Now().Unix()
|
||||||
|
outMu.Unlock()
|
||||||
|
},
|
||||||
}
|
}
|
||||||
e.mu.Unlock()
|
e.mu.Unlock()
|
||||||
|
resumeOutCh = resumeOut
|
||||||
|
|
||||||
|
// processH2SessionFrames will now block on toolResultCh (inline wait loop)
|
||||||
|
// while continuing to handle KV messages
|
||||||
},
|
},
|
||||||
|
toolResultCh,
|
||||||
|
usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
if !mcpExecReceived {
|
// processH2SessionFrames returned — stream is done
|
||||||
if thinkingActive {
|
if thinkingActive {
|
||||||
sendChunk(`{"content":"</think>"}`, "")
|
sendChunkSwitchable(`{"content":"</think>"}`, "")
|
||||||
}
|
|
||||||
sendChunk(`{}`, `"stop"`)
|
|
||||||
sendDone()
|
|
||||||
sessionCancel()
|
|
||||||
stream.Close()
|
|
||||||
}
|
}
|
||||||
// If mcpExecReceived, do NOT close stream or cancel — session resume will handle it
|
// Include token usage in the final stop chunk
|
||||||
|
inputTok, outputTok := usage.get()
|
||||||
|
stopDelta := fmt.Sprintf(`{},"usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}`,
|
||||||
|
inputTok, outputTok, inputTok+outputTok)
|
||||||
|
// Build the stop chunk with usage embedded in the choices array level
|
||||||
|
fr := `"stop"`
|
||||||
|
openaiJSON := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"%s","choices":[{"index":0,"delta":{},"finish_reason":%s}],"usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}}`,
|
||||||
|
chatId, created, parsed.Model, fr, inputTok, outputTok, inputTok+outputTok)
|
||||||
|
sseLine := []byte("data: " + openaiJSON + "\n")
|
||||||
|
if needsTranslate {
|
||||||
|
translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam)
|
||||||
|
for _, t := range translated {
|
||||||
|
emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)})
|
||||||
|
}
|
||||||
|
sendDoneSwitchable()
|
||||||
|
_ = stopDelta // unused
|
||||||
|
|
||||||
|
// Close whatever output channel is still active
|
||||||
|
outMu.Lock()
|
||||||
|
if currentOut != nil {
|
||||||
|
close(currentOut)
|
||||||
|
currentOut = nil
|
||||||
|
}
|
||||||
|
outMu.Unlock()
|
||||||
|
sessionCancel()
|
||||||
|
stream.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return &cliproxyexecutor.StreamResult{Chunks: chunks}, nil
|
return &cliproxyexecutor.StreamResult{Chunks: chunks}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// resumeWithToolResults sends MCP tool results back on the existing H2 stream,
|
// resumeWithToolResults injects tool results into the running processH2SessionFrames
|
||||||
// then continues reading the stream for the model's response.
|
// via the toolResultCh channel. The original goroutine from ExecuteStream is still alive,
|
||||||
// Mirrors resumeWithToolResults() in cursor-fetch.ts.
|
// blocking on toolResultCh. Once we send the results, it sends the MCP result to Cursor
|
||||||
|
// and continues processing the response text — all in the same goroutine that has been
|
||||||
|
// handling KV messages the whole time.
|
||||||
func (e *CursorExecutor) resumeWithToolResults(
|
func (e *CursorExecutor) resumeWithToolResults(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
session *cursorSession,
|
session *cursorSession,
|
||||||
@@ -453,208 +529,29 @@ func (e *CursorExecutor) resumeWithToolResults(
|
|||||||
originalPayload, payload []byte,
|
originalPayload, payload []byte,
|
||||||
needsTranslate bool,
|
needsTranslate bool,
|
||||||
) (*cliproxyexecutor.StreamResult, error) {
|
) (*cliproxyexecutor.StreamResult, error) {
|
||||||
stream := session.stream
|
log.Debugf("cursor: resumeWithToolResults: injecting %d tool results via channel", len(parsed.ToolResults))
|
||||||
log.Debugf("cursor: resumeWithToolResults: using stream ID=%s", stream.ID())
|
|
||||||
|
|
||||||
// Cancel old session-scoped heartbeat before starting a new one
|
if session.toolResultCh == nil {
|
||||||
session.cancel()
|
return nil, fmt.Errorf("cursor: session has no toolResultCh (stale session?)")
|
||||||
|
|
||||||
// CRITICAL: Process any pending messages from the channel before sending MCP result.
|
|
||||||
// After the initial processH2SessionFrames returned (upon receiving MCP args),
|
|
||||||
// the server may have sent more data (KV messages, text deltas) that are now buffered in dataCh.
|
|
||||||
// We must process KV messages (respond to them) but discard text deltas (stale responses).
|
|
||||||
drainedCount := 0
|
|
||||||
drainedBytes := 0
|
|
||||||
kvProcessedCount := 0
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case staleData, ok := <-stream.Data():
|
|
||||||
if !ok {
|
|
||||||
log.Debugf("cursor: resumeWithToolResults: dataCh closed during drain")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
drainedCount++
|
|
||||||
drainedBytes += len(staleData)
|
|
||||||
log.Debugf("cursor: resumeWithToolResults: processing stale data #%d: len=%d, first bytes: %x (%q)", drainedCount, len(staleData), staleData[:min(20, len(staleData))], string(staleData[:min(20, len(staleData))]))
|
|
||||||
|
|
||||||
// Try to decode and handle KV messages (they need responses)
|
|
||||||
if len(staleData) > 5 {
|
|
||||||
frameLen := binary.BigEndian.Uint32(staleData[1:5])
|
|
||||||
if int(frameLen)+5 <= len(staleData) {
|
|
||||||
payload := staleData[5 : 5+frameLen]
|
|
||||||
msg, err := cursorproto.DecodeAgentServerMessage(payload)
|
|
||||||
if err == nil && msg.Type == cursorproto.ServerMsgKvGetBlob {
|
|
||||||
// Respond to KV getBlob
|
|
||||||
blobKey := cursorproto.BlobIdHex(msg.BlobId)
|
|
||||||
data := session.blobStore[blobKey]
|
|
||||||
log.Debugf("cursor: resumeWithToolResults: responding to stale KV getBlob kvId=%d blobKey=%s found=%v", msg.KvId, blobKey, len(data) > 0)
|
|
||||||
resp := cursorproto.EncodeKvGetBlobResult(msg.KvId, data)
|
|
||||||
stream.Write(cursorproto.FrameConnectMessage(resp, 0))
|
|
||||||
kvProcessedCount++
|
|
||||||
continue
|
|
||||||
} else if err == nil && msg.Type == cursorproto.ServerMsgKvSetBlob {
|
|
||||||
// Respond to KV setBlob
|
|
||||||
blobKey := cursorproto.BlobIdHex(msg.BlobId)
|
|
||||||
session.blobStore[blobKey] = append([]byte(nil), msg.BlobData...)
|
|
||||||
log.Debugf("cursor: resumeWithToolResults: responding to stale KV setBlob kvId=%d blobKey=%s", msg.KvId, blobKey)
|
|
||||||
resp := cursorproto.EncodeKvSetBlobResult(msg.KvId)
|
|
||||||
stream.Write(cursorproto.FrameConnectMessage(resp, 0))
|
|
||||||
kvProcessedCount++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
log.Debugf("cursor: resumeWithToolResults: discarding non-KV stale data")
|
|
||||||
default:
|
|
||||||
// No more data in channel
|
|
||||||
goto drainDone
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
drainDone:
|
if session.resumeOutCh == nil {
|
||||||
if drainedCount > 0 {
|
return nil, fmt.Errorf("cursor: session has no resumeOutCh")
|
||||||
log.Debugf("cursor: resumeWithToolResults: processed %d stale frames (%d bytes total, %d KV responded)", drainedCount, drainedBytes, kvProcessedCount)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send MCP results back on the same H2 stream
|
log.Debugf("cursor: resumeWithToolResults: switching output to resumeOutCh and injecting results")
|
||||||
for _, exec := range session.pending {
|
|
||||||
var content string
|
// Switch the output channel BEFORE injecting results, so that when
|
||||||
var isError bool
|
// processH2SessionFrames unblocks and starts emitting text, it writes
|
||||||
found := false
|
// to the resumeOutCh which the new HTTP handler is reading from.
|
||||||
for _, tr := range parsed.ToolResults {
|
if session.switchOutput != nil {
|
||||||
if tr.ToolCallId == exec.ToolCallId {
|
session.switchOutput(session.resumeOutCh)
|
||||||
content = tr.Content
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
content = "Tool result not provided"
|
|
||||||
isError = true
|
|
||||||
}
|
|
||||||
log.Debugf("cursor: sending MCP result for tool=%s callId=%s execMsgId=%d execId=%s contentLen=%d isError=%v",
|
|
||||||
exec.ToolName, exec.ToolCallId, exec.ExecMsgId, exec.ExecId, len(content), isError)
|
|
||||||
resultBytes := cursorproto.EncodeExecMcpResult(exec.ExecMsgId, exec.ExecId, content, isError)
|
|
||||||
framedResult := cursorproto.FrameConnectMessage(resultBytes, 0)
|
|
||||||
// Log the framed result details for debugging
|
|
||||||
log.Debugf("cursor: MCP result frame size=%d bytes", len(framedResult))
|
|
||||||
log.Debugf("cursor: MCP result frame header: flags=%d, len=%d", framedResult[0], binary.BigEndian.Uint32(framedResult[1:5]))
|
|
||||||
log.Debugf("cursor: MCP result protobuf hex (first 50 bytes): %x", resultBytes[:min(50, len(resultBytes))])
|
|
||||||
if err := stream.Write(framedResult); err != nil {
|
|
||||||
stream.Close()
|
|
||||||
return nil, fmt.Errorf("cursor: failed to send MCP result: %w", err)
|
|
||||||
}
|
|
||||||
log.Debugf("cursor: MCP result sent successfully for tool=%s", exec.ToolName)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start new session-scoped heartbeat (independent of HTTP request context)
|
// Inject tool results — this unblocks the waiting processH2SessionFrames
|
||||||
sessionCtx, sessionCancel := context.WithCancel(context.Background())
|
session.toolResultCh <- parsed.ToolResults
|
||||||
go cursorH2Heartbeat(sessionCtx, stream)
|
|
||||||
log.Debugf("cursor: started new heartbeat for resumed session, waiting for Cursor response...")
|
|
||||||
|
|
||||||
chunks := make(chan cliproxyexecutor.StreamChunk, 64)
|
// Return the resumeOutCh for the new HTTP handler to read from
|
||||||
chatId := "chatcmpl-" + uuid.New().String()[:28]
|
return &cliproxyexecutor.StreamResult{Chunks: session.resumeOutCh}, nil
|
||||||
created := time.Now().Unix()
|
|
||||||
sessionKey := deriveSessionKey(parsed.Model, parsed.Messages)
|
|
||||||
|
|
||||||
var streamParam any
|
|
||||||
sendChunk := func(delta string, finishReason string) {
|
|
||||||
fr := "null"
|
|
||||||
if finishReason != "" {
|
|
||||||
fr = finishReason
|
|
||||||
}
|
|
||||||
openaiJSON := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"%s","choices":[{"index":0,"delta":%s,"finish_reason":%s}]}`,
|
|
||||||
chatId, created, parsed.Model, delta, fr)
|
|
||||||
sseLine := []byte("data: " + openaiJSON + "\n")
|
|
||||||
|
|
||||||
if needsTranslate {
|
|
||||||
translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam)
|
|
||||||
for _, t := range translated {
|
|
||||||
chunks <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
chunks <- cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sendDone := func() {
|
|
||||||
if needsTranslate {
|
|
||||||
done := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, []byte("data: [DONE]\n"), &streamParam)
|
|
||||||
for _, d := range done {
|
|
||||||
chunks <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(d)}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
chunks <- cliproxyexecutor.StreamChunk{Payload: []byte("[DONE]")}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
log.Debugf("cursor: resume goroutine exiting, closing chunks channel")
|
|
||||||
close(chunks)
|
|
||||||
}()
|
|
||||||
log.Debugf("cursor: resume goroutine started, entering processH2SessionFrames")
|
|
||||||
|
|
||||||
thinkingActive := false
|
|
||||||
toolCallIndex := 0
|
|
||||||
mcpExecReceived := false
|
|
||||||
|
|
||||||
processH2SessionFrames(sessionCtx, stream, session.blobStore, session.mcpTools,
|
|
||||||
func(text string, isThinking bool) {
|
|
||||||
log.Debugf("cursor: resume received text (isThinking=%v, len=%d)", isThinking, len(text))
|
|
||||||
if isThinking {
|
|
||||||
if !thinkingActive {
|
|
||||||
thinkingActive = true
|
|
||||||
sendChunk(`{"role":"assistant","content":"<think>"}`, "")
|
|
||||||
}
|
|
||||||
sendChunk(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
|
|
||||||
} else {
|
|
||||||
if thinkingActive {
|
|
||||||
thinkingActive = false
|
|
||||||
sendChunk(`{"content":"</think>"}`, "")
|
|
||||||
}
|
|
||||||
sendChunk(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
func(exec pendingMcpExec) {
|
|
||||||
mcpExecReceived = true
|
|
||||||
if thinkingActive {
|
|
||||||
thinkingActive = false
|
|
||||||
sendChunk(`{"content":"</think>"}`, "")
|
|
||||||
}
|
|
||||||
toolCallJSON := fmt.Sprintf(`{"tool_calls":[{"index":%d,"id":"%s","type":"function","function":{"name":"%s","arguments":%s}}]}`,
|
|
||||||
toolCallIndex, exec.ToolCallId, exec.ToolName, jsonString(exec.Args))
|
|
||||||
toolCallIndex++
|
|
||||||
sendChunk(toolCallJSON, "")
|
|
||||||
sendChunk(`{}`, `"tool_calls"`)
|
|
||||||
sendDone()
|
|
||||||
|
|
||||||
// Save session again for another round of tool calls
|
|
||||||
log.Debugf("cursor: saving session %s for another MCP tool resume (tool=%s, streamID=%s)", sessionKey, exec.ToolName, stream.ID())
|
|
||||||
e.mu.Lock()
|
|
||||||
e.sessions[sessionKey] = &cursorSession{
|
|
||||||
stream: stream,
|
|
||||||
blobStore: session.blobStore,
|
|
||||||
mcpTools: session.mcpTools,
|
|
||||||
pending: []pendingMcpExec{exec},
|
|
||||||
cancel: sessionCancel,
|
|
||||||
createdAt: time.Now(),
|
|
||||||
}
|
|
||||||
e.mu.Unlock()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if !mcpExecReceived {
|
|
||||||
if thinkingActive {
|
|
||||||
sendChunk(`{"content":"</think>"}`, "")
|
|
||||||
}
|
|
||||||
sendChunk(`{}`, `"stop"`)
|
|
||||||
sendDone()
|
|
||||||
sessionCancel()
|
|
||||||
stream.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return &cliproxyexecutor.StreamResult{Chunks: chunks}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- H2Stream helpers ---
|
// --- H2Stream helpers ---
|
||||||
@@ -693,6 +590,35 @@ func cursorH2Heartbeat(ctx context.Context, stream *cursorproto.H2Stream) {
|
|||||||
|
|
||||||
// --- Response processing ---
|
// --- Response processing ---
|
||||||
|
|
||||||
|
// cursorTokenUsage tracks token counts from Cursor's TokenDeltaUpdate messages.
// The methods on this type take mu before touching either counter.
|
||||||
|
type cursorTokenUsage struct {
|
||||||
|
// mu is held by addOutput, setInputEstimate, and get while they access the counters.
mu sync.Mutex
|
||||||
|
// outputTokens is the running sum of the deltas passed to addOutput.
outputTokens int64
|
||||||
|
// inputTokensEst is written by setInputEstimate (payload bytes / 4, minimum 1).
inputTokensEst int64 // estimated from request payload size
|
||||||
|
}
|
||||||
|
|
||||||
|
// addOutput adds delta to the accumulated output token count while holding mu.
// It performs no validation of delta itself; the call site in
// processH2SessionFrames only invokes it when the delta is positive.
func (u *cursorTokenUsage) addOutput(delta int64) {
|
||||||
|
u.mu.Lock()
|
||||||
|
defer u.mu.Unlock()
|
||||||
|
u.outputTokens += delta
|
||||||
|
}
|
||||||
|
|
||||||
|
// setInputEstimate overwrites the estimated input token count with a value
// derived from payloadBytes: integer division by 4, clamped to a minimum of 1.
// The update happens while holding mu.
func (u *cursorTokenUsage) setInputEstimate(payloadBytes int) {
|
||||||
|
u.mu.Lock()
|
||||||
|
defer u.mu.Unlock()
|
||||||
|
// Rough estimate: ~4 bytes per token for mixed content
|
||||||
|
u.inputTokensEst = int64(payloadBytes / 4)
|
||||||
|
// The division yields 0 for sizes below 4 (and a non-positive value for a
// non-positive size), so clamp to keep the estimate at least 1.
if u.inputTokensEst < 1 {
|
||||||
|
u.inputTokensEst = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// get returns the current input token estimate and the accumulated output
// token count. Both values are read under mu, so they are read together
// with respect to the other methods on this type.
func (u *cursorTokenUsage) get() (input, output int64) {
|
||||||
|
u.mu.Lock()
|
||||||
|
defer u.mu.Unlock()
|
||||||
|
return u.inputTokensEst, u.outputTokens
|
||||||
|
}
|
||||||
|
|
||||||
func processH2SessionFrames(
|
func processH2SessionFrames(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
stream *cursorproto.H2Stream,
|
stream *cursorproto.H2Stream,
|
||||||
@@ -700,6 +626,8 @@ func processH2SessionFrames(
|
|||||||
mcpTools []cursorproto.McpToolDef,
|
mcpTools []cursorproto.McpToolDef,
|
||||||
onText func(text string, isThinking bool),
|
onText func(text string, isThinking bool),
|
||||||
onMcpExec func(exec pendingMcpExec),
|
onMcpExec func(exec pendingMcpExec),
|
||||||
|
toolResultCh <-chan []toolResultInfo, // nil for no tool result injection; non-nil to wait for results
|
||||||
|
tokenUsage *cursorTokenUsage, // tracks accumulated token usage (may be nil)
|
||||||
) {
|
) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
rejectReason := "Tool not available in this environment. Use the MCP tools provided instead."
|
rejectReason := "Tool not available in this environment. Use the MCP tools provided instead."
|
||||||
@@ -762,6 +690,20 @@ func processH2SessionFrames(
|
|||||||
case cursorproto.ServerMsgThinkingCompleted:
|
case cursorproto.ServerMsgThinkingCompleted:
|
||||||
// Handled by caller
|
// Handled by caller
|
||||||
|
|
||||||
|
case cursorproto.ServerMsgTurnEnded:
|
||||||
|
log.Debugf("cursor: TurnEnded received, stream will finish")
|
||||||
|
return
|
||||||
|
|
||||||
|
case cursorproto.ServerMsgHeartbeat:
|
||||||
|
// Server heartbeat, ignore silently
|
||||||
|
continue
|
||||||
|
|
||||||
|
case cursorproto.ServerMsgTokenDelta:
|
||||||
|
if tokenUsage != nil && msg.TokenDelta > 0 {
|
||||||
|
tokenUsage.addOutput(msg.TokenDelta)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
|
||||||
case cursorproto.ServerMsgKvGetBlob:
|
case cursorproto.ServerMsgKvGetBlob:
|
||||||
blobKey := cursorproto.BlobIdHex(msg.BlobId)
|
blobKey := cursorproto.BlobIdHex(msg.BlobId)
|
||||||
data := blobStore[blobKey]
|
data := blobStore[blobKey]
|
||||||
@@ -785,17 +727,85 @@ func processH2SessionFrames(
|
|||||||
if toolCallId == "" {
|
if toolCallId == "" {
|
||||||
toolCallId = uuid.New().String()
|
toolCallId = uuid.New().String()
|
||||||
}
|
}
|
||||||
// Debug: log the received execId from server
|
|
||||||
log.Debugf("cursor: received mcpArgs from server: execMsgId=%d execId=%q toolName=%s toolCallId=%s",
|
log.Debugf("cursor: received mcpArgs from server: execMsgId=%d execId=%q toolName=%s toolCallId=%s",
|
||||||
msg.ExecMsgId, msg.ExecId, msg.McpToolName, toolCallId)
|
msg.ExecMsgId, msg.ExecId, msg.McpToolName, toolCallId)
|
||||||
onMcpExec(pendingMcpExec{
|
pending := pendingMcpExec{
|
||||||
ExecMsgId: msg.ExecMsgId,
|
ExecMsgId: msg.ExecMsgId,
|
||||||
ExecId: msg.ExecId,
|
ExecId: msg.ExecId,
|
||||||
ToolCallId: toolCallId,
|
ToolCallId: toolCallId,
|
||||||
ToolName: msg.McpToolName,
|
ToolName: msg.McpToolName,
|
||||||
Args: decodedArgs,
|
Args: decodedArgs,
|
||||||
})
|
}
|
||||||
return
|
onMcpExec(pending)
|
||||||
|
|
||||||
|
if toolResultCh == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inline mode: wait for tool result while handling KV/heartbeat
|
||||||
|
log.Debugf("cursor: waiting for tool result on channel (inline mode)...")
|
||||||
|
var toolResults []toolResultInfo
|
||||||
|
waitLoop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case results, ok := <-toolResultCh:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
toolResults = results
|
||||||
|
break waitLoop
|
||||||
|
case waitData, ok := <-stream.Data():
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
buf.Write(waitData)
|
||||||
|
for {
|
||||||
|
cb := buf.Bytes()
|
||||||
|
if len(cb) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
wf, wp, wc, wok := cursorproto.ParseConnectFrame(cb)
|
||||||
|
if !wok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
buf.Next(wc)
|
||||||
|
if wf&cursorproto.ConnectEndStreamFlag != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
wmsg, werr := cursorproto.DecodeAgentServerMessage(wp)
|
||||||
|
if werr != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch wmsg.Type {
|
||||||
|
case cursorproto.ServerMsgKvGetBlob:
|
||||||
|
blobKey := cursorproto.BlobIdHex(wmsg.BlobId)
|
||||||
|
d := blobStore[blobKey]
|
||||||
|
stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeKvGetBlobResult(wmsg.KvId, d), 0))
|
||||||
|
case cursorproto.ServerMsgKvSetBlob:
|
||||||
|
blobKey := cursorproto.BlobIdHex(wmsg.BlobId)
|
||||||
|
blobStore[blobKey] = append([]byte(nil), wmsg.BlobData...)
|
||||||
|
stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeKvSetBlobResult(wmsg.KvId), 0))
|
||||||
|
case cursorproto.ServerMsgExecRequestCtx:
|
||||||
|
stream.Write(cursorproto.FrameConnectMessage(cursorproto.EncodeExecRequestContextResult(wmsg.ExecMsgId, wmsg.ExecId, mcpTools), 0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case <-stream.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send MCP result
|
||||||
|
for _, tr := range toolResults {
|
||||||
|
if tr.ToolCallId == pending.ToolCallId {
|
||||||
|
log.Debugf("cursor: sending inline MCP result for tool=%s", pending.ToolName)
|
||||||
|
resultBytes := cursorproto.EncodeExecMcpResult(pending.ExecMsgId, pending.ExecId, tr.Content, false)
|
||||||
|
stream.Write(cursorproto.FrameConnectMessage(resultBytes, 0))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
case cursorproto.ServerMsgExecReadArgs:
|
case cursorproto.ServerMsgExecReadArgs:
|
||||||
|
|||||||
Reference in New Issue
Block a user