Files
drip/internal/client/tcp/pool_client.go
Gouryella 0c19c3300c feat(tunnel): switch to yamux stream proxying and connection pooling
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server
- Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly
- Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
2025-12-13 18:03:44 +08:00

451 lines
9.7 KiB
Go

package tcp
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"runtime"
"sync"
"sync/atomic"
"time"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
"drip/internal/shared/constants"
"drip/internal/shared/protocol"
"drip/internal/shared/stats"
"drip/pkg/config"
)
// PoolClient manages a pool of yamux sessions for tunnel connections.
//
// One "primary" session carries the registration handshake and is always
// present after a successful Connect; additional data sessions are tracked in
// dataSessions. Closing the primary session closes the whole client.
type PoolClient struct {
// serverAddr is the address dialed over TLS for every connection.
serverAddr string
// tlsConfig is the TLS configuration used when dialing serverAddr.
tlsConfig *tls.Config
// token, tunnelType, localPort and subdomain are sent to the server in the
// registration request; subdomain is overwritten with the server's answer.
token string
tunnelType protocol.TunnelType
// localHost defaults to "127.0.0.1" when not configured.
// NOTE(review): presumably the host of the local service being exposed —
// its use is outside this part of the file.
localHost string
localPort int
subdomain string
// assignedURL is the URL returned by the server in the register response.
assignedURL string
// tunnelID is set only when the server reports data-connection support and
// returns a non-empty tunnel ID; an empty value disables pool scaling.
tunnelID string
// Pool bounds, normalized in NewPoolClient so that
// minSessions <= initialSessions <= maxSessions.
minSessions int
maxSessions int
initialSessions int
// stats accumulates request and active-connection counters.
stats *stats.TrafficStats
// httpClient is created only for HTTP and HTTPS tunnel types.
httpClient *http.Client
// latencyCallback always holds a non-nil LatencyCallback; it is invoked
// after every successful ping.
latencyCallback atomic.Value // LatencyCallback
// latencyNanos is the most recent ping round-trip time in nanoseconds.
latencyNanos atomic.Int64
// ctx is created in NewPoolClient and cancelled by Close via cancel.
ctx context.Context
cancel context.CancelFunc
// stopCh is closed by Close to stop the background loops.
stopCh chan struct{}
// doneCh is closed once every goroutine tracked by wg has finished.
doneCh chan struct{}
// once makes Close idempotent.
once sync.Once
// wg tracks the background goroutines started by Connect and the loops.
wg sync.WaitGroup
// closed is set to true at the start of Close.
closed atomic.Bool
// primary is the handle of the primary session; set by Connect and cleared
// by Close.
primary *sessionHandle
// mu is held when Close swaps out dataSessions and primary, and when
// Connect sets desiredTotal.
mu sync.RWMutex
// dataSessions holds the additional data sessions.
// NOTE(review): presumably keyed by sessionHandle.id — confirm against
// removeDataSession.
dataSessions map[string]*sessionHandle
// desiredTotal is initialized to initialSessions in Connect.
// NOTE(review): presumably the scaler's target session count — confirm.
desiredTotal int
// NOTE(review): lastScale is not used in this part of the file; presumably
// the time of the last scaling decision — confirm against scalerLoop.
lastScale time.Time
logger *zap.Logger
}
// NewPoolClient builds a PoolClient from cfg without touching the network;
// call Connect to register the tunnel and start the background workers.
//
// Defaults applied here:
//   - TLS: the insecure client config when cfg.Insecure is set, otherwise a
//     config built for the host part of cfg.ServerAddr.
//   - Local host: "127.0.0.1" when cfg.LocalHost is empty.
//   - Tunnel type: protocol.TunnelTypeTCP when cfg.TunnelType is empty.
//   - Pool bounds: the minimum defaults to 2, the maximum defaults to
//     max(NumCPU*16, minimum) and is never below the minimum, and the initial
//     size defaults to 4 and is clamped into [minimum, maximum].
//
// The latency callback starts out as a no-op so it can always be invoked.
func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
	var tlsConfig *tls.Config
	switch {
	case cfg.Insecure:
		tlsConfig = config.GetClientTLSConfigInsecure()
	default:
		// The split error is ignored: an address that cannot be split
		// yields an empty host here.
		serverName, _, _ := net.SplitHostPort(cfg.ServerAddr)
		tlsConfig = config.GetClientTLSConfig(serverName)
	}

	localHost := "127.0.0.1"
	if cfg.LocalHost != "" {
		localHost = cfg.LocalHost
	}

	tunnelType := cfg.TunnelType
	if tunnelType == "" {
		tunnelType = protocol.TunnelTypeTCP
	}

	// Lower bound of the pool.
	minSessions := cfg.PoolMin
	if minSessions <= 0 {
		minSessions = 2
	}

	// Upper bound of the pool; never below the lower bound.
	maxSessions := cfg.PoolMax
	switch {
	case maxSessions <= 0:
		maxSessions = max(runtime.NumCPU()*16, minSessions)
	case maxSessions < minSessions:
		maxSessions = minSessions
	}

	// Starting size, clamped into [minSessions, maxSessions].
	initialSessions := cfg.PoolSize
	if initialSessions <= 0 {
		initialSessions = 4
	}
	initialSessions = max(initialSessions, minSessions)
	initialSessions = min(initialSessions, maxSessions)

	ctx, cancel := context.WithCancel(context.Background())

	client := &PoolClient{
		serverAddr:      cfg.ServerAddr,
		tlsConfig:       tlsConfig,
		token:           cfg.Token,
		tunnelType:      tunnelType,
		localHost:       localHost,
		localPort:       cfg.LocalPort,
		subdomain:       cfg.Subdomain,
		minSessions:     minSessions,
		maxSessions:     maxSessions,
		initialSessions: initialSessions,
		stats:           stats.NewTrafficStats(),
		ctx:             ctx,
		cancel:          cancel,
		stopCh:          make(chan struct{}),
		doneCh:          make(chan struct{}),
		dataSessions:    make(map[string]*sessionHandle),
		logger:          logger,
	}

	// Only HTTP-style tunnels need a client for the local service.
	switch tunnelType {
	case protocol.TunnelTypeHTTP, protocol.TunnelTypeHTTPS:
		client.httpClient = newLocalHTTPClient(tunnelType)
	}

	client.latencyCallback.Store(LatencyCallback(func(time.Duration) {}))
	return client
}
// Connect dials the server, registers the tunnel over the primary connection,
// wraps that connection in a yamux session and starts the background workers.
//
// Handshake: a Register frame is written, then a single reply frame is read
// under a constants.RequestTimeout read deadline. An Error frame or any frame
// other than RegisterAck fails the call. On every failure after the dial the
// primary connection is closed before the error is returned.
//
// On success the assigned URL and subdomain are recorded, and the tunnel ID is
// stored only if the server reports data-connection support. This side runs
// the yamux session in server mode with yamux keep-alives disabled; liveness
// is checked by pingLoop instead. When a tunnel ID was assigned, the desired
// session count is set to the initial pool size, ensureSessions is called and
// the scaler loop is started.
func (c *PoolClient) Connect() error {
	conn, err := c.dialTLS()
	if err != nil {
		return err
	}

	// abort closes the primary connection and passes the error through, so
	// every failure path below releases the socket.
	abort := func(cause error) error {
		_ = conn.Close()
		return cause
	}

	register := protocol.RegisterRequest{
		Token:           c.token,
		CustomSubdomain: c.subdomain,
		TunnelType:      c.tunnelType,
		LocalPort:       c.localPort,
		ConnectionType:  "primary",
		PoolCapabilities: &protocol.PoolCapabilities{
			// The primary connection takes one slot of the pool.
			MaxDataConns: max(c.maxSessions-1, 0),
			Version:      1,
		},
	}

	body, err := json.Marshal(register)
	if err != nil {
		return abort(fmt.Errorf("failed to marshal request: %w", err))
	}

	frame := protocol.NewFrame(protocol.FrameTypeRegister, body)
	if err := protocol.WriteFrame(conn, frame); err != nil {
		return abort(fmt.Errorf("failed to send registration: %w", err))
	}

	// Bound the wait for the server's reply.
	_ = conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
	ack, err := protocol.ReadFrame(conn)
	if err != nil {
		return abort(fmt.Errorf("failed to read register ack: %w", err))
	}
	defer ack.Release()
	// Clear the deadline so it does not affect later reads on this connection.
	_ = conn.SetReadDeadline(time.Time{})

	switch ack.Type {
	case protocol.FrameTypeError:
		var errMsg protocol.ErrorMessage
		if json.Unmarshal(ack.Payload, &errMsg) != nil {
			return abort(fmt.Errorf("registration error"))
		}
		return abort(fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message))
	case protocol.FrameTypeRegisterAck:
		// Expected reply; handled below.
	default:
		return abort(fmt.Errorf("unexpected register ack frame: %s", ack.Type))
	}

	var resp protocol.RegisterResponse
	if err := json.Unmarshal(ack.Payload, &resp); err != nil {
		return abort(fmt.Errorf("failed to parse register response: %w", err))
	}

	c.assignedURL = resp.URL
	c.subdomain = resp.Subdomain
	if resp.SupportsDataConn && resp.TunnelID != "" {
		c.tunnelID = resp.TunnelID
	}

	muxCfg := yamux.DefaultConfig()
	muxCfg.EnableKeepAlive = false
	muxCfg.LogOutput = io.Discard
	muxCfg.AcceptBacklog = constants.YamuxAcceptBacklog

	session, err := yamux.Server(conn, muxCfg)
	if err != nil {
		return abort(fmt.Errorf("failed to init yamux session: %w", err))
	}

	primary := &sessionHandle{
		id:      "primary",
		conn:    conn,
		session: session,
	}
	primary.touch()
	c.primary = primary

	// Keeps the wait group non-zero until Close closes stopCh.
	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		<-c.stopCh
	}()

	c.wg.Add(1)
	go c.acceptLoop(primary, true)

	c.wg.Add(1)
	go c.sessionWatcher(primary, true)

	c.wg.Add(1)
	go c.pingLoop(primary)

	// Pooling is only available when the server handed out a tunnel ID.
	if c.tunnelID != "" {
		c.mu.Lock()
		c.desiredTotal = c.initialSessions
		c.mu.Unlock()

		c.ensureSessions()

		c.wg.Add(1)
		go c.scalerLoop()
	}

	// Signal Wait once every tracked goroutine has returned.
	go func() {
		c.wg.Wait()
		close(c.doneCh)
	}()

	return nil
}
// dialTLS opens a TLS connection to c.serverAddr and tunes the underlying TCP
// socket.
//
// The dial and handshake are bound to the client's context, so Close (which
// cancels that context) aborts an in-flight attempt instead of letting it run
// for up to the 10 second dial timeout. As with tls.DialWithDialer, the
// dialer timeout covers the TCP connect and the TLS handshake together, and
// cancelling the context after a successful dial does not affect the
// established connection.
//
// The connection is rejected unless TLS 1.3 was negotiated. The socket options
// (no-delay, keep-alive, buffer sizes) are applied best-effort; their errors
// are ignored.
//
// Returns the established connection, or a nil connection and an error.
func (c *PoolClient) dialTLS() (net.Conn, error) {
	ctx := c.ctx
	if ctx == nil {
		// Clients not built through NewPoolClient have no context.
		ctx = context.Background()
	}

	dialer := &tls.Dialer{
		NetDialer: &net.Dialer{Timeout: 10 * time.Second},
		Config:    c.tlsConfig,
	}
	rawConn, err := dialer.DialContext(ctx, "tcp", c.serverAddr)
	if err != nil {
		return nil, fmt.Errorf("failed to connect: %w", err)
	}

	// tls.Dialer documents that the returned connection is always a *tls.Conn;
	// the check only guards against a panic should that ever change.
	conn, ok := rawConn.(*tls.Conn)
	if !ok {
		_ = rawConn.Close()
		return nil, fmt.Errorf("failed to connect: unexpected connection type %T", rawConn)
	}

	state := conn.ConnectionState()
	if state.Version != tls.VersionTLS13 {
		_ = conn.Close()
		return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
	}

	if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
		_ = tcpConn.SetNoDelay(true)
		_ = tcpConn.SetKeepAlive(true)
		_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
		_ = tcpConn.SetReadBuffer(256 * 1024)
		_ = tcpConn.SetWriteBuffer(256 * 1024)
	}

	return conn, nil
}
// acceptLoop accepts incoming yamux streams on session h and hands each one to
// handleStream in its own goroutine until the client stops or Accept fails.
//
// On an Accept error the loop always ends. If the client is already closed or
// the error is an expected close error, nothing else happens; otherwise a
// failure on the primary session closes the whole client, while a failure on a
// data session only removes that session from the pool.
func (c *PoolClient) acceptLoop(h *sessionHandle, isPrimary bool) {
	defer c.wg.Done()

	for {
		// Non-blocking check so a stopped client does not accept again.
		select {
		case <-c.stopCh:
			return
		default:
		}

		stream, err := h.session.Accept()
		if err == nil {
			h.active.Add(1)
			h.touch()
			c.stats.AddRequest()
			c.stats.IncActiveConnections()

			// Registered before the goroutine starts so the wait group
			// cannot reach zero in between.
			c.wg.Add(1)
			go c.handleStream(h, stream)
			continue
		}

		switch {
		case c.IsClosed() || isExpectedCloseError(err):
			// Normal shutdown; nothing to report.
		case isPrimary:
			c.logger.Debug("Primary session accept failed", zap.Error(err))
			_ = c.Close()
		default:
			c.logger.Debug("Data session accept failed", zap.String("session_id", h.id), zap.Error(err))
			c.removeDataSession(h.id)
		}
		return
	}
}
// sessionWatcher blocks until either the client is stopped or session h is
// closed. When the session closes first, a primary session takes the whole
// client down, while a data session is only removed from the pool.
func (c *PoolClient) sessionWatcher(h *sessionHandle, isPrimary bool) {
	defer c.wg.Done()

	select {
	case <-c.stopCh:
		// Client is shutting down; Close handles the sessions.
		return
	case <-h.session.CloseChan():
	}

	if !isPrimary {
		c.removeDataSession(h.id)
		return
	}
	_ = c.Close()
}
// pingLoop pings session h once per constants.HeartbeatInterval until the
// client stops or the session is gone.
//
// A successful ping resets the failure counter, refreshes the handle's
// activity mark, records the latency and invokes the latency callback. After
// three consecutive failed pings the session is given up: the primary session
// closes the whole client, any other session is removed from the pool.
func (c *PoolClient) pingLoop(h *sessionHandle) {
	defer c.wg.Done()

	const maxConsecutiveFailures = 3

	ticker := time.NewTicker(constants.HeartbeatInterval)
	defer ticker.Stop()

	failures := 0
	for {
		select {
		case <-c.stopCh:
			return
		case <-ticker.C:
		}

		if h.session == nil || h.session.IsClosed() {
			return
		}

		rtt, err := h.session.Ping()
		if err == nil {
			failures = 0
			h.touch()
			c.latencyNanos.Store(int64(rtt))
			if cb, ok := c.latencyCallback.Load().(LatencyCallback); ok && cb != nil {
				cb(rtt)
			}
			continue
		}

		failures++
		c.logger.Debug("Ping failed",
			zap.String("session_id", h.id),
			zap.Int("consecutive_failures", failures),
			zap.Error(err),
		)
		if failures < maxConsecutiveFailures {
			// Tolerate transient failures; try again on the next tick.
			continue
		}

		c.logger.Warn("Session ping failed too many times, closing",
			zap.String("session_id", h.id),
			zap.Int("failures", failures),
		)
		if h.id == "primary" {
			_ = c.Close()
		} else {
			c.removeDataSession(h.id)
		}
		return
	}
}
// Close shuts the client down: it marks the client closed, closes stopCh,
// cancels the client context and then closes every data session followed by
// the primary session, each together with its underlying connection.
//
// Only the first call does any work; it returns the error from closing the
// primary yamux session, if any. Later calls return nil. Close does not wait
// for the background goroutines to finish; use Wait for that.
func (c *PoolClient) Close() error {
	var closeErr error

	c.once.Do(func() {
		c.closed.Store(true)
		close(c.stopCh)
		if c.cancel != nil {
			c.cancel()
		}

		// Detach all sessions while holding the lock; close them afterwards
		// so no network I/O happens under the mutex.
		c.mu.Lock()
		detached := make([]*sessionHandle, 0, len(c.dataSessions))
		for _, h := range c.dataSessions {
			detached = append(detached, h)
		}
		c.dataSessions = make(map[string]*sessionHandle)
		primary := c.primary
		c.primary = nil
		c.mu.Unlock()

		// teardown closes a handle's session and connection and reports the
		// session's close error; a nil handle is a no-op.
		teardown := func(h *sessionHandle) error {
			if h == nil {
				return nil
			}
			var err error
			if h.session != nil {
				err = h.session.Close()
			}
			if h.conn != nil {
				_ = h.conn.Close()
			}
			return err
		}

		for _, h := range detached {
			_ = teardown(h)
		}
		closeErr = teardown(primary)
	})

	return closeErr
}
// Wait blocks until doneCh is closed, which happens once every goroutine
// tracked by the client's wait group has returned.
func (c *PoolClient) Wait() {
	<-c.doneCh
}
// GetURL returns the URL the server assigned during registration; it is empty
// until Connect has processed the register response.
func (c *PoolClient) GetURL() string {
	return c.assignedURL
}
// GetSubdomain returns the configured subdomain, or the one reported by the
// server once Connect has processed the register response.
func (c *PoolClient) GetSubdomain() string {
	return c.subdomain
}
// GetLatency returns the round-trip time of the most recent successful ping,
// or zero if no ping has succeeded yet.
func (c *PoolClient) GetLatency() time.Duration {
	nanos := c.latencyNanos.Load()
	return time.Duration(nanos)
}
// GetStats returns the client's traffic statistics collector. The same
// instance is returned on every call.
func (c *PoolClient) GetStats() *stats.TrafficStats {
	return c.stats
}
// IsClosed reports whether Close has been called on the client.
func (c *PoolClient) IsClosed() bool {
	return c.closed.Load()
}
// SetLatencyCallback installs the function invoked with the measured latency
// after every successful ping. Passing nil installs a no-op, so the stored
// value always has the LatencyCallback type and is never nil.
func (c *PoolClient) SetLatencyCallback(cb LatencyCallback) {
	effective := cb
	if effective == nil {
		effective = func(time.Duration) {}
	}
	c.latencyCallback.Store(effective)
}