feat(tunnel): switch to yamux stream proxying and connection pooling

- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server
- Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly
- Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
This commit is contained in:
Gouryella
2025-12-13 18:03:44 +08:00
parent 3c93789266
commit 0c19c3300c
55 changed files with 3380 additions and 4849 deletions

View File

@@ -2,81 +2,23 @@ package protocol
import (
"sync/atomic"
"time"
"drip/internal/shared/pool"
)
// AdaptivePoolManager dynamically adjusts buffer pool usage based on load.
// It tracks the number of active connections and derives the current
// pooling threshold (in bytes) from that count.
//
// Fix: the diff residue declared activeConnections twice, which is a
// compile error ("duplicate field"); it is declared exactly once here.
type AdaptivePoolManager struct {
	activeConnections           atomic.Int64 // live connection count (Register/Unregister)
	currentThreshold            atomic.Int64 // current pooling threshold in bytes
	highLoadConnectionThreshold int64        // connection count at/above which the high-load threshold applies
	midLoadConnectionThreshold  int64        // connection count below which the mid-load threshold applies
	midLoadThreshold            int64        // pooling threshold used under mid/low load
	highLoadThreshold           int64        // pooling threshold used under high load
}
var globalAdaptiveManager = NewAdaptivePoolManager()
// NewAdaptivePoolManager builds a manager with default load thresholds and
// starts its background monitor goroutine.
func NewAdaptivePoolManager() *AdaptivePoolManager {
	mgr := new(AdaptivePoolManager)
	mgr.highLoadConnectionThreshold = 300
	mgr.midLoadConnectionThreshold = 150
	mgr.midLoadThreshold = int64(pool.SizeLarge)
	mgr.highLoadThreshold = int64(pool.SizeMedium)
	// Start from the mid-load threshold until the monitor observes load.
	mgr.currentThreshold.Store(mgr.midLoadThreshold)
	go mgr.monitor()
	return mgr
}
// monitor periodically re-derives the pooling threshold from the live
// connection count. It runs for the life of the process (never returns).
func (m *AdaptivePoolManager) monitor() {
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for range ticker.C {
		switch conns := m.activeConnections.Load(); {
		case conns >= m.highLoadConnectionThreshold:
			m.currentThreshold.Store(m.highLoadThreshold)
		case conns < m.midLoadConnectionThreshold:
			m.currentThreshold.Store(m.midLoadThreshold)
			// Hysteresis zone (150-300): keep the current threshold unchanged.
		}
	}
}
// GetThreshold returns the current pooling threshold in bytes.
func (m *AdaptivePoolManager) GetThreshold() int {
	return int(m.currentThreshold.Load())
}

// RegisterConnection increments the active connection count.
func (m *AdaptivePoolManager) RegisterConnection() {
	m.activeConnections.Add(1)
}

// UnregisterConnection decrements the active connection count.
func (m *AdaptivePoolManager) UnregisterConnection() {
	m.activeConnections.Add(-1)
}

// GetActiveConnections returns the current active connection count.
func (m *AdaptivePoolManager) GetActiveConnections() int64 {
	return m.activeConnections.Load()
}
// GetAdaptiveThreshold returns the pooling threshold of the process-wide
// adaptive manager.
func GetAdaptiveThreshold() int {
	return globalAdaptiveManager.GetThreshold()
}
var globalAdaptiveManager = &AdaptivePoolManager{}
// RegisterConnection records one new active connection on the global manager.
//
// Fix: the diff residue left two bodies (the method call AND a direct
// activeConnections.Add(1)), which would double-count every connection;
// exactly one increment is kept.
func RegisterConnection() {
	globalAdaptiveManager.RegisterConnection()
}
// UnregisterConnection records the teardown of one active connection on the
// global manager.
//
// Fix: the diff residue left two bodies (the method call AND a direct
// activeConnections.Add(-1)), which would double-decrement every connection;
// exactly one decrement is kept.
func UnregisterConnection() {
	globalAdaptiveManager.UnregisterConnection()
}
// GetActiveConnections returns the number of active connections tracked by
// the global manager.
//
// Fix: the diff residue left two return statements (the second unreachable);
// only one is kept.
func GetActiveConnections() int64 {
	return globalAdaptiveManager.GetActiveConnections()
}

View File

@@ -1,162 +0,0 @@
package protocol
import (
"encoding/binary"
"errors"
)
// DataHeader is the binary header that prefixes every data-plane payload.
// The data plane avoids JSON/msgpack entirely and uses this compact binary
// layout for performance.
type DataHeader struct {
	Type      DataType // frame kind, stored in the low 3 bits of the flags byte
	IsLast    bool     // final frame of the stream, stored in flags bit 3
	StreamID  string   // logical stream identifier
	RequestID string   // request correlation identifier
}

// DataType identifies the kind of data frame carried after the header.
type DataType uint8

const (
	DataTypeData          DataType = 0x00 // 000
	DataTypeResponse      DataType = 0x01 // 001
	DataTypeClose         DataType = 0x02 // 010
	DataTypeHTTPRequest   DataType = 0x03 // 011
	DataTypeHTTPResponse  DataType = 0x04 // 100
	DataTypeHTTPHead      DataType = 0x05 // 101 - streaming headers (shared)
	DataTypeHTTPBodyChunk DataType = 0x06 // 110 - streaming body chunks (shared)

	// Request streaming reuses the shared head/body codes so every type
	// still fits in the 3-bit flags field.
	DataTypeHTTPRequestHead      DataType = DataTypeHTTPHead
	DataTypeHTTPRequestBodyChunk DataType = DataTypeHTTPBodyChunk
)

// dataTypeNames maps each frame type to its wire-protocol string name.
var dataTypeNames = map[DataType]string{
	DataTypeData:          "data",
	DataTypeResponse:      "response",
	DataTypeClose:         "close",
	DataTypeHTTPRequest:   "http_request",
	DataTypeHTTPResponse:  "http_response",
	DataTypeHTTPHead:      "http_head",
	DataTypeHTTPBodyChunk: "http_body_chunk",
}

// dataTypeValues is the reverse of dataTypeNames, used by DataTypeFromString.
var dataTypeValues = func() map[string]DataType {
	m := make(map[string]DataType, len(dataTypeNames))
	for typ, name := range dataTypeNames {
		m[name] = typ
	}
	return m
}()

// String returns the string representation of the DataType; unrecognized
// values report "unknown".
func (t DataType) String() string {
	if name, ok := dataTypeNames[t]; ok {
		return name
	}
	return "unknown"
}

// DataTypeFromString converts a string name to its DataType; unrecognized
// names fall back to DataTypeData.
func DataTypeFromString(s string) DataType {
	if typ, ok := dataTypeValues[s]; ok {
		return typ
	}
	return DataTypeData
}

// Binary layout:
//
//	+--------+-----------------+-----------------+
//	| Flags  | StreamID Length | RequestID Len   |
//	| 1 byte | 2 bytes (BE)    | 2 bytes (BE)    |
//	+--------+-----------------+-----------------+
//	| StreamID (variable)                        |
//	+--------------------------------------------+
//	| RequestID (variable)                       |
//	+--------------------------------------------+
//
// Flags byte:
//   - bits 0-2: Type (3 bits)
//   - bit 3:    IsLast
//   - bits 4-7: reserved
const (
	binaryHeaderMinSize = 5 // 1 byte flags + 2 bytes streamID len + 2 bytes requestID len
)

// MarshalBinary encodes the header into the binary layout above.
func (h *DataHeader) MarshalBinary() []byte {
	sidLen := len(h.StreamID)
	ridLen := len(h.RequestID)
	buf := make([]byte, binaryHeaderMinSize+sidLen+ridLen)

	flags := uint8(h.Type) & 0x07 // Type occupies bits 0-2
	if h.IsLast {
		flags |= 0x08 // IsLast occupies bit 3
	}
	buf[0] = flags
	binary.BigEndian.PutUint16(buf[1:3], uint16(sidLen))
	binary.BigEndian.PutUint16(buf[3:5], uint16(ridLen))

	off := binaryHeaderMinSize
	off += copy(buf[off:], h.StreamID)
	copy(buf[off:], h.RequestID)
	return buf
}

// UnmarshalBinary decodes a header from the binary layout above. It reports
// an error when the input is shorter than the fixed prefix or than the
// lengths declared in that prefix.
func (h *DataHeader) UnmarshalBinary(data []byte) error {
	if len(data) < binaryHeaderMinSize {
		return errors.New("invalid binary header: too short")
	}

	flags := data[0]
	h.Type = DataType(flags & 0x07) // bits 0-2
	h.IsLast = flags&0x08 != 0      // bit 3

	sidLen := int(binary.BigEndian.Uint16(data[1:3]))
	ridLen := int(binary.BigEndian.Uint16(data[3:5]))
	if len(data) < binaryHeaderMinSize+sidLen+ridLen {
		return errors.New("invalid binary header: length mismatch")
	}

	sidEnd := binaryHeaderMinSize + sidLen
	h.StreamID = string(data[binaryHeaderMinSize:sidEnd])
	h.RequestID = string(data[sidEnd : sidEnd+ridLen])
	return nil
}

// Size returns the number of bytes MarshalBinary would produce for h.
func (h *DataHeader) Size() int {
	return binaryHeaderMinSize + len(h.StreamID) + len(h.RequestID)
}

View File

@@ -1,34 +0,0 @@
package protocol
import (
json "github.com/goccy/go-json"
)
// FlowControlAction names a flow-control command sent over the tunnel.
type FlowControlAction string

const (
	// FlowControlPause asks the peer to stop sending on a stream.
	FlowControlPause FlowControlAction = "pause"
	// FlowControlResume asks the peer to resume sending on a stream.
	FlowControlResume FlowControlAction = "resume"
)

// FlowControlMessage is the JSON payload of a flow-control frame.
type FlowControlMessage struct {
	StreamID string            `json:"stream_id"`
	Action   FlowControlAction `json:"action"`
}
// NewFlowControlFrame builds a flow-control frame carrying the given action
// for the given stream.
func NewFlowControlFrame(streamID string, action FlowControlAction) *Frame {
	msg := FlowControlMessage{
		StreamID: streamID,
		Action:   action,
	}
	// Marshal error is deliberately ignored: the message contains only
	// string fields, so encoding cannot fail here.
	payload, _ := json.Marshal(&msg)
	return NewFrame(FrameTypeFlowControl, payload)
}
// DecodeFlowControlMessage parses a flow-control payload produced by
// NewFlowControlFrame.
func DecodeFlowControlMessage(payload []byte) (*FlowControlMessage, error) {
	msg := new(FlowControlMessage)
	if err := json.Unmarshal(payload, msg); err != nil {
		return nil, err
	}
	return msg, nil
}

View File

@@ -18,14 +18,14 @@ const (
// FrameType identifies the kind of control-plane frame.
type FrameType byte

// Fix: the diff residue declared every constant twice (old and new value
// sets in one const block), which is a redeclaration compile error. The new
// set is kept — it matches the new String() cases (DataConnect/DataConnectAck)
// introduced by the connection-pool protocol.
const (
	FrameTypeRegister       FrameType = 0x01
	FrameTypeRegisterAck    FrameType = 0x02
	FrameTypeHeartbeat      FrameType = 0x03
	FrameTypeHeartbeatAck   FrameType = 0x04
	FrameTypeClose          FrameType = 0x05
	FrameTypeError          FrameType = 0x06
	FrameTypeDataConnect    FrameType = 0x07
	FrameTypeDataConnectAck FrameType = 0x08
)
// String returns the string representation of frame type
@@ -39,14 +39,14 @@ func (t FrameType) String() string {
return "Heartbeat"
case FrameTypeHeartbeatAck:
return "HeartbeatAck"
case FrameTypeData:
return "Data"
case FrameTypeClose:
return "Close"
case FrameTypeError:
return "Error"
case FrameTypeFlowControl:
return "FlowControl"
case FrameTypeDataConnect:
return "DataConnect"
case FrameTypeDataConnectAck:
return "DataConnectAck"
default:
return fmt.Sprintf("Unknown(%d)", t)
}
@@ -56,6 +56,9 @@ type Frame struct {
Type FrameType
Payload []byte
poolBuffer *[]byte
// queuedBytes is set by FrameWriter when the frame is enqueued.
// It allows the writer to decrement backlog counters exactly once.
queuedBytes int64
}
func WriteFrame(w io.Writer, frame *Frame) error {
@@ -130,6 +133,8 @@ func (f *Frame) Release() {
f.poolBuffer = nil
f.Payload = nil
}
// Reset queued marker to avoid carrying over stale state if the frame is reused.
f.queuedBytes = 0
}
// NewFrame creates a new frame

View File

@@ -1,119 +0,0 @@
package protocol
import (
"errors"
json "github.com/goccy/go-json"
"github.com/vmihailenco/msgpack/v5"
)
// EncodeHTTPRequest encodes an HTTPRequest with msgpack (the optimized v2
// wire format).
func EncodeHTTPRequest(req *HTTPRequest) ([]byte, error) {
	return msgpack.Marshal(req)
}
// DecodeHTTPRequest decodes an HTTPRequest payload, auto-detecting the wire
// format from the first byte: '{' means legacy JSON (v1), anything else is
// msgpack (v2, fixmap bytes start at 0x80).
func DecodeHTTPRequest(data []byte) (*HTTPRequest, error) {
	if len(data) == 0 {
		return nil, errors.New("empty data")
	}
	req := new(HTTPRequest)
	var err error
	if data[0] == '{' {
		err = json.Unmarshal(data, req)
	} else {
		err = msgpack.Unmarshal(data, req)
	}
	if err != nil {
		return nil, err
	}
	return req, nil
}
// EncodeHTTPRequestHead encodes streaming HTTP request headers (no body)
// with msgpack.
func EncodeHTTPRequestHead(head *HTTPRequestHead) ([]byte, error) {
	return msgpack.Marshal(head)
}
// DecodeHTTPRequestHead decodes streaming request headers, auto-detecting
// the wire format: a leading '{' means legacy JSON, all else msgpack.
func DecodeHTTPRequestHead(data []byte) (*HTTPRequestHead, error) {
	if len(data) == 0 {
		return nil, errors.New("empty data")
	}
	head := new(HTTPRequestHead)
	var err error
	if data[0] == '{' {
		err = json.Unmarshal(data, head)
	} else {
		err = msgpack.Unmarshal(data, head)
	}
	if err != nil {
		return nil, err
	}
	return head, nil
}
// EncodeHTTPResponse encodes an HTTPResponse with msgpack (the optimized v2
// wire format).
func EncodeHTTPResponse(resp *HTTPResponse) ([]byte, error) {
	return msgpack.Marshal(resp)
}
// DecodeHTTPResponse decodes an HTTPResponse payload, auto-detecting the
// wire format from the first byte: '{' means legacy JSON (v1), anything
// else is msgpack (v2, fixmap bytes start at 0x80).
func DecodeHTTPResponse(data []byte) (*HTTPResponse, error) {
	if len(data) == 0 {
		return nil, errors.New("empty data")
	}
	resp := new(HTTPResponse)
	var err error
	if data[0] == '{' {
		err = json.Unmarshal(data, resp)
	} else {
		err = msgpack.Unmarshal(data, resp)
	}
	if err != nil {
		return nil, err
	}
	return resp, nil
}
// EncodeHTTPResponseHead encodes streaming HTTP response headers (no body)
// with msgpack.
func EncodeHTTPResponseHead(head *HTTPResponseHead) ([]byte, error) {
	return msgpack.Marshal(head)
}
// DecodeHTTPResponseHead decodes streaming response headers, auto-detecting
// the wire format: a leading '{' means legacy JSON, all else msgpack.
func DecodeHTTPResponseHead(data []byte) (*HTTPResponseHead, error) {
	if len(data) == 0 {
		return nil, errors.New("empty data")
	}
	head := new(HTTPResponseHead)
	var err error
	if data[0] == '{' {
		err = json.Unmarshal(data, head)
	} else {
		err = msgpack.Unmarshal(data, head)
	}
	if err != nil {
		return nil, err
	}
	return head, nil
}

View File

@@ -1,71 +0,0 @@
package protocol
// MessageType defines the type of tunnel message exchanged on the legacy
// JSON control channel.
type MessageType string

const (
	// TypeRegister is sent when a client connects and gets a subdomain assigned.
	TypeRegister MessageType = "register"
	// TypeRequest is sent from server to client when an HTTP request arrives.
	TypeRequest MessageType = "request"
	// TypeResponse is sent from client to server with the HTTP response.
	TypeResponse MessageType = "response"
	// TypeHeartbeat is sent periodically to keep the connection alive.
	TypeHeartbeat MessageType = "heartbeat"
	// TypeError is sent when an error occurs.
	TypeError MessageType = "error"
)

// Message represents a tunnel protocol message envelope. Optional fields
// are dropped from the JSON encoding when empty.
type Message struct {
	Type      MessageType            `json:"type"`
	ID        string                 `json:"id,omitempty"`
	Subdomain string                 `json:"subdomain,omitempty"`
	Data      map[string]interface{} `json:"data,omitempty"`
	Error     string                 `json:"error,omitempty"`
}

// HTTPRequest represents a fully-buffered HTTP request to be forwarded.
type HTTPRequest struct {
	Method  string              `json:"method"`
	URL     string              `json:"url"`
	Headers map[string][]string `json:"headers"`
	Body    []byte              `json:"body,omitempty"`
}

// HTTPRequestHead represents HTTP request headers for streaming (no body).
type HTTPRequestHead struct {
	Method        string              `json:"method"`
	URL           string              `json:"url"`
	Headers       map[string][]string `json:"headers"`
	ContentLength int64               `json:"content_length"` // -1 for unknown/chunked
}

// HTTPResponse represents a fully-buffered HTTP response from the local service.
type HTTPResponse struct {
	StatusCode int                 `json:"status_code"`
	Status     string              `json:"status"`
	Headers    map[string][]string `json:"headers"`
	Body       []byte              `json:"body,omitempty"`
}

// HTTPResponseHead represents HTTP response headers for streaming (no body).
type HTTPResponseHead struct {
	StatusCode    int                 `json:"status_code"`
	Status        string              `json:"status"`
	Headers       map[string][]string `json:"headers"`
	ContentLength int64               `json:"content_length"` // -1 for unknown/chunked
}

// RegisterData contains information sent when a tunnel is registered.
type RegisterData struct {
	Subdomain string `json:"subdomain"`
	URL       string `json:"url"`
	Message   string `json:"message"`
}

// ErrorData contains error information.
type ErrorData struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}

View File

@@ -2,12 +2,23 @@ package protocol
import json "github.com/goccy/go-json"
// PoolCapabilities advertises client connection pool capabilities.
type PoolCapabilities struct {
	MaxDataConns int `json:"max_data_conns"` // Maximum data connections client supports
	Version      int `json:"version"`        // Protocol version for pool features
}

// RegisterRequest is sent by a client to register a tunnel.
type RegisterRequest struct {
	Token           string     `json:"token"`            // Authentication token
	CustomSubdomain string     `json:"custom_subdomain"` // Optional custom subdomain
	TunnelType      TunnelType `json:"tunnel_type"`      // http, tcp, udp
	LocalPort       int        `json:"local_port"`       // Local port to forward to

	// Connection pool fields (optional, for multi-connection support).
	ConnectionType   string            `json:"connection_type,omitempty"`   // "primary" or empty for legacy
	TunnelID         string            `json:"tunnel_id,omitempty"`         // For data connections to join
	PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"` // Client pool capabilities
}
// RegisterResponse is sent by server after successful registration
@@ -16,6 +27,25 @@ type RegisterResponse struct {
Port int `json:"port,omitempty"` // Assigned TCP port (for TCP tunnels)
URL string `json:"url"` // Full tunnel URL
Message string `json:"message"` // Success message
// Connection pool fields (optional, for multi-connection support)
TunnelID string `json:"tunnel_id,omitempty"` // Unique tunnel identifier
SupportsDataConn bool `json:"supports_data_conn,omitempty"` // Server supports multi-connection
RecommendedConns int `json:"recommended_conns,omitempty"` // Suggested data connection count
}
// DataConnectRequest is sent by a data connection to join an existing
// tunnel's connection pool.
type DataConnectRequest struct {
	TunnelID     string `json:"tunnel_id"`     // Tunnel to join
	Token        string `json:"token"`         // Same auth token as primary
	ConnectionID string `json:"connection_id"` // Unique connection identifier
}

// DataConnectResponse acknowledges a data connection attempt.
type DataConnectResponse struct {
	Accepted     bool   `json:"accepted"`          // Whether connection was accepted
	ConnectionID string `json:"connection_id"`     // Echoed connection ID
	Message      string `json:"message,omitempty"` // Optional message
}
// ErrorMessage represents an error
@@ -24,9 +54,6 @@ type ErrorMessage struct {
Message string `json:"message"` // Error message
}
// Note: DataHeader is now defined in binary_header.go as a pure binary structure
// TCPData has been removed - use DataHeader + raw bytes directly
// Marshal helpers for control plane messages (JSON encoding)
func MarshalJSON(v interface{}) ([]byte, error) {
return json.Marshal(v)

View File

@@ -1,96 +0,0 @@
package protocol
import (
"encoding/binary"
"errors"
"drip/internal/shared/pool"
)
// encodeDataPayload serializes header+data into one plainly heap-allocated
// frame payload: [flags][sidLen][ridLen][StreamID][RequestID][data].
func encodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
	sidLen := len(header.StreamID)
	ridLen := len(header.RequestID)
	buf := make([]byte, binaryHeaderMinSize+sidLen+ridLen+len(data))

	flags := uint8(header.Type) & 0x07
	if header.IsLast {
		flags |= 0x08
	}
	buf[0] = flags
	binary.BigEndian.PutUint16(buf[1:3], uint16(sidLen))
	binary.BigEndian.PutUint16(buf[3:5], uint16(ridLen))

	off := binaryHeaderMinSize
	off += copy(buf[off:], header.StreamID)
	off += copy(buf[off:], header.RequestID)
	copy(buf[off:], data)
	return buf, nil
}
// EncodeDataPayloadPooled encodes a header+data payload, borrowing a pooled
// buffer when the total size makes pooling worthwhile.
//
// It returns the encoded payload and, when a pooled buffer was used, the
// pool buffer pointer the caller must return to the pool after the payload
// is written (nil when the payload was plainly allocated).
func EncodeDataPayloadPooled(header DataHeader, data []byte) (payload []byte, poolBuffer *[]byte, err error) {
	sidLen := len(header.StreamID)
	ridLen := len(header.RequestID)
	total := binaryHeaderMinSize + sidLen + ridLen + len(data)

	// Below the adaptive threshold a plain allocation is cheaper than pool
	// bookkeeping; above the largest pool class no pooled buffer fits.
	if total < GetAdaptiveThreshold() || total > pool.SizeLarge {
		plain, encErr := encodeDataPayload(header, data)
		return plain, nil, encErr
	}

	poolBuffer = pool.GetBuffer(total)
	payload = (*poolBuffer)[:total]

	flags := uint8(header.Type) & 0x07
	if header.IsLast {
		flags |= 0x08
	}
	payload[0] = flags
	binary.BigEndian.PutUint16(payload[1:3], uint16(sidLen))
	binary.BigEndian.PutUint16(payload[3:5], uint16(ridLen))

	off := binaryHeaderMinSize
	off += copy(payload[off:], header.StreamID)
	off += copy(payload[off:], header.RequestID)
	copy(payload[off:], data)
	return payload, poolBuffer, nil
}
// DecodeDataPayload splits a frame payload into its decoded header and the
// trailing data bytes (a sub-slice of payload, not a copy).
func DecodeDataPayload(payload []byte) (DataHeader, []byte, error) {
	if len(payload) < binaryHeaderMinSize {
		return DataHeader{}, nil, errors.New("invalid payload: too short")
	}
	var hdr DataHeader
	if err := hdr.UnmarshalBinary(payload); err != nil {
		return DataHeader{}, nil, err
	}
	if n := hdr.Size(); len(payload) >= n {
		return hdr, payload[n:], nil
	}
	return DataHeader{}, nil, errors.New("invalid payload: data missing")
}

View File

@@ -4,20 +4,11 @@ import (
"sync"
)
// SafeFrame wraps Frame with automatic resource cleanup: Close releases the
// underlying frame exactly once.
type SafeFrame struct {
	*Frame
	once sync.Once // guards the single Release performed by Close
}
// NewSafeFrame creates a SafeFrame (an io.Closer) around a freshly built frame.
func NewSafeFrame(frameType FrameType, payload []byte) *SafeFrame {
	return &SafeFrame{
		Frame: NewFrame(frameType, payload),
	}
}
// Close implements io.Closer, ensures Release is called exactly once
func (sf *SafeFrame) Close() error {
sf.once.Do(func() {
if sf.Frame != nil {
@@ -27,14 +18,6 @@ func (sf *SafeFrame) Close() error {
return nil
}
// WithFrame wraps an existing Frame with automatic cleanup.
func WithFrame(frame *Frame) *SafeFrame {
	return &SafeFrame{Frame: frame}
}
// MustClose closes the frame and panics if Close reports an error; intended
// for defer-based cleanup where a failure is a programming bug.
func (sf *SafeFrame) MustClose() {
	err := sf.Close()
	if err != nil {
		panic(err)
	}
}

View File

@@ -4,16 +4,18 @@ import (
"errors"
"io"
"sync"
"sync/atomic"
"time"
)
type FrameWriter struct {
conn io.Writer
queue chan *Frame
batch []*Frame
mu sync.Mutex
done chan struct{}
closed bool
conn io.Writer
queue chan *Frame
controlQueue chan *Frame
batch []*Frame
mu sync.Mutex
done chan struct{}
closed bool
maxBatch int
maxBatchWait time.Duration
@@ -24,13 +26,20 @@ type FrameWriter struct {
heartbeatControl chan struct{}
// Error handling
writeErr error
errOnce sync.Once
onWriteError func(error) // Callback for write errors
writeErr error
errOnce sync.Once
onWriteError func(error) // Callback for write errors
// Adaptive flushing
adaptiveFlush bool // Enable adaptive flush based on queue depth
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
adaptiveFlush bool // Enable adaptive flush based on queue depth
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
// Hooks
preWriteHook func(*Frame) // Called right before a frame is written to conn
// Backlog tracking
queuedFrames atomic.Int64
queuedBytes atomic.Int64
}
func NewFrameWriter(conn io.Writer) *FrameWriter {
@@ -41,8 +50,14 @@ func NewFrameWriter(conn io.Writer) *FrameWriter {
func NewFrameWriterWithConfig(conn io.Writer, maxBatch int, maxBatchWait time.Duration, queueSize int) *FrameWriter {
w := &FrameWriter{
conn: conn,
queue: make(chan *Frame, queueSize),
conn: conn,
queue: make(chan *Frame, queueSize),
controlQueue: make(chan *Frame, func() int {
if queueSize < 256 {
return queueSize
}
return 256
}()), // control path needs small, fast buffer
batch: make([]*Frame, 0, maxBatch),
maxBatch: maxBatch,
maxBatchWait: maxBatchWait,
@@ -74,6 +89,22 @@ func (w *FrameWriter) writeLoop() {
}()
for {
// Always drain control queue first to prioritize control/heartbeat frames.
select {
case frame, ok := <-w.controlQueue:
if !ok {
w.mu.Lock()
w.flushBatchLocked()
w.mu.Unlock()
return
}
w.mu.Lock()
w.flushFrameLocked(frame)
w.mu.Unlock()
continue
default:
}
select {
case frame, ok := <-w.queue:
if !ok {
@@ -105,8 +136,7 @@ func (w *FrameWriter) writeLoop() {
w.mu.Lock()
if w.heartbeatCallback != nil {
if frame := w.heartbeatCallback(); frame != nil {
w.batch = append(w.batch, frame)
w.flushBatchLocked()
w.flushFrameLocked(frame)
}
}
w.mu.Unlock()
@@ -139,22 +169,47 @@ func (w *FrameWriter) flushBatchLocked() {
}
for _, frame := range w.batch {
if err := WriteFrame(w.conn, frame); err != nil {
w.errOnce.Do(func() {
w.writeErr = err
if w.onWriteError != nil {
go w.onWriteError(err)
}
w.closed = true
})
}
frame.Release()
w.flushFrameLocked(frame)
}
w.batch = w.batch[:0]
}
// flushFrameLocked writes a single frame to the connection immediately,
// recording the first write error and closing the writer on failure.
// Caller must hold w.mu.
func (w *FrameWriter) flushFrameLocked(frame *Frame) {
	if frame == nil {
		return
	}
	if hook := w.preWriteHook; hook != nil {
		hook(frame)
	}
	err := WriteFrame(w.conn, frame)
	if err != nil {
		// Record only the first error; later writes fail fast via w.closed.
		w.errOnce.Do(func() {
			w.writeErr = err
			if w.onWriteError != nil {
				go w.onWriteError(err)
			}
			w.closed = true
		})
	}
	w.unmarkQueued(frame)
	frame.Release()
}
// WriteFrame enqueues a frame for writing without a cancellation channel.
func (w *FrameWriter) WriteFrame(frame *Frame) error {
	return w.WriteFrameWithCancel(frame, nil)
}
// WriteFrameWithCancel writes a frame with an optional cancellation channel
// If cancel is closed, the write will be aborted immediately
func (w *FrameWriter) WriteFrameWithCancel(frame *Frame, cancel <-chan struct{}) error {
if frame == nil {
return nil
}
w.mu.Lock()
if w.closed {
w.mu.Unlock()
@@ -165,10 +220,19 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
}
w.mu.Unlock()
size := int64(len(frame.Payload) + FrameHeaderSize)
w.queuedFrames.Add(1)
w.queuedBytes.Add(size)
atomic.StoreInt64(&frame.queuedBytes, size)
// Try non-blocking first for best performance
select {
case w.queue <- frame:
return nil
case <-w.done:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
w.mu.Lock()
err := w.writeErr
w.mu.Unlock()
@@ -176,6 +240,54 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
return err
}
return errors.New("writer closed")
default:
}
// Queue full - block with cancellation support
if cancel != nil {
select {
case w.queue <- frame:
return nil
case <-w.done:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
w.mu.Lock()
err := w.writeErr
w.mu.Unlock()
if err != nil {
return err
}
return errors.New("writer closed")
case <-cancel:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
return errors.New("write cancelled")
}
}
// No cancel channel - block with timeout
select {
case w.queue <- frame:
return nil
case <-w.done:
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
w.mu.Lock()
err := w.writeErr
w.mu.Unlock()
if err != nil {
return err
}
return errors.New("writer closed")
case <-time.After(30 * time.Second):
w.queuedFrames.Add(-1)
w.queuedBytes.Add(-size)
atomic.StoreInt64(&frame.queuedBytes, 0)
return errors.New("write queue full timeout")
}
}
@@ -189,8 +301,14 @@ func (w *FrameWriter) Close() error {
w.mu.Unlock()
close(w.queue)
close(w.controlQueue)
for frame := range w.queue {
w.unmarkQueued(frame)
frame.Release()
}
for frame := range w.controlQueue {
w.unmarkQueued(frame)
frame.Release()
}
@@ -264,3 +382,97 @@ func (w *FrameWriter) DisableAdaptiveFlush() {
w.adaptiveFlush = false
w.mu.Unlock()
}
// WriteControl enqueues a control/prioritized frame to be written ahead of
// data frames. It first tries a non-blocking enqueue, then waits briefly;
// control frames use a short 50ms timeout because they must not stall.
func (w *FrameWriter) WriteControl(frame *Frame) error {
	if frame == nil {
		return nil
	}
	w.mu.Lock()
	if w.closed {
		werr := w.writeErr
		w.mu.Unlock()
		if werr != nil {
			return werr
		}
		return errors.New("writer closed")
	}
	w.mu.Unlock()

	size := int64(len(frame.Payload) + FrameHeaderSize)
	w.queuedFrames.Add(1)
	w.queuedBytes.Add(size)
	atomic.StoreInt64(&frame.queuedBytes, size)

	// undo reverses the backlog accounting when the enqueue fails.
	undo := func() {
		w.queuedFrames.Add(-1)
		w.queuedBytes.Add(-size)
		atomic.StoreInt64(&frame.queuedBytes, 0)
	}
	// closedErr reports the recorded write error, or a generic closed error.
	closedErr := func() error {
		w.mu.Lock()
		err := w.writeErr
		w.mu.Unlock()
		if err != nil {
			return err
		}
		return errors.New("writer closed")
	}

	// Fast path: non-blocking enqueue.
	select {
	case w.controlQueue <- frame:
		return nil
	case <-w.done:
		undo()
		return closedErr()
	default:
	}

	// Queue full: wait with a short timeout.
	select {
	case w.controlQueue <- frame:
		return nil
	case <-w.done:
		undo()
		return closedErr()
	case <-time.After(50 * time.Millisecond):
		undo()
		return errors.New("control queue full timeout")
	}
}
// SetPreWriteHook registers a callback invoked just before each frame is
// written to the underlying writer.
func (w *FrameWriter) SetPreWriteHook(hook func(*Frame)) {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.preWriteHook = hook
}
// QueuedFrames returns the number of frames currently queued (data + control).
func (w *FrameWriter) QueuedFrames() int64 {
	return w.queuedFrames.Load()
}
// QueuedBytes returns the approximate number of bytes currently queued.
func (w *FrameWriter) QueuedBytes() int64 {
	return w.queuedBytes.Load()
}
// unmarkQueued decrements the backlog counters for a frame exactly once,
// when it is written or discarded. The atomic swap of the frame's queued
// marker makes repeated calls for the same frame harmless.
func (w *FrameWriter) unmarkQueued(frame *Frame) {
	if frame == nil {
		return
	}
	if size := atomic.SwapInt64(&frame.queuedBytes, 0); size > 0 {
		w.queuedFrames.Add(-1)
		w.queuedBytes.Add(-size)
	}
}