Merge pull request #7 from Gouryella/perf/fix-something

perf: Add flow control, panic recovery and performance tuning
This commit is contained in:
Gouryella
2025-12-10 20:42:16 +08:00
committed by GitHub
12 changed files with 460 additions and 41 deletions

View File

@@ -145,7 +145,7 @@ func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) er
select {
case <-done:
// Closed successfully
case <-time.After(5 * time.Second):
case <-time.After(2 * time.Second):
fmt.Println(ui.Warning("Force closing (timeout)..."))
}

View File

@@ -158,13 +158,13 @@ func RenderRetrying(interval time.Duration) string {
// formatLatency formats latency with color
func formatLatency(d time.Duration) string {
ms := d.Milliseconds()
var style lipgloss.Style
if ms == 0 {
if d == 0 {
return mutedStyle.Render("measuring...")
}
ms := d.Milliseconds()
var style lipgloss.Style
switch {
case ms < 50:
style = lipgloss.NewStyle().Foreground(latencyFastColor)
@@ -176,6 +176,11 @@ func formatLatency(d time.Duration) string {
style = lipgloss.NewStyle().Foreground(latencyRedColor)
}
if ms == 0 {
us := d.Microseconds()
return style.Render(fmt.Sprintf("%dµs", us))
}
return style.Render(fmt.Sprintf("%dms", ms))
}

View File

@@ -12,6 +12,7 @@ import (
"drip/internal/shared/constants"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/recovery"
"drip/pkg/config"
"go.uber.org/zap"
)
@@ -47,6 +48,9 @@ type Connector struct {
// Worker pool for handling data frames
dataFrameQueue chan *protocol.Frame
workerCount int
recoverer *recovery.Recoverer
panicMetrics *recovery.PanicMetrics
}
// ConnectorConfig holds connector configuration
@@ -78,6 +82,9 @@ func NewConnector(cfg *ConnectorConfig, logger *zap.Logger) *Connector {
numCPU := pool.NumCPU()
workerCount := max(numCPU+numCPU/2, 4)
panicMetrics := recovery.NewPanicMetrics(logger, nil)
recoverer := recovery.NewRecoverer(logger, panicMetrics)
return &Connector{
serverAddr: cfg.ServerAddr,
tlsConfig: tlsConfig,
@@ -90,6 +97,8 @@ func NewConnector(cfg *ConnectorConfig, logger *zap.Logger) *Connector {
stopCh: make(chan struct{}),
dataFrameQueue: make(chan *protocol.Frame, workerCount*100),
workerCount: workerCount,
recoverer: recoverer,
panicMetrics: panicMetrics,
}
}
@@ -150,6 +159,7 @@ func (c *Connector) Connect() error {
}
go c.frameHandler.WarmupConnectionPool(3)
go c.monitorQueuePressure()
go c.handleFrames()
return nil
@@ -216,6 +226,7 @@ func (c *Connector) register() error {
func (c *Connector) dataFrameWorker(workerID int) {
defer c.handlerWg.Done()
defer c.recoverer.Recover(fmt.Sprintf("dataFrameWorker-%d", workerID))
for {
select {
@@ -224,12 +235,17 @@ func (c *Connector) dataFrameWorker(workerID int) {
return
}
if err := c.frameHandler.HandleDataFrame(frame); err != nil {
c.logger.Error("Failed to handle data frame",
zap.Int("worker_id", workerID),
zap.Error(err))
}
frame.Release()
func() {
sf := protocol.WithFrame(frame)
defer sf.Close()
defer c.recoverer.Recover("handleDataFrame")
if err := c.frameHandler.HandleDataFrame(sf.Frame); err != nil {
c.logger.Error("Failed to handle data frame",
zap.Int("worker_id", workerID),
zap.Error(err))
}
}()
case <-c.stopCh:
return
@@ -240,6 +256,7 @@ func (c *Connector) dataFrameWorker(workerID int) {
// handleFrames handles incoming frames from server
func (c *Connector) handleFrames() {
defer c.Close()
defer c.recoverer.Recover("handleFrames")
for {
select {
@@ -263,7 +280,9 @@ func (c *Connector) handleFrames() {
return
}
}
switch frame.Type {
sf := protocol.WithFrame(frame)
switch sf.Frame.Type {
case protocol.FrameTypeHeartbeatAck:
c.heartbeatMu.Lock()
if !c.heartbeatSentAt.IsZero() {
@@ -280,39 +299,39 @@ func (c *Connector) handleFrames() {
c.heartbeatMu.Unlock()
c.logger.Debug("Received heartbeat ack")
}
frame.Release()
sf.Close()
case protocol.FrameTypeData:
select {
case c.dataFrameQueue <- frame:
case c.dataFrameQueue <- sf.Frame:
case <-c.stopCh:
frame.Release()
sf.Close()
return
default:
c.logger.Warn("Data frame queue full, dropping frame")
frame.Release()
sf.Close()
}
case protocol.FrameTypeClose:
frame.Release()
sf.Close()
c.logger.Info("Server requested close")
return
case protocol.FrameTypeError:
var errMsg protocol.ErrorMessage
if err := json.Unmarshal(frame.Payload, &errMsg); err == nil {
if err := json.Unmarshal(sf.Frame.Payload, &errMsg); err == nil {
c.logger.Error("Received error from server",
zap.String("code", errMsg.Code),
zap.String("message", errMsg.Message),
)
}
frame.Release()
sf.Close()
return
default:
frame.Release()
sf.Close()
c.logger.Warn("Unexpected frame type",
zap.String("type", frame.Type.String()),
zap.String("type", sf.Frame.Type.String()),
)
}
}
@@ -359,7 +378,7 @@ func (c *Connector) Close() error {
select {
case <-done:
case <-time.After(3 * time.Second):
case <-time.After(2 * time.Second):
c.logger.Warn("Force closing: some handlers are still active")
}
@@ -421,3 +440,54 @@ func (c *Connector) IsClosed() bool {
defer c.closedMu.RUnlock()
return c.closed
}
// monitorQueuePressure periodically samples how full dataFrameQueue is and
// asks the server to pause or resume sending via flow control frames.
//
// A pause is sent once usage rises above 80% of the queue capacity; a resume
// is sent once usage falls back below 50%. The gap between the two thresholds
// keeps the signal from flapping when usage hovers around a single value.
// The loop exits when stopCh is closed.
func (c *Connector) monitorQueuePressure() {
	defer c.recoverer.Recover("monitorQueuePressure")

	const (
		checkInterval   = 100 * time.Millisecond
		pauseThreshold  = 0.80
		resumeThreshold = 0.50
	)

	tick := time.NewTicker(checkInterval)
	defer tick.Stop()

	// paused tracks the last signal we sent so each transition is sent once.
	paused := false

	for {
		select {
		case <-c.stopCh:
			return
		case <-tick.C:
		}

		depth, capacity := len(c.dataFrameQueue), cap(c.dataFrameQueue)
		usage := float64(depth) / float64(capacity)

		switch {
		case !paused && usage > pauseThreshold:
			c.sendFlowControl("*", protocol.FlowControlPause)
			paused = true
			c.logger.Warn("Queue pressure high, sent pause signal",
				zap.Int("queue_len", depth),
				zap.Int("queue_cap", capacity),
				zap.Float64("usage", usage))

		case paused && usage < resumeThreshold:
			c.sendFlowControl("*", protocol.FlowControlResume)
			paused = false
			c.logger.Info("Queue pressure normal, sent resume signal",
				zap.Int("queue_len", depth),
				zap.Int("queue_cap", capacity),
				zap.Float64("usage", usage))
		}
	}
}
// sendFlowControl builds a flow control frame for streamID with the given
// action and sends it through SendFrame. A send failure is logged and
// otherwise ignored; nothing is returned to the caller.
func (c *Connector) sendFlowControl(streamID string, action protocol.FlowControlAction) {
	err := c.SendFrame(protocol.NewFlowControlFrame(streamID, action))
	if err == nil {
		return
	}

	c.logger.Error("Failed to send flow control",
		zap.String("action", string(action)),
		zap.Error(err))
}

View File

@@ -604,8 +604,10 @@ func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
</body>
</html>`
data := []byte(html)
w.Header().Set("Content-Type", "text/html")
w.Write([]byte(html))
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Write(data)
}
func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {

View File

@@ -44,6 +44,10 @@ type Connection struct {
tunnelType protocol.TunnelType // Track tunnel type
ctx context.Context
cancel context.CancelFunc
// Flow control
paused bool
pauseCond *sync.Cond
}
// HTTPResponseHandler interface for response channel operations
@@ -60,7 +64,7 @@ type HTTPResponseHandler interface {
// NewConnection creates a new connection handler
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
ctx, cancel := context.WithCancel(context.Background())
return &Connection{
c := &Connection{
conn: conn,
authToken: authToken,
manager: manager,
@@ -75,6 +79,8 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
ctx: ctx,
cancel: cancel,
}
c.pauseCond = sync.NewCond(&c.mu)
return c
}
// Handle handles the connection lifecycle
@@ -118,14 +124,15 @@ func (c *Connection) Handle() error {
if err != nil {
return fmt.Errorf("failed to read registration frame: %w", err)
}
defer frame.Release() // Return pool buffer when done
sf := protocol.WithFrame(frame)
defer sf.Close()
if frame.Type != protocol.FrameTypeRegister {
return fmt.Errorf("expected register frame, got %s", frame.Type)
if sf.Frame.Type != protocol.FrameTypeRegister {
return fmt.Errorf("expected register frame, got %s", sf.Frame.Type)
}
var req protocol.RegisterRequest
if err := json.Unmarshal(frame.Payload, &req); err != nil {
if err := json.Unmarshal(sf.Frame.Payload, &req); err != nil {
return fmt.Errorf("failed to parse registration request: %w", err)
}
@@ -390,25 +397,31 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
}
// Handle frame based on type
switch frame.Type {
sf := protocol.WithFrame(frame)
switch sf.Frame.Type {
case protocol.FrameTypeHeartbeat:
c.handleHeartbeat()
frame.Release()
sf.Close()
case protocol.FrameTypeData:
// Data frame from client (response to forwarded request)
c.handleDataFrame(frame)
frame.Release() // Release after processing
c.handleDataFrame(sf.Frame)
sf.Close()
case protocol.FrameTypeFlowControl:
c.handleFlowControl(sf.Frame)
sf.Close()
case protocol.FrameTypeClose:
frame.Release()
sf.Close()
c.logger.Info("Client requested close")
return nil
default:
frame.Release()
sf.Close()
c.logger.Warn("Unexpected frame type",
zap.String("type", frame.Type.String()),
zap.String("type", sf.Frame.Type.String()),
)
}
}
@@ -574,6 +587,9 @@ func (c *Connection) SendFrame(frame *protocol.Frame) error {
if c.frameWriter == nil {
return protocol.WriteFrame(c.conn, frame)
}
if frame.Type == protocol.FrameTypeData {
return c.sendWithBackpressure(frame)
}
return c.frameWriter.WriteFrame(frame)
}
@@ -681,3 +697,40 @@ func (w *httpResponseWriter) Write(data []byte) (int, error) {
}
return w.writer.Write(data)
}
// handleFlowControl applies a flow control request carried in frame's payload.
//
// A pause request sets c.paused; a resume request clears it and wakes every
// goroutine waiting on pauseCond. Unknown actions and undecodable payloads are
// logged and leave the pause state untouched. The state change and the
// accompanying log call both happen while c.mu is held.
func (c *Connection) handleFlowControl(frame *protocol.Frame) {
	msg, decodeErr := protocol.DecodeFlowControlMessage(frame.Payload)
	if decodeErr != nil {
		c.logger.Error("Failed to decode flow control", zap.Error(decodeErr))
		return
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	if msg.Action == protocol.FlowControlPause {
		c.paused = true
		c.logger.Warn("Client requested pause",
			zap.String("stream", msg.StreamID))
		return
	}

	if msg.Action == protocol.FlowControlResume {
		c.paused = false
		// Broadcast rather than Signal: every blocked sender may proceed.
		c.pauseCond.Broadcast()
		c.logger.Info("Client requested resume",
			zap.String("stream", msg.StreamID))
		return
	}

	c.logger.Warn("Unknown flow control action",
		zap.String("action", string(msg.Action)))
}
// sendWithBackpressure writes frame through the frame writer, first blocking
// for as long as the peer has this connection paused.
//
// It returns the connection context's error if the context is cancelled while
// waiting for a resume, otherwise whatever frameWriter.WriteFrame returns.
func (c *Connection) sendWithBackpressure(frame *protocol.Frame) error {
	if err := c.waitWhilePaused(); err != nil {
		return err
	}
	return c.frameWriter.WriteFrame(frame)
}

// waitWhilePaused blocks until c.paused is false or c.ctx is cancelled.
//
// sync.Cond has no built-in cancellation, so a sender parked in Wait would
// otherwise stay blocked forever if the peer pauses and then never sends a
// resume. While waiting, a helper goroutine watches c.ctx and broadcasts on
// pauseCond when it is cancelled.
// NOTE(review): this assumes c.cancel is invoked when the connection is torn
// down; if the context is never cancelled the behaviour is unchanged.
func (c *Connection) waitWhilePaused() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Fast path: no watcher goroutine unless we actually have to wait.
	if !c.paused {
		return nil
	}

	finished := make(chan struct{})
	defer close(finished) // runs before the deferred Unlock, releasing the watcher

	go func() {
		select {
		case <-c.ctx.Done():
			// Take the lock so the broadcast cannot slip in between the
			// waiter's ctx check and its call to Wait.
			c.mu.Lock()
			c.pauseCond.Broadcast()
			c.mu.Unlock()
		case <-finished:
		}
	}()

	for c.paused {
		if err := c.ctx.Err(); err != nil {
			return err
		}
		c.pauseCond.Wait()
	}
	return nil
}

View File

@@ -11,6 +11,7 @@ import (
"drip/internal/server/tunnel"
"drip/internal/shared/pool"
"drip/internal/shared/recovery"
"go.uber.org/zap"
)
@@ -32,6 +33,8 @@ type Listener struct {
connections map[string]*Connection
connMu sync.RWMutex
workerPool *pool.WorkerPool // Worker pool for connection handling
recoverer *recovery.Recoverer
panicMetrics *recovery.PanicMetrics
}
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Listener {
@@ -46,6 +49,9 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
zap.Int("queue_size", queueSize),
)
panicMetrics := recovery.NewPanicMetrics(logger, nil)
recoverer := recovery.NewRecoverer(logger, panicMetrics)
return &Listener{
address: address,
tlsConfig: tlsConfig,
@@ -60,6 +66,8 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
stopCh: make(chan struct{}),
connections: make(map[string]*Connection),
workerPool: workerPool,
recoverer: recoverer,
panicMetrics: panicMetrics,
}
}
@@ -86,6 +94,7 @@ func (l *Listener) Start() error {
// acceptLoop accepts incoming connections
func (l *Listener) acceptLoop() {
defer l.wg.Done()
defer l.recoverer.Recover("acceptLoop")
for {
select {
@@ -114,12 +123,20 @@ func (l *Listener) acceptLoop() {
}
l.wg.Add(1)
submitted := l.workerPool.Submit(func() {
l.handleConnection(conn)
})
submitted := l.workerPool.Submit(l.recoverer.WrapGoroutine(
fmt.Sprintf("handleConnection-%s", conn.RemoteAddr().String()),
func() {
l.handleConnection(conn)
},
))
if !submitted {
go l.handleConnection(conn)
l.recoverer.SafeGo(
fmt.Sprintf("handleConnection-fallback-%s", conn.RemoteAddr().String()),
func() {
l.handleConnection(conn)
},
)
}
}
}
@@ -128,6 +145,12 @@ func (l *Listener) acceptLoop() {
func (l *Listener) handleConnection(netConn net.Conn) {
defer l.wg.Done()
defer netConn.Close()
defer l.recoverer.RecoverWithCallback("handleConnection", func(p interface{}) {
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
delete(l.connections, connID)
l.connMu.Unlock()
})
tlsConn, ok := netConn.(*tls.Conn)
if !ok {

View File

@@ -10,10 +10,10 @@ const (
DefaultWSPort = 8080
// HeartbeatInterval is how often clients send heartbeat messages
HeartbeatInterval = 5 * time.Second
HeartbeatInterval = 2 * time.Second
// HeartbeatTimeout is how long the server waits before considering a connection dead
HeartbeatTimeout = 15 * time.Second
HeartbeatTimeout = 6 * time.Second
// RequestTimeout is the maximum time to wait for a response from the client
RequestTimeout = 30 * time.Second

View File

@@ -0,0 +1,34 @@
package protocol
import (
json "github.com/goccy/go-json"
)
// FlowControlAction names the operation requested by a flow control message.
// It is carried on the wire as a plain JSON string.
type FlowControlAction string
// Flow control actions understood by the protocol.
const (
// FlowControlPause asks the receiver to stop sending data frames.
FlowControlPause FlowControlAction = "pause"
// FlowControlResume asks the receiver to start sending data frames again.
FlowControlResume FlowControlAction = "resume"
)
// FlowControlMessage is the JSON payload of a FrameTypeFlowControl frame.
type FlowControlMessage struct {
// StreamID identifies the stream the action applies to.
StreamID string `json:"stream_id"`
// Action is the requested flow control operation.
Action FlowControlAction `json:"action"`
}
// NewFlowControlFrame returns a FrameTypeFlowControl frame whose payload is
// the JSON encoding of a FlowControlMessage built from streamID and action.
func NewFlowControlFrame(streamID string, action FlowControlAction) *Frame {
	// The marshal error is deliberately discarded: the message consists of
	// two string fields only, which are not expected to fail to encode.
	payload, _ := json.Marshal(&FlowControlMessage{
		StreamID: streamID,
		Action:   action,
	})

	return NewFrame(FrameTypeFlowControl, payload)
}
// DecodeFlowControlMessage parses payload as a JSON-encoded
// FlowControlMessage. On a decode failure it returns a nil message together
// with the unmarshal error.
func DecodeFlowControlMessage(payload []byte) (*FlowControlMessage, error) {
	decoded := new(FlowControlMessage)

	err := json.Unmarshal(payload, decoded)
	if err != nil {
		return nil, err
	}

	return decoded, nil
}

View File

@@ -25,6 +25,7 @@ const (
FrameTypeData FrameType = 0x05
FrameTypeClose FrameType = 0x06
FrameTypeError FrameType = 0x07
FrameTypeFlowControl FrameType = 0x08
)
// String returns the string representation of frame type
@@ -44,6 +45,8 @@ func (t FrameType) String() string {
return "Close"
case FrameTypeError:
return "Error"
case FrameTypeFlowControl:
return "FlowControl"
default:
return fmt.Sprintf("Unknown(%d)", t)
}

View File

@@ -0,0 +1,40 @@
package protocol
import (
"sync"
)
// SafeFrame wraps Frame with automatic resource cleanup.
//
// The embedded *Frame exposes the frame's fields and methods directly. Close
// releases the frame at most once, so it is safe to call Close from several
// code paths. A SafeFrame holds a sync.Once and must not be copied after
// first use; pass it around by pointer.
type SafeFrame struct {
*Frame
// once guarantees the wrapped frame's Release runs a single time.
once sync.Once
}
// NewSafeFrame builds a new Frame from frameType and payload via NewFrame and
// returns it wrapped in a SafeFrame, whose Close method satisfies io.Closer.
func NewSafeFrame(frameType FrameType, payload []byte) *SafeFrame {
	wrapped := new(SafeFrame)
	wrapped.Frame = NewFrame(frameType, payload)
	return wrapped
}
// Close implements io.Closer. The first call releases the wrapped frame (if
// there is one); every later call is a no-op. It always returns nil.
func (sf *SafeFrame) Close() error {
	release := func() {
		if sf.Frame == nil {
			return
		}
		sf.Frame.Release()
	}

	sf.once.Do(release)
	return nil
}
// WithFrame wraps an already-constructed Frame in a SafeFrame so that it can
// be released through Close. The frame pointer is stored as given; a nil
// frame is tolerated by Close.
func WithFrame(frame *Frame) *SafeFrame {
	wrapped := new(SafeFrame)
	wrapped.Frame = frame
	return wrapped
}
// MustClose calls Close and panics with the returned error if it is non-nil.
// It is intended for use in defer statements where an error cannot be
// returned.
func (sf *SafeFrame) MustClose() {
	err := sf.Close()
	if err == nil {
		return
	}
	panic(err)
}

View File

@@ -0,0 +1,110 @@
package recovery
import (
"fmt"
"runtime/debug"
"sync"
"sync/atomic"
"time"
"go.uber.org/zap"
)
// PanicMetrics counts recovered panics, keeps a bounded history of the most
// recent ones, and raises an alert when they occur at a high rate.
type PanicMetrics struct {
// totalPanics is the lifetime panic count; it is read and written with
// sync/atomic functions.
totalPanics uint64
// recentPanics holds the latest panic records, oldest first. Guarded by mu.
recentPanics []PanicRecord
// mu protects recentPanics.
mu sync.Mutex
// logger receives the high-panic-rate alert log entry.
logger *zap.Logger
// alerter, when non-nil, is additionally notified of high panic rates.
alerter Alerter
}
// PanicRecord describes a single recovered panic.
type PanicRecord struct {
// Location is the caller-supplied label of where the panic was recovered.
Location string
// Timestamp is when the panic was recorded.
Timestamp time.Time
// Value is the value that was passed to panic.
Value interface{}
// Stack is the goroutine stack trace captured when the panic was recorded.
Stack string
}
// Alerter delivers out-of-band notifications. PanicMetrics calls it when the
// recent panic rate is high.
type Alerter interface {
// SendAlert sends a notification with the given title and message body.
SendAlert(title string, message string)
}
// NewPanicMetrics returns a PanicMetrics that logs alerts to logger and, when
// alerter is non-nil, forwards them to it as well. The recent-panic history is
// preallocated with room for 100 records.
func NewPanicMetrics(logger *zap.Logger, alerter Alerter) *PanicMetrics {
	pm := &PanicMetrics{
		logger:  logger,
		alerter: alerter,
	}
	pm.recentPanics = make([]PanicRecord, 0, 100)
	return pm
}
// RecordPanic registers one recovered panic: it bumps the lifetime counter,
// appends a record (with the current stack trace) to the recent history, and
// sends an alert if the recent panic rate is high.
//
// The history keeps at most the 100 newest records.
func (pm *PanicMetrics) RecordPanic(location string, panicValue interface{}) {
	atomic.AddUint64(&pm.totalPanics, 1)

	// Capture the stack before locking; the timestamp is taken under the lock
	// so that records stay in chronological order.
	stack := string(debug.Stack())

	pm.mu.Lock()
	pm.recentPanics = append(pm.recentPanics, PanicRecord{
		Location:  location,
		Timestamp: time.Now(),
		Value:     panicValue,
		Stack:     stack,
	})
	if n := len(pm.recentPanics); n > 100 {
		// Drop the oldest record.
		pm.recentPanics = pm.recentPanics[1:]
	}
	alertNeeded := pm.shouldAlertUnlocked()
	pm.mu.Unlock()

	// sendAlert takes the lock itself, so it must run after Unlock.
	if !alertNeeded {
		return
	}
	pm.sendAlert()
}
// shouldAlertUnlocked reports whether the panic rate over the last five
// minutes is at least two per minute.
//
// Despite the name, it does not lock: the caller must already hold pm.mu.
// It scans the history from newest to oldest and stops at the first record
// outside the window, which relies on records being in chronological order.
func (pm *PanicMetrics) shouldAlertUnlocked() bool {
	cutoff := time.Now().Add(-5 * time.Minute)

	recent := 0
	for idx := len(pm.recentPanics) - 1; idx >= 0; idx-- {
		if !pm.recentPanics[idx].Timestamp.After(cutoff) {
			break
		}
		recent++
	}

	return float64(recent)/5.0 >= 2.0
}
// sendAlert reports a high panic rate. It recomputes the per-minute rate over
// the last five minutes under pm.mu, logs it together with the lifetime panic
// total, and forwards the same information to the alerter when one is set.
//
// The caller must not hold pm.mu.
func (pm *PanicMetrics) sendAlert() {
	total := atomic.LoadUint64(&pm.totalPanics)

	pm.mu.Lock()
	cutoff := time.Now().Add(-5 * time.Minute)
	recent := 0
	// Walk backwards from the newest record while still inside the window.
	for idx := len(pm.recentPanics) - 1; idx >= 0 && pm.recentPanics[idx].Timestamp.After(cutoff); idx-- {
		recent++
	}
	rate := float64(recent) / 5.0
	pm.mu.Unlock()

	pm.logger.Error("ALERT: High panic rate detected",
		zap.Uint64("total_panics", total),
		zap.Float64("rate_per_minute", rate),
	)

	if pm.alerter == nil {
		return
	}
	pm.alerter.SendAlert(
		"Drip: High Panic Rate",
		fmt.Sprintf("High panic rate detected: %.2f panics/minute (total: %d)", rate, total),
	)
}

View File

@@ -0,0 +1,79 @@
package recovery
import (
"runtime/debug"
"go.uber.org/zap"
)
// Recoverer recovers from panics, logs them with a stack trace, and reports
// them to an optional metrics collector.
type Recoverer struct {
// logger receives one error entry per recovered panic.
logger *zap.Logger
// metrics, when non-nil, is told about every recovered panic.
metrics MetricsCollector
}
// MetricsCollector is the sink a Recoverer reports recovered panics to.
type MetricsCollector interface {
// RecordPanic is called once per recovered panic with the location label
// and the value that was passed to panic.
RecordPanic(location string, panicValue interface{})
}
// NewRecoverer returns a Recoverer that logs recovered panics to logger and
// reports them to metrics. metrics may be nil, in which case panics are only
// logged.
func NewRecoverer(logger *zap.Logger, metrics MetricsCollector) *Recoverer {
	rec := new(Recoverer)
	rec.logger = logger
	rec.metrics = metrics
	return rec
}
// WrapGoroutine returns a function that runs fn and recovers any panic it
// raises. A recovered panic is logged under the given goroutine name together
// with a stack trace and, if a metrics collector is configured, recorded
// there. The panic is swallowed: the returned function returns normally.
func (r *Recoverer) WrapGoroutine(name string, fn func()) func() {
	// guard is used as the deferred function itself, so its recover call is
	// a direct call from a deferred function as the language requires.
	guard := func() {
		p := recover()
		if p == nil {
			return
		}

		r.logger.Error("goroutine panic recovered",
			zap.String("goroutine", name),
			zap.Any("panic", p),
			zap.ByteString("stack", debug.Stack()),
		)
		if r.metrics != nil {
			r.metrics.RecordPanic(name, p)
		}
	}

	return func() {
		defer guard()
		fn()
	}
}
// SafeGo runs fn in a new goroutine, wrapped by WrapGoroutine so that a panic
// inside fn is recovered and reported under name instead of crashing the
// process.
func (r *Recoverer) SafeGo(name string, fn func()) {
	guarded := r.WrapGoroutine(name, fn)
	go guarded()
}
// Recover stops an in-flight panic, logs it under location with a stack
// trace, and reports it to the metrics collector when one is configured.
//
// It must be the deferred call itself, i.e. `defer r.Recover("where")`.
// Calling it from inside another deferred function will not recover the
// panic, because recover only works when called directly by a deferred
// function.
func (r *Recoverer) Recover(location string) {
	p := recover()
	if p == nil {
		return
	}

	r.logger.Error("panic recovered",
		zap.String("location", location),
		zap.Any("panic", p),
		zap.ByteString("stack", debug.Stack()),
	)

	if r.metrics == nil {
		return
	}
	r.metrics.RecordPanic(location, p)
}
// RecoverWithCallback behaves like Recover and then, if callback is non-nil,
// invokes it with the recovered panic value so the caller can run cleanup.
//
// It must be the deferred call itself, i.e.
// `defer r.RecoverWithCallback("where", cleanup)`; recover only works when
// called directly by a deferred function. A panic raised by callback is not
// recovered here.
func (r *Recoverer) RecoverWithCallback(location string, callback func(panicValue interface{})) {
	p := recover()
	if p == nil {
		return
	}

	r.logger.Error("panic recovered with callback",
		zap.String("location", location),
		zap.Any("panic", p),
		zap.ByteString("stack", debug.Stack()),
	)

	if r.metrics != nil {
		r.metrics.RecordPanic(location, p)
	}

	if callback == nil {
		return
	}
	callback(p)
}