Files
drip/internal/client/tcp/connector.go
Gouryella f6f2c6fd5b feat (recovery): Introduced panic recovery and monitoring mechanisms.
A new recovery package has been added, containing Recoverer and PanicMetrics, for capturing panics in goroutines.
It records stack trace information and provides statistical metrics. This mechanism is also integrated into the TCP connector and listener.
Enhance service stability and observability.
2025-12-10 16:10:26 +08:00

443 lines
10 KiB
Go

package tcp
import (
"crypto/tls"
"fmt"
"net"
json "github.com/goccy/go-json"
"sync"
"time"
"drip/internal/shared/constants"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/recovery"
"drip/pkg/config"
"go.uber.org/zap"
)
// LatencyCallback is called when latency is measured from a heartbeat
// round trip. It is invoked on the frame-reading goroutine, so the
// callback must return quickly to avoid stalling frame processing.
type LatencyCallback func(latency time.Duration)
// Connector manages the TCP connection to the server.
// It dials a TLS 1.3 connection, registers a tunnel, then reads frames on
// a dedicated goroutine and dispatches data frames to a worker pool.
type Connector struct {
	serverAddr string              // server address as host:port
	tlsConfig  *tls.Config         // client TLS configuration (insecure variant skips verification)
	token      string              // authentication token sent during registration
	tunnelType protocol.TunnelType // tunnel protocol (set at construction)
	localHost  string              // local service host (defaults to 127.0.0.1)
	localPort  int                 // local service port to forward to
	subdomain  string              // requested subdomain; replaced by the server-assigned one after registration
	conn       net.Conn            // the TLS connection to the server
	logger     *zap.Logger
	stopCh     chan struct{} // closed in Close to signal all goroutines to stop
	once       sync.Once     // ensures Close runs its shutdown sequence exactly once

	registered  bool   // set after a successful register()
	assignedURL string // tunnel URL returned by the server

	frameHandler *FrameHandler         // processes data frames against the local service
	frameWriter  *protocol.FrameWriter // serializes outbound frame writes

	latencyCallback LatencyCallback // invoked with each measured heartbeat round-trip time
	heartbeatSentAt time.Time       // send time of the outstanding heartbeat; guarded by heartbeatMu
	heartbeatMu     sync.Mutex      // guards heartbeatSentAt and lastLatency
	lastLatency     time.Duration   // most recent measured latency; guarded by heartbeatMu

	handlerWg sync.WaitGroup // Tracks active data frame handlers
	closed    bool           // set in Close; guarded by closedMu
	closedMu  sync.RWMutex   // guards closed

	// Worker pool for handling data frames
	dataFrameQueue chan *protocol.Frame // buffered queue feeding dataFrameWorker goroutines
	workerCount    int                  // number of dataFrameWorker goroutines

	recoverer    *recovery.Recoverer    // wraps goroutine bodies with panic recovery
	panicMetrics *recovery.PanicMetrics // records panic statistics for observability
}
// ConnectorConfig holds connector configuration supplied by the caller
// of NewConnector.
type ConnectorConfig struct {
	ServerAddr string              // server address as host:port
	Token      string              // authentication token
	TunnelType protocol.TunnelType // tunnel protocol to register
	LocalHost  string              // Local host address (default: 127.0.0.1)
	LocalPort  int                 // local service port to forward to
	Subdomain  string              // Optional custom subdomain
	Insecure   bool                // Skip TLS verification (testing only)
}
// NewConnector creates a new connector.
//
// When cfg.Insecure is set, certificate verification is skipped (testing
// only); otherwise the TLS configuration verifies against the server
// hostname extracted from cfg.ServerAddr. The data-frame worker pool is
// sized at 1.5x the CPU count with a minimum of 4 workers, and the queue
// buffers 100 frames per worker.
func NewConnector(cfg *ConnectorConfig, logger *zap.Logger) *Connector {
	var tlsConfig *tls.Config
	if cfg.Insecure {
		tlsConfig = config.GetClientTLSConfigInsecure()
	} else {
		host, _, err := net.SplitHostPort(cfg.ServerAddr)
		if err != nil {
			// ServerAddr may be a bare hostname without a port; use it
			// directly as the TLS server name rather than an empty string.
			host = cfg.ServerAddr
		}
		tlsConfig = config.GetClientTLSConfig(host)
	}

	localHost := cfg.LocalHost
	if localHost == "" {
		localHost = "127.0.0.1"
	}

	numCPU := pool.NumCPU()
	workerCount := max(numCPU+numCPU/2, 4)

	panicMetrics := recovery.NewPanicMetrics(logger, nil)
	recoverer := recovery.NewRecoverer(logger, panicMetrics)

	return &Connector{
		serverAddr:     cfg.ServerAddr,
		tlsConfig:      tlsConfig,
		token:          cfg.Token,
		tunnelType:     cfg.TunnelType,
		localHost:      localHost,
		localPort:      cfg.LocalPort,
		subdomain:      cfg.Subdomain,
		logger:         logger,
		stopCh:         make(chan struct{}),
		dataFrameQueue: make(chan *protocol.Frame, workerCount*100),
		workerCount:    workerCount,
		recoverer:      recoverer,
		panicMetrics:   panicMetrics,
	}
}
// Connect connects to the server and registers the tunnel.
//
// It dials with a 10-second timeout, requires TLS 1.3, registers the
// tunnel, then starts the worker pool, connection-pool warmup, and the
// frame-reading loop. c.conn is only retained once the TLS version check
// passes, and is cleared again if registration fails, so a failed Connect
// leaves no stale connection behind.
func (c *Connector) Connect() error {
	c.logger.Info("Connecting to server",
		zap.String("server", c.serverAddr),
		zap.String("tunnel_type", string(c.tunnelType)),
		zap.String("local_host", c.localHost),
		zap.Int("local_port", c.localPort),
	)

	dialer := &net.Dialer{
		Timeout: 10 * time.Second,
	}
	conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
	if err != nil {
		return fmt.Errorf("failed to connect: %w", err)
	}

	// Enforce TLS 1.3 before keeping the connection.
	state := conn.ConnectionState()
	if state.Version != tls.VersionTLS13 {
		conn.Close()
		return fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
	}
	c.conn = conn

	c.logger.Info("TLS connection established",
		zap.String("cipher_suite", tls.CipherSuiteName(state.CipherSuite)),
	)

	if err := c.register(); err != nil {
		conn.Close()
		c.conn = nil // do not keep a half-initialized connection
		return fmt.Errorf("registration failed: %w", err)
	}

	c.frameWriter = protocol.NewFrameWriter(c.conn)
	bufferPool := pool.NewBufferPool()
	c.frameHandler = NewFrameHandler(
		c.conn,
		c.frameWriter,
		c.localHost,
		c.localPort,
		c.tunnelType,
		c.logger,
		c.IsClosed,
		bufferPool,
	)
	c.frameWriter.EnableHeartbeat(constants.HeartbeatInterval, c.createHeartbeatFrame)

	// Start the data-frame worker pool; each worker is tracked by handlerWg
	// so Close can wait for in-flight frames.
	for i := 0; i < c.workerCount; i++ {
		c.handlerWg.Add(1)
		go c.dataFrameWorker(i)
	}
	go c.frameHandler.WarmupConnectionPool(3)
	go c.handleFrames()
	return nil
}
// register sends the registration request over the raw connection and
// blocks until the server acknowledges (or rejects) it. On success it
// records the server-assigned subdomain and public URL.
func (c *Connector) register() error {
	body, err := json.Marshal(protocol.RegisterRequest{
		Token:           c.token,
		CustomSubdomain: c.subdomain,
		TunnelType:      c.tunnelType,
		LocalPort:       c.localPort,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}

	if err := protocol.WriteFrame(c.conn, protocol.NewFrame(protocol.FrameTypeRegister, body)); err != nil {
		return fmt.Errorf("failed to send registration: %w", err)
	}

	// Bound the wait for the server's reply, then clear the deadline so the
	// main frame loop manages its own timeouts.
	c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
	ack, err := protocol.ReadFrame(c.conn)
	if err != nil {
		return fmt.Errorf("failed to read ack: %w", err)
	}
	defer ack.Release()
	c.conn.SetReadDeadline(time.Time{})

	switch ack.Type {
	case protocol.FrameTypeError:
		var errMsg protocol.ErrorMessage
		if err := json.Unmarshal(ack.Payload, &errMsg); err == nil {
			return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
		}
		return fmt.Errorf("registration error")
	case protocol.FrameTypeRegisterAck:
		// Expected; fall through to decode the response.
	default:
		return fmt.Errorf("unexpected frame type: %s", ack.Type)
	}

	var resp protocol.RegisterResponse
	if err := json.Unmarshal(ack.Payload, &resp); err != nil {
		return fmt.Errorf("failed to parse response: %w", err)
	}

	c.registered = true
	c.assignedURL = resp.URL
	c.subdomain = resp.Subdomain

	c.logger.Info("Tunnel registered successfully",
		zap.String("subdomain", resp.Subdomain),
		zap.String("url", resp.URL),
		zap.Int("remote_port", resp.Port),
	)
	return nil
}
// dataFrameWorker consumes data frames from dataFrameQueue and hands them
// to the frame handler. Each frame is processed inside a closure with its
// own panic recovery so a panic in one frame releases that frame and does
// not kill the worker. The worker exits when the queue is closed or stopCh
// is signalled; handlerWg tracks its lifetime for Close.
func (c *Connector) dataFrameWorker(workerID int) {
	defer c.handlerWg.Done()
	defer c.recoverer.Recover(fmt.Sprintf("dataFrameWorker-%d", workerID))

	for {
		select {
		case frame, ok := <-c.dataFrameQueue:
			if !ok {
				return
			}
			// Per-frame scope: the deferred recovery releases the frame if
			// HandleDataFrame panics, so pooled buffers are not leaked.
			func() {
				defer c.recoverer.RecoverWithCallback("handleDataFrame", func(_ any) {
					if frame != nil {
						frame.Release()
					}
				})
				if err := c.frameHandler.HandleDataFrame(frame); err != nil {
					c.logger.Error("Failed to handle data frame",
						zap.Int("worker_id", workerID),
						zap.Error(err))
				}
				frame.Release()
			}()
		case <-c.stopCh:
			return
		}
	}
}
// handleFrames handles incoming frames from server.
//
// It runs as a dedicated goroutine: reads frames with a rolling deadline,
// measures heartbeat latency, dispatches data frames to the worker pool
// (dropping them when the queue is full), and shuts the connector down on
// close/error frames or read failure. Recovery runs before Close in the
// defer chain (LIFO) so a panic is recorded before teardown.
func (c *Connector) handleFrames() {
	defer c.Close()
	defer c.recoverer.Recover("handleFrames")

	for {
		select {
		case <-c.stopCh:
			return
		default:
		}

		// Rolling read deadline: heartbeat acks keep the connection alive;
		// a quiet connection past RequestTimeout is treated as dead.
		c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
		frame, err := protocol.ReadFrame(c.conn)
		if err != nil {
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				c.logger.Warn("Read timeout")
				return
			}
			select {
			case <-c.stopCh:
				// Shutting down; the read error is expected.
				return
			default:
				c.logger.Error("Failed to read frame", zap.Error(err))
				return
			}
		}

		switch frame.Type {
		case protocol.FrameTypeHeartbeatAck:
			c.heartbeatMu.Lock()
			if !c.heartbeatSentAt.IsZero() {
				latency := time.Since(c.heartbeatSentAt)
				c.lastLatency = latency
				c.heartbeatMu.Unlock()
				c.logger.Debug("Received heartbeat ack", zap.Duration("latency", latency))
				if c.latencyCallback != nil {
					c.latencyCallback(latency)
				}
			} else {
				c.heartbeatMu.Unlock()
				c.logger.Debug("Received heartbeat ack")
			}
			frame.Release()
		case protocol.FrameTypeData:
			select {
			case c.dataFrameQueue <- frame:
				// Ownership transferred to a worker, which releases it.
			case <-c.stopCh:
				frame.Release()
				return
			default:
				c.logger.Warn("Data frame queue full, dropping frame")
				frame.Release()
			}
		case protocol.FrameTypeClose:
			frame.Release()
			c.logger.Info("Server requested close")
			return
		case protocol.FrameTypeError:
			var errMsg protocol.ErrorMessage
			if err := json.Unmarshal(frame.Payload, &errMsg); err == nil {
				c.logger.Error("Received error from server",
					zap.String("code", errMsg.Code),
					zap.String("message", errMsg.Message),
				)
			}
			frame.Release()
			return
		default:
			// Log BEFORE releasing: frames are pooled, so reading frame.Type
			// after Release would touch a recycled frame.
			c.logger.Warn("Unexpected frame type",
				zap.String("type", frame.Type.String()),
			)
			frame.Release()
		}
	}
}
// createHeartbeatFrame builds a heartbeat frame for the frame writer's
// periodic heartbeat, recording the send time so the next heartbeat ack
// can compute round-trip latency. It returns nil once the connector is
// closed so no heartbeat is sent during shutdown.
func (c *Connector) createHeartbeatFrame() *protocol.Frame {
	// Reuse IsClosed instead of duplicating the closedMu lock dance.
	if c.IsClosed() {
		return nil
	}

	c.heartbeatMu.Lock()
	c.heartbeatSentAt = time.Now()
	c.heartbeatMu.Unlock()

	return protocol.NewFrame(protocol.FrameTypeHeartbeat, nil)
}
// SendFrame sends a frame to the server.
//
// It fails if the tunnel has not been registered yet, and guards against
// the narrow window where registration has succeeded but the frame writer
// has not been created yet (registered is set inside register(), before
// Connect constructs frameWriter).
func (c *Connector) SendFrame(frame *protocol.Frame) error {
	if !c.registered {
		return fmt.Errorf("not registered")
	}
	if c.frameWriter == nil {
		return fmt.Errorf("connection not established")
	}
	return c.frameWriter.WriteFrame(frame)
}
// Close shuts the connector down exactly once: it marks the connector
// closed, signals all goroutines via stopCh, waits (bounded) for data
// workers to finish, drains any queued frames back to the pool, sends a
// close frame, and closes the connection. It always returns nil.
//
// Note: dataFrameQueue is deliberately NOT closed here. handleFrames is
// the sender, and closing the channel while it may be selecting a send
// would panic ("send on closed channel"); workers exit via stopCh instead,
// and leftover frames are drained below.
func (c *Connector) Close() error {
	c.once.Do(func() {
		c.closedMu.Lock()
		c.closed = true
		c.closedMu.Unlock()

		close(c.stopCh)

		// Bounded wait for in-flight data handlers.
		done := make(chan struct{})
		go func() {
			c.handlerWg.Wait()
			close(done)
		}()
		select {
		case <-done:
		case <-time.After(3 * time.Second):
			c.logger.Warn("Force closing: some handlers are still active")
		}

		// Release frames still sitting in the queue so pooled buffers are
		// returned (workers may have exited via stopCh without draining).
	drain:
		for {
			select {
			case f := <-c.dataFrameQueue:
				if f != nil {
					f.Release()
				}
			default:
				break drain
			}
		}

		if c.conn != nil {
			closeFrame := protocol.NewFrame(protocol.FrameTypeClose, nil)
			if c.frameWriter != nil {
				c.frameWriter.WriteFrame(closeFrame)
				c.frameWriter.Close()
			} else {
				protocol.WriteFrame(c.conn, closeFrame)
			}
			c.conn.Close()
		}
		c.logger.Info("Connector closed")
	})
	return nil
}
// Wait blocks until connection is closed (i.e. until Close closes stopCh).
func (c *Connector) Wait() {
	<-c.stopCh
}
// GetURL returns the assigned tunnel URL.
// It is empty until registration succeeds.
func (c *Connector) GetURL() string {
	return c.assignedURL
}
// GetSubdomain returns the assigned subdomain.
// Before registration this is the requested subdomain (possibly empty);
// afterwards it is the server-assigned value.
func (c *Connector) GetSubdomain() string {
	return c.subdomain
}
// SetLatencyCallback sets the callback for latency updates.
// NOTE(review): the field is read by the frame-reading goroutine without
// synchronization, so call this before Connect to avoid a data race.
func (c *Connector) SetLatencyCallback(cb LatencyCallback) {
	c.latencyCallback = cb
}
// GetLatency returns the last measured heartbeat round-trip latency.
func (c *Connector) GetLatency() time.Duration {
	c.heartbeatMu.Lock()
	latency := c.lastLatency
	c.heartbeatMu.Unlock()
	return latency
}
// GetStats returns the traffic stats from the frame handler,
// or nil when the connector has not been connected yet.
func (c *Connector) GetStats() *TrafficStats {
	if c.frameHandler == nil {
		return nil
	}
	return c.frameHandler.GetStats()
}
// IsClosed returns whether the connector has been closed.
func (c *Connector) IsClosed() bool {
	c.closedMu.RLock()
	closed := c.closed
	c.closedMu.RUnlock()
	return closed
}