mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-27 14:50:52 +00:00
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server - Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly - Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
451 lines
9.7 KiB
Go
package tcp
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
json "github.com/goccy/go-json"
|
|
"github.com/hashicorp/yamux"
|
|
"go.uber.org/zap"
|
|
|
|
"drip/internal/shared/constants"
|
|
"drip/internal/shared/protocol"
|
|
"drip/internal/shared/stats"
|
|
"drip/pkg/config"
|
|
)
|
|
|
|
// PoolClient manages a pool of yamux sessions for tunnel connections.
type PoolClient struct {
	// Connection parameters fixed at construction time.
	serverAddr string              // host:port of the tunnel server
	tlsConfig  *tls.Config         // TLS client config (verified or insecure)
	token      string              // auth token sent in the register request
	tunnelType protocol.TunnelType // tcp/http/https — see NewPoolClient
	localHost  string              // local service host; defaults to 127.0.0.1
	localPort  int                 // local service port to expose
	subdomain  string              // requested subdomain; replaced by the server-assigned one after Connect

	// Values assigned by the server during Connect.
	assignedURL string // public URL for the tunnel
	tunnelID    string // non-empty only when the server supports pooled data connections

	// Pool sizing bounds, normalized so minSessions <= initialSessions <= maxSessions.
	minSessions     int
	maxSessions     int
	initialSessions int

	stats *stats.TrafficStats // shared traffic counters (requests, active connections)

	// Set only for HTTP/HTTPS tunnels (see NewPoolClient); nil for raw TCP.
	httpClient *http.Client

	latencyCallback atomic.Value // holds a LatencyCallback; never nil after NewPoolClient
	latencyNanos    atomic.Int64 // last measured ping latency, in nanoseconds

	ctx    context.Context    // canceled by Close
	cancel context.CancelFunc

	stopCh chan struct{}  // closed exactly once by Close to stop all workers
	doneCh chan struct{}  // closed after every worker goroutine has exited
	once   sync.Once      // guards the shutdown sequence in Close
	wg     sync.WaitGroup // tracks all background goroutines
	closed atomic.Bool    // set by Close; read via IsClosed

	primary *sessionHandle // the registration/control session

	mu           sync.RWMutex              // guards dataSessions, desiredTotal, lastScale
	dataSessions map[string]*sessionHandle // pooled data sessions keyed by session id
	desiredTotal int                       // target session count used by the scaler
	lastScale    time.Time                 // last time the pool was resized

	logger *zap.Logger
}
|
|
|
|
// NewPoolClient creates a new pool client.
|
|
func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
|
var tlsConfig *tls.Config
|
|
if cfg.Insecure {
|
|
tlsConfig = config.GetClientTLSConfigInsecure()
|
|
} else {
|
|
host, _, _ := net.SplitHostPort(cfg.ServerAddr)
|
|
tlsConfig = config.GetClientTLSConfig(host)
|
|
}
|
|
|
|
localHost := cfg.LocalHost
|
|
if localHost == "" {
|
|
localHost = "127.0.0.1"
|
|
}
|
|
|
|
tunnelType := cfg.TunnelType
|
|
if tunnelType == "" {
|
|
tunnelType = protocol.TunnelTypeTCP
|
|
}
|
|
|
|
numCPU := runtime.NumCPU()
|
|
|
|
minSessions := cfg.PoolMin
|
|
if minSessions <= 0 {
|
|
minSessions = 2
|
|
}
|
|
|
|
maxSessions := cfg.PoolMax
|
|
if maxSessions <= 0 {
|
|
maxSessions = max(numCPU*16, minSessions)
|
|
}
|
|
if maxSessions < minSessions {
|
|
maxSessions = minSessions
|
|
}
|
|
|
|
initialSessions := cfg.PoolSize
|
|
if initialSessions <= 0 {
|
|
initialSessions = 4
|
|
}
|
|
initialSessions = min(max(initialSessions, minSessions), maxSessions)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
c := &PoolClient{
|
|
serverAddr: cfg.ServerAddr,
|
|
tlsConfig: tlsConfig,
|
|
token: cfg.Token,
|
|
tunnelType: tunnelType,
|
|
localHost: localHost,
|
|
localPort: cfg.LocalPort,
|
|
subdomain: cfg.Subdomain,
|
|
minSessions: minSessions,
|
|
maxSessions: maxSessions,
|
|
initialSessions: initialSessions,
|
|
stats: stats.NewTrafficStats(),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
stopCh: make(chan struct{}),
|
|
doneCh: make(chan struct{}),
|
|
dataSessions: make(map[string]*sessionHandle),
|
|
logger: logger,
|
|
}
|
|
|
|
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
|
|
c.httpClient = newLocalHTTPClient(tunnelType)
|
|
}
|
|
|
|
c.latencyCallback.Store(LatencyCallback(func(time.Duration) {}))
|
|
return c
|
|
}
|
|
|
|
// Connect establishes the primary connection and starts background workers.
func (c *PoolClient) Connect() error {
	primaryConn, err := c.dialTLS()
	if err != nil {
		return err
	}

	// Advertise how many extra data connections we may open; the primary
	// connection itself is not counted.
	maxData := max(c.maxSessions-1, 0)
	req := protocol.RegisterRequest{
		Token:           c.token,
		CustomSubdomain: c.subdomain,
		TunnelType:      c.tunnelType,
		LocalPort:       c.localPort,
		ConnectionType:  "primary",
		PoolCapabilities: &protocol.PoolCapabilities{
			MaxDataConns: maxData,
			Version:      1,
		},
	}

	payload, err := json.Marshal(req)
	if err != nil {
		_ = primaryConn.Close()
		return fmt.Errorf("failed to marshal request: %w", err)
	}

	if err := protocol.WriteFrame(primaryConn, protocol.NewFrame(protocol.FrameTypeRegister, payload)); err != nil {
		_ = primaryConn.Close()
		return fmt.Errorf("failed to send registration: %w", err)
	}

	// Bound how long we wait for the server's ack, then clear the deadline
	// so it does not affect the yamux session created below.
	_ = primaryConn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
	ack, err := protocol.ReadFrame(primaryConn)
	if err != nil {
		_ = primaryConn.Close()
		return fmt.Errorf("failed to read register ack: %w", err)
	}
	defer ack.Release()
	_ = primaryConn.SetReadDeadline(time.Time{})

	if ack.Type == protocol.FrameTypeError {
		var errMsg protocol.ErrorMessage
		if e := json.Unmarshal(ack.Payload, &errMsg); e == nil {
			_ = primaryConn.Close()
			return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
		}
		// Payload was not a parseable ErrorMessage; report a generic failure.
		_ = primaryConn.Close()
		return fmt.Errorf("registration error")
	}
	if ack.Type != protocol.FrameTypeRegisterAck {
		_ = primaryConn.Close()
		return fmt.Errorf("unexpected register ack frame: %s", ack.Type)
	}

	var resp protocol.RegisterResponse
	if err := json.Unmarshal(ack.Payload, &resp); err != nil {
		_ = primaryConn.Close()
		return fmt.Errorf("failed to parse register response: %w", err)
	}

	// Record the server-assigned endpoint; the subdomain may differ from
	// the one requested.
	c.assignedURL = resp.URL
	c.subdomain = resp.Subdomain
	// Pooled data sessions require both server support and a tunnel ID.
	if resp.SupportsDataConn && resp.TunnelID != "" {
		c.tunnelID = resp.TunnelID
	}

	// Server-mode yamux: the remote end opens streams toward us. Built-in
	// keep-alive is off because pingLoop does its own health checking.
	yamuxCfg := yamux.DefaultConfig()
	yamuxCfg.EnableKeepAlive = false
	yamuxCfg.LogOutput = io.Discard
	yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog

	session, err := yamux.Server(primaryConn, yamuxCfg)
	if err != nil {
		_ = primaryConn.Close()
		return fmt.Errorf("failed to init yamux session: %w", err)
	}

	primary := &sessionHandle{
		id:      "primary",
		conn:    primaryConn,
		session: session,
	}
	primary.touch()
	c.primary = primary

	// Keep the WaitGroup non-zero until Close fires, so the doneCh closer
	// goroutine below cannot complete while the client is still running.
	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		<-c.stopCh
	}()

	c.wg.Add(1)
	go c.acceptLoop(primary, true)

	c.wg.Add(1)
	go c.sessionWatcher(primary, true)

	c.wg.Add(1)
	go c.pingLoop(primary)

	// Only scale out extra data sessions when the server granted a tunnel ID.
	if c.tunnelID != "" {
		c.mu.Lock()
		c.desiredTotal = c.initialSessions
		c.mu.Unlock()

		c.ensureSessions()

		c.wg.Add(1)
		go c.scalerLoop()
	}

	// Signal Wait() once every worker goroutine has exited.
	go func() {
		c.wg.Wait()
		close(c.doneCh)
	}()

	return nil
}
|
|
|
|
func (c *PoolClient) dialTLS() (net.Conn, error) {
|
|
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
|
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect: %w", err)
|
|
}
|
|
|
|
state := conn.ConnectionState()
|
|
if state.Version != tls.VersionTLS13 {
|
|
_ = conn.Close()
|
|
return nil, fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
|
|
}
|
|
|
|
if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
|
|
_ = tcpConn.SetNoDelay(true)
|
|
_ = tcpConn.SetKeepAlive(true)
|
|
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
|
_ = tcpConn.SetReadBuffer(256 * 1024)
|
|
_ = tcpConn.SetWriteBuffer(256 * 1024)
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
// acceptLoop accepts inbound yamux streams on h and dispatches each one to
// handleStream in its own goroutine. isPrimary controls failure handling:
// a broken primary session shuts down the whole client, while a broken
// data session is simply removed from the pool.
func (c *PoolClient) acceptLoop(h *sessionHandle, isPrimary bool) {
	defer c.wg.Done()

	for {
		// Non-blocking shutdown check before each accept.
		select {
		case <-c.stopCh:
			return
		default:
		}

		stream, err := h.session.Accept()
		if err != nil {
			// Accept errors are expected during shutdown or a normal
			// remote close; exit quietly in those cases.
			if c.IsClosed() || isExpectedCloseError(err) {
				return
			}
			if isPrimary {
				c.logger.Debug("Primary session accept failed", zap.Error(err))
				_ = c.Close()
				return
			}

			c.logger.Debug("Data session accept failed", zap.String("session_id", h.id), zap.Error(err))
			c.removeDataSession(h.id)
			return
		}

		// Track per-session activity and global traffic counters.
		h.active.Add(1)
		h.touch()

		c.stats.AddRequest()
		c.stats.IncActiveConnections()

		c.wg.Add(1)
		go c.handleStream(h, stream)
	}
}
|
|
|
|
func (c *PoolClient) sessionWatcher(h *sessionHandle, isPrimary bool) {
|
|
defer c.wg.Done()
|
|
|
|
select {
|
|
case <-c.stopCh:
|
|
return
|
|
case <-h.session.CloseChan():
|
|
if isPrimary {
|
|
_ = c.Close()
|
|
return
|
|
}
|
|
c.removeDataSession(h.id)
|
|
}
|
|
}
|
|
|
|
// pingLoop periodically pings h's yamux session to measure latency and
// detect dead sessions. After maxConsecutiveFailures failed pings in a row
// the session is torn down: the whole client when h is the primary session,
// otherwise just this data session.
func (c *PoolClient) pingLoop(h *sessionHandle) {
	defer c.wg.Done()

	const maxConsecutiveFailures = 3

	ticker := time.NewTicker(constants.HeartbeatInterval)
	defer ticker.Stop()

	consecutiveFailures := 0

	for {
		select {
		case <-c.stopCh:
			return
		case <-ticker.C:
		}

		// Session already gone; nothing left to ping.
		if h.session == nil || h.session.IsClosed() {
			return
		}

		latency, err := h.session.Ping()
		if err != nil {
			consecutiveFailures++
			c.logger.Debug("Ping failed",
				zap.String("session_id", h.id),
				zap.Int("consecutive_failures", consecutiveFailures),
				zap.Error(err),
			)

			if consecutiveFailures >= maxConsecutiveFailures {
				c.logger.Warn("Session ping failed too many times, closing",
					zap.String("session_id", h.id),
					zap.Int("failures", consecutiveFailures),
				)
				// The primary session carries the tunnel; losing it means
				// shutting down the whole client.
				if h.id == "primary" {
					_ = c.Close()
					return
				}
				c.removeDataSession(h.id)
				return
			}
			continue
		}

		// Healthy ping: reset the failure streak and record activity.
		consecutiveFailures = 0
		h.touch()

		// Publish the latest latency and notify the registered callback.
		c.latencyNanos.Store(int64(latency))
		if cb, ok := c.latencyCallback.Load().(LatencyCallback); ok && cb != nil {
			cb(latency)
		}
	}
}
|
|
|
|
// Close shuts down the client and all sessions.
|
|
func (c *PoolClient) Close() error {
|
|
var closeErr error
|
|
|
|
c.once.Do(func() {
|
|
c.closed.Store(true)
|
|
close(c.stopCh)
|
|
|
|
if c.cancel != nil {
|
|
c.cancel()
|
|
}
|
|
|
|
var data []*sessionHandle
|
|
var primary *sessionHandle
|
|
|
|
c.mu.Lock()
|
|
for _, h := range c.dataSessions {
|
|
data = append(data, h)
|
|
}
|
|
c.dataSessions = make(map[string]*sessionHandle)
|
|
primary = c.primary
|
|
c.primary = nil
|
|
c.mu.Unlock()
|
|
|
|
for _, h := range data {
|
|
if h == nil {
|
|
continue
|
|
}
|
|
if h.session != nil {
|
|
_ = h.session.Close()
|
|
}
|
|
if h.conn != nil {
|
|
_ = h.conn.Close()
|
|
}
|
|
}
|
|
|
|
if primary != nil {
|
|
if primary.session != nil {
|
|
closeErr = primary.session.Close()
|
|
}
|
|
if primary.conn != nil {
|
|
_ = primary.conn.Close()
|
|
}
|
|
}
|
|
})
|
|
|
|
return closeErr
|
|
}
|
|
|
|
// Wait blocks until the client has shut down and every background worker
// goroutine has exited.
func (c *PoolClient) Wait() { <-c.doneCh }

// GetURL returns the public URL the server assigned during Connect.
func (c *PoolClient) GetURL() string { return c.assignedURL }

// GetSubdomain returns the subdomain assigned by the server (which may
// differ from the one originally requested).
func (c *PoolClient) GetSubdomain() string { return c.subdomain }

// GetLatency returns the most recently measured ping latency.
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }

// GetStats returns the shared traffic statistics collector.
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }

// IsClosed reports whether Close has been called.
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
|
|
|
|
func (c *PoolClient) SetLatencyCallback(cb LatencyCallback) {
|
|
if cb == nil {
|
|
cb = func(time.Duration) {}
|
|
}
|
|
c.latencyCallback.Store(cb)
|
|
}
|