Merge pull request #12 from Gouryella/perf/pool-scaling-and-latency-fix

Perf/pool scaling and latency fix
This commit is contained in:
Gouryella
2025-12-24 10:14:01 +08:00
committed by GitHub
28 changed files with 1082 additions and 364 deletions

2
.gitignore vendored
View File

@@ -52,5 +52,5 @@ temp/
certs/
.drip-server.env
benchmark-results/
drip
drip-linux-amd64
./drip

View File

@@ -3,8 +3,6 @@ package main
import (
"fmt"
"os"
"runtime"
"runtime/debug"
"drip/internal/client/cli"
)
@@ -16,9 +14,6 @@ var (
)
func main() {
// Performance optimizations
setupPerformanceOptimizations()
cli.SetVersion(Version, GitCommit, BuildTime)
if err := cli.Execute(); err != nil {
@@ -26,19 +21,3 @@ func main() {
os.Exit(1)
}
}
// setupPerformanceOptimizations configures runtime settings for optimal performance
func setupPerformanceOptimizations() {
// Set GOMAXPROCS to use all available CPU cores
numCPU := runtime.NumCPU()
runtime.GOMAXPROCS(numCPU)
// Reduce GC frequency for high-throughput scenarios
// Default is 100, setting to 200 reduces GC overhead at cost of more memory
// This is beneficial since we now use buffer pools (less garbage)
debug.SetGCPercent(200)
// Set memory limit to prevent OOM (adjust based on your server)
// This is a soft limit - Go will try to stay under this
debug.SetMemoryLimit(8 * 1024 * 1024 * 1024) // 8GB limit
}

View File

@@ -13,6 +13,7 @@ import (
"drip/internal/server/tcp"
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
"drip/internal/shared/tuning"
"drip/internal/shared/utils"
"drip/pkg/config"
"github.com/spf13/cobra"
@@ -60,6 +61,9 @@ func init() {
}
func runServer(_ *cobra.Command, _ []string) error {
// Apply server-mode GC tuning (high throughput, more memory)
tuning.ApplyMode(tuning.ModeServer)
if serverTLSCert == "" {
return fmt.Errorf("TLS certificate path is required (use --tls-cert flag or DRIP_TLS_CERT environment variable)")
}

View File

@@ -9,8 +9,10 @@ import (
"time"
"drip/internal/client/tcp"
"drip/internal/shared/tuning"
"drip/internal/shared/ui"
"drip/internal/shared/utils"
"go.uber.org/zap"
)
@@ -20,6 +22,8 @@ const (
)
func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) error {
tuning.ApplyMode(tuning.ModeClient)
if err := utils.InitLogger(verbose); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"runtime"
@@ -17,6 +16,7 @@ import (
"go.uber.org/zap"
"drip/internal/shared/constants"
"drip/internal/shared/mux"
"drip/internal/shared/protocol"
"drip/internal/shared/stats"
"drip/pkg/config"
@@ -202,10 +202,7 @@ func (c *PoolClient) Connect() error {
c.tunnelID = resp.TunnelID
}
yamuxCfg := yamux.DefaultConfig()
yamuxCfg.EnableKeepAlive = false
yamuxCfg.LogOutput = io.Discard
yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
yamuxCfg := mux.NewClientConfig()
session, err := yamux.Server(primaryConn, yamuxCfg)
if err != nil {
@@ -241,7 +238,7 @@ func (c *PoolClient) Connect() error {
c.desiredTotal = c.initialSessions
c.mu.Unlock()
c.ensureSessions()
c.warmupSessions()
c.wg.Add(1)
go c.scalerLoop()

View File

@@ -2,19 +2,20 @@ package tcp
import (
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
"drip/internal/shared/constants"
"drip/internal/shared/mux"
"drip/internal/shared/protocol"
)
var dataConnCounter atomic.Uint64
// sessionHandle wraps a yamux session with metadata.
type sessionHandle struct {
id string
@@ -37,17 +38,49 @@ func (h *sessionHandle) lastActiveTime() time.Time {
return time.Unix(0, n)
}
// warmupSessions pre-creates initial sessions in parallel to eliminate cold-start latency.
func (c *PoolClient) warmupSessions() {
	if c.IsClosed() || c.tunnelID == "" {
		return
	}

	c.mu.RLock()
	target := c.desiredTotal
	c.mu.RUnlock()

	missing := target - c.sessionCount()
	if missing <= 0 {
		return
	}

	var pending sync.WaitGroup
	pending.Add(missing)
	for n := 0; n < missing; n++ {
		go func() {
			defer pending.Done()
			_ = c.addDataSession() // best-effort; failures are retried by the scaler
		}()
	}
	pending.Wait()

	// Give the server a brief moment to register every session before
	// traffic starts flowing.
	time.Sleep(100 * time.Millisecond)
}
// scalerLoop monitors load and adjusts session count.
func (c *PoolClient) scalerLoop() {
defer c.wg.Done()
const (
checkInterval = 5 * time.Second
scaleUpCooldown = 5 * time.Second
checkInterval = 1 * time.Second
scaleUpCooldown = 1 * time.Second
scaleDownCooldown = 60 * time.Second
capacityPerSession = int64(64)
scaleUpLoad = 0.7
scaleDownLoad = 0.3
capacityPerSession = int64(256)
scaleUpLoad = 0.6
scaleDownLoad = 0.2
burstThreshold = 0.9
maxBurstAdd = 4
)
ticker := time.NewTicker(checkInterval)
@@ -76,11 +109,23 @@ func (c *PoolClient) scalerLoop() {
active := c.stats.GetActiveConnections()
load := float64(active) / float64(int64(current)*capacityPerSession)
sinceLastScale := time.Since(lastScale)
if sinceLastScale >= scaleUpCooldown && load > scaleUpLoad && desired < c.maxSessions {
if load > burstThreshold && desired < c.maxSessions {
sessionsToAdd := min(maxBurstAdd, c.maxSessions-desired)
if sessionsToAdd > 0 {
c.mu.Lock()
c.desiredTotal = min(c.desiredTotal+sessionsToAdd, c.maxSessions)
c.lastScale = time.Now()
c.mu.Unlock()
}
} else if sinceLastScale >= scaleUpCooldown && load > scaleUpLoad && desired < c.maxSessions {
sessionsToAdd := 1
if load > 0.8 {
sessionsToAdd = 2
}
c.mu.Lock()
c.desiredTotal = min(c.desiredTotal+1, c.maxSessions)
c.desiredTotal = min(c.desiredTotal+sessionsToAdd, c.maxSessions)
c.lastScale = time.Now()
c.mu.Unlock()
} else if sinceLastScale >= scaleDownCooldown && load < scaleDownLoad && desired > c.minSessions {
@@ -108,11 +153,20 @@ func (c *PoolClient) ensureSessions() {
current := c.sessionCount()
if current < desired {
for i := 0; i < desired-current; i++ {
if err := c.addDataSession(); err != nil {
c.logger.Debug("Add data session failed", zap.Error(err))
break
toAdd := desired - current
if toAdd > 1 {
var wg sync.WaitGroup
for i := 0; i < toAdd; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = c.addDataSession()
}()
}
wg.Wait()
} else {
_ = c.addDataSession()
}
return
}
@@ -139,7 +193,7 @@ func (c *PoolClient) addDataSession() error {
return err
}
connID := fmt.Sprintf("data-%d", time.Now().UnixNano())
connID := fmt.Sprintf("data-%d", dataConnCounter.Add(1))
req := protocol.DataConnectRequest{
TunnelID: c.tunnelID,
@@ -191,10 +245,7 @@ func (c *PoolClient) addDataSession() error {
return fmt.Errorf("data connection rejected: %s", resp.Message)
}
yamuxCfg := yamux.DefaultConfig()
yamuxCfg.EnableKeepAlive = false
yamuxCfg.LogOutput = io.Discard
yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
yamuxCfg := mux.NewClientConfig()
session, err := yamux.Server(conn, yamuxCfg)
if err != nil {
@@ -219,6 +270,9 @@ func (c *PoolClient) addDataSession() error {
c.wg.Add(1)
go c.sessionWatcher(h, false)
c.wg.Add(1)
go c.pingLoop(h)
return nil
}
@@ -310,3 +364,39 @@ func (c *PoolClient) sessionCount() int {
}
return count
}
// SessionStats holds per-session statistics.
type SessionStats struct {
	ID           string
	IsPrimary    bool
	ActiveCount  int64
	LastActiveAt time.Time
}

// GetSessionStats returns statistics for all sessions.
func (c *PoolClient) GetSessionStats() []SessionStats {
	c.mu.RLock()
	defer c.mu.RUnlock()

	// snapshot captures one session handle's counters into a stats record.
	snapshot := func(h *sessionHandle, primary bool) SessionStats {
		return SessionStats{
			ID:           h.id,
			IsPrimary:    primary,
			ActiveCount:  h.active.Load(),
			LastActiveAt: h.lastActiveTime(),
		}
	}

	out := make([]SessionStats, 0, len(c.dataSessions)+1)
	if c.primary != nil {
		out = append(out, snapshot(c.primary, true))
	}
	for _, h := range c.dataSessions {
		out = append(out, snapshot(h, false))
	}
	return out
}

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
json "github.com/goccy/go-json"
@@ -16,12 +17,20 @@ import (
"drip/internal/server/tunnel"
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
const openStreamTimeout = 10 * time.Second
// bufio.Reader pool to reduce allocations on hot path
var bufioReaderPool = sync.Pool{
New: func() interface{} {
return bufio.NewReaderSize(nil, 32*1024)
},
}
const openStreamTimeout = 3 * time.Second
type Handler struct {
manager *tunnel.Manager
@@ -104,13 +113,19 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
reader := bufioReaderPool.Get().(*bufio.Reader)
reader.Reset(countingStream)
resp, err := http.ReadResponse(reader, r)
if err != nil {
bufioReaderPool.Put(reader)
w.Header().Set("Connection", "close")
http.Error(w, "Read response failed", http.StatusBadGateway)
return
}
defer resp.Body.Close()
defer func() {
resp.Body.Close()
bufioReaderPool.Put(reader)
}()
h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
@@ -147,7 +162,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}()
_, _ = io.Copy(w, resp.Body)
// Use pooled buffer for zero-copy optimization
buf := pool.GetBuffer(pool.SizeLarge)
_, _ = io.CopyBuffer(w, resp.Body, (*buf)[:])
pool.PutBuffer(buf)
close(done)
stream.Close()
}
@@ -158,10 +177,18 @@ func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, err
err error
}
ch := make(chan result, 1)
done := make(chan struct{})
defer close(done)
go func() {
s, err := tconn.OpenStream()
ch <- result{s, err}
select {
case ch <- result{s, err}:
case <-done:
if s != nil {
s.Close()
}
}
}()
select {
@@ -349,32 +376,30 @@ func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
if h.authToken != "" {
token := r.URL.Query().Get("token")
if token == "" {
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
token = strings.TrimPrefix(authHeader, "Bearer ")
}
// Only accept token via Authorization header (Bearer token)
// URL query parameters are insecure (logged, cached, visible in browser history)
var token string
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
token = strings.TrimPrefix(authHeader, "Bearer ")
}
if token != h.authToken {
http.Error(w, "Unauthorized: invalid or missing token", http.StatusUnauthorized)
w.Header().Set("WWW-Authenticate", `Bearer realm="stats"`)
http.Error(w, "Unauthorized: provide token via 'Authorization: Bearer <token>' header", http.StatusUnauthorized)
return
}
}
connections := h.manager.List()
stats := map[string]interface{}{
"total_tunnels": len(connections),
"tunnels": []map[string]interface{}{},
}
// Pre-allocate slice to avoid O(n²) reallocations
tunnelStats := make([]map[string]interface{}, 0, len(connections))
for _, conn := range connections {
if conn == nil {
continue
}
stats["tunnels"] = append(stats["tunnels"].([]map[string]interface{}), map[string]interface{}{
tunnelStats = append(tunnelStats, map[string]interface{}{
"subdomain": conn.Subdomain,
"tunnel_type": string(conn.GetTunnelType()),
"last_active": conn.LastActive.Unix(),
@@ -385,6 +410,11 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
})
}
stats := map[string]interface{}{
"total_tunnels": len(tunnelStats),
"tunnels": tunnelStats,
}
data, err := json.Marshal(stats)
if err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)

View File

@@ -18,11 +18,19 @@ import (
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
"drip/internal/shared/mux"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// bufioWriterPool reuses bufio.Writer instances to reduce GC pressure
var bufioWriterPool = sync.Pool{
New: func() interface{} {
return bufio.NewWriterSize(nil, 4096)
},
}
type Connection struct {
conn net.Conn
authToken string
@@ -344,9 +352,13 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
zap.String("host", req.Host),
)
// Get writer from pool to reduce GC pressure
pooledWriter := bufioWriterPool.Get().(*bufio.Writer)
pooledWriter.Reset(c.conn)
respWriter := &httpResponseWriter{
conn: c.conn,
writer: bufio.NewWriterSize(c.conn, 4096),
writer: pooledWriter,
header: make(http.Header),
}
@@ -356,10 +368,12 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
c.logger.Debug("Failed to flush HTTP response", zap.Error(err))
}
if tcpConn, ok := c.conn.(*net.TCPConn); ok {
tcpConn.SetNoDelay(true)
tcpConn.SetNoDelay(false)
}
// Return writer to pool
pooledWriter.Reset(nil) // Clear reference to connection
bufioWriterPool.Put(pooledWriter)
// Keep TCP_NODELAY enabled for low latency HTTP responses
// (removed the toggle that was disabling it)
c.logger.Debug("HTTP request processing completed",
zap.String("method", req.Method),
@@ -673,10 +687,8 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
reader: reader,
}
cfg := yamux.DefaultConfig()
cfg.EnableKeepAlive = false
cfg.LogOutput = io.Discard
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
// Use optimized mux config for server
cfg := mux.NewServerConfig()
session, err := yamux.Client(bc, cfg)
if err != nil {

View File

@@ -1,10 +1,10 @@
package tcp
import (
"container/heap"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/yamux"
@@ -15,6 +15,54 @@ import (
"go.uber.org/zap"
)
// sessionEntry represents a session with its current stream count for heap operations
type sessionEntry struct {
	id      string
	session *yamux.Session
	streams int
	heapIdx int // index in the heap, managed by heap.Interface
}

// sessionHeap implements heap.Interface for O(log n) session selection
type sessionHeap []*sessionEntry

func (h sessionHeap) Len() int { return len(h) }

// Less orders the heap as a min-heap: the session with fewer open streams
// has higher priority.
func (h sessionHeap) Less(i, j int) bool { return h[i].streams < h[j].streams }

func (h sessionHeap) Swap(i, j int) {
	h[i], h[j] = h[j], h[i]
	h[i].heapIdx, h[j].heapIdx = i, j
}

func (h *sessionHeap) Push(x interface{}) {
	e := x.(*sessionEntry)
	e.heapIdx = len(*h)
	*h = append(*h, e)
}

func (h *sessionHeap) Pop() interface{} {
	old := *h
	last := len(old) - 1
	e := old[last]
	old[last] = nil // drop the reference so the GC can reclaim the entry
	e.heapIdx = -1
	*h = old[:last]
	return e
}

// sessionHeapPool reuses heap slices to reduce allocations
var sessionHeapPool = sync.Pool{
	New: func() interface{} {
		s := make(sessionHeap, 0, 16)
		return &s
	},
}
type ConnectionGroup struct {
TunnelID string
Subdomain string
@@ -60,6 +108,12 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
const maxConsecutiveFailures = 3
failureCount := make(map[string]int)
type sessionSnapshot struct {
id string
session *yamux.Session
}
sessions := make([]sessionSnapshot, 0, 16)
for {
select {
case <-g.stopCh:
@@ -67,26 +121,25 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
case <-ticker.C:
}
sessions = sessions[:0]
g.mu.RLock()
sessions := make(map[string]*yamux.Session, len(g.Sessions))
for id, s := range g.Sessions {
sessions[id] = s
sessions = append(sessions, sessionSnapshot{id: id, session: s})
}
g.mu.RUnlock()
for id, session := range sessions {
if session == nil || session.IsClosed() {
g.RemoveSession(id)
delete(failureCount, id)
for _, snap := range sessions {
if snap.session == nil || snap.session.IsClosed() {
g.RemoveSession(snap.id)
delete(failureCount, snap.id)
continue
}
// Ping with timeout
done := make(chan error, 1)
go func(s *yamux.Session) {
_, err := s.Ping()
done <- err
}(session)
}(snap.session)
var err error
select {
@@ -98,31 +151,29 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
}
if err != nil {
failureCount[id]++
failureCount[snap.id]++
g.logger.Debug("Session ping failed",
zap.String("session_id", id),
zap.Int("consecutive_failures", failureCount[id]),
zap.String("session_id", snap.id),
zap.Int("consecutive_failures", failureCount[snap.id]),
zap.Error(err),
)
if failureCount[id] >= maxConsecutiveFailures {
if failureCount[snap.id] >= maxConsecutiveFailures {
g.logger.Warn("Session ping failed too many times, removing",
zap.String("session_id", id),
zap.Int("failures", failureCount[id]),
zap.String("session_id", snap.id),
zap.Int("failures", failureCount[snap.id]),
)
g.RemoveSession(id)
delete(failureCount, id)
g.RemoveSession(snap.id)
delete(failureCount, snap.id)
}
} else {
// Reset on success
failureCount[id] = 0
failureCount[snap.id] = 0
g.mu.Lock()
g.LastActivity = time.Now()
g.mu.Unlock()
}
}
// Check if all sessions are gone
g.mu.RLock()
sessionCount := len(g.Sessions)
g.mu.RUnlock()
@@ -214,11 +265,12 @@ func (g *ConnectionGroup) SessionCount() int {
return len(g.Sessions)
}
// OpenStream opens a new stream using a min-heap for O(log n) session selection.
func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
const (
maxStreamsPerSession = 256
maxRetries = 3
backoffBase = 25 * time.Millisecond
backoffBase = 5 * time.Millisecond
)
var lastErr error
@@ -230,61 +282,33 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
default:
}
// Prefer data sessions for data-plane traffic; keep the primary session
// as control-plane (client ping/latency), and only fall back to primary
// when no data session exists.
sessions := g.sessionsSnapshot(false)
if len(sessions) == 0 {
sessions = g.sessionsSnapshot(true)
h := g.buildSessionHeap(false)
if h.Len() == 0 {
h = g.buildSessionHeap(true)
}
if len(sessions) == 0 {
if h.Len() == 0 {
return nil, net.ErrClosed
}
tried := make([]bool, len(sessions))
anyUnderCap := false
start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
for h.Len() > 0 {
entry := heap.Pop(h).(*sessionEntry)
session := entry.session
for range sessions {
bestIdx := -1
minStreams := int(^uint(0) >> 1)
for i := 0; i < len(sessions); i++ {
idx := (start + i) % len(sessions)
if tried[idx] {
continue
}
session := sessions[idx]
if session == nil || session.IsClosed() {
tried[idx] = true
continue
}
n := session.NumStreams()
if n >= maxStreamsPerSession {
continue
}
anyUnderCap = true
if n < minStreams {
minStreams = n
bestIdx = idx
}
}
if bestIdx == -1 {
break
}
tried[bestIdx] = true
session := sessions[bestIdx]
if session == nil || session.IsClosed() {
continue
}
currentStreams := session.NumStreams()
if currentStreams >= maxStreamsPerSession {
continue
}
anyUnderCap = true
stream, err := session.Open()
if err == nil {
*h = (*h)[:0]
sessionHeapPool.Put(h)
return stream, nil
}
lastErr = err
@@ -294,6 +318,9 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
}
}
*h = (*h)[:0]
sessionHeapPool.Put(h)
if !anyUnderCap {
lastErr = fmt.Errorf("all sessions are at stream capacity (%d)", maxStreamsPerSession)
}
@@ -313,36 +340,64 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
return nil, lastErr
}
func (g *ConnectionGroup) selectSession() *yamux.Session {
sessions := g.sessionsSnapshot(false)
if len(sessions) == 0 {
sessions = g.sessionsSnapshot(true)
}
if len(sessions) == 0 {
return nil
// buildSessionHeap creates a min-heap of sessions ordered by stream count.
func (g *ConnectionGroup) buildSessionHeap(includePrimary bool) *sessionHeap {
g.mu.RLock()
defer g.mu.RUnlock()
if len(g.Sessions) == 0 {
h := sessionHeapPool.Get().(*sessionHeap)
return h
}
start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
minStreams := int(^uint(0) >> 1)
var best *yamux.Session
h := sessionHeapPool.Get().(*sessionHeap)
*h = (*h)[:0]
for i := 0; i < len(sessions); i++ {
session := sessions[(start+i)%len(sessions)]
for id, session := range g.Sessions {
if session == nil || session.IsClosed() {
continue
}
if n := session.NumStreams(); n < minStreams {
minStreams = n
best = session
if id == "primary" && !includePrimary {
continue
}
*h = append(*h, &sessionEntry{
id: id,
session: session,
streams: session.NumStreams(),
})
}
return best
heap.Init(h)
return h
}
// selectSession picks the least-loaded usable session, preferring data
// sessions over the primary (control-plane) session. It returns nil when no
// open session is available.
func (g *ConnectionGroup) selectSession() *yamux.Session {
	h := g.buildSessionHeap(false)
	if h.Len() == 0 {
		sessionHeapPool.Put(h)
		h = g.buildSessionHeap(true)
	}
	if h.Len() == 0 {
		sessionHeapPool.Put(h)
		return nil
	}

	entry := heap.Pop(h).(*sessionEntry)
	session := entry.session

	// Zero the remaining entries before returning the slice to the pool:
	// truncating alone ((*h)[:0]) leaves the backing array holding
	// *sessionEntry (and their *yamux.Session) pointers, pinning them until
	// the pooled slice is next reused.
	clear(*h)
	*h = (*h)[:0]
	sessionHeapPool.Put(h)

	if session == nil || session.IsClosed() {
		return nil
	}
	return session
}
func (g *ConnectionGroup) sessionsSnapshot(includePrimary bool) []*yamux.Session {
g.mu.Lock()
defer g.mu.Unlock()
g.mu.RLock()
defer g.mu.RUnlock()
if len(g.Sessions) == 0 {
return nil
@@ -351,7 +406,6 @@ func (g *ConnectionGroup) sessionsSnapshot(includePrimary bool) []*yamux.Session
sessions := make([]*yamux.Session, 0, len(g.Sessions))
for id, session := range g.Sessions {
if session == nil || session.IsClosed() {
delete(g.Sessions, id)
continue
}
if id == "primary" && !includePrimary {
@@ -360,10 +414,6 @@ func (g *ConnectionGroup) sessionsSnapshot(includePrimary bool) []*yamux.Session
sessions = append(sessions, session)
}
if len(sessions) > 0 {
g.LastActivity = time.Now()
}
return sessions
}

View File

@@ -92,11 +92,11 @@ func (l *Listener) Start() error {
l.httpServer = &http.Server{
Handler: l.httpHandler,
ReadHeaderTimeout: 30 * time.Second,
ReadTimeout: 0,
WriteTimeout: 0,
IdleTimeout: 120 * time.Second,
MaxHeaderBytes: 1 << 20,
ReadHeaderTimeout: 10 * time.Second, // Time to read request headers
ReadTimeout: 30 * time.Second, // Total time to read request (prevents slow-loris)
WriteTimeout: 60 * time.Second, // Time to write response (allows large responses)
IdleTimeout: 120 * time.Second, // Keep-alive timeout
MaxHeaderBytes: 1 << 18, // 256KB max header size (reduced from 1MB)
}
if err := http2.ConfigureServer(l.httpServer, &http2.Server{

View File

@@ -182,17 +182,25 @@ func (p *Proxy) handleConn(conn net.Conn) {
return
}
// Open stream with timeout to prevent goroutine leak
const openStreamTimeout = 10 * time.Second
const openStreamTimeout = 3 * time.Second
type streamResult struct {
stream net.Conn
err error
}
resultCh := make(chan streamResult, 1)
ctx, cancel := context.WithTimeout(p.ctx, openStreamTimeout)
defer cancel()
go func() {
s, err := p.openStream()
resultCh <- streamResult{s, err}
select {
case resultCh <- streamResult{s, err}:
case <-ctx.Done():
if s != nil {
s.Close()
}
}
}()
var stream net.Conn
@@ -205,7 +213,7 @@ func (p *Proxy) handleConn(conn net.Conn) {
return
}
stream = result.stream
case <-time.After(openStreamTimeout):
case <-ctx.Done():
p.logger.Debug("Open stream timeout")
return
case <-p.stopCh:

View File

@@ -3,12 +3,11 @@ package tcp
import (
"bufio"
"fmt"
"io"
"net"
"github.com/hashicorp/yamux"
"drip/internal/shared/constants"
"drip/internal/shared/mux"
)
type bufferedConn struct {
@@ -27,10 +26,8 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
reader: reader,
}
cfg := yamux.DefaultConfig()
cfg.EnableKeepAlive = false
cfg.LogOutput = io.Discard
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
// Use optimized mux config for server
cfg := mux.NewServerConfig()
session, err := yamux.Client(bc, cfg)
if err != nil {
@@ -66,10 +63,8 @@ func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
reader: reader,
}
cfg := yamux.DefaultConfig()
cfg.EnableKeepAlive = false
cfg.LogOutput = io.Discard
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
// Use optimized mux config for server
cfg := mux.NewServerConfig()
session, err := yamux.Client(bc, cfg)
if err != nil {

View File

@@ -20,9 +20,10 @@ type Connection struct {
LastActive time.Time
mu sync.RWMutex
logger *zap.Logger
closed bool
closed atomic.Bool // Use atomic for lock-free hot path checks
tunnelType protocol.TunnelType
openStream func() (net.Conn, error)
remoteIP string // Client IP for rate limiting tracking
bytesIn atomic.Int64
bytesOut atomic.Int64
@@ -38,18 +39,15 @@ func NewConnection(subdomain string, conn *websocket.Conn, logger *zap.Logger) *
CloseCh: make(chan struct{}),
LastActive: time.Now(),
logger: logger,
closed: false,
}
}
// Send sends data through the WebSocket connection
func (c *Connection) Send(data []byte) error {
c.mu.RLock()
if c.closed {
c.mu.RUnlock()
// Lock-free check using atomic - avoids RLock contention on hot path
if c.closed.Load() {
return ErrConnectionClosed
}
c.mu.RUnlock()
select {
case c.SendCh <- data:
@@ -75,14 +73,14 @@ func (c *Connection) IsAlive(timeout time.Duration) bool {
// Close closes the connection and all associated channels
func (c *Connection) Close() {
// Use atomic swap to ensure only one goroutine closes
if c.closed.Swap(true) {
return // Already closed
}
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return
}
c.closed = true
close(c.CloseCh)
close(c.SendCh)
@@ -100,9 +98,7 @@ func (c *Connection) Close() {
// IsClosed returns whether the connection is closed
func (c *Connection) IsClosed() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.closed
return c.closed.Load() // Lock-free check
}
// SetTunnelType sets the tunnel type.
@@ -119,8 +115,7 @@ func (c *Connection) GetTunnelType() protocol.TunnelType {
return c.tunnelType
}
// SetOpenStream registers a yamux stream opener for this tunnel.
// It is used by the HTTP proxy to forward each request over a mux stream.
// SetOpenStream registers a stream opener for this tunnel.
func (c *Connection) SetOpenStream(open func() (net.Conn, error)) {
c.mu.Lock()
c.openStream = open
@@ -129,12 +124,16 @@ func (c *Connection) SetOpenStream(open func() (net.Conn, error)) {
// OpenStream opens a new mux stream to the tunnel client.
func (c *Connection) OpenStream() (net.Conn, error) {
// Lock-free closed check
if c.closed.Load() {
return nil, ErrConnectionClosed
}
c.mu.RLock()
open := c.openStream
closed := c.closed
c.mu.RUnlock()
if closed || open == nil {
if open == nil {
return nil, ErrConnectionClosed
}
return open()
@@ -178,17 +177,11 @@ func (c *Connection) GetActiveConnections() int64 {
// StartWritePump starts the write pump for sending messages
func (c *Connection) StartWritePump() {
// Skip write pump for TCP-only connections (no WebSocket)
if c.Conn == nil {
c.logger.Debug("Skipping WritePump for TCP connection",
zap.String("subdomain", c.Subdomain),
)
// Still need to drain SendCh to prevent blocking
go func() {
for {
select {
case <-c.SendCh:
// Discard messages for TCP mode
case <-c.CloseCh:
return
}

View File

@@ -1,7 +1,10 @@
package tunnel
import (
"errors"
"hash/fnv"
"sync"
"sync/atomic"
"time"
"drip/internal/shared/utils"
@@ -9,59 +12,285 @@ import (
"go.uber.org/zap"
)
// Manager manages all active tunnel connections
type Manager struct {
tunnels map[string]*Connection // subdomain -> connection
mu sync.RWMutex
used map[string]bool // track used subdomains
logger *zap.Logger
// Manager limits
const (
DefaultMaxTunnels = 1000 // Maximum total tunnels
DefaultMaxTunnelsPerIP = 10 // Maximum tunnels per IP
DefaultRateLimit = 10 // Registrations per IP per minute
DefaultRateLimitWindow = 1 * time.Minute // Rate limit window
// numShards is the number of shards for lock distribution
// Using 32 shards reduces lock contention by ~32x under high concurrency
numShards = 32
)
var (
ErrTooManyTunnels = errors.New("maximum tunnel limit reached")
ErrTooManyPerIP = errors.New("maximum tunnels per IP reached")
ErrRateLimitExceeded = errors.New("rate limit exceeded, try again later")
)
// rateLimitEntry tracks registration attempts per IP
type rateLimitEntry struct {
	count     int       // registrations counted in the current window
	windowEnd time.Time // instant at which the current window expires
}
// NewManager creates a new tunnel manager
func NewManager(logger *zap.Logger) *Manager {
return &Manager{
tunnels: make(map[string]*Connection),
used: make(map[string]bool),
logger: logger,
// shard holds a subset of tunnels with its own lock
type shard struct {
	tunnels map[string]*Connection // subdomain -> connection, for this shard only
	used    map[string]bool        // subdomains currently reserved in this shard
	mu      sync.RWMutex           // guards tunnels and used
}
// Manager manages all active tunnel connections with sharded locking
type Manager struct {
	shards [numShards]shard // subdomain-hashed shards spreading lock contention
	logger *zap.Logger

	// Limits
	maxTunnels      int           // cap on total tunnels across all shards
	maxTunnelsPerIP int           // cap on concurrent tunnels per client IP
	rateLimit       int           // registrations allowed per IP per window
	rateLimitWindow time.Duration // length of one rate-limit window

	// Global counters (atomic for lock-free reads)
	tunnelCount atomic.Int64

	// Per-IP tracking (requires separate lock as it spans shards)
	ipMu        sync.RWMutex
	tunnelsByIP map[string]int             // IP -> tunnel count
	rateLimits  map[string]*rateLimitEntry // IP -> rate limit entry

	// Lifecycle
	stopCh chan struct{} // presumably closed to stop background work — TODO confirm consumer
}
// ManagerConfig holds configuration for the Manager
type ManagerConfig struct {
	MaxTunnels      int           // maximum total tunnels; <=0 selects the default
	MaxTunnelsPerIP int           // maximum tunnels per client IP; <=0 selects the default
	RateLimit       int           // Registrations per IP per window
	RateLimitWindow time.Duration // rate-limit window length; <=0 selects the default
}
// DefaultManagerConfig returns default configuration
func DefaultManagerConfig() ManagerConfig {
	var cfg ManagerConfig
	cfg.MaxTunnels = DefaultMaxTunnels
	cfg.MaxTunnelsPerIP = DefaultMaxTunnelsPerIP
	cfg.RateLimit = DefaultRateLimit
	cfg.RateLimitWindow = DefaultRateLimitWindow
	return cfg
}
// Register registers a new tunnel connection
// Returns the assigned subdomain and any error
// NewManager creates a new tunnel manager with default config
func NewManager(logger *zap.Logger) *Manager {
	cfg := DefaultManagerConfig()
	return NewManagerWithConfig(logger, cfg)
}
// NewManagerWithConfig creates a new tunnel manager with custom config.
// Any non-positive limit in cfg is replaced with the package default before
// the manager is built, so a zero-value ManagerConfig is safe to pass.
func NewManagerWithConfig(logger *zap.Logger, cfg ManagerConfig) *Manager {
	// Normalize invalid/unset limits to defaults.
	if cfg.MaxTunnels <= 0 {
		cfg.MaxTunnels = DefaultMaxTunnels
	}
	if cfg.MaxTunnelsPerIP <= 0 {
		cfg.MaxTunnelsPerIP = DefaultMaxTunnelsPerIP
	}
	if cfg.RateLimit <= 0 {
		cfg.RateLimit = DefaultRateLimit
	}
	if cfg.RateLimitWindow <= 0 {
		cfg.RateLimitWindow = DefaultRateLimitWindow
	}
	// Log the effective (post-normalization) limits once at startup.
	logger.Info("Tunnel manager configured",
		zap.Int("max_tunnels", cfg.MaxTunnels),
		zap.Int("max_per_ip", cfg.MaxTunnelsPerIP),
		zap.Int("rate_limit", cfg.RateLimit),
		zap.Duration("rate_window", cfg.RateLimitWindow),
		zap.Int("num_shards", numShards),
	)
	m := &Manager{
		logger:          logger,
		maxTunnels:      cfg.MaxTunnels,
		maxTunnelsPerIP: cfg.MaxTunnelsPerIP,
		rateLimit:       cfg.RateLimit,
		rateLimitWindow: cfg.RateLimitWindow,
		tunnelsByIP:     make(map[string]int),
		rateLimits:      make(map[string]*rateLimitEntry),
		stopCh:          make(chan struct{}),
	}
	// Initialize all shards so registration never writes to a nil map.
	for i := 0; i < numShards; i++ {
		m.shards[i].tunnels = make(map[string]*Connection)
		m.shards[i].used = make(map[string]bool)
	}
	return m
}
// getShard returns the shard for a given subdomain using FNV-1a hash
func (m *Manager) getShard(subdomain string) *shard {
	hasher := fnv.New32a()
	_, _ = hasher.Write([]byte(subdomain)) // fnv hash Write never fails
	idx := hasher.Sum32() % numShards
	return &m.shards[idx]
}
// checkRateLimitLocked reports whether ip may register now under the
// fixed-window rate limit, incrementing its counter on success
// (caller must hold ipMu).
func (m *Manager) checkRateLimitLocked(ip string) bool {
	now := time.Now()
	entry, exists := m.rateLimits[ip]
	if !exists || now.After(entry.windowEnd) {
		// New window: this attempt counts as the first registration.
		m.rateLimits[ip] = &rateLimitEntry{
			count:     1,
			windowEnd: now.Add(m.rateLimitWindow),
		}
		return true
	}
	if entry.count >= m.rateLimit {
		// Quota exhausted for the open window; reject without incrementing.
		return false
	}
	entry.count++
	return true
}
// Register registers a new tunnel connection with IP-based limits.
// It delegates to RegisterWithIP with an empty IP (skipping per-IP
// accounting). All locking is handled inside RegisterWithIP — the sharded
// Manager has no manager-wide mutex, so the stale m.mu.Lock/Unlock pair is
// removed; holding a global lock across registration would also serialize
// every caller.
func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string, error) {
	return m.RegisterWithIP(conn, customSubdomain, "")
}
// RegisterWithIP registers a new tunnel with IP tracking
func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, remoteIP string) (string, error) {
// Reserve a global slot atomically using CAS loop
for {
current := m.tunnelCount.Load()
if current >= int64(m.maxTunnels) {
m.logger.Warn("Maximum tunnel limit reached",
zap.Int64("current", current),
zap.Int("max", m.maxTunnels),
)
return "", ErrTooManyTunnels
}
if m.tunnelCount.CompareAndSwap(current, current+1) {
break
}
// CAS failed, another goroutine modified the counter, retry
}
// Rollback helper for global counter
rollbackGlobal := func() {
m.tunnelCount.Add(-1)
}
// Check per-IP limits and reserve slot atomically
if remoteIP != "" {
m.ipMu.Lock()
if !m.checkRateLimitLocked(remoteIP) {
m.ipMu.Unlock()
rollbackGlobal()
m.logger.Warn("Rate limit exceeded",
zap.String("ip", remoteIP),
zap.Int("limit", m.rateLimit),
)
return "", ErrRateLimitExceeded
}
if m.tunnelsByIP[remoteIP] >= m.maxTunnelsPerIP {
currentPerIP := m.tunnelsByIP[remoteIP]
m.ipMu.Unlock()
rollbackGlobal()
m.logger.Warn("Per-IP tunnel limit reached",
zap.String("ip", remoteIP),
zap.Int("current", currentPerIP),
zap.Int("max", m.maxTunnelsPerIP),
)
return "", ErrTooManyPerIP
}
// Reserve per-IP slot while still holding the lock
m.tunnelsByIP[remoteIP]++
m.ipMu.Unlock()
}
// Rollback helper for per-IP counter
rollbackPerIP := func() {
if remoteIP != "" {
m.ipMu.Lock()
if m.tunnelsByIP[remoteIP] > 0 {
m.tunnelsByIP[remoteIP]--
if m.tunnelsByIP[remoteIP] == 0 {
delete(m.tunnelsByIP, remoteIP)
}
}
m.ipMu.Unlock()
}
}
var subdomain string
if customSubdomain != "" {
// Validate custom subdomain
if !utils.ValidateSubdomain(customSubdomain) {
rollbackPerIP()
rollbackGlobal()
return "", ErrInvalidSubdomain
}
if utils.IsReserved(customSubdomain) {
rollbackPerIP()
rollbackGlobal()
return "", ErrReservedSubdomain
}
if m.used[customSubdomain] {
// Check if subdomain is taken in its shard
s := m.getShard(customSubdomain)
s.mu.Lock()
if s.used[customSubdomain] {
s.mu.Unlock()
rollbackPerIP()
rollbackGlobal()
return "", ErrSubdomainTaken
}
subdomain = customSubdomain
// Register in shard
tc := NewConnection(subdomain, conn, m.logger)
tc.remoteIP = remoteIP
s.tunnels[subdomain] = tc
s.used[subdomain] = true
s.mu.Unlock()
} else {
// Generate unique random subdomain
subdomain = m.generateUniqueSubdomain()
s := m.getShard(subdomain)
s.mu.Lock()
tc := NewConnection(subdomain, conn, m.logger)
tc.remoteIP = remoteIP
s.tunnels[subdomain] = tc
s.used[subdomain] = true
s.mu.Unlock()
}
// Create connection
tc := NewConnection(subdomain, conn, m.logger)
m.tunnels[subdomain] = tc
m.used[subdomain] = true
// Start write pump in background
go tc.StartWritePump()
// Get connection and start write pump
s := m.getShard(subdomain)
s.mu.RLock()
tc := s.tunnels[subdomain]
s.mu.RUnlock()
if tc != nil {
go tc.StartWritePump()
}
m.logger.Info("Tunnel registered",
zap.String("subdomain", subdomain),
zap.Int("total_tunnels", len(m.tunnels)),
zap.String("ip", remoteIP),
zap.Int64("total_tunnels", m.tunnelCount.Load()),
)
return subdomain, nil
@@ -69,85 +298,143 @@ func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string
// Unregister removes a tunnel connection
func (m *Manager) Unregister(subdomain string) {
m.mu.Lock()
defer m.mu.Unlock()
s := m.getShard(subdomain)
s.mu.Lock()
if tc, ok := m.tunnels[subdomain]; ok {
tc.Close()
delete(m.tunnels, subdomain)
delete(m.used, subdomain)
m.logger.Info("Tunnel unregistered",
zap.String("subdomain", subdomain),
zap.Int("total_tunnels", len(m.tunnels)),
)
tc, ok := s.tunnels[subdomain]
if !ok {
s.mu.Unlock()
return
}
remoteIP := tc.remoteIP
tc.Close()
delete(s.tunnels, subdomain)
delete(s.used, subdomain)
s.mu.Unlock()
// Update counters
m.tunnelCount.Add(-1)
if remoteIP != "" {
m.ipMu.Lock()
if m.tunnelsByIP[remoteIP] > 0 {
m.tunnelsByIP[remoteIP]--
if m.tunnelsByIP[remoteIP] == 0 {
delete(m.tunnelsByIP, remoteIP)
}
}
m.ipMu.Unlock()
}
m.logger.Info("Tunnel unregistered",
zap.String("subdomain", subdomain),
zap.Int64("total_tunnels", m.tunnelCount.Load()),
)
}
// Get retrieves a tunnel connection by subdomain, consulting only the
// shard that owns the key.
//
// Fix: the merged text carried both the legacy m.mu.RLock body and the
// sharded one; this is the sharded version only.
func (m *Manager) Get(subdomain string) (*Connection, bool) {
	s := m.getShard(subdomain)
	s.mu.RLock()
	tc, ok := s.tunnels[subdomain]
	s.mu.RUnlock()
	return tc, ok
}
// List returns all active tunnel connections
func (m *Manager) List() []*Connection {
m.mu.RLock()
defer m.mu.RUnlock()
// Pre-allocate with approximate capacity
connections := make([]*Connection, 0, m.tunnelCount.Load())
connections := make([]*Connection, 0, len(m.tunnels))
for _, tc := range m.tunnels {
connections = append(connections, tc)
for i := 0; i < numShards; i++ {
s := &m.shards[i]
s.mu.RLock()
for _, tc := range s.tunnels {
connections = append(connections, tc)
}
s.mu.RUnlock()
}
return connections
}
// Count returns the number of active tunnels from the lock-free global
// counter; no shard locks are taken.
//
// Fix: reconstructed from the merged text, which still carried the legacy
// m.mu/len(m.tunnels) body alongside this one.
func (m *Manager) Count() int {
	return int(m.tunnelCount.Load())
}
// CleanupStale removes stale connections that haven't been active
func (m *Manager) CleanupStale(timeout time.Duration) int {
m.mu.Lock()
defer m.mu.Unlock()
totalCleaned := 0
staleSubdomains := []string{}
// Clean up each shard independently
for i := 0; i < numShards; i++ {
s := &m.shards[i]
s.mu.Lock()
for subdomain, tc := range m.tunnels {
if !tc.IsAlive(timeout) {
staleSubdomains = append(staleSubdomains, subdomain)
var staleSubdomains []string
for subdomain, tc := range s.tunnels {
if !tc.IsAlive(timeout) {
staleSubdomains = append(staleSubdomains, subdomain)
}
}
for _, subdomain := range staleSubdomains {
if tc, ok := s.tunnels[subdomain]; ok {
remoteIP := tc.remoteIP
tc.Close()
delete(s.tunnels, subdomain)
delete(s.used, subdomain)
// Update counters
m.tunnelCount.Add(-1)
if remoteIP != "" {
m.ipMu.Lock()
if m.tunnelsByIP[remoteIP] > 0 {
m.tunnelsByIP[remoteIP]--
if m.tunnelsByIP[remoteIP] == 0 {
delete(m.tunnelsByIP, remoteIP)
}
}
m.ipMu.Unlock()
}
}
}
totalCleaned += len(staleSubdomains)
s.mu.Unlock()
}
for _, subdomain := range staleSubdomains {
if tc, ok := m.tunnels[subdomain]; ok {
tc.Close()
delete(m.tunnels, subdomain)
delete(m.used, subdomain)
// Cleanup expired rate limit entries
m.ipMu.Lock()
now := time.Now()
for ip, entry := range m.rateLimits {
if now.After(entry.windowEnd) {
delete(m.rateLimits, ip)
}
}
m.ipMu.Unlock()
if len(staleSubdomains) > 0 {
if totalCleaned > 0 {
m.logger.Info("Cleaned up stale tunnels",
zap.Int("count", len(staleSubdomains)),
zap.Int("count", totalCleaned),
)
}
return len(staleSubdomains)
return totalCleaned
}
// StartCleanupTask starts a background task to clean up stale connections
func (m *Manager) StartCleanupTask(interval, timeout time.Duration) {
ticker := time.NewTicker(interval)
go func() {
for range ticker.C {
m.CleanupStale(timeout)
defer ticker.Stop()
for {
select {
case <-ticker.C:
m.CleanupStale(timeout)
case <-m.stopCh:
return
}
}
}()
}
@@ -158,7 +445,16 @@ func (m *Manager) generateUniqueSubdomain() string {
for i := 0; i < maxAttempts; i++ {
subdomain := utils.GenerateSubdomain(6)
if !m.used[subdomain] && !utils.IsReserved(subdomain) {
if utils.IsReserved(subdomain) {
continue
}
s := m.getShard(subdomain)
s.mu.RLock()
taken := s.used[subdomain]
s.mu.RUnlock()
if !taken {
return subdomain
}
}
@@ -169,17 +465,24 @@ func (m *Manager) generateUniqueSubdomain() string {
// Shutdown gracefully shuts down all tunnels
func (m *Manager) Shutdown() {
m.mu.Lock()
defer m.mu.Unlock()
// Signal cleanup goroutine to stop
close(m.stopCh)
m.logger.Info("Shutting down tunnel manager",
zap.Int("active_tunnels", len(m.tunnels)),
zap.Int64("active_tunnels", m.tunnelCount.Load()),
)
for _, tc := range m.tunnels {
tc.Close()
// Close all tunnels in each shard
for i := 0; i < numShards; i++ {
s := &m.shards[i]
s.mu.Lock()
for _, tc := range s.tunnels {
tc.Close()
}
s.tunnels = make(map[string]*Connection)
s.used = make(map[string]bool)
s.mu.Unlock()
}
m.tunnels = make(map[string]*Connection)
m.used = make(map[string]bool)
m.tunnelCount.Store(0)
}

View File

@@ -9,9 +9,33 @@ const (
// DefaultWSPort is the default WebSocket port
DefaultWSPort = 8080
// ==================== Yamux Configuration ====================
// These settings are tuned for high-throughput tunnel scenarios
// YamuxAcceptBacklog controls how many incoming streams can be queued
// before yamux starts blocking stream opens under load.
YamuxAcceptBacklog = 4096
// Increased from default 256 to handle burst traffic.
YamuxAcceptBacklog = 8192
// YamuxMaxStreamWindowSize is the maximum window size for a stream.
// Larger windows allow more data in-flight, improving throughput on high-latency links.
// Default is 256KB, increased to 512KB for better throughput.
YamuxMaxStreamWindowSize = 512 * 1024
// YamuxStreamOpenTimeout is how long to wait for a stream open to complete.
YamuxStreamOpenTimeout = 10 * time.Second
// YamuxStreamCloseTimeout is how long to wait for a stream close to complete.
YamuxStreamCloseTimeout = 5 * time.Minute
// YamuxKeepAliveInterval is how often to send keep-alive pings.
// Set higher than HeartbeatInterval to avoid redundant pings.
YamuxKeepAliveInterval = 15 * time.Second
// YamuxConnectionWriteTimeout is the timeout for writing to the underlying connection.
YamuxConnectionWriteTimeout = 10 * time.Second
// ==================== Heartbeat Configuration ====================
// HeartbeatInterval is how often clients send heartbeat messages
HeartbeatInterval = 2 * time.Second
@@ -19,9 +43,13 @@ const (
// HeartbeatTimeout is how long the server waits before considering a connection dead
HeartbeatTimeout = 6 * time.Second
// ==================== Request/Response Timeouts ====================
// RequestTimeout is the maximum time to wait for a response from the client
RequestTimeout = 30 * time.Second
// ==================== Reconnection Configuration ====================
// ReconnectBaseDelay is the initial delay for reconnection attempts
ReconnectBaseDelay = 1 * time.Second
@@ -31,9 +59,12 @@ const (
// MaxReconnectAttempts is the maximum number of reconnection attempts (0 = infinite)
MaxReconnectAttempts = 0
// ==================== TCP Port Allocation ====================
// DefaultTCPPortMin/Max define the default allocation range for TCP tunnels
DefaultTCPPortMin = 20000
DefaultTCPPortMax = 40000
// DefaultDomain is the default domain for tunnels
DefaultDomain = "tunnel.localhost"
)

View File

@@ -0,0 +1,35 @@
package mux
import (
"io"
"github.com/hashicorp/yamux"
"drip/internal/shared/constants"
)
// NewOptimizedConfig returns a multiplexer config optimized for tunnel
// scenarios. It starts from yamux defaults and overrides backlog, stream
// window, timeouts, and keep-alive with the values from the shared
// constants package.
func NewOptimizedConfig() *yamux.Config {
	cfg := yamux.DefaultConfig()
	cfg.AcceptBacklog = constants.YamuxAcceptBacklog
	cfg.MaxStreamWindowSize = constants.YamuxMaxStreamWindowSize
	cfg.StreamOpenTimeout = constants.YamuxStreamOpenTimeout
	cfg.StreamCloseTimeout = constants.YamuxStreamCloseTimeout
	cfg.ConnectionWriteTimeout = constants.YamuxConnectionWriteTimeout
	cfg.EnableKeepAlive = true
	cfg.KeepAliveInterval = constants.YamuxKeepAliveInterval
	// Silence yamux's internal logger; failures surface as returned errors.
	cfg.LogOutput = io.Discard
	return cfg
}

// NewServerConfig returns a multiplexer config for server-side use.
// Currently identical to the optimized base config (keep-alive enabled).
func NewServerConfig() *yamux.Config {
	return NewOptimizedConfig()
}

// NewClientConfig returns a multiplexer config for client-side use.
// Keep-alive is disabled here — presumably the application-level heartbeat
// covers liveness on the client side; confirm before relying on yamux-level
// dead-peer detection for client sessions.
func NewClientConfig() *yamux.Config {
	cfg := NewOptimizedConfig()
	cfg.EnableKeepAlive = false
	return cfg
}

View File

@@ -2,14 +2,12 @@ package netutil
import "net"
// CountingConn wraps a net.Conn to count bytes read/written.
type CountingConn struct {
	net.Conn
	// OnRead receives byte counts for reads — presumably invoked by a Read
	// wrapper not visible here; confirm nil-safety at the call site.
	OnRead func(int64)
	// OnWrite receives byte counts for writes — presumably invoked by a
	// Write wrapper not visible here; confirm nil-safety at the call site.
	OnWrite func(int64)
}
// NewCountingConn creates a new CountingConn.
func NewCountingConn(conn net.Conn, onRead, onWrite func(int64)) *CountingConn {
return &CountingConn{
Conn: conn,

View File

@@ -3,15 +3,17 @@ package pool
import "sync"
const (
SizeSmall = 4 * 1024 // 4KB
SizeMedium = 32 * 1024 // 32KB
SizeLarge = 256 * 1024 // 256KB
SizeSmall = 4 * 1024 // 4KB - HTTP headers, small messages
SizeMedium = 32 * 1024 // 32KB - HTTP request/response bodies
SizeLarge = 256 * 1024 // 256KB - Data pipe, file transfers
SizeXLarge = 1024 * 1024 // 1MB - Large file transfers, bulk data
)
type BufferPool struct {
small sync.Pool
medium sync.Pool
large sync.Pool
xlarge sync.Pool
}
func NewBufferPool() *BufferPool {
@@ -34,6 +36,12 @@ func NewBufferPool() *BufferPool {
return &b
},
},
xlarge: sync.Pool{
New: func() interface{} {
b := make([]byte, SizeXLarge)
return &b
},
},
}
}
@@ -43,8 +51,10 @@ func (p *BufferPool) Get(size int) *[]byte {
return p.small.Get().(*[]byte)
case size <= SizeMedium:
return p.medium.Get().(*[]byte)
default:
case size <= SizeLarge:
return p.large.Get().(*[]byte)
default:
return p.xlarge.Get().(*[]byte)
}
}
@@ -63,7 +73,24 @@ func (p *BufferPool) Put(buf *[]byte) {
p.medium.Put(buf)
case SizeLarge:
p.large.Put(buf)
case SizeXLarge:
p.xlarge.Put(buf)
}
// Note: buffers with non-standard sizes are not pooled (let GC handle them)
}
// GetXLarge returns a 1MB buffer for bulk data transfers
func (p *BufferPool) GetXLarge() *[]byte {
	return p.xlarge.Get().(*[]byte)
}

// PutXLarge returns a 1MB buffer to the pool.
// Buffers whose capacity is not exactly SizeXLarge are dropped (left to the
// GC) so the pool only ever holds uniformly sized buffers.
func (p *BufferPool) PutXLarge(buf *[]byte) {
	if buf == nil || cap(*buf) != SizeXLarge {
		return
	}
	// Restore full length so the next GetXLarge hands out a full-size slice.
	*buf = (*buf)[:cap(*buf)]
	p.xlarge.Put(buf)
}
var globalBufferPool = NewBufferPool()
@@ -75,3 +102,13 @@ func GetBuffer(size int) *[]byte {
func PutBuffer(buf *[]byte) {
globalBufferPool.Put(buf)
}
// GetXLargeBuffer returns a 1MB buffer from the global pool.
func GetXLargeBuffer() *[]byte {
	return globalBufferPool.GetXLarge()
}

// PutXLargeBuffer returns a 1MB buffer to the global pool.
// Non-1MB buffers are silently discarded (see BufferPool.PutXLarge).
func PutXLargeBuffer(buf *[]byte) {
	globalBufferPool.PutXLarge(buf)
}

View File

@@ -11,7 +11,9 @@ import (
const (
FrameHeaderSize = 5
MaxFrameSize = 10 * 1024 * 1024
// MaxFrameSize limits payload size to prevent memory exhaustion attacks.
// 1MB is sufficient for most HTTP requests/responses while limiting DoS impact.
MaxFrameSize = 1 * 1024 * 1024 // 1MB (reduced from 10MB)
)
// FrameType defines the type of frame
@@ -88,8 +90,9 @@ func WriteFrame(w io.Writer, frame *Frame) error {
}
func ReadFrame(r io.Reader) (*Frame, error) {
header := make([]byte, FrameHeaderSize)
if _, err := io.ReadFull(r, header); err != nil {
// Use stack-allocated array to avoid heap allocation
var header [FrameHeaderSize]byte
if _, err := io.ReadFull(r, header[:]); err != nil {
return nil, fmt.Errorf("failed to read frame header: %w", err)
}

View File

@@ -6,31 +6,24 @@ import (
"time"
)
// TrafficStats tracks traffic statistics for a tunnel connection.
// The struct contains a sync.Mutex and must not be copied after first use.
type TrafficStats struct {
	// Cumulative counters, updated with sync/atomic (see AddBytesIn et al.).
	// Kept first in the struct, which also guarantees 64-bit alignment for
	// atomic access on 32-bit platforms.
	totalBytesIn      int64
	totalBytesOut     int64
	totalRequests     int64
	activeConnections int64

	// Previous speed sample, used to compute deltas; guarded by speedMu.
	lastBytesIn  int64
	lastBytesOut int64
	lastTime     time.Time
	speedMu      sync.Mutex

	// Most recently computed transfer rates in bytes/second; guarded by
	// speedMu.
	speedIn  int64
	speedOut int64

	// startTime marks when tracking began — appears to be set at
	// construction (see NewTrafficStats).
	startTime time.Time
}
// NewTrafficStats creates a new traffic stats tracker
func NewTrafficStats() *TrafficStats {
now := time.Now()
return &TrafficStats{
@@ -39,17 +32,14 @@ func NewTrafficStats() *TrafficStats {
}
}
// AddBytesIn atomically adds n incoming bytes to the cumulative counter.
func (s *TrafficStats) AddBytesIn(n int64) {
	atomic.AddInt64(&s.totalBytesIn, n)
}

// AddBytesOut atomically adds n outgoing bytes to the cumulative counter.
func (s *TrafficStats) AddBytesOut(n int64) {
	atomic.AddInt64(&s.totalBytesOut, n)
}

// AddRequest atomically increments the total request counter.
func (s *TrafficStats) AddRequest() {
	atomic.AddInt64(&s.totalRequests, 1)
}
@@ -59,23 +49,25 @@ func (s *TrafficStats) IncActiveConnections() {
}
func (s *TrafficStats) DecActiveConnections() {
v := atomic.AddInt64(&s.activeConnections, -1)
if v < 0 {
atomic.StoreInt64(&s.activeConnections, 0)
for {
old := atomic.LoadInt64(&s.activeConnections)
if old <= 0 {
return
}
if atomic.CompareAndSwapInt64(&s.activeConnections, old, old-1) {
return
}
}
}
// GetTotalBytesIn returns total incoming bytes (atomic read).
func (s *TrafficStats) GetTotalBytesIn() int64 {
	return atomic.LoadInt64(&s.totalBytesIn)
}

// GetTotalBytesOut returns total outgoing bytes (atomic read).
func (s *TrafficStats) GetTotalBytesOut() int64 {
	return atomic.LoadInt64(&s.totalBytesOut)
}

// GetTotalRequests returns the total request count (atomic read).
func (s *TrafficStats) GetTotalRequests() int64 {
	return atomic.LoadInt64(&s.totalRequests)
}
@@ -84,21 +76,16 @@ func (s *TrafficStats) GetActiveConnections() int64 {
return atomic.LoadInt64(&s.activeConnections)
}
// GetTotalBytes returns total bytes (in + out). The two counters are read
// independently, so the sum is not a single atomic snapshot.
func (s *TrafficStats) GetTotalBytes() int64 {
	return s.GetTotalBytesIn() + s.GetTotalBytesOut()
}
// UpdateSpeed calculates current transfer speed
// Should be called periodically (e.g., every second)
func (s *TrafficStats) UpdateSpeed() {
s.speedMu.Lock()
defer s.speedMu.Unlock()
now := time.Now()
elapsed := now.Sub(s.lastTime).Seconds()
// Require minimum interval of 100ms to avoid division issues
if elapsed < 0.1 {
return
}
@@ -109,18 +96,15 @@ func (s *TrafficStats) UpdateSpeed() {
deltaIn := currentIn - s.lastBytesIn
deltaOut := currentOut - s.lastBytesOut
// Calculate instantaneous speed
if deltaIn > 0 {
s.speedIn = int64(float64(deltaIn) / elapsed)
} else {
// No new bytes - set speed to 0 immediately
s.speedIn = 0
}
if deltaOut > 0 {
s.speedOut = int64(float64(deltaOut) / elapsed)
} else {
// No new bytes - set speed to 0 immediately
s.speedOut = 0
}
@@ -129,38 +113,33 @@ func (s *TrafficStats) UpdateSpeed() {
s.lastTime = now
}
// GetSpeedIn returns the current incoming speed in bytes per second, as
// last computed by UpdateSpeed; guarded by speedMu.
func (s *TrafficStats) GetSpeedIn() int64 {
	s.speedMu.Lock()
	defer s.speedMu.Unlock()
	return s.speedIn
}

// GetSpeedOut returns the current outgoing speed in bytes per second, as
// last computed by UpdateSpeed; guarded by speedMu.
func (s *TrafficStats) GetSpeedOut() int64 {
	s.speedMu.Lock()
	defer s.speedMu.Unlock()
	return s.speedOut
}

// GetUptime returns how long the connection has been active.
func (s *TrafficStats) GetUptime() time.Duration {
	return time.Since(s.startTime)
}
// Snapshot returns a snapshot of all stats
type Snapshot struct {
TotalBytesIn int64
TotalBytesOut int64
TotalBytes int64
TotalRequests int64
ActiveConnections int64
SpeedIn int64 // bytes per second
SpeedOut int64 // bytes per second
SpeedIn int64
SpeedOut int64
Uptime time.Duration
}
// GetSnapshot returns a snapshot of current stats
func (s *TrafficStats) GetSnapshot() Snapshot {
s.speedMu.Lock()
speedIn := s.speedIn

View File

@@ -0,0 +1,61 @@
package tuning
import (
"runtime"
"runtime/debug"
)
// Mode selects a preset runtime-tuning profile.
type Mode int

const (
	// ModeClient favors a small memory footprint (see DefaultClientConfig).
	ModeClient Mode = iota
	// ModeServer favors throughput at the cost of more memory
	// (see DefaultServerConfig).
	ModeServer
)

// Config holds the Go runtime knobs applied by Apply.
type Config struct {
	// GCPercent is passed to debug.SetGCPercent when > 0.
	GCPercent int
	// MemoryLimit is the soft memory limit in bytes, passed to
	// debug.SetMemoryLimit when > 0.
	MemoryLimit int64
}
// DefaultClientConfig returns the tuning profile for client processes:
// standard GC pacing (100%) and a soft memory limit of one quarter of
// system RAM, never below 64 MiB.
func DefaultClientConfig() Config {
	const minLimit = 64 * 1024 * 1024
	limit := int64(getSystemTotalMemory()) / 4
	if limit < minLimit {
		limit = minLimit
	}
	return Config{GCPercent: 100, MemoryLimit: limit}
}

// DefaultServerConfig returns the tuning profile for server processes:
// relaxed GC pacing (200%, fewer collections) and a soft memory limit of
// three quarters of system RAM, never below 128 MiB.
func DefaultServerConfig() Config {
	const minLimit = 128 * 1024 * 1024
	limit := int64(getSystemTotalMemory()) * 3 / 4
	if limit < minLimit {
		limit = minLimit
	}
	return Config{GCPercent: 200, MemoryLimit: limit}
}
// Apply installs the given runtime tuning: GOMAXPROCS is pinned to the CPU
// count, and the GC percent and soft memory limit are set when positive.
func Apply(cfg Config) {
	// Explicit for clarity; NumCPU has been the GOMAXPROCS default
	// since Go 1.5.
	runtime.GOMAXPROCS(runtime.NumCPU())

	if gc := cfg.GCPercent; gc > 0 {
		debug.SetGCPercent(gc)
	}
	if limit := cfg.MemoryLimit; limit > 0 {
		debug.SetMemoryLimit(limit)
	}
}

// ApplyMode applies the preset profile for the given mode.
// Unknown modes are a no-op.
func ApplyMode(mode Mode) {
	var cfg Config
	switch mode {
	case ModeClient:
		cfg = DefaultClientConfig()
	case ModeServer:
		cfg = DefaultServerConfig()
	default:
		return
	}
	Apply(cfg)
}

View File

@@ -0,0 +1,28 @@
//go:build darwin
package tuning
import (
"syscall"
"unsafe"
)
// getSystemTotalMemory reports total physical RAM in bytes on macOS by
// issuing the hw.memsize sysctl via the raw syscall interface, returning a
// conservative 1 GiB if the call fails.
func getSystemTotalMemory() uint64 {
	mib := [2]int32{6, 24} // CTL_HW, HW_MEMSIZE
	var value uint64
	size := unsafe.Sizeof(value)
	// Equivalent to sysctl(mib, 2, &value, &size, NULL, 0).
	_, _, errno := syscall.Syscall6(
		syscall.SYS___SYSCTL,
		uintptr(unsafe.Pointer(&mib[0])),
		2,
		uintptr(unsafe.Pointer(&value)),
		uintptr(unsafe.Pointer(&size)),
		0,
		0,
	)
	if errno != 0 {
		// Syscall failed: fall back to a sane 1 GiB lower bound.
		return 1024 * 1024 * 1024
	}
	return value
}

View File

@@ -0,0 +1,13 @@
//go:build linux
package tuning
import "syscall"
// getSystemTotalMemory reports total physical RAM in bytes via sysinfo(2),
// returning a conservative 1 GiB if the syscall fails.
func getSystemTotalMemory() uint64 {
	var info syscall.Sysinfo_t
	if err := syscall.Sysinfo(&info); err != nil {
		return 1024 * 1024 * 1024 // 1 GiB fallback
	}
	// Totalram is expressed in units of info.Unit bytes.
	return info.Totalram * uint64(info.Unit)
}

View File

@@ -0,0 +1,7 @@
//go:build !linux && !darwin && !windows
package tuning
// getSystemTotalMemory returns a fixed 1 GiB estimate on platforms that
// have no dedicated memory probe (anything other than linux/darwin/windows).
func getSystemTotalMemory() uint64 {
	const fallbackTotal uint64 = 1 << 30 // 1GB fallback
	return fallbackTotal
}

View File

@@ -0,0 +1,36 @@
//go:build windows
package tuning
import (
"syscall"
"unsafe"
)
var (
	kernel32             = syscall.NewLazyDLL("kernel32.dll")
	globalMemoryStatusEx = kernel32.NewProc("GlobalMemoryStatusEx")
)

// memoryStatusEx mirrors the Win32 MEMORYSTATUSEX structure. Field order
// and widths must match the ABI exactly, since a pointer to this struct is
// handed straight to the DLL call below.
type memoryStatusEx struct {
	dwLength                uint32
	dwMemoryLoad            uint32
	ullTotalPhys            uint64
	ullAvailPhys            uint64
	ullTotalPageFile        uint64
	ullAvailPageFile        uint64
	ullTotalVirtual         uint64
	ullAvailVirtual         uint64
	ullAvailExtendedVirtual uint64
}

// getSystemTotalMemory reports total physical RAM in bytes on Windows via
// GlobalMemoryStatusEx, returning a conservative 1 GiB on failure.
func getSystemTotalMemory() uint64 {
	var mem memoryStatusEx
	// The API requires dwLength to be pre-set to the struct size.
	mem.dwLength = uint32(unsafe.Sizeof(mem))
	ret, _, _ := globalMemoryStatusEx.Call(uintptr(unsafe.Pointer(&mem)))
	if ret == 0 {
		// A zero return signals failure; fall back to 1 GiB.
		return 1024 * 1024 * 1024
	}
	return mem.ullTotalPhys
}

View File

@@ -3,29 +3,34 @@ package utils
import (
"crypto/rand"
"encoding/hex"
"time"
"fmt"
)
// GenerateID generates a cryptographically secure random unique ID
// (32 hex chars). It panics if crypto/rand fails, which indicates a
// critical system issue — the old timestamp fallback produced predictable,
// collision-prone IDs and has been removed.
func GenerateID() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("crypto/rand failed: %v", err))
	}
	return hex.EncodeToString(b)
}
// GenerateShortID generates a cryptographically secure shorter random ID
// (8 hex chars). It panics if crypto/rand fails — the old timestamp
// fallback produced predictable IDs and has been removed.
func GenerateShortID() string {
	b := make([]byte, 4)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("crypto/rand failed: %v", err))
	}
	return hex.EncodeToString(b)
}
// TryGenerateID returns a 32-hex-char random ID, or an error if
// crypto/rand fails (for callers where a panic is not acceptable).
//
// Fix: the merged text interleaved this function with the remnants of the
// deleted generateFallbackID (timestamp-based, non-random); only the
// crypto/rand version is kept.
func TryGenerateID() (string, error) {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("crypto/rand failed: %w", err)
	}
	return hex.EncodeToString(b), nil
}

View File

@@ -66,7 +66,7 @@ cleanup() {
fi
# Extra cleanup: ensure ports are released
pkill -f "python.*${HTTP_TEST_PORT}" 2>/dev/null || true
pkill -f "test-server.*${HTTP_TEST_PORT}" 2>/dev/null || true
pkill -f "drip server.*${DRIP_SERVER_PORT}" 2>/dev/null || true
pkill -f "drip http ${HTTP_TEST_PORT}" 2>/dev/null || true
@@ -86,8 +86,8 @@ check_dependencies() {
missing="${missing}\n - wrk (brew install wrk)"
fi
if ! command -v python3 &> /dev/null; then
missing="${missing}\n - python3"
if ! command -v go &> /dev/null; then
missing="${missing}\n - go (https://go.dev/dl/)"
fi
if ! command -v openssl &> /dev/null; then
@@ -157,42 +157,28 @@ wait_for_port() {
return 0
}
# Start HTTP test server
# Start HTTP test server (high-performance Go server)
start_http_server() {
log_step "Starting HTTP test server (port $HTTP_TEST_PORT)..."
# Create simple test server
cat > "${LOG_DIR}/test-server.py" << 'EOF'
import http.server
import socketserver
import json
from datetime import datetime
import sys
# Build Go test server
local test_server_dir="scripts/test/test-server"
local test_server_bin="${LOG_DIR}/test-server"
PORT = int(sys.argv[1]) if len(sys.argv) > 1 else 3000
if [ ! -d "$test_server_dir" ]; then
log_error "Test server source not found: $test_server_dir"
exit 1
fi
class TestHandler(http.server.SimpleHTTPRequestHandler):
def do_GET(self):
response = {
"status": "ok",
"timestamp": datetime.now().isoformat(),
"message": "Test server response"
}
log_info "Building Go test server..."
if ! go build -o "$test_server_bin" "./$test_server_dir" > "${LOG_DIR}/build.log" 2>&1; then
log_error "Failed to build test server"
cat "${LOG_DIR}/build.log"
exit 1
fi
self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(response).encode())
def log_message(self, format, *args):
pass # Silent logging
with socketserver.TCPServer(("", PORT), TestHandler) as httpd:
print(f"Server started on port {PORT}", flush=True)
httpd.serve_forever()
EOF
python3 "${LOG_DIR}/test-server.py" "$HTTP_TEST_PORT" \
# Start the Go test server
"$test_server_bin" -port "$HTTP_TEST_PORT" \
> "${LOG_DIR}/http-server.log" 2>&1 &
local pid=$!
echo "$pid" >> "$PIDS_FILE"
@@ -201,6 +187,7 @@ EOF
log_info "✓ HTTP test server started (PID: $pid)"
else
log_error "HTTP test server failed to start"
cat "${LOG_DIR}/http-server.log"
exit 1
fi
}
@@ -475,12 +462,12 @@ main() {
exit 1
fi
# Warm up
# Warm up (sequential to ensure connection pool is ready)
log_info "Warming up tunnel (5s)..."
for _ in {1..5}; do
curl -sk "$TUNNEL_URL" > /dev/null 2>&1 || true
sleep 1
for _ in {1..20}; do
curl -sk --max-time 2 "$TUNNEL_URL" > /dev/null 2>&1 || true
done
sleep 2
# Run tests
run_performance_tests "$TUNNEL_URL"

View File

@@ -0,0 +1,33 @@
// High-performance test HTTP server for benchmarking
package main
import (
"flag"
"fmt"
"log"
"net/http"
"runtime"
)
// main runs a minimal upstream HTTP server for tunnel benchmarks.
// It serves plain "OK" on "/" and a small JSON status on "/health".
func main() {
	listenPort := flag.Int("port", 3000, "Port to listen on")
	flag.Parse()

	// Explicit for clarity; NumCPU is already the default since Go 1.5.
	runtime.GOMAXPROCS(runtime.NumCPU())

	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/plain")
		w.WriteHeader(http.StatusOK)
		fmt.Fprint(w, "OK")
	})
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		fmt.Fprint(w, `{"status":"ok"}`)
	})

	addr := fmt.Sprintf(":%d", *listenPort)
	log.Printf("Test server listening on %s", addr)
	log.Fatal(http.ListenAndServe(addr, mux))
}