diff --git a/internal/client/tcp/connector.go b/internal/client/tcp/connector.go index a70ea2c..618378a 100644 --- a/internal/client/tcp/connector.go +++ b/internal/client/tcp/connector.go @@ -12,6 +12,7 @@ import ( "drip/internal/shared/constants" "drip/internal/shared/pool" "drip/internal/shared/protocol" + "drip/internal/shared/recovery" "drip/pkg/config" "go.uber.org/zap" ) @@ -47,6 +48,9 @@ type Connector struct { // Worker pool for handling data frames dataFrameQueue chan *protocol.Frame workerCount int + + recoverer *recovery.Recoverer + panicMetrics *recovery.PanicMetrics } // ConnectorConfig holds connector configuration @@ -78,6 +82,9 @@ func NewConnector(cfg *ConnectorConfig, logger *zap.Logger) *Connector { numCPU := pool.NumCPU() workerCount := max(numCPU+numCPU/2, 4) + panicMetrics := recovery.NewPanicMetrics(logger, nil) + recoverer := recovery.NewRecoverer(logger, panicMetrics) + return &Connector{ serverAddr: cfg.ServerAddr, tlsConfig: tlsConfig, @@ -90,6 +97,8 @@ func NewConnector(cfg *ConnectorConfig, logger *zap.Logger) *Connector { stopCh: make(chan struct{}), dataFrameQueue: make(chan *protocol.Frame, workerCount*100), workerCount: workerCount, + recoverer: recoverer, + panicMetrics: panicMetrics, } } @@ -216,6 +225,7 @@ func (c *Connector) register() error { func (c *Connector) dataFrameWorker(workerID int) { defer c.handlerWg.Done() + defer c.recoverer.Recover(fmt.Sprintf("dataFrameWorker-%d", workerID)) for { select { @@ -224,12 +234,20 @@ func (c *Connector) dataFrameWorker(workerID int) { return } - if err := c.frameHandler.HandleDataFrame(frame); err != nil { - c.logger.Error("Failed to handle data frame", - zap.Int("worker_id", workerID), - zap.Error(err)) - } - frame.Release() + func() { + defer c.recoverer.RecoverWithCallback("handleDataFrame", func(p interface{}) { + if frame != nil { + frame.Release() + } + }) + + if err := c.frameHandler.HandleDataFrame(frame); err != nil { + c.logger.Error("Failed to handle data frame", + zap.Int("worker_id", workerID), + zap.Error(err)) + } + frame.Release() + }() case <-c.stopCh: return @@ -240,6 +258,7 @@ func (c *Connector) dataFrameWorker(workerID int) { // handleFrames handles incoming frames from server func (c *Connector) handleFrames() { defer c.Close() + defer c.recoverer.Recover("handleFrames") for { select { diff --git a/internal/server/tcp/listener.go b/internal/server/tcp/listener.go index 28194c4..337b075 100644 --- a/internal/server/tcp/listener.go +++ b/internal/server/tcp/listener.go @@ -11,6 +11,7 @@ import ( "drip/internal/server/tunnel" "drip/internal/shared/pool" + "drip/internal/shared/recovery" "go.uber.org/zap" ) @@ -32,6 +33,8 @@ type Listener struct { connections map[string]*Connection connMu sync.RWMutex workerPool *pool.WorkerPool // Worker pool for connection handling + recoverer *recovery.Recoverer + panicMetrics *recovery.PanicMetrics } func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Listener { @@ -46,6 +49,9 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage zap.Int("queue_size", queueSize), ) + panicMetrics := recovery.NewPanicMetrics(logger, nil) + recoverer := recovery.NewRecoverer(logger, panicMetrics) + return &Listener{ address: address, tlsConfig: tlsConfig, @@ -60,6 +66,8 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage stopCh: make(chan struct{}), connections: make(map[string]*Connection), workerPool: workerPool, + recoverer: recoverer, + panicMetrics: panicMetrics, } } @@ -86,6 +94,7 @@ func (l *Listener) Start() error { // acceptLoop accepts incoming connections func (l *Listener) acceptLoop() { defer l.wg.Done() + defer l.recoverer.Recover("acceptLoop") for { select { @@ -114,12 +123,20 @@ func (l *Listener) acceptLoop() { } l.wg.Add(1) - submitted := l.workerPool.Submit(func() { - l.handleConnection(conn) - }) + submitted := l.workerPool.Submit(l.recoverer.WrapGoroutine( + fmt.Sprintf("handleConnection-%s", conn.RemoteAddr().String()), + func() { + l.handleConnection(conn) + }, + )) if !submitted { - go l.handleConnection(conn) + l.recoverer.SafeGo( + fmt.Sprintf("handleConnection-fallback-%s", conn.RemoteAddr().String()), + func() { + l.handleConnection(conn) + }, + ) } } } @@ -128,6 +145,12 @@ func (l *Listener) acceptLoop() { func (l *Listener) handleConnection(netConn net.Conn) { defer l.wg.Done() defer netConn.Close() + defer l.recoverer.RecoverWithCallback("handleConnection", func(p interface{}) { + connID := netConn.RemoteAddr().String() + l.connMu.Lock() + delete(l.connections, connID) + l.connMu.Unlock() + }) tlsConn, ok := netConn.(*tls.Conn) if !ok { diff --git a/internal/shared/recovery/metrics.go b/internal/shared/recovery/metrics.go new file mode 100644 index 0000000..4792fe1 --- /dev/null +++ b/internal/shared/recovery/metrics.go @@ -0,0 +1,110 @@ +package recovery + +import ( + "fmt" + "runtime/debug" + "sync" + "sync/atomic" + "time" + + "go.uber.org/zap" +) + +type PanicMetrics struct { + totalPanics uint64 + recentPanics []PanicRecord + mu sync.Mutex + logger *zap.Logger + alerter Alerter +} + +type PanicRecord struct { + Location string + Timestamp time.Time + Value interface{} + Stack string +} + +type Alerter interface { + SendAlert(title string, message string) +} + +func NewPanicMetrics(logger *zap.Logger, alerter Alerter) *PanicMetrics { + return &PanicMetrics{ + recentPanics: make([]PanicRecord, 0, 100), + logger: logger, + alerter: alerter, + } +} + +func (pm *PanicMetrics) RecordPanic(location string, panicValue interface{}) { + atomic.AddUint64(&pm.totalPanics, 1) + + pm.mu.Lock() + + record := PanicRecord{ + Location: location, + Timestamp: time.Now(), + Value: panicValue, + Stack: string(debug.Stack()), + } + + pm.recentPanics = append(pm.recentPanics, record) + + if len(pm.recentPanics) > 100 { + pm.recentPanics = pm.recentPanics[1:] + } + + shouldAlert := pm.shouldAlertUnlocked() + pm.mu.Unlock() + + if shouldAlert { + pm.sendAlert() + } +} + +func (pm *PanicMetrics) shouldAlertUnlocked() bool { + threshold := time.Now().Add(-5 * time.Minute) + count := 0 + + for i := len(pm.recentPanics) - 1; i >= 0; i-- { + if pm.recentPanics[i].Timestamp.After(threshold) { + count++ + } else { + break + } + } + + rate := float64(count) / 5.0 + return rate >= 2.0 +} + +func (pm *PanicMetrics) sendAlert() { + total := atomic.LoadUint64(&pm.totalPanics) + + pm.mu.Lock() + threshold := time.Now().Add(-5 * time.Minute) + count := 0 + for i := len(pm.recentPanics) - 1; i >= 0; i-- { + if pm.recentPanics[i].Timestamp.After(threshold) { + count++ + } else { + break + } + } + rate := float64(count) / 5.0 + pm.mu.Unlock() + + pm.logger.Error("ALERT: High panic rate detected", + zap.Uint64("total_panics", total), + zap.Float64("rate_per_minute", rate), + ) + + if pm.alerter != nil { + message := "High panic rate detected: %.2f panics/minute (total: %d)" + pm.alerter.SendAlert( + "Drip: High Panic Rate", + fmt.Sprintf(message, rate, total), + ) + } +} diff --git a/internal/shared/recovery/middleware.go b/internal/shared/recovery/middleware.go new file mode 100644 index 0000000..e1df81f --- /dev/null +++ b/internal/shared/recovery/middleware.go @@ -0,0 +1,79 @@ +package recovery + +import ( + "runtime/debug" + + "go.uber.org/zap" +) + +type Recoverer struct { + logger *zap.Logger + metrics MetricsCollector +} + +type MetricsCollector interface { + RecordPanic(location string, panicValue interface{}) +} + +func NewRecoverer(logger *zap.Logger, metrics MetricsCollector) *Recoverer { + return &Recoverer{ + logger: logger, + metrics: metrics, + } +} + +func (r *Recoverer) WrapGoroutine(name string, fn func()) func() { + return func() { + defer func() { + if p := recover(); p != nil { + r.logger.Error("goroutine panic recovered", + zap.String("goroutine", name), + zap.Any("panic", p), + zap.ByteString("stack", debug.Stack()), + ) + + if r.metrics != nil { + r.metrics.RecordPanic(name, p) + } + } + }() + + fn() + } +} + +func (r *Recoverer) SafeGo(name string, fn func()) { + go r.WrapGoroutine(name, fn)() +} + +func (r *Recoverer) Recover(location string) { + if p := recover(); p != nil { + r.logger.Error("panic recovered", + zap.String("location", location), + zap.Any("panic", p), + zap.ByteString("stack", debug.Stack()), + ) + + if r.metrics != nil { + r.metrics.RecordPanic(location, p) + } + } +} + +func (r *Recoverer) RecoverWithCallback(location string, callback func(panicValue interface{})) { + if p := recover(); p != nil { + r.logger.Error("panic recovered with callback", + zap.String("location", location), + zap.Any("panic", p), + zap.ByteString("stack", debug.Stack()), + ) + + if r.metrics != nil { + r.metrics.RecordPanic(location, p) + } + + if callback != nil { + callback(p) + } + } +}