perf(core): Optimizes performance configuration and resource management

- Remove the manual performance-optimization configuration from main.go and replace it with a new tuning module.
- Add per-mode GC tuning in server.go and tunnel_runner.go.
- Unify the yamux configuration into a single optimized configuration to improve throughput.
- Implement connection-pool preheating to eliminate cold-start delay.
- Optimize session selection with a min-heap, reducing time complexity from O(n) to O(log n).
- Add a bufio.Reader pool and a buffer pool to reduce memory allocations.
- Implement a sharded lock manager to improve performance under high concurrency.
- Adjust heartbeat and timeout configurations to suit high-throughput scenarios.
BREAKING CHANGE: Manual GC tuning configuration has been removed; automatic tuning mode is now used.
This commit is contained in:
Gouryella
2025-12-23 11:16:12 +08:00
parent 0cff316334
commit 88e4525bf6
22 changed files with 662 additions and 272 deletions

2
.gitignore vendored
View File

@@ -52,5 +52,5 @@ temp/
certs/
.drip-server.env
benchmark-results/
drip
drip-linux-amd64
./drip

View File

@@ -3,8 +3,6 @@ package main
import (
"fmt"
"os"
"runtime"
"runtime/debug"
"drip/internal/client/cli"
)
@@ -16,9 +14,6 @@ var (
)
func main() {
// Performance optimizations
setupPerformanceOptimizations()
cli.SetVersion(Version, GitCommit, BuildTime)
if err := cli.Execute(); err != nil {
@@ -26,19 +21,3 @@ func main() {
os.Exit(1)
}
}
// setupPerformanceOptimizations configures runtime settings for optimal performance
func setupPerformanceOptimizations() {
// Set GOMAXPROCS to use all available CPU cores
numCPU := runtime.NumCPU()
runtime.GOMAXPROCS(numCPU)
// Reduce GC frequency for high-throughput scenarios
// Default is 100, setting to 200 reduces GC overhead at cost of more memory
// This is beneficial since we now use buffer pools (less garbage)
debug.SetGCPercent(200)
// Set memory limit to prevent OOM (adjust based on your server)
// This is a soft limit - Go will try to stay under this
debug.SetMemoryLimit(8 * 1024 * 1024 * 1024) // 8GB limit
}

View File

@@ -13,6 +13,7 @@ import (
"drip/internal/server/tcp"
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
"drip/internal/shared/tuning"
"drip/internal/shared/utils"
"drip/pkg/config"
"github.com/spf13/cobra"
@@ -60,6 +61,9 @@ func init() {
}
func runServer(_ *cobra.Command, _ []string) error {
// Apply server-mode GC tuning (high throughput, more memory)
tuning.ApplyMode(tuning.ModeServer)
if serverTLSCert == "" {
return fmt.Errorf("TLS certificate path is required (use --tls-cert flag or DRIP_TLS_CERT environment variable)")
}

View File

@@ -9,8 +9,10 @@ import (
"time"
"drip/internal/client/tcp"
"drip/internal/shared/tuning"
"drip/internal/shared/ui"
"drip/internal/shared/utils"
"go.uber.org/zap"
)
@@ -20,6 +22,8 @@ const (
)
func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) error {
tuning.ApplyMode(tuning.ModeClient)
if err := utils.InitLogger(verbose); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"runtime"
@@ -17,6 +16,7 @@ import (
"go.uber.org/zap"
"drip/internal/shared/constants"
"drip/internal/shared/mux"
"drip/internal/shared/protocol"
"drip/internal/shared/stats"
"drip/pkg/config"
@@ -202,10 +202,7 @@ func (c *PoolClient) Connect() error {
c.tunnelID = resp.TunnelID
}
yamuxCfg := yamux.DefaultConfig()
yamuxCfg.EnableKeepAlive = false
yamuxCfg.LogOutput = io.Discard
yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
yamuxCfg := mux.NewClientConfig()
session, err := yamux.Server(primaryConn, yamuxCfg)
if err != nil {
@@ -241,7 +238,7 @@ func (c *PoolClient) Connect() error {
c.desiredTotal = c.initialSessions
c.mu.Unlock()
c.ensureSessions()
c.warmupSessions()
c.wg.Add(1)
go c.scalerLoop()

View File

@@ -2,7 +2,6 @@ package tcp
import (
"fmt"
"io"
"net"
"sync"
"sync/atomic"
@@ -11,10 +10,12 @@ import (
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"drip/internal/shared/constants"
"drip/internal/shared/mux"
"drip/internal/shared/protocol"
)
var dataConnCounter atomic.Uint64
// sessionHandle wraps a yamux session with metadata.
type sessionHandle struct {
id string
@@ -37,6 +38,36 @@ func (h *sessionHandle) lastActiveTime() time.Time {
return time.Unix(0, n)
}
// warmupSessions pre-creates initial sessions in parallel to eliminate cold-start latency.
// It is a no-op when the client is closed, has no tunnel ID yet, or already
// holds the desired number of sessions.
func (c *PoolClient) warmupSessions() {
	if c.IsClosed() || c.tunnelID == "" {
		return
	}

	c.mu.RLock()
	want := c.desiredTotal
	c.mu.RUnlock()

	missing := want - c.sessionCount()
	if missing <= 0 {
		return
	}

	// Dial all missing sessions concurrently; individual failures are
	// tolerated here and left for the scaler loop to repair.
	var wg sync.WaitGroup
	wg.Add(missing)
	for i := 0; i < missing; i++ {
		go func() {
			defer wg.Done()
			_ = c.addDataSession()
		}()
	}
	wg.Wait()

	// Brief wait for server to register all sessions
	time.Sleep(100 * time.Millisecond)
}
// scalerLoop monitors load and adjusts session count.
func (c *PoolClient) scalerLoop() {
defer c.wg.Done()
@@ -162,7 +193,7 @@ func (c *PoolClient) addDataSession() error {
return err
}
connID := fmt.Sprintf("data-%d", time.Now().UnixNano())
connID := fmt.Sprintf("data-%d", dataConnCounter.Add(1))
req := protocol.DataConnectRequest{
TunnelID: c.tunnelID,
@@ -214,10 +245,7 @@ func (c *PoolClient) addDataSession() error {
return fmt.Errorf("data connection rejected: %s", resp.Message)
}
yamuxCfg := yamux.DefaultConfig()
yamuxCfg.EnableKeepAlive = false
yamuxCfg.LogOutput = io.Discard
yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
yamuxCfg := mux.NewClientConfig()
session, err := yamux.Server(conn, yamuxCfg)
if err != nil {

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
json "github.com/goccy/go-json"
@@ -16,11 +17,19 @@ import (
"drip/internal/server/tunnel"
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// bufio.Reader pool to reduce allocations on hot path.
// Readers come out of the pool bound to a nil source: callers must Reset the
// reader onto their stream after Get, and must only Put it back once nothing
// (e.g. a response body read through it) still references it.
var bufioReaderPool = sync.Pool{
New: func() interface{} {
return bufio.NewReaderSize(nil, 32*1024)
},
}
const openStreamTimeout = 3 * time.Second
type Handler struct {
@@ -104,13 +113,19 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
reader := bufioReaderPool.Get().(*bufio.Reader)
reader.Reset(countingStream)
resp, err := http.ReadResponse(reader, r)
if err != nil {
bufioReaderPool.Put(reader)
w.Header().Set("Connection", "close")
http.Error(w, "Read response failed", http.StatusBadGateway)
return
}
defer resp.Body.Close()
defer func() {
resp.Body.Close()
bufioReaderPool.Put(reader)
}()
h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
@@ -147,7 +162,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}()
_, _ = io.Copy(w, resp.Body)
// Use pooled buffer for zero-copy optimization
buf := pool.GetBuffer(pool.SizeLarge)
_, _ = io.CopyBuffer(w, resp.Body, (*buf)[:])
pool.PutBuffer(buf)
close(done)
stream.Close()
}

View File

@@ -18,6 +18,7 @@ import (
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
"drip/internal/shared/mux"
"drip/internal/shared/protocol"
"go.uber.org/zap"
@@ -686,10 +687,8 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
reader: reader,
}
cfg := yamux.DefaultConfig()
cfg.EnableKeepAlive = false
cfg.LogOutput = io.Discard
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
// Use optimized mux config for server
cfg := mux.NewServerConfig()
session, err := yamux.Client(bc, cfg)
if err != nil {

View File

@@ -1,10 +1,10 @@
package tcp
import (
"container/heap"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/yamux"
@@ -15,6 +15,54 @@ import (
"go.uber.org/zap"
)
// sessionEntry represents a session with its current stream count for heap operations.
type sessionEntry struct {
	id      string
	session *yamux.Session
	streams int
	heapIdx int // index in the heap, managed by heap.Interface
}

// sessionHeap implements heap.Interface for O(log n) session selection.
type sessionHeap []*sessionEntry

// Len reports how many entries the heap currently holds.
func (h sessionHeap) Len() int { return len(h) }

// Less orders entries so the session with the fewest streams surfaces first (min-heap).
func (h sessionHeap) Less(i, j int) bool { return h[i].streams < h[j].streams }

// Swap exchanges two entries and keeps their cached heap indices in sync.
func (h sessionHeap) Swap(i, j int) {
	h[i], h[j] = h[j], h[i]
	h[i].heapIdx, h[j].heapIdx = i, j
}

// Push appends an entry to the heap; invoked by container/heap, not directly.
func (h *sessionHeap) Push(x interface{}) {
	e := x.(*sessionEntry)
	e.heapIdx = len(*h)
	*h = append(*h, e)
}

// Pop removes and returns the final entry; invoked by container/heap, not directly.
func (h *sessionHeap) Pop() interface{} {
	old := *h
	last := len(old) - 1
	e := old[last]
	old[last] = nil // avoid memory leak
	e.heapIdx = -1
	*h = old[:last]
	return e
}
// sessionHeapPool reuses heap slices to reduce allocations.
// Callers truncate a heap to length 0 before putting it back, so pooled heaps
// retain capacity (initially 16) but never carry stale entries.
var sessionHeapPool = sync.Pool{
New: func() interface{} {
h := make(sessionHeap, 0, 16)
return &h
},
}
type ConnectionGroup struct {
TunnelID string
Subdomain string
@@ -60,6 +108,12 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
const maxConsecutiveFailures = 3
failureCount := make(map[string]int)
type sessionSnapshot struct {
id string
session *yamux.Session
}
sessions := make([]sessionSnapshot, 0, 16)
for {
select {
case <-g.stopCh:
@@ -67,26 +121,25 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
case <-ticker.C:
}
sessions = sessions[:0]
g.mu.RLock()
sessions := make(map[string]*yamux.Session, len(g.Sessions))
for id, s := range g.Sessions {
sessions[id] = s
sessions = append(sessions, sessionSnapshot{id: id, session: s})
}
g.mu.RUnlock()
for id, session := range sessions {
if session == nil || session.IsClosed() {
g.RemoveSession(id)
delete(failureCount, id)
for _, snap := range sessions {
if snap.session == nil || snap.session.IsClosed() {
g.RemoveSession(snap.id)
delete(failureCount, snap.id)
continue
}
// Ping with timeout
done := make(chan error, 1)
go func(s *yamux.Session) {
_, err := s.Ping()
done <- err
}(session)
}(snap.session)
var err error
select {
@@ -98,31 +151,29 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
}
if err != nil {
failureCount[id]++
failureCount[snap.id]++
g.logger.Debug("Session ping failed",
zap.String("session_id", id),
zap.Int("consecutive_failures", failureCount[id]),
zap.String("session_id", snap.id),
zap.Int("consecutive_failures", failureCount[snap.id]),
zap.Error(err),
)
if failureCount[id] >= maxConsecutiveFailures {
if failureCount[snap.id] >= maxConsecutiveFailures {
g.logger.Warn("Session ping failed too many times, removing",
zap.String("session_id", id),
zap.Int("failures", failureCount[id]),
zap.String("session_id", snap.id),
zap.Int("failures", failureCount[snap.id]),
)
g.RemoveSession(id)
delete(failureCount, id)
g.RemoveSession(snap.id)
delete(failureCount, snap.id)
}
} else {
// Reset on success
failureCount[id] = 0
failureCount[snap.id] = 0
g.mu.Lock()
g.LastActivity = time.Now()
g.mu.Unlock()
}
}
// Check if all sessions are gone
g.mu.RLock()
sessionCount := len(g.Sessions)
g.mu.RUnlock()
@@ -214,6 +265,7 @@ func (g *ConnectionGroup) SessionCount() int {
return len(g.Sessions)
}
// OpenStream opens a new stream using a min-heap for O(log n) session selection.
func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
const (
maxStreamsPerSession = 256
@@ -230,61 +282,33 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
default:
}
// Prefer data sessions for data-plane traffic; keep the primary session
// as control-plane (client ping/latency), and only fall back to primary
// when no data session exists.
sessions := g.sessionsSnapshot(false)
if len(sessions) == 0 {
sessions = g.sessionsSnapshot(true)
h := g.buildSessionHeap(false)
if h.Len() == 0 {
h = g.buildSessionHeap(true)
}
if len(sessions) == 0 {
if h.Len() == 0 {
return nil, net.ErrClosed
}
tried := make([]bool, len(sessions))
anyUnderCap := false
start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
for h.Len() > 0 {
entry := heap.Pop(h).(*sessionEntry)
session := entry.session
for range sessions {
bestIdx := -1
minStreams := int(^uint(0) >> 1)
for i := 0; i < len(sessions); i++ {
idx := (start + i) % len(sessions)
if tried[idx] {
continue
}
session := sessions[idx]
if session == nil || session.IsClosed() {
tried[idx] = true
continue
}
n := session.NumStreams()
if n >= maxStreamsPerSession {
continue
}
anyUnderCap = true
if n < minStreams {
minStreams = n
bestIdx = idx
}
}
if bestIdx == -1 {
break
}
tried[bestIdx] = true
session := sessions[bestIdx]
if session == nil || session.IsClosed() {
continue
}
currentStreams := session.NumStreams()
if currentStreams >= maxStreamsPerSession {
continue
}
anyUnderCap = true
stream, err := session.Open()
if err == nil {
*h = (*h)[:0]
sessionHeapPool.Put(h)
return stream, nil
}
lastErr = err
@@ -294,6 +318,9 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
}
}
*h = (*h)[:0]
sessionHeapPool.Put(h)
if !anyUnderCap {
lastErr = fmt.Errorf("all sessions are at stream capacity (%d)", maxStreamsPerSession)
}
@@ -313,31 +340,59 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
return nil, lastErr
}
func (g *ConnectionGroup) selectSession() *yamux.Session {
sessions := g.sessionsSnapshot(false)
if len(sessions) == 0 {
sessions = g.sessionsSnapshot(true)
}
if len(sessions) == 0 {
return nil
// buildSessionHeap creates a min-heap of sessions ordered by stream count.
func (g *ConnectionGroup) buildSessionHeap(includePrimary bool) *sessionHeap {
g.mu.RLock()
defer g.mu.RUnlock()
if len(g.Sessions) == 0 {
h := sessionHeapPool.Get().(*sessionHeap)
return h
}
start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
minStreams := int(^uint(0) >> 1)
var best *yamux.Session
h := sessionHeapPool.Get().(*sessionHeap)
*h = (*h)[:0]
for i := 0; i < len(sessions); i++ {
session := sessions[(start+i)%len(sessions)]
for id, session := range g.Sessions {
if session == nil || session.IsClosed() {
continue
}
if n := session.NumStreams(); n < minStreams {
minStreams = n
best = session
if id == "primary" && !includePrimary {
continue
}
*h = append(*h, &sessionEntry{
id: id,
session: session,
streams: session.NumStreams(),
})
}
return best
heap.Init(h)
return h
}
// selectSession returns the least-loaded open session, preferring data
// sessions and falling back to the primary session; nil when none is usable.
func (g *ConnectionGroup) selectSession() *yamux.Session {
	h := g.buildSessionHeap(false)
	if h.Len() == 0 {
		// No data sessions available; retry including the primary session.
		sessionHeapPool.Put(h)
		h = g.buildSessionHeap(true)
	}
	if h.Len() == 0 {
		sessionHeapPool.Put(h)
		return nil
	}

	// The heap root is the session with the fewest open streams.
	best := heap.Pop(h).(*sessionEntry).session

	// Reset and recycle the heap before any early return below.
	*h = (*h)[:0]
	sessionHeapPool.Put(h)

	if best == nil || best.IsClosed() {
		return nil
	}
	return best
}
func (g *ConnectionGroup) sessionsSnapshot(includePrimary bool) []*yamux.Session {

View File

@@ -3,12 +3,11 @@ package tcp
import (
"bufio"
"fmt"
"io"
"net"
"github.com/hashicorp/yamux"
"drip/internal/shared/constants"
"drip/internal/shared/mux"
)
type bufferedConn struct {
@@ -27,10 +26,8 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
reader: reader,
}
cfg := yamux.DefaultConfig()
cfg.EnableKeepAlive = false
cfg.LogOutput = io.Discard
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
// Use optimized mux config for server
cfg := mux.NewServerConfig()
session, err := yamux.Client(bc, cfg)
if err != nil {
@@ -66,10 +63,8 @@ func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
reader: reader,
}
cfg := yamux.DefaultConfig()
cfg.EnableKeepAlive = false
cfg.LogOutput = io.Discard
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
// Use optimized mux config for server
cfg := mux.NewServerConfig()
session, err := yamux.Client(bc, cfg)
if err != nil {

View File

@@ -115,8 +115,7 @@ func (c *Connection) GetTunnelType() protocol.TunnelType {
return c.tunnelType
}
// SetOpenStream registers a yamux stream opener for this tunnel.
// It is used by the HTTP proxy to forward each request over a mux stream.
// SetOpenStream registers a stream opener for this tunnel.
func (c *Connection) SetOpenStream(open func() (net.Conn, error)) {
c.mu.Lock()
c.openStream = open
@@ -178,17 +177,11 @@ func (c *Connection) GetActiveConnections() int64 {
// StartWritePump starts the write pump for sending messages
func (c *Connection) StartWritePump() {
// Skip write pump for TCP-only connections (no WebSocket)
if c.Conn == nil {
c.logger.Debug("Skipping WritePump for TCP connection",
zap.String("subdomain", c.Subdomain),
)
// Still need to drain SendCh to prevent blocking
go func() {
for {
select {
case <-c.SendCh:
// Discard messages for TCP mode
case <-c.CloseCh:
return
}

View File

@@ -2,7 +2,9 @@ package tunnel
import (
"errors"
"hash/fnv"
"sync"
"sync/atomic"
"time"
"drip/internal/shared/utils"
@@ -12,10 +14,14 @@ import (
// Manager limits
const (
DefaultMaxTunnels = 1000 // Maximum total tunnels
DefaultMaxTunnelsPerIP = 10 // Maximum tunnels per IP
DefaultRateLimit = 10 // Registrations per IP per minute
DefaultRateLimitWindow = 1 * time.Minute // Rate limit window
DefaultMaxTunnels = 1000 // Maximum total tunnels
DefaultMaxTunnelsPerIP = 10 // Maximum tunnels per IP
DefaultRateLimit = 10 // Registrations per IP per minute
DefaultRateLimitWindow = 1 * time.Minute // Rate limit window
// numShards is the number of shards for lock distribution
// Using 32 shards reduces lock contention by ~32x under high concurrency
numShards = 32
)
var (
@@ -30,12 +36,17 @@ type rateLimitEntry struct {
windowEnd time.Time
}
// Manager manages all active tunnel connections
type Manager struct {
tunnels map[string]*Connection // subdomain -> connection
// shard holds a subset of tunnels with its own lock
type shard struct {
tunnels map[string]*Connection
used map[string]bool
mu sync.RWMutex
used map[string]bool // track used subdomains
logger *zap.Logger
}
// Manager manages all active tunnel connections with sharded locking
type Manager struct {
shards [numShards]shard
logger *zap.Logger
// Limits
maxTunnels int
@@ -43,7 +54,11 @@ type Manager struct {
rateLimit int
rateLimitWindow time.Duration
// Per-IP tracking
// Global counters (atomic for lock-free reads)
tunnelCount atomic.Int64
// Per-IP tracking (requires separate lock as it spans shards)
ipMu sync.RWMutex
tunnelsByIP map[string]int // IP -> tunnel count
rateLimits map[string]*rateLimitEntry // IP -> rate limit entry
@@ -94,11 +109,10 @@ func NewManagerWithConfig(logger *zap.Logger, cfg ManagerConfig) *Manager {
zap.Int("max_per_ip", cfg.MaxTunnelsPerIP),
zap.Int("rate_limit", cfg.RateLimit),
zap.Duration("rate_window", cfg.RateLimitWindow),
zap.Int("num_shards", numShards),
)
return &Manager{
tunnels: make(map[string]*Connection),
used: make(map[string]bool),
m := &Manager{
logger: logger,
maxTunnels: cfg.MaxTunnels,
maxTunnelsPerIP: cfg.MaxTunnelsPerIP,
@@ -108,10 +122,25 @@ func NewManagerWithConfig(logger *zap.Logger, cfg ManagerConfig) *Manager {
rateLimits: make(map[string]*rateLimitEntry),
stopCh: make(chan struct{}),
}
// Initialize all shards
for i := 0; i < numShards; i++ {
m.shards[i].tunnels = make(map[string]*Connection)
m.shards[i].used = make(map[string]bool)
}
return m
}
// checkRateLimit checks if the IP has exceeded rate limit
func (m *Manager) checkRateLimit(ip string) bool {
// getShard returns the shard for a given subdomain using FNV-1a hash
func (m *Manager) getShard(subdomain string) *shard {
h := fnv.New32a()
h.Write([]byte(subdomain))
return &m.shards[h.Sum32()%numShards]
}
// checkRateLimit checks if the IP has exceeded rate limit (caller must hold ipMu)
func (m *Manager) checkRateLimitLocked(ip string) bool {
now := time.Now()
entry, exists := m.rateLimits[ip]
@@ -139,22 +168,20 @@ func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string
// RegisterWithIP registers a new tunnel with IP tracking
func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, remoteIP string) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Check total tunnel limit
if len(m.tunnels) >= m.maxTunnels {
// Check global limits first (lock-free read)
if m.tunnelCount.Load() >= int64(m.maxTunnels) {
m.logger.Warn("Maximum tunnel limit reached",
zap.Int("current", len(m.tunnels)),
zap.Int64("current", m.tunnelCount.Load()),
zap.Int("max", m.maxTunnels),
)
return "", ErrTooManyTunnels
}
// Check per-IP limits if IP is provided
// Check per-IP limits
if remoteIP != "" {
// Check rate limit
if !m.checkRateLimit(remoteIP) {
m.ipMu.Lock()
if !m.checkRateLimitLocked(remoteIP) {
m.ipMu.Unlock()
m.logger.Warn("Rate limit exceeded",
zap.String("ip", remoteIP),
zap.Int("limit", m.rateLimit),
@@ -162,8 +189,8 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r
return "", ErrRateLimitExceeded
}
// Check per-IP tunnel limit
if m.tunnelsByIP[remoteIP] >= m.maxTunnelsPerIP {
m.ipMu.Unlock()
m.logger.Warn("Per-IP tunnel limit reached",
zap.String("ip", remoteIP),
zap.Int("current", m.tunnelsByIP[remoteIP]),
@@ -171,6 +198,7 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r
)
return "", ErrTooManyPerIP
}
m.ipMu.Unlock()
}
var subdomain string
@@ -183,33 +211,56 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r
if utils.IsReserved(customSubdomain) {
return "", ErrReservedSubdomain
}
if m.used[customSubdomain] {
// Check if subdomain is taken in its shard
s := m.getShard(customSubdomain)
s.mu.Lock()
if s.used[customSubdomain] {
s.mu.Unlock()
return "", ErrSubdomainTaken
}
subdomain = customSubdomain
// Register in shard
tc := NewConnection(subdomain, conn, m.logger)
tc.remoteIP = remoteIP
s.tunnels[subdomain] = tc
s.used[subdomain] = true
s.mu.Unlock()
} else {
// Generate unique random subdomain
subdomain = m.generateUniqueSubdomain()
s := m.getShard(subdomain)
s.mu.Lock()
tc := NewConnection(subdomain, conn, m.logger)
tc.remoteIP = remoteIP
s.tunnels[subdomain] = tc
s.used[subdomain] = true
s.mu.Unlock()
}
// Create connection
tc := NewConnection(subdomain, conn, m.logger)
tc.remoteIP = remoteIP // Track IP for cleanup
m.tunnels[subdomain] = tc
m.used[subdomain] = true
// Update per-IP counter
// Update counters
m.tunnelCount.Add(1)
if remoteIP != "" {
m.ipMu.Lock()
m.tunnelsByIP[remoteIP]++
m.ipMu.Unlock()
}
// Start write pump in background
go tc.StartWritePump()
// Get connection and start write pump
s := m.getShard(subdomain)
s.mu.RLock()
tc := s.tunnels[subdomain]
s.mu.RUnlock()
if tc != nil {
go tc.StartWritePump()
}
m.logger.Info("Tunnel registered",
zap.String("subdomain", subdomain),
zap.String("ip", remoteIP),
zap.Int("total_tunnels", len(m.tunnels)),
zap.Int64("total_tunnels", m.tunnelCount.Load()),
)
return subdomain, nil
@@ -217,101 +268,129 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r
// Unregister removes a tunnel connection
func (m *Manager) Unregister(subdomain string) {
m.mu.Lock()
defer m.mu.Unlock()
s := m.getShard(subdomain)
s.mu.Lock()
if tc, ok := m.tunnels[subdomain]; ok {
// Decrement per-IP counter
if tc.remoteIP != "" && m.tunnelsByIP[tc.remoteIP] > 0 {
m.tunnelsByIP[tc.remoteIP]--
if m.tunnelsByIP[tc.remoteIP] == 0 {
delete(m.tunnelsByIP, tc.remoteIP)
tc, ok := s.tunnels[subdomain]
if !ok {
s.mu.Unlock()
return
}
remoteIP := tc.remoteIP
tc.Close()
delete(s.tunnels, subdomain)
delete(s.used, subdomain)
s.mu.Unlock()
// Update counters
m.tunnelCount.Add(-1)
if remoteIP != "" {
m.ipMu.Lock()
if m.tunnelsByIP[remoteIP] > 0 {
m.tunnelsByIP[remoteIP]--
if m.tunnelsByIP[remoteIP] == 0 {
delete(m.tunnelsByIP, remoteIP)
}
}
tc.Close()
delete(m.tunnels, subdomain)
delete(m.used, subdomain)
m.logger.Info("Tunnel unregistered",
zap.String("subdomain", subdomain),
zap.Int("total_tunnels", len(m.tunnels)),
)
m.ipMu.Unlock()
}
m.logger.Info("Tunnel unregistered",
zap.String("subdomain", subdomain),
zap.Int64("total_tunnels", m.tunnelCount.Load()),
)
}
// Get retrieves a tunnel connection by subdomain
func (m *Manager) Get(subdomain string) (*Connection, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
tc, ok := m.tunnels[subdomain]
s := m.getShard(subdomain)
s.mu.RLock()
tc, ok := s.tunnels[subdomain]
s.mu.RUnlock()
return tc, ok
}
// List returns all active tunnel connections
func (m *Manager) List() []*Connection {
m.mu.RLock()
defer m.mu.RUnlock()
// Pre-allocate with approximate capacity
connections := make([]*Connection, 0, m.tunnelCount.Load())
connections := make([]*Connection, 0, len(m.tunnels))
for _, tc := range m.tunnels {
connections = append(connections, tc)
for i := 0; i < numShards; i++ {
s := &m.shards[i]
s.mu.RLock()
for _, tc := range s.tunnels {
connections = append(connections, tc)
}
s.mu.RUnlock()
}
return connections
}
// Count returns the number of active tunnels
func (m *Manager) Count() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.tunnels)
return int(m.tunnelCount.Load())
}
// CleanupStale removes stale connections that haven't been active
func (m *Manager) CleanupStale(timeout time.Duration) int {
m.mu.Lock()
defer m.mu.Unlock()
totalCleaned := 0
staleSubdomains := []string{}
// Clean up each shard independently
for i := 0; i < numShards; i++ {
s := &m.shards[i]
s.mu.Lock()
for subdomain, tc := range m.tunnels {
if !tc.IsAlive(timeout) {
staleSubdomains = append(staleSubdomains, subdomain)
var staleSubdomains []string
for subdomain, tc := range s.tunnels {
if !tc.IsAlive(timeout) {
staleSubdomains = append(staleSubdomains, subdomain)
}
}
}
for _, subdomain := range staleSubdomains {
if tc, ok := m.tunnels[subdomain]; ok {
// Decrement per-IP counter
if tc.remoteIP != "" && m.tunnelsByIP[tc.remoteIP] > 0 {
m.tunnelsByIP[tc.remoteIP]--
if m.tunnelsByIP[tc.remoteIP] == 0 {
delete(m.tunnelsByIP, tc.remoteIP)
for _, subdomain := range staleSubdomains {
if tc, ok := s.tunnels[subdomain]; ok {
remoteIP := tc.remoteIP
tc.Close()
delete(s.tunnels, subdomain)
delete(s.used, subdomain)
// Update counters
m.tunnelCount.Add(-1)
if remoteIP != "" {
m.ipMu.Lock()
if m.tunnelsByIP[remoteIP] > 0 {
m.tunnelsByIP[remoteIP]--
if m.tunnelsByIP[remoteIP] == 0 {
delete(m.tunnelsByIP, remoteIP)
}
}
m.ipMu.Unlock()
}
}
tc.Close()
delete(m.tunnels, subdomain)
delete(m.used, subdomain)
}
totalCleaned += len(staleSubdomains)
s.mu.Unlock()
}
// Cleanup expired rate limit entries
m.ipMu.Lock()
now := time.Now()
for ip, entry := range m.rateLimits {
if now.After(entry.windowEnd) {
delete(m.rateLimits, ip)
}
}
m.ipMu.Unlock()
if len(staleSubdomains) > 0 {
if totalCleaned > 0 {
m.logger.Info("Cleaned up stale tunnels",
zap.Int("count", len(staleSubdomains)),
zap.Int("count", totalCleaned),
)
}
return len(staleSubdomains)
return totalCleaned
}
// StartCleanupTask starts a background task to clean up stale connections
@@ -336,7 +415,16 @@ func (m *Manager) generateUniqueSubdomain() string {
for i := 0; i < maxAttempts; i++ {
subdomain := utils.GenerateSubdomain(6)
if !m.used[subdomain] && !utils.IsReserved(subdomain) {
if utils.IsReserved(subdomain) {
continue
}
s := m.getShard(subdomain)
s.mu.RLock()
taken := s.used[subdomain]
s.mu.RUnlock()
if !taken {
return subdomain
}
}
@@ -350,17 +438,21 @@ func (m *Manager) Shutdown() {
// Signal cleanup goroutine to stop
close(m.stopCh)
m.mu.Lock()
defer m.mu.Unlock()
m.logger.Info("Shutting down tunnel manager",
zap.Int("active_tunnels", len(m.tunnels)),
zap.Int64("active_tunnels", m.tunnelCount.Load()),
)
for _, tc := range m.tunnels {
tc.Close()
// Close all tunnels in each shard
for i := 0; i < numShards; i++ {
s := &m.shards[i]
s.mu.Lock()
for _, tc := range s.tunnels {
tc.Close()
}
s.tunnels = make(map[string]*Connection)
s.used = make(map[string]bool)
s.mu.Unlock()
}
m.tunnels = make(map[string]*Connection)
m.used = make(map[string]bool)
m.tunnelCount.Store(0)
}

View File

@@ -9,9 +9,33 @@ const (
// DefaultWSPort is the default WebSocket port
DefaultWSPort = 8080
// ==================== Yamux Configuration ====================
// These settings are tuned for high-throughput tunnel scenarios
// YamuxAcceptBacklog controls how many incoming streams can be queued
// before yamux starts blocking stream opens under load.
YamuxAcceptBacklog = 4096
// Increased from default 256 to handle burst traffic.
YamuxAcceptBacklog = 8192
// YamuxMaxStreamWindowSize is the maximum window size for a stream.
// Larger windows allow more data in-flight, improving throughput on high-latency links.
// Default is 256KB, increased to 512KB for better throughput.
YamuxMaxStreamWindowSize = 512 * 1024
// YamuxStreamOpenTimeout is how long to wait for a stream open to complete.
YamuxStreamOpenTimeout = 10 * time.Second
// YamuxStreamCloseTimeout is how long to wait for a stream close to complete.
YamuxStreamCloseTimeout = 5 * time.Minute
// YamuxKeepAliveInterval is how often to send keep-alive pings.
// Set higher than HeartbeatInterval to avoid redundant pings.
YamuxKeepAliveInterval = 15 * time.Second
// YamuxConnectionWriteTimeout is the timeout for writing to the underlying connection.
YamuxConnectionWriteTimeout = 10 * time.Second
// ==================== Heartbeat Configuration ====================
// HeartbeatInterval is how often clients send heartbeat messages
HeartbeatInterval = 2 * time.Second
@@ -19,9 +43,13 @@ const (
// HeartbeatTimeout is how long the server waits before considering a connection dead
HeartbeatTimeout = 6 * time.Second
// ==================== Request/Response Timeouts ====================
// RequestTimeout is the maximum time to wait for a response from the client
RequestTimeout = 30 * time.Second
// ==================== Reconnection Configuration ====================
// ReconnectBaseDelay is the initial delay for reconnection attempts
ReconnectBaseDelay = 1 * time.Second
@@ -31,9 +59,12 @@ const (
// MaxReconnectAttempts is the maximum number of reconnection attempts (0 = infinite)
MaxReconnectAttempts = 0
// ==================== TCP Port Allocation ====================
// DefaultTCPPortMin/Max define the default allocation range for TCP tunnels
DefaultTCPPortMin = 20000
DefaultTCPPortMax = 40000
// DefaultDomain is the default domain for tunnels
DefaultDomain = "tunnel.localhost"
)

View File

@@ -0,0 +1,35 @@
package mux
import (
"io"
"github.com/hashicorp/yamux"
"drip/internal/shared/constants"
)
// NewOptimizedConfig returns a multiplexer config optimized for tunnel scenarios.
// It starts from yamux defaults, widens the accept backlog and stream window
// for throughput, applies tuned timeouts, and silences yamux's own logging.
func NewOptimizedConfig() *yamux.Config {
	cfg := yamux.DefaultConfig()

	// Throughput: more queued stream opens and more in-flight bytes per stream.
	cfg.AcceptBacklog = constants.YamuxAcceptBacklog
	cfg.MaxStreamWindowSize = constants.YamuxMaxStreamWindowSize

	// Liveness and timeouts.
	cfg.EnableKeepAlive = true
	cfg.KeepAliveInterval = constants.YamuxKeepAliveInterval
	cfg.StreamOpenTimeout = constants.YamuxStreamOpenTimeout
	cfg.StreamCloseTimeout = constants.YamuxStreamCloseTimeout
	cfg.ConnectionWriteTimeout = constants.YamuxConnectionWriteTimeout

	cfg.LogOutput = io.Discard
	return cfg
}
// NewServerConfig returns a multiplexer config for server-side use.
// Servers keep the full optimized defaults, including yamux keep-alives.
func NewServerConfig() *yamux.Config {
return NewOptimizedConfig()
}
// NewClientConfig returns a multiplexer config for client-side use.
// Keep-alive is disabled on the client — presumably the application-level
// heartbeat already covers liveness (see HeartbeatInterval); confirm before
// re-enabling.
func NewClientConfig() *yamux.Config {
cfg := NewOptimizedConfig()
cfg.EnableKeepAlive = false
return cfg
}

View File

@@ -2,14 +2,12 @@ package netutil
import "net"
// CountingConn wraps a net.Conn to count bytes read/written.
type CountingConn struct {
net.Conn
OnRead func(int64)
OnWrite func(int64)
}
// NewCountingConn creates a new CountingConn.
func NewCountingConn(conn net.Conn, onRead, onWrite func(int64)) *CountingConn {
return &CountingConn{
Conn: conn,

View File

@@ -3,15 +3,17 @@ package pool
import "sync"
const (
SizeSmall = 4 * 1024 // 4KB
SizeMedium = 32 * 1024 // 32KB
SizeLarge = 256 * 1024 // 256KB
SizeSmall = 4 * 1024 // 4KB - HTTP headers, small messages
SizeMedium = 32 * 1024 // 32KB - HTTP request/response bodies
SizeLarge = 256 * 1024 // 256KB - Data pipe, file transfers
SizeXLarge = 1024 * 1024 // 1MB - Large file transfers, bulk data
)
type BufferPool struct {
small sync.Pool
medium sync.Pool
large sync.Pool
xlarge sync.Pool
}
func NewBufferPool() *BufferPool {
@@ -34,6 +36,12 @@ func NewBufferPool() *BufferPool {
return &b
},
},
xlarge: sync.Pool{
New: func() interface{} {
b := make([]byte, SizeXLarge)
return &b
},
},
}
}
@@ -43,8 +51,10 @@ func (p *BufferPool) Get(size int) *[]byte {
return p.small.Get().(*[]byte)
case size <= SizeMedium:
return p.medium.Get().(*[]byte)
default:
case size <= SizeLarge:
return p.large.Get().(*[]byte)
default:
return p.xlarge.Get().(*[]byte)
}
}
@@ -63,7 +73,24 @@ func (p *BufferPool) Put(buf *[]byte) {
p.medium.Put(buf)
case SizeLarge:
p.large.Put(buf)
case SizeXLarge:
p.xlarge.Put(buf)
}
// Note: buffers with non-standard sizes are not pooled (let GC handle them)
}
// GetXLarge returns a 1MB buffer for bulk data transfers.
func (p *BufferPool) GetXLarge() *[]byte {
	buf := p.xlarge.Get().(*[]byte)
	return buf
}
// PutXLarge returns a 1MB buffer to the pool. Buffers whose capacity is
// not exactly SizeXLarge are dropped so the pool only ever holds
// uniformly sized slices; nil is ignored.
func (p *BufferPool) PutXLarge(buf *[]byte) {
	if buf == nil {
		return
	}
	if cap(*buf) != SizeXLarge {
		// Wrong-sized buffer: let the GC reclaim it.
		return
	}
	// Restore full length before pooling so the next Get sees 1MB.
	*buf = (*buf)[:SizeXLarge]
	p.xlarge.Put(buf)
}
// globalBufferPool backs the package-level buffer helpers.
var globalBufferPool = NewBufferPool()
@@ -75,3 +102,13 @@ func GetBuffer(size int) *[]byte {
// PutBuffer returns buf to the process-wide buffer pool.
func PutBuffer(buf *[]byte) {
	globalBufferPool.Put(buf)
}
// GetXLargeBuffer returns a 1MB (SizeXLarge) buffer from the global pool.
func GetXLargeBuffer() *[]byte {
	return globalBufferPool.GetXLarge()
}
// PutXLargeBuffer returns a 1MB buffer to the global pool; buffers whose
// capacity is not exactly SizeXLarge are dropped for the GC to reclaim.
func PutXLargeBuffer(buf *[]byte) {
	globalBufferPool.PutXLarge(buf)
}

View File

@@ -6,31 +6,24 @@ import (
"time"
)
// TrafficStats tracks traffic statistics for a tunnel connection.
// Cumulative counters are mutated with sync/atomic; the speed-sampling
// fields are guarded by speedMu and refreshed by UpdateSpeed.
type TrafficStats struct {
	// Total bytes transferred, updated atomically.
	totalBytesIn int64
	totalBytesOut int64
	// Request / connection counts, updated atomically.
	totalRequests int64
	activeConnections int64
	// For speed calculation: last sampled totals and sample time (speedMu).
	lastBytesIn int64
	lastBytesOut int64
	lastTime time.Time
	speedMu sync.Mutex
	// Current speed (bytes per second), guarded by speedMu.
	speedIn int64
	speedOut int64
	// Start time of the connection; read by GetUptime.
	startTime time.Time
}
// NewTrafficStats creates a new traffic stats tracker
func NewTrafficStats() *TrafficStats {
now := time.Now()
return &TrafficStats{
@@ -39,17 +32,14 @@ func NewTrafficStats() *TrafficStats {
}
}
// AddBytesIn atomically adds n incoming bytes to the cumulative counter.
func (s *TrafficStats) AddBytesIn(n int64) {
	atomic.AddInt64(&s.totalBytesIn, n)
}
// AddBytesOut atomically adds n outgoing bytes to the cumulative counter.
func (s *TrafficStats) AddBytesOut(n int64) {
	atomic.AddInt64(&s.totalBytesOut, n)
}
// AddRequest atomically increments the total request counter.
func (s *TrafficStats) AddRequest() {
	atomic.AddInt64(&s.totalRequests, 1)
}
@@ -59,23 +49,25 @@ func (s *TrafficStats) IncActiveConnections() {
}
func (s *TrafficStats) DecActiveConnections() {
v := atomic.AddInt64(&s.activeConnections, -1)
if v < 0 {
atomic.StoreInt64(&s.activeConnections, 0)
for {
old := atomic.LoadInt64(&s.activeConnections)
if old <= 0 {
return
}
if atomic.CompareAndSwapInt64(&s.activeConnections, old, old-1) {
return
}
}
}
// GetTotalBytesIn returns total incoming bytes (atomic read).
func (s *TrafficStats) GetTotalBytesIn() int64 {
	return atomic.LoadInt64(&s.totalBytesIn)
}
// GetTotalBytesOut returns total outgoing bytes (atomic read).
func (s *TrafficStats) GetTotalBytesOut() int64 {
	return atomic.LoadInt64(&s.totalBytesOut)
}
// GetTotalRequests returns the total request count (atomic read).
func (s *TrafficStats) GetTotalRequests() int64 {
	return atomic.LoadInt64(&s.totalRequests)
}
@@ -84,21 +76,16 @@ func (s *TrafficStats) GetActiveConnections() int64 {
return atomic.LoadInt64(&s.activeConnections)
}
// GetTotalBytes returns total bytes transferred in both directions.
// The two counters are read atomically but not as one unit, so the sum
// is a best-effort snapshot — same as summing the two getters.
func (s *TrafficStats) GetTotalBytes() int64 {
	in := atomic.LoadInt64(&s.totalBytesIn)
	out := atomic.LoadInt64(&s.totalBytesOut)
	return in + out
}
// UpdateSpeed calculates current transfer speed
// Should be called periodically (e.g., every second)
func (s *TrafficStats) UpdateSpeed() {
s.speedMu.Lock()
defer s.speedMu.Unlock()
now := time.Now()
elapsed := now.Sub(s.lastTime).Seconds()
// Require minimum interval of 100ms to avoid division issues
if elapsed < 0.1 {
return
}
@@ -109,18 +96,15 @@ func (s *TrafficStats) UpdateSpeed() {
deltaIn := currentIn - s.lastBytesIn
deltaOut := currentOut - s.lastBytesOut
// Calculate instantaneous speed
if deltaIn > 0 {
s.speedIn = int64(float64(deltaIn) / elapsed)
} else {
// No new bytes - set speed to 0 immediately
s.speedIn = 0
}
if deltaOut > 0 {
s.speedOut = int64(float64(deltaOut) / elapsed)
} else {
// No new bytes - set speed to 0 immediately
s.speedOut = 0
}
@@ -129,38 +113,33 @@ func (s *TrafficStats) UpdateSpeed() {
s.lastTime = now
}
// GetSpeedIn returns the current incoming speed in bytes per second.
func (s *TrafficStats) GetSpeedIn() int64 {
	s.speedMu.Lock()
	v := s.speedIn
	s.speedMu.Unlock()
	return v
}
// GetSpeedOut returns the current outgoing speed in bytes per second.
func (s *TrafficStats) GetSpeedOut() int64 {
	s.speedMu.Lock()
	v := s.speedOut
	s.speedMu.Unlock()
	return v
}
// GetUptime returns how long the connection has been active, measured
// from the startTime recorded at construction.
func (s *TrafficStats) GetUptime() time.Duration {
	return time.Since(s.startTime)
}
// Snapshot returns a snapshot of all stats
type Snapshot struct {
TotalBytesIn int64
TotalBytesOut int64
TotalBytes int64
TotalRequests int64
ActiveConnections int64
SpeedIn int64 // bytes per second
SpeedOut int64 // bytes per second
SpeedIn int64
SpeedOut int64
Uptime time.Duration
}
// GetSnapshot returns a snapshot of current stats
func (s *TrafficStats) GetSnapshot() Snapshot {
s.speedMu.Lock()
speedIn := s.speedIn

View File

@@ -0,0 +1,61 @@
package tuning
import (
"runtime"
"runtime/debug"
)
// Mode selects a preset runtime-tuning profile for ApplyMode.
type Mode int

const (
	// ModeClient uses DefaultClientConfig (GC 100%, quarter-of-RAM limit).
	ModeClient Mode = iota
	// ModeServer uses DefaultServerConfig (GC 200%, three-quarters-of-RAM limit).
	ModeServer
)

// Config holds the runtime tuning knobs installed by Apply.
type Config struct {
	GCPercent int    // value for debug.SetGCPercent; <=0 leaves it untouched
	MemoryLimit int64 // soft limit in bytes for debug.SetMemoryLimit; <=0 leaves it untouched
}
// DefaultClientConfig returns tuning defaults for client processes:
// standard GC pacing (100%) and a soft memory limit of one quarter of
// system RAM, floored at 64 MiB.
func DefaultClientConfig() Config {
	const floor = 64 * 1024 * 1024
	limit := int64(getSystemTotalMemory()) / 4
	if limit < floor {
		limit = floor
	}
	return Config{GCPercent: 100, MemoryLimit: limit}
}
// DefaultServerConfig returns tuning defaults for server processes:
// relaxed GC pacing (200%) and a soft memory limit of three quarters of
// system RAM, never below 128 MiB.
func DefaultServerConfig() Config {
	const floor = 128 * 1024 * 1024
	// Multiply before dividing, as the original did, to keep rounding identical.
	limit := int64(getSystemTotalMemory()) * 3 / 4
	if limit < floor {
		limit = floor
	}
	return Config{GCPercent: 200, MemoryLimit: limit}
}
// Apply installs the given tuning configuration on the Go runtime:
// GOMAXPROCS is pinned to the CPU count, and GC percent / memory limit
// are set only when positive, leaving runtime defaults otherwise.
func Apply(cfg Config) {
	runtime.GOMAXPROCS(runtime.NumCPU())
	if gc := cfg.GCPercent; gc > 0 {
		debug.SetGCPercent(gc)
	}
	if lim := cfg.MemoryLimit; lim > 0 {
		debug.SetMemoryLimit(lim)
	}
}
// ApplyMode applies the built-in defaults for the given profile:
// DefaultClientConfig for ModeClient, DefaultServerConfig for
// ModeServer. Any other value is a no-op.
func ApplyMode(mode Mode) {
	switch mode {
	case ModeClient:
		Apply(DefaultClientConfig())
	case ModeServer:
		Apply(DefaultServerConfig())
	}
}

View File

@@ -0,0 +1,28 @@
//go:build darwin
package tuning
import (
"syscall"
"unsafe"
)
// getSystemTotalMemory reports the machine's physical memory in bytes by
// issuing the hw.memsize sysctl directly. Returns a 1 GiB fallback if the
// syscall fails. NOTE(review): uses the deprecated raw syscall.Syscall6
// path; golang.org/x/sys/unix.SysctlUint64("hw.memsize") would be the
// supported equivalent — left as-is to avoid a new dependency.
func getSystemTotalMemory() uint64 {
	mib := [2]int32{6, 24} // CTL_HW, HW_MEMSIZE
	var value uint64
	size := unsafe.Sizeof(value)
	// sysctl(mib, 2, &value, &size, NULL, 0); only errno matters here.
	_, _, errno := syscall.Syscall6(
		syscall.SYS___SYSCTL,
		uintptr(unsafe.Pointer(&mib[0])),
		2,
		uintptr(unsafe.Pointer(&value)),
		uintptr(unsafe.Pointer(&size)),
		0,
		0,
	)
	if errno != 0 {
		return 1024 * 1024 * 1024 // 1 GiB fallback
	}
	return value
}

View File

@@ -0,0 +1,13 @@
//go:build linux
package tuning
import "syscall"
// getSystemTotalMemory reports physical RAM in bytes via sysinfo(2),
// scaling Totalram by the kernel-reported memory unit. Falls back to
// 1 GiB when the syscall fails.
func getSystemTotalMemory() uint64 {
	var info syscall.Sysinfo_t
	if syscall.Sysinfo(&info) != nil {
		return 1024 * 1024 * 1024 // fallback: 1 GiB
	}
	return info.Totalram * uint64(info.Unit)
}

View File

@@ -0,0 +1,7 @@
//go:build !linux && !darwin && !windows
package tuning
// getSystemTotalMemory returns a conservative 1 GiB estimate on
// platforms that have no native memory probe implemented.
func getSystemTotalMemory() uint64 {
	const fallbackBytes = 1 << 30 // 1 GiB
	return fallbackBytes
}

View File

@@ -0,0 +1,36 @@
//go:build windows
package tuning
import (
"syscall"
"unsafe"
)
var (
	kernel32 = syscall.NewLazyDLL("kernel32.dll")
	// GlobalMemoryStatusEx fills a MEMORYSTATUSEX with system memory info.
	globalMemoryStatusEx = kernel32.NewProc("GlobalMemoryStatusEx")
)
// memoryStatusEx mirrors the Win32 MEMORYSTATUSEX layout; field order and
// sizes must match the ABI exactly, since the struct is passed by pointer
// to the kernel32 call below.
type memoryStatusEx struct {
	dwLength uint32
	dwMemoryLoad uint32
	ullTotalPhys uint64
	ullAvailPhys uint64
	ullTotalPageFile uint64
	ullAvailPageFile uint64
	ullTotalVirtual uint64
	ullAvailVirtual uint64
	ullAvailExtendedVirtual uint64
}
// getSystemTotalMemory reports physical RAM in bytes via
// GlobalMemoryStatusEx, returning a 1 GiB fallback when the call fails
// (the API returns zero on failure).
func getSystemTotalMemory() uint64 {
	var mem memoryStatusEx
	// dwLength must be set to the struct size before the call per the Win32 contract.
	mem.dwLength = uint32(unsafe.Sizeof(mem))
	ret, _, _ := globalMemoryStatusEx.Call(uintptr(unsafe.Pointer(&mem)))
	if ret == 0 {
		return 1024 * 1024 * 1024 // 1 GiB fallback
	}
	return mem.ullTotalPhys
}