mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-20 14:41:49 +00:00
fix(cursor): MCP tool call resume, H2 flow control, and token usage
- Rewrite tool call mechanism from interrupt-resume to inline-wait mode: processH2SessionFrames no longer exits on mcpArgs; instead blocks on toolResultCh while continuing to handle KV/heartbeat messages, then sends MCP result and continues processing text in the same goroutine. Fixes the issue where server stopped generating text after resume. - Add switchable output channel (outMu/currentOut) so first HTTP response closes after tool_calls+[DONE], and resumed text goes to a new channel returned by resumeWithToolResults. Reset streamParam on switch so Translator produces fresh message_start/content_block_start events. - Implement send-side H2 flow control: track server's initial window size and WINDOW_UPDATE increments; Write() blocks when window exhausted. Fixes RST_STREAM FLOW_CONTROL_ERROR on large requests (178KB+). - Decode new InteractionUpdate fields: TurnEndedUpdate (field 14) as stream termination signal, HeartbeatUpdate (field 13) silently ignored, TokenDeltaUpdate (field 8) for token usage tracking. - Include token usage in final stop chunk (prompt_tokens estimated from payload size, completion_tokens from accumulated TokenDeltaUpdate deltas) so Claude CLI status bar shows non-zero token counts. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -32,6 +32,9 @@ const (
|
||||
ServerMsgExecBgShellSpawn // Rejected: background shell
|
||||
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
|
||||
ServerMsgExecOther // Other exec types (respond with empty)
|
||||
ServerMsgTurnEnded // Turn has ended (no more output)
|
||||
ServerMsgHeartbeat // Server heartbeat
|
||||
ServerMsgTokenDelta // Token usage delta
|
||||
)
|
||||
|
||||
// DecodedServerMessage holds parsed data from an AgentServerMessage.
|
||||
@@ -63,6 +66,9 @@ type DecodedServerMessage struct {
|
||||
|
||||
// For other exec - the raw field number for building a response
|
||||
ExecFieldNumber int
|
||||
|
||||
// For TokenDeltaUpdate
|
||||
TokenDelta int64
|
||||
}
|
||||
|
||||
// DecodeAgentServerMessage parses an AgentServerMessage and returns
|
||||
@@ -160,6 +166,24 @@ func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
|
||||
case 3:
|
||||
// tool_call_completed - ignore but log
|
||||
log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
|
||||
case 8:
|
||||
// token_delta - extract token count
|
||||
msg.Type = ServerMsgTokenDelta
|
||||
msg.TokenDelta = decodeVarintField(val, 1)
|
||||
log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta)
|
||||
case 13:
|
||||
// heartbeat from server
|
||||
msg.Type = ServerMsgHeartbeat
|
||||
case 14:
|
||||
// turn_ended - critical: model finished generating
|
||||
msg.Type = ServerMsgTurnEnded
|
||||
log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end")
|
||||
case 16:
|
||||
// step_started - ignore
|
||||
log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)")
|
||||
case 17:
|
||||
// step_completed - ignore
|
||||
log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)")
|
||||
default:
|
||||
log.Debugf("decodeInteractionUpdate: unknown field %d", num)
|
||||
}
|
||||
@@ -500,6 +524,34 @@ func decodeBytesField(data []byte, targetField protowire.Number) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeVarintField extracts an int64 from the first matching varint field in a submessage.
|
||||
func decodeVarintField(data []byte, targetField protowire.Number) int64 {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
data = data[n:]
|
||||
if typ == protowire.VarintType {
|
||||
val, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
data = data[n:]
|
||||
if num == targetField {
|
||||
return int64(val)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// BlobIdHex returns the hex string of a blob ID for use as a map key.
|
||||
func BlobIdHex(blobId []byte) string {
|
||||
return hex.EncodeToString(blobId)
|
||||
|
||||
@@ -13,6 +13,11 @@ import (
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultInitialWindowSize = 65535 // HTTP/2 default
|
||||
maxFramePayload = 16384 // HTTP/2 default max frame size
|
||||
)
|
||||
|
||||
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
|
||||
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
|
||||
type H2Stream struct {
|
||||
@@ -21,11 +26,17 @@ type H2Stream struct {
|
||||
streamID uint32
|
||||
mu sync.Mutex
|
||||
id string // unique identifier for debugging
|
||||
frameNum int64 // sequential frame counter for debugging
|
||||
frameNum int64 // sequential frame counter for debugging
|
||||
|
||||
dataCh chan []byte
|
||||
doneCh chan struct{}
|
||||
err error
|
||||
|
||||
// Send-side flow control
|
||||
sendWindow int32 // available bytes we can send on this stream
|
||||
connWindow int32 // available bytes on the connection level
|
||||
windowCond *sync.Cond // signaled when window is updated
|
||||
windowMu sync.Mutex // protects sendWindow, connWindow
|
||||
}
|
||||
|
||||
// ID returns the unique identifier for this stream (for logging).
|
||||
@@ -59,7 +70,7 @@ func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
|
||||
return nil, fmt.Errorf("h2: preface write failed: %w", err)
|
||||
}
|
||||
|
||||
// Send initial SETTINGS (with large initial window)
|
||||
// Send initial SETTINGS (tell server how much WE can receive)
|
||||
if err := framer.WriteSettings(
|
||||
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
|
||||
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
|
||||
@@ -68,14 +79,17 @@ func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
|
||||
return nil, fmt.Errorf("h2: settings write failed: %w", err)
|
||||
}
|
||||
|
||||
// Connection-level window update (default is 65535, bump it up)
|
||||
// Connection-level window update (for receiving)
|
||||
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: window update failed: %w", err)
|
||||
}
|
||||
|
||||
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
|
||||
for i := 0; i < 5; i++ {
|
||||
// Track server's initial window size (how much WE can send)
|
||||
serverInitialWindowSize := int32(defaultInitialWindowSize)
|
||||
connWindowSize := int32(defaultInitialWindowSize) // connection-level send window
|
||||
for i := 0; i < 10; i++ {
|
||||
f, err := framer.ReadFrame()
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
@@ -84,12 +98,22 @@ func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
|
||||
switch sf := f.(type) {
|
||||
case *http2.SettingsFrame:
|
||||
if !sf.IsAck() {
|
||||
sf.ForeachSetting(func(s http2.Setting) error {
|
||||
if s.ID == http2.SettingInitialWindowSize {
|
||||
serverInitialWindowSize = int32(s.Val)
|
||||
log.Debugf("h2: server initial window size: %d", s.Val)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
framer.WriteSettingsAck()
|
||||
} else {
|
||||
goto handshakeDone
|
||||
}
|
||||
case *http2.WindowUpdateFrame:
|
||||
// ignore
|
||||
if sf.StreamID == 0 {
|
||||
connWindowSize += int32(sf.Increment)
|
||||
log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize)
|
||||
}
|
||||
default:
|
||||
// unexpected but continue
|
||||
}
|
||||
@@ -124,36 +148,53 @@ handshakeDone:
|
||||
}
|
||||
|
||||
s := &H2Stream{
|
||||
framer: framer,
|
||||
conn: tlsConn,
|
||||
streamID: streamID,
|
||||
dataCh: make(chan []byte, 256),
|
||||
doneCh: make(chan struct{}),
|
||||
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
|
||||
frameNum: 0,
|
||||
framer: framer,
|
||||
conn: tlsConn,
|
||||
streamID: streamID,
|
||||
dataCh: make(chan []byte, 256),
|
||||
doneCh: make(chan struct{}),
|
||||
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
|
||||
frameNum: 0,
|
||||
sendWindow: serverInitialWindowSize,
|
||||
connWindow: connWindowSize,
|
||||
}
|
||||
s.windowCond = sync.NewCond(&s.windowMu)
|
||||
go s.readLoop()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Write sends a DATA frame on the stream.
|
||||
// Write sends a DATA frame on the stream, respecting flow control.
|
||||
func (s *H2Stream) Write(data []byte) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
const maxFrame = 16384
|
||||
for len(data) > 0 {
|
||||
chunk := data
|
||||
if len(chunk) > maxFrame {
|
||||
chunk = data[:maxFrame]
|
||||
if len(chunk) > maxFramePayload {
|
||||
chunk = data[:maxFramePayload]
|
||||
}
|
||||
data = data[len(chunk):]
|
||||
if err := s.framer.WriteData(s.streamID, false, chunk); err != nil {
|
||||
|
||||
// Wait for flow control window
|
||||
s.windowMu.Lock()
|
||||
for s.sendWindow <= 0 || s.connWindow <= 0 {
|
||||
s.windowCond.Wait()
|
||||
}
|
||||
// Limit chunk to available window
|
||||
allowed := int(s.sendWindow)
|
||||
if int(s.connWindow) < allowed {
|
||||
allowed = int(s.connWindow)
|
||||
}
|
||||
if len(chunk) > allowed {
|
||||
chunk = chunk[:allowed]
|
||||
}
|
||||
s.sendWindow -= int32(len(chunk))
|
||||
s.connWindow -= int32(len(chunk))
|
||||
s.windowMu.Unlock()
|
||||
|
||||
s.mu.Lock()
|
||||
err := s.framer.WriteData(s.streamID, false, chunk)
|
||||
s.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Try to flush the underlying connection if it supports it
|
||||
if flusher, ok := s.conn.(interface{ Flush() error }); ok {
|
||||
flusher.Flush()
|
||||
data = data[len(chunk):]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -167,12 +208,13 @@ func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }
|
||||
// Close tears down the connection.
|
||||
func (s *H2Stream) Close() {
|
||||
s.conn.Close()
|
||||
// Unblock any writers waiting on flow control
|
||||
s.windowCond.Broadcast()
|
||||
}
|
||||
|
||||
func (s *H2Stream) readLoop() {
|
||||
defer close(s.doneCh)
|
||||
defer close(s.dataCh)
|
||||
log.Debugf("h2stream[%s]: readLoop started for streamID=%d", s.id, s.streamID)
|
||||
|
||||
for {
|
||||
f, err := s.framer.ReadFrame()
|
||||
@@ -180,71 +222,47 @@ func (s *H2Stream) readLoop() {
|
||||
if err != io.EOF {
|
||||
s.err = err
|
||||
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
|
||||
} else {
|
||||
log.Debugf("h2stream[%s]: readLoop EOF", s.id)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Increment frame counter for debugging
|
||||
// Increment frame counter
|
||||
s.mu.Lock()
|
||||
s.frameNum++
|
||||
frameNum := s.frameNum
|
||||
s.mu.Unlock()
|
||||
|
||||
switch frame := f.(type) {
|
||||
case *http2.DataFrame:
|
||||
log.Debugf("h2stream[%s]: frame#%d received DATA frame streamID=%d, len=%d, endStream=%v", s.id, frameNum, frame.StreamID, len(frame.Data()), frame.StreamEnded())
|
||||
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
|
||||
cp := make([]byte, len(frame.Data()))
|
||||
copy(cp, frame.Data())
|
||||
// Log first 20 bytes for debugging
|
||||
previewLen := len(cp)
|
||||
if previewLen > 20 {
|
||||
previewLen = 20
|
||||
}
|
||||
log.Debugf("h2stream[%s]: frame#%d sending to dataCh: len=%d, dataCh len=%d/%d, first bytes: %x (%q)", s.id, frameNum, len(cp), len(s.dataCh), cap(s.dataCh), cp[:previewLen], string(cp[:previewLen]))
|
||||
s.dataCh <- cp
|
||||
|
||||
// Flow control: send WINDOW_UPDATE
|
||||
// Flow control: send WINDOW_UPDATE for received data
|
||||
s.mu.Lock()
|
||||
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
|
||||
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
|
||||
s.mu.Unlock()
|
||||
}
|
||||
if frame.StreamEnded() {
|
||||
log.Debugf("h2stream[%s]: frame#%d DATA frame has END_STREAM flag, stream ending", s.id, frameNum)
|
||||
return
|
||||
}
|
||||
|
||||
case *http2.HeadersFrame:
|
||||
// Decode HPACK headers for debugging
|
||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {
|
||||
log.Debugf("h2stream[%s]: frame#%d header: %s = %q", s.id, frameNum, hf.Name, hf.Value)
|
||||
// Check for error status
|
||||
if hf.Name == "grpc-status" || hf.Name == ":status" && hf.Value != "200" {
|
||||
log.Warnf("h2stream[%s]: frame#%d received error status header: %s = %q", s.id, frameNum, hf.Name, hf.Value)
|
||||
}
|
||||
})
|
||||
decoder.Write(frame.HeaderBlockFragment())
|
||||
log.Debugf("h2stream[%s]: frame#%d received HEADERS frame streamID=%d, endStream=%v", s.id, frameNum, frame.StreamID, frame.StreamEnded())
|
||||
if frame.StreamEnded() {
|
||||
log.Debugf("h2stream[%s]: frame#%d HEADERS frame has END_STREAM flag, stream ending", s.id, frameNum)
|
||||
return
|
||||
}
|
||||
|
||||
case *http2.RSTStreamFrame:
|
||||
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
|
||||
log.Debugf("h2stream[%s]: frame#%d received RST_STREAM code=%d", s.id, frameNum, frame.ErrCode)
|
||||
log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode)
|
||||
return
|
||||
|
||||
case *http2.GoAwayFrame:
|
||||
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
|
||||
log.Debugf("h2stream[%s]: received GOAWAY code=%d", s.id, frame.ErrCode)
|
||||
return
|
||||
|
||||
case *http2.PingFrame:
|
||||
log.Debugf("h2stream[%s]: received PING frame, isAck=%v", s.id, frame.IsAck())
|
||||
if !frame.IsAck() {
|
||||
s.mu.Lock()
|
||||
s.framer.WritePing(true, frame.Data)
|
||||
@@ -252,15 +270,33 @@ func (s *H2Stream) readLoop() {
|
||||
}
|
||||
|
||||
case *http2.SettingsFrame:
|
||||
log.Debugf("h2stream[%s]: received SETTINGS frame, isAck=%v, numSettings=%d", s.id, frame.IsAck(), frame.NumSettings())
|
||||
if !frame.IsAck() {
|
||||
// Check for window size changes
|
||||
frame.ForeachSetting(func(setting http2.Setting) error {
|
||||
if setting.ID == http2.SettingInitialWindowSize {
|
||||
s.windowMu.Lock()
|
||||
delta := int32(setting.Val) - s.sendWindow
|
||||
s.sendWindow += delta
|
||||
s.windowMu.Unlock()
|
||||
s.windowCond.Broadcast()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
s.mu.Lock()
|
||||
s.framer.WriteSettingsAck()
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
case *http2.WindowUpdateFrame:
|
||||
log.Debugf("h2stream[%s]: received WINDOW_UPDATE frame", s.id)
|
||||
// Update send-side flow control window
|
||||
s.windowMu.Lock()
|
||||
if frame.StreamID == 0 {
|
||||
s.connWindow += int32(frame.Increment)
|
||||
} else if frame.StreamID == s.streamID {
|
||||
s.sendWindow += int32(frame.Increment)
|
||||
}
|
||||
s.windowMu.Unlock()
|
||||
s.windowCond.Broadcast()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user