mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
feat(tunnel): switch to yamux stream proxying and connection pooling
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server - Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly - Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
This commit is contained in:
@@ -2,81 +2,23 @@ package protocol
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
)
|
||||
|
||||
// AdaptivePoolManager dynamically adjusts buffer pool usage based on load
|
||||
// AdaptivePoolManager tracks active connections for load monitoring
|
||||
type AdaptivePoolManager struct {
|
||||
activeConnections atomic.Int64
|
||||
currentThreshold atomic.Int64
|
||||
highLoadConnectionThreshold int64
|
||||
midLoadConnectionThreshold int64
|
||||
midLoadThreshold int64
|
||||
highLoadThreshold int64
|
||||
activeConnections atomic.Int64
|
||||
}
|
||||
|
||||
var globalAdaptiveManager = NewAdaptivePoolManager()
|
||||
|
||||
func NewAdaptivePoolManager() *AdaptivePoolManager {
|
||||
m := &AdaptivePoolManager{
|
||||
highLoadConnectionThreshold: 300,
|
||||
midLoadConnectionThreshold: 150,
|
||||
midLoadThreshold: int64(pool.SizeLarge),
|
||||
highLoadThreshold: int64(pool.SizeMedium),
|
||||
}
|
||||
|
||||
m.currentThreshold.Store(m.midLoadThreshold)
|
||||
go m.monitor()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) monitor() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
connections := m.activeConnections.Load()
|
||||
|
||||
if connections >= m.highLoadConnectionThreshold {
|
||||
m.currentThreshold.Store(m.highLoadThreshold)
|
||||
} else if connections < m.midLoadConnectionThreshold {
|
||||
m.currentThreshold.Store(m.midLoadThreshold)
|
||||
}
|
||||
// Hysteresis zone (150-300): maintain current threshold
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) GetThreshold() int {
|
||||
return int(m.currentThreshold.Load())
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) RegisterConnection() {
|
||||
m.activeConnections.Add(1)
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) UnregisterConnection() {
|
||||
m.activeConnections.Add(-1)
|
||||
}
|
||||
|
||||
func (m *AdaptivePoolManager) GetActiveConnections() int64 {
|
||||
return m.activeConnections.Load()
|
||||
}
|
||||
|
||||
func GetAdaptiveThreshold() int {
|
||||
return globalAdaptiveManager.GetThreshold()
|
||||
}
|
||||
var globalAdaptiveManager = &AdaptivePoolManager{}
|
||||
|
||||
func RegisterConnection() {
|
||||
globalAdaptiveManager.RegisterConnection()
|
||||
globalAdaptiveManager.activeConnections.Add(1)
|
||||
}
|
||||
|
||||
func UnregisterConnection() {
|
||||
globalAdaptiveManager.UnregisterConnection()
|
||||
globalAdaptiveManager.activeConnections.Add(-1)
|
||||
}
|
||||
|
||||
func GetActiveConnections() int64 {
|
||||
return globalAdaptiveManager.GetActiveConnections()
|
||||
return globalAdaptiveManager.activeConnections.Load()
|
||||
}
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// DataHeader represents a binary-encoded data header for data plane
|
||||
// All data transmission uses pure binary encoding for performance
|
||||
type DataHeader struct {
|
||||
Type DataType
|
||||
IsLast bool
|
||||
StreamID string
|
||||
RequestID string
|
||||
}
|
||||
|
||||
// DataType represents the type of data frame
|
||||
type DataType uint8
|
||||
|
||||
const (
|
||||
DataTypeData DataType = 0x00 // 000
|
||||
DataTypeResponse DataType = 0x01 // 001
|
||||
DataTypeClose DataType = 0x02 // 010
|
||||
DataTypeHTTPRequest DataType = 0x03 // 011
|
||||
DataTypeHTTPResponse DataType = 0x04 // 100
|
||||
DataTypeHTTPHead DataType = 0x05 // 101 - streaming headers (shared)
|
||||
DataTypeHTTPBodyChunk DataType = 0x06 // 110 - streaming body chunks (shared)
|
||||
|
||||
// Reuse the same type codes for request streaming to stay within 3 bits.
|
||||
DataTypeHTTPRequestHead DataType = DataTypeHTTPHead
|
||||
DataTypeHTTPRequestBodyChunk DataType = DataTypeHTTPBodyChunk
|
||||
)
|
||||
|
||||
// String returns the string representation of DataType
|
||||
func (t DataType) String() string {
|
||||
switch t {
|
||||
case DataTypeData:
|
||||
return "data"
|
||||
case DataTypeResponse:
|
||||
return "response"
|
||||
case DataTypeClose:
|
||||
return "close"
|
||||
case DataTypeHTTPRequest:
|
||||
return "http_request"
|
||||
case DataTypeHTTPResponse:
|
||||
return "http_response"
|
||||
case DataTypeHTTPHead:
|
||||
return "http_head"
|
||||
case DataTypeHTTPBodyChunk:
|
||||
return "http_body_chunk"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// FromString converts a string to DataType
|
||||
func DataTypeFromString(s string) DataType {
|
||||
switch s {
|
||||
case "data":
|
||||
return DataTypeData
|
||||
case "response":
|
||||
return DataTypeResponse
|
||||
case "close":
|
||||
return DataTypeClose
|
||||
case "http_request":
|
||||
return DataTypeHTTPRequest
|
||||
case "http_response":
|
||||
return DataTypeHTTPResponse
|
||||
case "http_head":
|
||||
return DataTypeHTTPHead
|
||||
case "http_body_chunk":
|
||||
return DataTypeHTTPBodyChunk
|
||||
default:
|
||||
return DataTypeData
|
||||
}
|
||||
}
|
||||
|
||||
// Binary format:
|
||||
// +--------+--------+--------+--------+--------+
|
||||
// | Flags | StreamID Length | RequestID Len |
|
||||
// | 1 byte | 2 bytes | 2 bytes |
|
||||
// +--------+--------+--------+--------+--------+
|
||||
// | StreamID (variable) |
|
||||
// +--------+--------+--------+--------+--------+
|
||||
// | RequestID (variable) |
|
||||
// +--------+--------+--------+--------+--------+
|
||||
//
|
||||
// Flags (8 bits):
|
||||
// - Bit 0-2: Type (3 bits)
|
||||
// - Bit 3: IsLast (1 bit)
|
||||
// - Bit 4-7: Reserved (4 bits)
|
||||
|
||||
const (
|
||||
binaryHeaderMinSize = 5 // 1 byte flags + 2 bytes streamID len + 2 bytes requestID len
|
||||
)
|
||||
|
||||
// MarshalBinary encodes the header to binary format
|
||||
func (h *DataHeader) MarshalBinary() []byte {
|
||||
streamIDLen := len(h.StreamID)
|
||||
requestIDLen := len(h.RequestID)
|
||||
|
||||
totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen
|
||||
buf := make([]byte, totalLen)
|
||||
|
||||
// Encode flags
|
||||
flags := uint8(h.Type) & 0x07 // Type uses bits 0-2
|
||||
if h.IsLast {
|
||||
flags |= 0x08 // IsLast uses bit 3
|
||||
}
|
||||
buf[0] = flags
|
||||
|
||||
// Encode lengths (big-endian)
|
||||
binary.BigEndian.PutUint16(buf[1:3], uint16(streamIDLen))
|
||||
binary.BigEndian.PutUint16(buf[3:5], uint16(requestIDLen))
|
||||
|
||||
// Encode StreamID
|
||||
offset := binaryHeaderMinSize
|
||||
copy(buf[offset:], h.StreamID)
|
||||
offset += streamIDLen
|
||||
|
||||
// Encode RequestID
|
||||
copy(buf[offset:], h.RequestID)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes the header from binary format
|
||||
func (h *DataHeader) UnmarshalBinary(data []byte) error {
|
||||
if len(data) < binaryHeaderMinSize {
|
||||
return errors.New("invalid binary header: too short")
|
||||
}
|
||||
|
||||
// Decode flags
|
||||
flags := data[0]
|
||||
h.Type = DataType(flags & 0x07) // Bits 0-2
|
||||
h.IsLast = (flags & 0x08) != 0 // Bit 3
|
||||
|
||||
// Decode lengths
|
||||
streamIDLen := int(binary.BigEndian.Uint16(data[1:3]))
|
||||
requestIDLen := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
|
||||
// Validate total length
|
||||
expectedLen := binaryHeaderMinSize + streamIDLen + requestIDLen
|
||||
if len(data) < expectedLen {
|
||||
return errors.New("invalid binary header: length mismatch")
|
||||
}
|
||||
|
||||
// Decode StreamID
|
||||
offset := binaryHeaderMinSize
|
||||
h.StreamID = string(data[offset : offset+streamIDLen])
|
||||
offset += streamIDLen
|
||||
|
||||
// Decode RequestID
|
||||
h.RequestID = string(data[offset : offset+requestIDLen])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Size returns the size of the binary-encoded header
|
||||
func (h *DataHeader) Size() int {
|
||||
return binaryHeaderMinSize + len(h.StreamID) + len(h.RequestID)
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
json "github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
type FlowControlAction string
|
||||
|
||||
const (
|
||||
FlowControlPause FlowControlAction = "pause"
|
||||
FlowControlResume FlowControlAction = "resume"
|
||||
)
|
||||
|
||||
type FlowControlMessage struct {
|
||||
StreamID string `json:"stream_id"`
|
||||
Action FlowControlAction `json:"action"`
|
||||
}
|
||||
|
||||
func NewFlowControlFrame(streamID string, action FlowControlAction) *Frame {
|
||||
msg := FlowControlMessage{
|
||||
StreamID: streamID,
|
||||
Action: action,
|
||||
}
|
||||
payload, _ := json.Marshal(&msg)
|
||||
return NewFrame(FrameTypeFlowControl, payload)
|
||||
}
|
||||
|
||||
func DecodeFlowControlMessage(payload []byte) (*FlowControlMessage, error) {
|
||||
var msg FlowControlMessage
|
||||
if err := json.Unmarshal(payload, &msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &msg, nil
|
||||
}
|
||||
@@ -18,14 +18,14 @@ const (
|
||||
type FrameType byte
|
||||
|
||||
const (
|
||||
FrameTypeRegister FrameType = 0x01
|
||||
FrameTypeRegisterAck FrameType = 0x02
|
||||
FrameTypeHeartbeat FrameType = 0x03
|
||||
FrameTypeHeartbeatAck FrameType = 0x04
|
||||
FrameTypeData FrameType = 0x05
|
||||
FrameTypeClose FrameType = 0x06
|
||||
FrameTypeError FrameType = 0x07
|
||||
FrameTypeFlowControl FrameType = 0x08
|
||||
FrameTypeRegister FrameType = 0x01
|
||||
FrameTypeRegisterAck FrameType = 0x02
|
||||
FrameTypeHeartbeat FrameType = 0x03
|
||||
FrameTypeHeartbeatAck FrameType = 0x04
|
||||
FrameTypeClose FrameType = 0x05
|
||||
FrameTypeError FrameType = 0x06
|
||||
FrameTypeDataConnect FrameType = 0x07
|
||||
FrameTypeDataConnectAck FrameType = 0x08
|
||||
)
|
||||
|
||||
// String returns the string representation of frame type
|
||||
@@ -39,14 +39,14 @@ func (t FrameType) String() string {
|
||||
return "Heartbeat"
|
||||
case FrameTypeHeartbeatAck:
|
||||
return "HeartbeatAck"
|
||||
case FrameTypeData:
|
||||
return "Data"
|
||||
case FrameTypeClose:
|
||||
return "Close"
|
||||
case FrameTypeError:
|
||||
return "Error"
|
||||
case FrameTypeFlowControl:
|
||||
return "FlowControl"
|
||||
case FrameTypeDataConnect:
|
||||
return "DataConnect"
|
||||
case FrameTypeDataConnectAck:
|
||||
return "DataConnectAck"
|
||||
default:
|
||||
return fmt.Sprintf("Unknown(%d)", t)
|
||||
}
|
||||
@@ -56,6 +56,9 @@ type Frame struct {
|
||||
Type FrameType
|
||||
Payload []byte
|
||||
poolBuffer *[]byte
|
||||
// queuedBytes is set by FrameWriter when the frame is enqueued.
|
||||
// It allows the writer to decrement backlog counters exactly once.
|
||||
queuedBytes int64
|
||||
}
|
||||
|
||||
func WriteFrame(w io.Writer, frame *Frame) error {
|
||||
@@ -130,6 +133,8 @@ func (f *Frame) Release() {
|
||||
f.poolBuffer = nil
|
||||
f.Payload = nil
|
||||
}
|
||||
// Reset queued marker to avoid carrying over stale state if the frame is reused.
|
||||
f.queuedBytes = 0
|
||||
}
|
||||
|
||||
// NewFrame creates a new frame
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
// EncodeHTTPRequest encodes HTTPRequest using msgpack encoding (optimized)
|
||||
func EncodeHTTPRequest(req *HTTPRequest) ([]byte, error) {
|
||||
return msgpack.Marshal(req)
|
||||
}
|
||||
|
||||
// DecodeHTTPRequest decodes HTTPRequest with automatic version detection
|
||||
// Detects based on first byte: '{' = JSON, else = msgpack
|
||||
func DecodeHTTPRequest(data []byte) (*HTTPRequest, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var req HTTPRequest
|
||||
|
||||
// Auto-detect: JSON starts with '{', msgpack starts with 0x80-0x8f (fixmap)
|
||||
if data[0] == '{' {
|
||||
// v1: JSON
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// v2: msgpack
|
||||
if err := msgpack.Unmarshal(data, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPRequestHead encodes HTTP request headers for streaming
|
||||
func EncodeHTTPRequestHead(head *HTTPRequestHead) ([]byte, error) {
|
||||
return msgpack.Marshal(head)
|
||||
}
|
||||
|
||||
// DecodeHTTPRequestHead decodes HTTP request headers for streaming
|
||||
func DecodeHTTPRequestHead(data []byte) (*HTTPRequestHead, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var head HTTPRequestHead
|
||||
if data[0] == '{' {
|
||||
if err := json.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := msgpack.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &head, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPResponse encodes HTTPResponse using msgpack encoding (optimized)
|
||||
func EncodeHTTPResponse(resp *HTTPResponse) ([]byte, error) {
|
||||
return msgpack.Marshal(resp)
|
||||
}
|
||||
|
||||
// DecodeHTTPResponse decodes HTTPResponse with automatic version detection
|
||||
// Detects based on first byte: '{' = JSON, else = msgpack
|
||||
func DecodeHTTPResponse(data []byte) (*HTTPResponse, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var resp HTTPResponse
|
||||
|
||||
// Auto-detect: JSON starts with '{', msgpack starts with 0x80-0x8f (fixmap)
|
||||
if data[0] == '{' {
|
||||
// v1: JSON
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// v2: msgpack
|
||||
if err := msgpack.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPResponseHead encodes HTTP response headers for streaming
|
||||
func EncodeHTTPResponseHead(head *HTTPResponseHead) ([]byte, error) {
|
||||
return msgpack.Marshal(head)
|
||||
}
|
||||
|
||||
// DecodeHTTPResponseHead decodes HTTP response headers for streaming
|
||||
func DecodeHTTPResponseHead(data []byte) (*HTTPResponseHead, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var head HTTPResponseHead
|
||||
if data[0] == '{' {
|
||||
if err := json.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := msgpack.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &head, nil
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
package protocol
|
||||
|
||||
// MessageType defines the type of tunnel message
|
||||
type MessageType string
|
||||
|
||||
const (
|
||||
// TypeRegister is sent when a client connects and gets a subdomain assigned
|
||||
TypeRegister MessageType = "register"
|
||||
// TypeRequest is sent from server to client when an HTTP request arrives
|
||||
TypeRequest MessageType = "request"
|
||||
// TypeResponse is sent from client to server with the HTTP response
|
||||
TypeResponse MessageType = "response"
|
||||
// TypeHeartbeat is sent periodically to keep the connection alive
|
||||
TypeHeartbeat MessageType = "heartbeat"
|
||||
// TypeError is sent when an error occurs
|
||||
TypeError MessageType = "error"
|
||||
)
|
||||
|
||||
// Message represents a tunnel protocol message
|
||||
type Message struct {
|
||||
Type MessageType `json:"type"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Subdomain string `json:"subdomain,omitempty"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPRequest represents an HTTP request to be forwarded
|
||||
type HTTPRequest struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
Body []byte `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPRequestHead represents HTTP request headers for streaming (no body)
|
||||
type HTTPRequestHead struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
ContentLength int64 `json:"content_length"` // -1 for unknown/chunked
|
||||
}
|
||||
|
||||
// HTTPResponse represents an HTTP response from the local service
|
||||
type HTTPResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Status string `json:"status"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
Body []byte `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPResponseHead represents HTTP response headers for streaming (no body)
|
||||
type HTTPResponseHead struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Status string `json:"status"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
ContentLength int64 `json:"content_length"` // -1 for unknown/chunked
|
||||
}
|
||||
|
||||
// RegisterData contains information sent when a tunnel is registered
|
||||
type RegisterData struct {
|
||||
Subdomain string `json:"subdomain"`
|
||||
URL string `json:"url"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ErrorData contains error information
|
||||
type ErrorData struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
@@ -2,12 +2,23 @@ package protocol
|
||||
|
||||
import json "github.com/goccy/go-json"
|
||||
|
||||
// PoolCapabilities advertises client connection pool capabilities
|
||||
type PoolCapabilities struct {
|
||||
MaxDataConns int `json:"max_data_conns"` // Maximum data connections client supports
|
||||
Version int `json:"version"` // Protocol version for pool features
|
||||
}
|
||||
|
||||
// RegisterRequest is sent by client to register a tunnel
|
||||
type RegisterRequest struct {
|
||||
Token string `json:"token"` // Authentication token
|
||||
CustomSubdomain string `json:"custom_subdomain"` // Optional custom subdomain
|
||||
TunnelType TunnelType `json:"tunnel_type"` // http, tcp, udp
|
||||
LocalPort int `json:"local_port"` // Local port to forward to
|
||||
|
||||
// Connection pool fields (optional, for multi-connection support)
|
||||
ConnectionType string `json:"connection_type,omitempty"` // "primary" or empty for legacy
|
||||
TunnelID string `json:"tunnel_id,omitempty"` // For data connections to join
|
||||
PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"` // Client pool capabilities
|
||||
}
|
||||
|
||||
// RegisterResponse is sent by server after successful registration
|
||||
@@ -16,6 +27,25 @@ type RegisterResponse struct {
|
||||
Port int `json:"port,omitempty"` // Assigned TCP port (for TCP tunnels)
|
||||
URL string `json:"url"` // Full tunnel URL
|
||||
Message string `json:"message"` // Success message
|
||||
|
||||
// Connection pool fields (optional, for multi-connection support)
|
||||
TunnelID string `json:"tunnel_id,omitempty"` // Unique tunnel identifier
|
||||
SupportsDataConn bool `json:"supports_data_conn,omitempty"` // Server supports multi-connection
|
||||
RecommendedConns int `json:"recommended_conns,omitempty"` // Suggested data connection count
|
||||
}
|
||||
|
||||
// DataConnectRequest is sent by data connections to join a tunnel
|
||||
type DataConnectRequest struct {
|
||||
TunnelID string `json:"tunnel_id"` // Tunnel to join
|
||||
Token string `json:"token"` // Same auth token as primary
|
||||
ConnectionID string `json:"connection_id"` // Unique connection identifier
|
||||
}
|
||||
|
||||
// DataConnectResponse acknowledges data connection
|
||||
type DataConnectResponse struct {
|
||||
Accepted bool `json:"accepted"` // Whether connection was accepted
|
||||
ConnectionID string `json:"connection_id"` // Echoed connection ID
|
||||
Message string `json:"message,omitempty"` // Optional message
|
||||
}
|
||||
|
||||
// ErrorMessage represents an error
|
||||
@@ -24,9 +54,6 @@ type ErrorMessage struct {
|
||||
Message string `json:"message"` // Error message
|
||||
}
|
||||
|
||||
// Note: DataHeader is now defined in binary_header.go as a pure binary structure
|
||||
// TCPData has been removed - use DataHeader + raw bytes directly
|
||||
|
||||
// Marshal helpers for control plane messages (JSON encoding)
|
||||
func MarshalJSON(v interface{}) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
)
|
||||
|
||||
// encodeDataPayload encodes a data header and payload into a frame payload.
|
||||
func encodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
|
||||
streamIDLen := len(header.StreamID)
|
||||
requestIDLen := len(header.RequestID)
|
||||
|
||||
totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen + len(data)
|
||||
payload := make([]byte, totalLen)
|
||||
|
||||
flags := uint8(header.Type) & 0x07
|
||||
if header.IsLast {
|
||||
flags |= 0x08
|
||||
}
|
||||
payload[0] = flags
|
||||
|
||||
binary.BigEndian.PutUint16(payload[1:3], uint16(streamIDLen))
|
||||
binary.BigEndian.PutUint16(payload[3:5], uint16(requestIDLen))
|
||||
|
||||
offset := binaryHeaderMinSize
|
||||
copy(payload[offset:], header.StreamID)
|
||||
offset += streamIDLen
|
||||
copy(payload[offset:], header.RequestID)
|
||||
offset += requestIDLen
|
||||
copy(payload[offset:], data)
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// EncodeDataPayloadPooled encodes with adaptive allocation based on load.
|
||||
// Returns payload slice and pool buffer pointer (may be nil).
|
||||
func EncodeDataPayloadPooled(header DataHeader, data []byte) (payload []byte, poolBuffer *[]byte, err error) {
|
||||
streamIDLen := len(header.StreamID)
|
||||
requestIDLen := len(header.RequestID)
|
||||
totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen + len(data)
|
||||
|
||||
dynamicThreshold := GetAdaptiveThreshold()
|
||||
|
||||
if totalLen < dynamicThreshold {
|
||||
regularPayload, err := encodeDataPayload(header, data)
|
||||
return regularPayload, nil, err
|
||||
}
|
||||
|
||||
if totalLen > pool.SizeLarge {
|
||||
regularPayload, err := encodeDataPayload(header, data)
|
||||
return regularPayload, nil, err
|
||||
}
|
||||
|
||||
poolBuffer = pool.GetBuffer(totalLen)
|
||||
payload = (*poolBuffer)[:totalLen]
|
||||
|
||||
flags := uint8(header.Type) & 0x07
|
||||
if header.IsLast {
|
||||
flags |= 0x08
|
||||
}
|
||||
payload[0] = flags
|
||||
|
||||
binary.BigEndian.PutUint16(payload[1:3], uint16(streamIDLen))
|
||||
binary.BigEndian.PutUint16(payload[3:5], uint16(requestIDLen))
|
||||
|
||||
offset := binaryHeaderMinSize
|
||||
copy(payload[offset:], header.StreamID)
|
||||
offset += streamIDLen
|
||||
copy(payload[offset:], header.RequestID)
|
||||
offset += requestIDLen
|
||||
copy(payload[offset:], data)
|
||||
|
||||
return payload, poolBuffer, nil
|
||||
}
|
||||
|
||||
// DecodeDataPayload decodes a frame payload into header and data.
|
||||
func DecodeDataPayload(payload []byte) (DataHeader, []byte, error) {
|
||||
if len(payload) < binaryHeaderMinSize {
|
||||
return DataHeader{}, nil, errors.New("invalid payload: too short")
|
||||
}
|
||||
|
||||
var header DataHeader
|
||||
if err := header.UnmarshalBinary(payload); err != nil {
|
||||
return DataHeader{}, nil, err
|
||||
}
|
||||
|
||||
headerSize := header.Size()
|
||||
if len(payload) < headerSize {
|
||||
return DataHeader{}, nil, errors.New("invalid payload: data missing")
|
||||
}
|
||||
|
||||
data := payload[headerSize:]
|
||||
return header, data, nil
|
||||
}
|
||||
@@ -4,20 +4,11 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SafeFrame wraps Frame with automatic resource cleanup
|
||||
type SafeFrame struct {
|
||||
*Frame
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// NewSafeFrame creates a SafeFrame that implements io.Closer
|
||||
func NewSafeFrame(frameType FrameType, payload []byte) *SafeFrame {
|
||||
return &SafeFrame{
|
||||
Frame: NewFrame(frameType, payload),
|
||||
}
|
||||
}
|
||||
|
||||
// Close implements io.Closer, ensures Release is called exactly once
|
||||
func (sf *SafeFrame) Close() error {
|
||||
sf.once.Do(func() {
|
||||
if sf.Frame != nil {
|
||||
@@ -27,14 +18,6 @@ func (sf *SafeFrame) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithFrame wraps an existing Frame with automatic cleanup
|
||||
func WithFrame(frame *Frame) *SafeFrame {
|
||||
return &SafeFrame{Frame: frame}
|
||||
}
|
||||
|
||||
// MustClose is a helper that calls Close and panics on error (for defer cleanup)
|
||||
func (sf *SafeFrame) MustClose() {
|
||||
if err := sf.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,16 +4,18 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FrameWriter struct {
|
||||
conn io.Writer
|
||||
queue chan *Frame
|
||||
batch []*Frame
|
||||
mu sync.Mutex
|
||||
done chan struct{}
|
||||
closed bool
|
||||
conn io.Writer
|
||||
queue chan *Frame
|
||||
controlQueue chan *Frame
|
||||
batch []*Frame
|
||||
mu sync.Mutex
|
||||
done chan struct{}
|
||||
closed bool
|
||||
|
||||
maxBatch int
|
||||
maxBatchWait time.Duration
|
||||
@@ -24,13 +26,20 @@ type FrameWriter struct {
|
||||
heartbeatControl chan struct{}
|
||||
|
||||
// Error handling
|
||||
writeErr error
|
||||
errOnce sync.Once
|
||||
onWriteError func(error) // Callback for write errors
|
||||
writeErr error
|
||||
errOnce sync.Once
|
||||
onWriteError func(error) // Callback for write errors
|
||||
|
||||
// Adaptive flushing
|
||||
adaptiveFlush bool // Enable adaptive flush based on queue depth
|
||||
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
|
||||
adaptiveFlush bool // Enable adaptive flush based on queue depth
|
||||
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
|
||||
|
||||
// Hooks
|
||||
preWriteHook func(*Frame) // Called right before a frame is written to conn
|
||||
|
||||
// Backlog tracking
|
||||
queuedFrames atomic.Int64
|
||||
queuedBytes atomic.Int64
|
||||
}
|
||||
|
||||
func NewFrameWriter(conn io.Writer) *FrameWriter {
|
||||
@@ -41,8 +50,14 @@ func NewFrameWriter(conn io.Writer) *FrameWriter {
|
||||
|
||||
func NewFrameWriterWithConfig(conn io.Writer, maxBatch int, maxBatchWait time.Duration, queueSize int) *FrameWriter {
|
||||
w := &FrameWriter{
|
||||
conn: conn,
|
||||
queue: make(chan *Frame, queueSize),
|
||||
conn: conn,
|
||||
queue: make(chan *Frame, queueSize),
|
||||
controlQueue: make(chan *Frame, func() int {
|
||||
if queueSize < 256 {
|
||||
return queueSize
|
||||
}
|
||||
return 256
|
||||
}()), // control path needs small, fast buffer
|
||||
batch: make([]*Frame, 0, maxBatch),
|
||||
maxBatch: maxBatch,
|
||||
maxBatchWait: maxBatchWait,
|
||||
@@ -74,6 +89,22 @@ func (w *FrameWriter) writeLoop() {
|
||||
}()
|
||||
|
||||
for {
|
||||
// Always drain control queue first to prioritize control/heartbeat frames.
|
||||
select {
|
||||
case frame, ok := <-w.controlQueue:
|
||||
if !ok {
|
||||
w.mu.Lock()
|
||||
w.flushBatchLocked()
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.mu.Lock()
|
||||
w.flushFrameLocked(frame)
|
||||
w.mu.Unlock()
|
||||
continue
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case frame, ok := <-w.queue:
|
||||
if !ok {
|
||||
@@ -105,8 +136,7 @@ func (w *FrameWriter) writeLoop() {
|
||||
w.mu.Lock()
|
||||
if w.heartbeatCallback != nil {
|
||||
if frame := w.heartbeatCallback(); frame != nil {
|
||||
w.batch = append(w.batch, frame)
|
||||
w.flushBatchLocked()
|
||||
w.flushFrameLocked(frame)
|
||||
}
|
||||
}
|
||||
w.mu.Unlock()
|
||||
@@ -139,22 +169,47 @@ func (w *FrameWriter) flushBatchLocked() {
|
||||
}
|
||||
|
||||
for _, frame := range w.batch {
|
||||
if err := WriteFrame(w.conn, frame); err != nil {
|
||||
w.errOnce.Do(func() {
|
||||
w.writeErr = err
|
||||
if w.onWriteError != nil {
|
||||
go w.onWriteError(err)
|
||||
}
|
||||
w.closed = true
|
||||
})
|
||||
}
|
||||
frame.Release()
|
||||
w.flushFrameLocked(frame)
|
||||
}
|
||||
|
||||
w.batch = w.batch[:0]
|
||||
}
|
||||
|
||||
// flushFrameLocked writes a single frame immediately. Caller must hold w.mu.
|
||||
func (w *FrameWriter) flushFrameLocked(frame *Frame) {
|
||||
if frame == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if w.preWriteHook != nil {
|
||||
w.preWriteHook(frame)
|
||||
}
|
||||
|
||||
if err := WriteFrame(w.conn, frame); err != nil {
|
||||
w.errOnce.Do(func() {
|
||||
w.writeErr = err
|
||||
if w.onWriteError != nil {
|
||||
go w.onWriteError(err)
|
||||
}
|
||||
w.closed = true
|
||||
})
|
||||
}
|
||||
|
||||
w.unmarkQueued(frame)
|
||||
frame.Release()
|
||||
}
|
||||
|
||||
func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
return w.WriteFrameWithCancel(frame, nil)
|
||||
}
|
||||
|
||||
// WriteFrameWithCancel writes a frame with an optional cancellation channel
|
||||
// If cancel is closed, the write will be aborted immediately
|
||||
func (w *FrameWriter) WriteFrameWithCancel(frame *Frame, cancel <-chan struct{}) error {
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
@@ -165,10 +220,19 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
size := int64(len(frame.Payload) + FrameHeaderSize)
|
||||
w.queuedFrames.Add(1)
|
||||
w.queuedBytes.Add(size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, size)
|
||||
|
||||
// Try non-blocking first for best performance
|
||||
select {
|
||||
case w.queue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
@@ -176,6 +240,54 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
default:
|
||||
}
|
||||
|
||||
// Queue full - block with cancellation support
|
||||
if cancel != nil {
|
||||
select {
|
||||
case w.queue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
case <-cancel:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
return errors.New("write cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
// No cancel channel - block with timeout
|
||||
select {
|
||||
case w.queue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
case <-time.After(30 * time.Second):
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
return errors.New("write queue full timeout")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,8 +301,14 @@ func (w *FrameWriter) Close() error {
|
||||
w.mu.Unlock()
|
||||
|
||||
close(w.queue)
|
||||
close(w.controlQueue)
|
||||
|
||||
for frame := range w.queue {
|
||||
w.unmarkQueued(frame)
|
||||
frame.Release()
|
||||
}
|
||||
for frame := range w.controlQueue {
|
||||
w.unmarkQueued(frame)
|
||||
frame.Release()
|
||||
}
|
||||
|
||||
@@ -264,3 +382,97 @@ func (w *FrameWriter) DisableAdaptiveFlush() {
|
||||
w.adaptiveFlush = false
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
// WriteControl enqueues a control/prioritized frame to be written ahead of data frames.
|
||||
func (w *FrameWriter) WriteControl(frame *Frame) error {
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
if w.writeErr != nil {
|
||||
return w.writeErr
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
size := int64(len(frame.Payload) + FrameHeaderSize)
|
||||
w.queuedFrames.Add(1)
|
||||
w.queuedBytes.Add(size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, size)
|
||||
|
||||
// Try non-blocking first
|
||||
select {
|
||||
case w.controlQueue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
default:
|
||||
}
|
||||
|
||||
// Queue full - wait with timeout
|
||||
select {
|
||||
case w.controlQueue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Control frames should have priority, shorter timeout
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
atomic.StoreInt64(&frame.queuedBytes, 0)
|
||||
return errors.New("control queue full timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// SetPreWriteHook registers a callback invoked just before a frame is written to the underlying writer.
|
||||
func (w *FrameWriter) SetPreWriteHook(hook func(*Frame)) {
|
||||
w.mu.Lock()
|
||||
w.preWriteHook = hook
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
// QueuedFrames returns the number of frames currently queued (data + control).
|
||||
func (w *FrameWriter) QueuedFrames() int64 {
|
||||
return w.queuedFrames.Load()
|
||||
}
|
||||
|
||||
// QueuedBytes returns the approximate number of bytes currently queued.
|
||||
func (w *FrameWriter) QueuedBytes() int64 {
|
||||
return w.queuedBytes.Load()
|
||||
}
|
||||
|
||||
// unmarkQueued decrements backlog counters for a frame once it is written or discarded.
|
||||
func (w *FrameWriter) unmarkQueued(frame *Frame) {
|
||||
if frame == nil {
|
||||
return
|
||||
}
|
||||
size := atomic.SwapInt64(&frame.queuedBytes, 0)
|
||||
if size <= 0 {
|
||||
return
|
||||
}
|
||||
w.queuedFrames.Add(-1)
|
||||
w.queuedBytes.Add(-size)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user