Files
drip/internal/client/tcp/pool_client.go
Gouryella 6139a9c0ed feat: add transport protocol option supporting TCP and WebSocket connections
Added --transport parameter to allow users to select transport protocol type:
- auto: automatically choose based on server address (default)
- tcp: direct TLS 1.3 connection
- wss: WebSocket over TLS (CDN-friendly)

Also updated client connector to support WebSocket transport, and added server-side discovery endpoint to query supported transport protocols.
2026-01-14 12:49:08 +08:00

639 lines
14 KiB
Go

package tcp
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
"drip/internal/shared/constants"
"drip/internal/shared/mux"
"drip/internal/shared/protocol"
"drip/internal/shared/stats"
"drip/internal/shared/wsutil"
"drip/pkg/config"
)
// PoolClient manages a pool of yamux sessions for tunnel connections.
type PoolClient struct {
serverAddr string
tlsConfig *tls.Config
token string
tunnelType protocol.TunnelType
localHost string
localPort int
subdomain string
assignedURL string
tunnelID string
minSessions int
maxSessions int
initialSessions int
stats *stats.TrafficStats
httpClient *http.Client
latencyCallback atomic.Value // LatencyCallback
latencyNanos atomic.Int64
ctx context.Context
cancel context.CancelFunc
stopCh chan struct{}
doneCh chan struct{}
once sync.Once
wg sync.WaitGroup
closed atomic.Bool
primary *sessionHandle
mu sync.RWMutex
dataSessions map[string]*sessionHandle
desiredTotal int
lastScale time.Time
logger *zap.Logger
allowIPs []string
denyIPs []string
authPass string
// Transport protocol selection
transport TransportType
insecure bool
}
// NewPoolClient creates a new pool client.
func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
// Parse server address to get host for TLS config
serverAddr := cfg.ServerAddr
host := serverAddr
// Handle wss:// prefix
if strings.HasPrefix(serverAddr, "wss://") {
if u, err := url.Parse(serverAddr); err == nil {
host = u.Host
// Normalize server address for internal use
if u.Port() == "" {
host = u.Host + ":443"
}
serverAddr = host
}
}
// Extract hostname without port for TLS
hostOnly, _, _ := net.SplitHostPort(host)
if hostOnly == "" {
hostOnly = host
}
var tlsConfig *tls.Config
if cfg.Insecure {
tlsConfig = config.GetClientTLSConfigInsecure()
} else {
tlsConfig = config.GetClientTLSConfig(hostOnly)
}
localHost := cfg.LocalHost
if localHost == "" {
localHost = "127.0.0.1"
}
tunnelType := cfg.TunnelType
if tunnelType == "" {
tunnelType = protocol.TunnelTypeTCP
}
numCPU := runtime.NumCPU()
minSessions := cfg.PoolMin
if minSessions <= 0 {
minSessions = 2
}
maxSessions := cfg.PoolMax
if maxSessions <= 0 {
maxSessions = max(numCPU*16, minSessions)
}
if maxSessions < minSessions {
maxSessions = minSessions
}
initialSessions := cfg.PoolSize
if initialSessions <= 0 {
initialSessions = 4
}
initialSessions = min(max(initialSessions, minSessions), maxSessions)
// Determine transport type
transport := cfg.Transport
if transport == "" {
transport = TransportAuto
}
ctx, cancel := context.WithCancel(context.Background())
c := &PoolClient{
serverAddr: serverAddr,
tlsConfig: tlsConfig,
token: cfg.Token,
tunnelType: tunnelType,
localHost: localHost,
localPort: cfg.LocalPort,
subdomain: cfg.Subdomain,
minSessions: minSessions,
maxSessions: maxSessions,
initialSessions: initialSessions,
stats: stats.NewTrafficStats(),
ctx: ctx,
cancel: cancel,
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
dataSessions: make(map[string]*sessionHandle),
logger: logger,
allowIPs: cfg.AllowIPs,
denyIPs: cfg.DenyIPs,
authPass: cfg.AuthPass,
transport: transport,
insecure: cfg.Insecure,
}
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
c.httpClient = newLocalHTTPClient(tunnelType)
}
c.latencyCallback.Store(LatencyCallback(func(time.Duration) {}))
return c
}
// Connect establishes the primary connection and starts background workers.
func (c *PoolClient) Connect() error {
primaryConn, err := c.dial()
if err != nil {
return err
}
maxData := max(c.maxSessions-1, 0)
req := protocol.RegisterRequest{
Token: c.token,
CustomSubdomain: c.subdomain,
TunnelType: c.tunnelType,
LocalPort: c.localPort,
ConnectionType: "primary",
PoolCapabilities: &protocol.PoolCapabilities{
MaxDataConns: maxData,
Version: 1,
},
}
if len(c.allowIPs) > 0 || len(c.denyIPs) > 0 {
req.IPAccess = &protocol.IPAccessControl{
AllowIPs: c.allowIPs,
DenyIPs: c.denyIPs,
}
}
if c.authPass != "" {
req.ProxyAuth = &protocol.ProxyAuth{
Enabled: true,
Password: c.authPass,
}
}
payload, err := json.Marshal(req)
if err != nil {
_ = primaryConn.Close()
return fmt.Errorf("failed to marshal request: %w", err)
}
if err := protocol.WriteFrame(primaryConn, protocol.NewFrame(protocol.FrameTypeRegister, payload)); err != nil {
_ = primaryConn.Close()
return fmt.Errorf("failed to send registration: %w", err)
}
_ = primaryConn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
ack, err := protocol.ReadFrame(primaryConn)
if err != nil {
_ = primaryConn.Close()
return fmt.Errorf("failed to read register ack: %w", err)
}
defer ack.Release()
_ = primaryConn.SetReadDeadline(time.Time{})
if ack.Type == protocol.FrameTypeError {
var errMsg protocol.ErrorMessage
if e := json.Unmarshal(ack.Payload, &errMsg); e == nil {
_ = primaryConn.Close()
return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
}
_ = primaryConn.Close()
return fmt.Errorf("registration error")
}
if ack.Type != protocol.FrameTypeRegisterAck {
_ = primaryConn.Close()
return fmt.Errorf("unexpected register ack frame: %s", ack.Type)
}
var resp protocol.RegisterResponse
if err := json.Unmarshal(ack.Payload, &resp); err != nil {
_ = primaryConn.Close()
return fmt.Errorf("failed to parse register response: %w", err)
}
c.assignedURL = resp.URL
c.subdomain = resp.Subdomain
if resp.SupportsDataConn && resp.TunnelID != "" {
c.tunnelID = resp.TunnelID
}
yamuxCfg := mux.NewClientConfig()
session, err := yamux.Server(primaryConn, yamuxCfg)
if err != nil {
_ = primaryConn.Close()
return fmt.Errorf("failed to init yamux session: %w", err)
}
primary := &sessionHandle{
id: "primary",
conn: primaryConn,
session: session,
}
primary.touch()
c.primary = primary
c.wg.Add(1)
go func() {
defer c.wg.Done()
<-c.stopCh
}()
c.wg.Add(1)
go c.acceptLoop(primary, true)
c.wg.Add(1)
go c.sessionWatcher(primary, true)
c.wg.Add(1)
go c.pingLoop(primary)
if c.tunnelID != "" {
c.mu.Lock()
c.desiredTotal = c.initialSessions
c.mu.Unlock()
c.warmupSessions()
c.wg.Add(1)
go c.scalerLoop()
}
go func() {
c.wg.Wait()
close(c.doneCh)
}()
return nil
}
func (c *PoolClient) dialTLS() (net.Conn, error) {
dialer := &net.Dialer{Timeout: 10 * time.Second}
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
if err != nil {
return nil, fmt.Errorf("failed to connect: %w", err)
}
state := conn.ConnectionState()
if state.Version != tls.VersionTLS13 {
_ = conn.Close()
return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
}
if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
_ = tcpConn.SetNoDelay(true)
_ = tcpConn.SetKeepAlive(true)
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
_ = tcpConn.SetReadBuffer(256 * 1024)
_ = tcpConn.SetWriteBuffer(256 * 1024)
}
return conn, nil
}
// serverCapabilities holds the discovered server capabilities
type serverCapabilities struct {
Transports []string `json:"transports"`
Preferred string `json:"preferred"`
}
// dial selects the appropriate transport and establishes a connection
func (c *PoolClient) dial() (net.Conn, error) {
switch c.transport {
case TransportWebSocket:
return c.dialWebSocket()
case TransportTCP:
// User explicitly requested TCP, verify server supports it
caps := c.discoverServerCapabilities()
if caps != nil && len(caps.Transports) > 0 {
tcpSupported := false
for _, t := range caps.Transports {
if t == "tcp" {
tcpSupported = true
break
}
}
if !tcpSupported {
return nil, fmt.Errorf("server only supports %v transport(s), but --transport tcp was specified. Use --transport wss instead", caps.Transports)
}
}
return c.dialTLS()
default: // TransportAuto
// Check if server address indicates WebSocket
if strings.HasPrefix(c.serverAddr, "wss://") {
return c.dialWebSocket()
}
// Query server for preferred transport
caps := c.discoverServerCapabilities()
if caps != nil && caps.Preferred == "wss" {
return c.dialWebSocket()
}
// Default to TCP
return c.dialTLS()
}
}
// discoverServerCapabilities queries the server for its capabilities
func (c *PoolClient) discoverServerCapabilities() *serverCapabilities {
host, port, err := net.SplitHostPort(c.serverAddr)
if err != nil {
host = c.serverAddr
port = "443"
}
discoverURL := fmt.Sprintf("https://%s:%s/_drip/discover", host, port)
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
TLSClientConfig: c.tlsConfig,
},
}
resp, err := client.Get(discoverURL)
if err != nil {
c.logger.Debug("Failed to discover server capabilities",
zap.Error(err),
)
return nil
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil
}
var caps serverCapabilities
if err := json.NewDecoder(resp.Body).Decode(&caps); err != nil {
return nil
}
c.logger.Debug("Discovered server capabilities",
zap.Strings("transports", caps.Transports),
zap.String("preferred", caps.Preferred),
)
return &caps
}
// dialWebSocket establishes a WebSocket connection to the server over TLS
func (c *PoolClient) dialWebSocket() (net.Conn, error) {
// Build WebSocket URL
host, port, err := net.SplitHostPort(c.serverAddr)
if err != nil {
// No port specified, use default
host = c.serverAddr
port = "443"
}
wsURL := fmt.Sprintf("wss://%s:%s/_drip/ws", host, port)
c.logger.Debug("Connecting via WebSocket over TLS",
zap.String("url", wsURL),
)
dialer := websocket.Dialer{
TLSClientConfig: c.tlsConfig,
HandshakeTimeout: 10 * time.Second,
ReadBufferSize: 256 * 1024,
WriteBufferSize: 256 * 1024,
}
// Add authorization header if token is set
header := http.Header{}
if c.token != "" {
header.Set("Authorization", "Bearer "+c.token)
}
ws, resp, err := dialer.Dial(wsURL, header)
if err != nil {
if resp != nil {
return nil, fmt.Errorf("WebSocket dial failed (status %d): %w", resp.StatusCode, err)
}
return nil, fmt.Errorf("WebSocket dial failed: %w", err)
}
// Wrap WebSocket as net.Conn with ping loop for CDN keep-alive
conn := wsutil.NewConnWithPing(ws, 30*time.Second)
c.logger.Debug("WebSocket connection established",
zap.String("remote_addr", ws.RemoteAddr().String()),
)
return conn, nil
}
func (c *PoolClient) acceptLoop(h *sessionHandle, isPrimary bool) {
defer c.wg.Done()
for {
select {
case <-c.stopCh:
return
default:
}
stream, err := h.session.Accept()
if err != nil {
if c.IsClosed() || isExpectedCloseError(err) {
return
}
if isPrimary {
c.logger.Debug("Primary session accept failed", zap.Error(err))
_ = c.Close()
return
}
c.logger.Debug("Data session accept failed", zap.String("session_id", h.id), zap.Error(err))
c.removeDataSession(h.id)
return
}
h.active.Add(1)
h.touch()
c.stats.AddRequest()
c.stats.IncActiveConnections()
c.wg.Add(1)
go c.handleStream(h, stream)
}
}
func (c *PoolClient) sessionWatcher(h *sessionHandle, isPrimary bool) {
defer c.wg.Done()
select {
case <-c.stopCh:
return
case <-h.session.CloseChan():
if isPrimary {
_ = c.Close()
return
}
c.removeDataSession(h.id)
}
}
func (c *PoolClient) pingLoop(h *sessionHandle) {
defer c.wg.Done()
const maxConsecutiveFailures = 3
ticker := time.NewTicker(constants.HeartbeatInterval)
defer ticker.Stop()
consecutiveFailures := 0
for {
select {
case <-c.stopCh:
return
case <-ticker.C:
}
if h.session == nil || h.session.IsClosed() {
return
}
latency, err := h.session.Ping()
if err != nil {
consecutiveFailures++
c.logger.Debug("Ping failed",
zap.String("session_id", h.id),
zap.Int("consecutive_failures", consecutiveFailures),
zap.Error(err),
)
if consecutiveFailures >= maxConsecutiveFailures {
c.logger.Warn("Session ping failed too many times, closing",
zap.String("session_id", h.id),
zap.Int("failures", consecutiveFailures),
)
if h.id == "primary" {
_ = c.Close()
return
}
c.removeDataSession(h.id)
return
}
continue
}
consecutiveFailures = 0
h.touch()
c.latencyNanos.Store(int64(latency))
if cb, ok := c.latencyCallback.Load().(LatencyCallback); ok && cb != nil {
cb(latency)
}
}
}
// Close shuts down the client and all sessions.
func (c *PoolClient) Close() error {
var closeErr error
c.once.Do(func() {
c.closed.Store(true)
close(c.stopCh)
if c.cancel != nil {
c.cancel()
}
var data []*sessionHandle
var primary *sessionHandle
c.mu.Lock()
for _, h := range c.dataSessions {
data = append(data, h)
}
c.dataSessions = make(map[string]*sessionHandle)
primary = c.primary
c.primary = nil
c.mu.Unlock()
for _, h := range data {
if h == nil {
continue
}
if h.session != nil {
_ = h.session.Close()
}
if h.conn != nil {
_ = h.conn.Close()
}
}
if primary != nil {
if primary.session != nil {
closeErr = primary.session.Close()
}
if primary.conn != nil {
_ = primary.conn.Close()
}
}
})
return closeErr
}
func (c *PoolClient) Wait() { <-c.doneCh }
func (c *PoolClient) GetURL() string { return c.assignedURL }
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
func (c *PoolClient) SetLatencyCallback(cb LatencyCallback) {
if cb == nil {
cb = func(time.Duration) {}
}
c.latencyCallback.Store(cb)
}