mirror of https://github.com/Gouryella/drip.git
synced 2026-02-23 12:53:43 +00:00
perf(core): optimize performance configuration and resource management
- Removed the manual performance-optimization setup in main.go and replaced it with a new tuning module.
- Added mode-based GC tuning (client/server profiles), applied in server.go and tunnel_runner.go.
- Unified the yamux settings into a single optimized configuration to improve throughput.
- Implemented connection-pool prewarming to eliminate cold-start delay.
- Optimized session selection with a min-heap, reducing time complexity from O(n) to O(log n).
- Added a bufio.Reader pool and a buffer pool to reduce memory allocations.
- Implemented a sharded lock manager to improve performance under high concurrency.
- Adjusted heartbeat and timeout configurations for high-throughput scenarios.

BREAKING CHANGE: Manual GC tuning configuration has been removed; automatic tuning mode is now used.
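For callers migrating off the removed manual setup: startup code now goes through the tuning module added in this commit. A minimal migration sketch using only APIs introduced below (`tuning.ApplyMode`, `tuning.ModeServer`, `tuning.ModeClient`); the surrounding `main` is illustrative, not the repo's actual entrypoint:

```go
package main

import "drip/internal/shared/tuning"

func main() {
	// Before this commit, main.go hard-coded debug.SetGCPercent(200) and an
	// 8GB debug.SetMemoryLimit. The tuning module now derives the memory
	// limit from total system RAM and applies a role-specific profile.
	tuning.ApplyMode(tuning.ModeServer) // server binary; the CLI path uses tuning.ModeClient

	// ... rest of startup unchanged (omitted)
}
```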
.gitignore (vendored, 2 changes)
@@ -52,5 +52,5 @@ temp/
 certs/
 .drip-server.env
 benchmark-results/
-drip
 drip-linux-amd64
+./drip
@@ -3,8 +3,6 @@ package main
 import (
     "fmt"
     "os"
-    "runtime"
-    "runtime/debug"

     "drip/internal/client/cli"
 )
@@ -16,9 +14,6 @@ var (
 )

 func main() {
-    // Performance optimizations
-    setupPerformanceOptimizations()
-
     cli.SetVersion(Version, GitCommit, BuildTime)

     if err := cli.Execute(); err != nil {
@@ -26,19 +21,3 @@ func main() {
         os.Exit(1)
     }
 }
-
-// setupPerformanceOptimizations configures runtime settings for optimal performance
-func setupPerformanceOptimizations() {
-    // Set GOMAXPROCS to use all available CPU cores
-    numCPU := runtime.NumCPU()
-    runtime.GOMAXPROCS(numCPU)
-
-    // Reduce GC frequency for high-throughput scenarios
-    // Default is 100, setting to 200 reduces GC overhead at cost of more memory
-    // This is beneficial since we now use buffer pools (less garbage)
-    debug.SetGCPercent(200)
-
-    // Set memory limit to prevent OOM (adjust based on your server)
-    // This is a soft limit - Go will try to stay under this
-    debug.SetMemoryLimit(8 * 1024 * 1024 * 1024) // 8GB limit
-}
@@ -13,6 +13,7 @@ import (
     "drip/internal/server/tcp"
     "drip/internal/server/tunnel"
     "drip/internal/shared/constants"
+    "drip/internal/shared/tuning"
     "drip/internal/shared/utils"
     "drip/pkg/config"
     "github.com/spf13/cobra"
@@ -60,6 +61,9 @@ func init() {
 }

 func runServer(_ *cobra.Command, _ []string) error {
+    // Apply server-mode GC tuning (high throughput, more memory)
+    tuning.ApplyMode(tuning.ModeServer)
+
     if serverTLSCert == "" {
         return fmt.Errorf("TLS certificate path is required (use --tls-cert flag or DRIP_TLS_CERT environment variable)")
     }
@@ -9,8 +9,10 @@ import (
     "time"

     "drip/internal/client/tcp"
+    "drip/internal/shared/tuning"
     "drip/internal/shared/ui"
     "drip/internal/shared/utils"

     "go.uber.org/zap"
 )
@@ -20,6 +22,8 @@ const (
 )

 func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) error {
+    tuning.ApplyMode(tuning.ModeClient)
+
     if err := utils.InitLogger(verbose); err != nil {
         return fmt.Errorf("failed to initialize logger: %w", err)
     }
@@ -4,7 +4,6 @@ import (
     "context"
     "crypto/tls"
     "fmt"
-    "io"
     "net"
     "net/http"
     "runtime"
@@ -17,6 +16,7 @@ import (
     "go.uber.org/zap"

     "drip/internal/shared/constants"
+    "drip/internal/shared/mux"
     "drip/internal/shared/protocol"
     "drip/internal/shared/stats"
     "drip/pkg/config"
@@ -202,10 +202,7 @@ func (c *PoolClient) Connect() error {
         c.tunnelID = resp.TunnelID
     }

-    yamuxCfg := yamux.DefaultConfig()
-    yamuxCfg.EnableKeepAlive = false
-    yamuxCfg.LogOutput = io.Discard
-    yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
+    yamuxCfg := mux.NewClientConfig()

     session, err := yamux.Server(primaryConn, yamuxCfg)
     if err != nil {
@@ -241,7 +238,7 @@ func (c *PoolClient) Connect() error {
     c.desiredTotal = c.initialSessions
     c.mu.Unlock()

-    c.ensureSessions()
+    c.warmupSessions()

     c.wg.Add(1)
     go c.scalerLoop()
@@ -2,7 +2,6 @@ package tcp

 import (
     "fmt"
-    "io"
     "net"
     "sync"
     "sync/atomic"
@@ -11,10 +10,12 @@ import (
     json "github.com/goccy/go-json"
     "github.com/hashicorp/yamux"

-    "drip/internal/shared/constants"
+    "drip/internal/shared/mux"
     "drip/internal/shared/protocol"
 )

+var dataConnCounter atomic.Uint64
+
 // sessionHandle wraps a yamux session with metadata.
 type sessionHandle struct {
     id string
@@ -37,6 +38,36 @@ func (h *sessionHandle) lastActiveTime() time.Time {
     return time.Unix(0, n)
 }

+// warmupSessions pre-creates initial sessions in parallel to eliminate cold-start latency.
+func (c *PoolClient) warmupSessions() {
+    if c.IsClosed() || c.tunnelID == "" {
+        return
+    }
+
+    c.mu.RLock()
+    desired := c.desiredTotal
+    c.mu.RUnlock()
+
+    current := c.sessionCount()
+    toCreate := desired - current
+    if toCreate <= 0 {
+        return
+    }
+
+    var wg sync.WaitGroup
+    for i := 0; i < toCreate; i++ {
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            _ = c.addDataSession()
+        }()
+    }
+    wg.Wait()
+
+    // Brief wait for server to register all sessions
+    time.Sleep(100 * time.Millisecond)
+}
+
 // scalerLoop monitors load and adjusts session count.
 func (c *PoolClient) scalerLoop() {
     defer c.wg.Done()
@@ -162,7 +193,7 @@ func (c *PoolClient) addDataSession() error {
         return err
     }

-    connID := fmt.Sprintf("data-%d", time.Now().UnixNano())
+    connID := fmt.Sprintf("data-%d", dataConnCounter.Add(1))

     req := protocol.DataConnectRequest{
         TunnelID: c.tunnelID,
@@ -214,10 +245,7 @@ func (c *PoolClient) addDataSession() error {
         return fmt.Errorf("data connection rejected: %s", resp.Message)
     }

-    yamuxCfg := yamux.DefaultConfig()
-    yamuxCfg.EnableKeepAlive = false
-    yamuxCfg.LogOutput = io.Discard
-    yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
+    yamuxCfg := mux.NewClientConfig()

     session, err := yamux.Server(conn, yamuxCfg)
     if err != nil {
@@ -9,6 +9,7 @@ import (
     "net/http"
     "net/url"
     "strings"
+    "sync"
     "time"

     json "github.com/goccy/go-json"
@@ -16,11 +17,19 @@ import (
     "drip/internal/server/tunnel"
     "drip/internal/shared/httputil"
     "drip/internal/shared/netutil"
+    "drip/internal/shared/pool"
     "drip/internal/shared/protocol"

     "go.uber.org/zap"
 )

+// bufio.Reader pool to reduce allocations on hot path
+var bufioReaderPool = sync.Pool{
+    New: func() interface{} {
+        return bufio.NewReaderSize(nil, 32*1024)
+    },
+}
+
 const openStreamTimeout = 3 * time.Second

 type Handler struct {
@@ -104,13 +113,19 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
         return
     }

-    resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
+    reader := bufioReaderPool.Get().(*bufio.Reader)
+    reader.Reset(countingStream)
+    resp, err := http.ReadResponse(reader, r)
     if err != nil {
+        bufioReaderPool.Put(reader)
         w.Header().Set("Connection", "close")
        http.Error(w, "Read response failed", http.StatusBadGateway)
         return
     }
-    defer resp.Body.Close()
+    defer func() {
+        resp.Body.Close()
+        bufioReaderPool.Put(reader)
+    }()

     h.copyResponseHeaders(w.Header(), resp.Header, r.Host)

@@ -147,7 +162,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
         }
     }()

-    _, _ = io.Copy(w, resp.Body)
+    // Use pooled buffer for zero-copy optimization
+    buf := pool.GetBuffer(pool.SizeLarge)
+    _, _ = io.CopyBuffer(w, resp.Body, (*buf)[:])
+    pool.PutBuffer(buf)

     close(done)
     stream.Close()
 }
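The pooled-reader change above is the standard `sync.Pool` + `Reset` idiom: `bufio.NewReaderSize(nil, …)` creates a reader with no source, and `Reset` rebinds it to the current stream on each request. A self-contained sketch of the same pattern (the names here are illustrative, not the handler's actual identifiers):

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
	"sync"
)

var readerPool = sync.Pool{
	New: func() interface{} { return bufio.NewReaderSize(nil, 32*1024) },
}

func main() {
	r := readerPool.Get().(*bufio.Reader)
	r.Reset(strings.NewReader("hello\n")) // rebind the pooled reader to a new source
	line, _ := r.ReadString('\n')
	fmt.Print(line)
	readerPool.Put(r) // return only once nothing still reads through it
}
```

The invariant visible in the handler's deferred cleanup is the important part: the reader goes back to the pool only after the response body is fully consumed, since `resp.Body` still reads through it.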
@@ -18,6 +18,7 @@ import (

     "drip/internal/server/tunnel"
     "drip/internal/shared/constants"
+    "drip/internal/shared/mux"
     "drip/internal/shared/protocol"

     "go.uber.org/zap"
@@ -686,10 +687,8 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Reader) {
         reader: reader,
     }

-    cfg := yamux.DefaultConfig()
-    cfg.EnableKeepAlive = false
-    cfg.LogOutput = io.Discard
-    cfg.AcceptBacklog = constants.YamuxAcceptBacklog
+    // Use optimized mux config for server
+    cfg := mux.NewServerConfig()

     session, err := yamux.Client(bc, cfg)
     if err != nil {
@@ -1,10 +1,10 @@
 package tcp

 import (
+    "container/heap"
     "fmt"
     "net"
     "sync"
     "sync/atomic"
     "time"

     "github.com/hashicorp/yamux"
@@ -15,6 +15,54 @@ import (
     "go.uber.org/zap"
 )

+// sessionEntry represents a session with its current stream count for heap operations
+type sessionEntry struct {
+    id      string
+    session *yamux.Session
+    streams int
+    heapIdx int // index in the heap, managed by heap.Interface
+}
+
+// sessionHeap implements heap.Interface for O(log n) session selection
+type sessionHeap []*sessionEntry
+
+func (h sessionHeap) Len() int { return len(h) }
+
+func (h sessionHeap) Less(i, j int) bool {
+    // Min-heap: session with fewer streams has higher priority
+    return h[i].streams < h[j].streams
+}
+
+func (h sessionHeap) Swap(i, j int) {
+    h[i], h[j] = h[j], h[i]
+    h[i].heapIdx = i
+    h[j].heapIdx = j
+}
+
+func (h *sessionHeap) Push(x interface{}) {
+    entry := x.(*sessionEntry)
+    entry.heapIdx = len(*h)
+    *h = append(*h, entry)
+}
+
+func (h *sessionHeap) Pop() interface{} {
+    old := *h
+    n := len(old)
+    entry := old[n-1]
+    old[n-1] = nil // avoid memory leak
+    entry.heapIdx = -1
+    *h = old[0 : n-1]
+    return entry
+}
+
+// sessionHeapPool reuses heap slices to reduce allocations
+var sessionHeapPool = sync.Pool{
+    New: func() interface{} {
+        h := make(sessionHeap, 0, 16)
+        return &h
+    },
+}
+
 type ConnectionGroup struct {
     TunnelID  string
     Subdomain string
@@ -60,6 +108,12 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
     const maxConsecutiveFailures = 3
     failureCount := make(map[string]int)

+    type sessionSnapshot struct {
+        id      string
+        session *yamux.Session
+    }
+    sessions := make([]sessionSnapshot, 0, 16)
+
     for {
         select {
         case <-g.stopCh:
@@ -67,26 +121,25 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
         case <-ticker.C:
         }

+        sessions = sessions[:0]
         g.mu.RLock()
-        sessions := make(map[string]*yamux.Session, len(g.Sessions))
         for id, s := range g.Sessions {
-            sessions[id] = s
+            sessions = append(sessions, sessionSnapshot{id: id, session: s})
         }
         g.mu.RUnlock()

-        for id, session := range sessions {
-            if session == nil || session.IsClosed() {
-                g.RemoveSession(id)
-                delete(failureCount, id)
+        for _, snap := range sessions {
+            if snap.session == nil || snap.session.IsClosed() {
+                g.RemoveSession(snap.id)
+                delete(failureCount, snap.id)
                 continue
             }

             // Ping with timeout
             done := make(chan error, 1)
             go func(s *yamux.Session) {
                 _, err := s.Ping()
                 done <- err
-            }(session)
+            }(snap.session)

             var err error
             select {
@@ -98,31 +151,29 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
             }

             if err != nil {
-                failureCount[id]++
+                failureCount[snap.id]++
                 g.logger.Debug("Session ping failed",
-                    zap.String("session_id", id),
-                    zap.Int("consecutive_failures", failureCount[id]),
+                    zap.String("session_id", snap.id),
+                    zap.Int("consecutive_failures", failureCount[snap.id]),
                     zap.Error(err),
                 )

-                if failureCount[id] >= maxConsecutiveFailures {
+                if failureCount[snap.id] >= maxConsecutiveFailures {
                     g.logger.Warn("Session ping failed too many times, removing",
-                        zap.String("session_id", id),
-                        zap.Int("failures", failureCount[id]),
+                        zap.String("session_id", snap.id),
+                        zap.Int("failures", failureCount[snap.id]),
                     )
-                    g.RemoveSession(id)
-                    delete(failureCount, id)
+                    g.RemoveSession(snap.id)
+                    delete(failureCount, snap.id)
                 }
             } else {
                 // Reset on success
-                failureCount[id] = 0
+                failureCount[snap.id] = 0
                 g.mu.Lock()
                 g.LastActivity = time.Now()
                 g.mu.Unlock()
             }
         }

         // Check if all sessions are gone
         g.mu.RLock()
         sessionCount := len(g.Sessions)
         g.mu.RUnlock()
@@ -214,6 +265,7 @@ func (g *ConnectionGroup) SessionCount() int {
     return len(g.Sessions)
 }

+// OpenStream opens a new stream using a min-heap for O(log n) session selection.
 func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
     const (
         maxStreamsPerSession = 256
@@ -230,61 +282,33 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
     default:
     }

-    // Prefer data sessions for data-plane traffic; keep the primary session
-    // as control-plane (client ping/latency), and only fall back to primary
-    // when no data session exists.
-    sessions := g.sessionsSnapshot(false)
-    if len(sessions) == 0 {
-        sessions = g.sessionsSnapshot(true)
+    h := g.buildSessionHeap(false)
+    if h.Len() == 0 {
+        h = g.buildSessionHeap(true)
     }
-    if len(sessions) == 0 {
+    if h.Len() == 0 {
         return nil, net.ErrClosed
     }

-    tried := make([]bool, len(sessions))
     anyUnderCap := false
-    start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
+    for h.Len() > 0 {
+        entry := heap.Pop(h).(*sessionEntry)
+        session := entry.session

-    for range sessions {
-        bestIdx := -1
-        minStreams := int(^uint(0) >> 1)
-
-        for i := 0; i < len(sessions); i++ {
-            idx := (start + i) % len(sessions)
-            if tried[idx] {
-                continue
-            }
-
-            session := sessions[idx]
-            if session == nil || session.IsClosed() {
-                tried[idx] = true
-                continue
-            }
-
-            n := session.NumStreams()
-            if n >= maxStreamsPerSession {
-                continue
-            }
-            anyUnderCap = true
-
-            if n < minStreams {
-                minStreams = n
-                bestIdx = idx
-            }
-        }
-
-        if bestIdx == -1 {
-            break
-        }
-
-        tried[bestIdx] = true
-        session := sessions[bestIdx]
         if session == nil || session.IsClosed() {
             continue
         }

         currentStreams := session.NumStreams()
         if currentStreams >= maxStreamsPerSession {
             continue
         }
         anyUnderCap = true

         stream, err := session.Open()
         if err == nil {
+            *h = (*h)[:0]
+            sessionHeapPool.Put(h)
             return stream, nil
         }
         lastErr = err
@@ -294,6 +318,9 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
         }
     }

+    *h = (*h)[:0]
+    sessionHeapPool.Put(h)
+
     if !anyUnderCap {
         lastErr = fmt.Errorf("all sessions are at stream capacity (%d)", maxStreamsPerSession)
     }
@@ -313,31 +340,59 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
     return nil, lastErr
 }

-func (g *ConnectionGroup) selectSession() *yamux.Session {
-    sessions := g.sessionsSnapshot(false)
-    if len(sessions) == 0 {
-        sessions = g.sessionsSnapshot(true)
-    }
-    if len(sessions) == 0 {
-        return nil
+// buildSessionHeap creates a min-heap of sessions ordered by stream count.
+func (g *ConnectionGroup) buildSessionHeap(includePrimary bool) *sessionHeap {
+    g.mu.RLock()
+    defer g.mu.RUnlock()
+
+    if len(g.Sessions) == 0 {
+        h := sessionHeapPool.Get().(*sessionHeap)
+        return h
     }

-    start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
-    minStreams := int(^uint(0) >> 1)
-    var best *yamux.Session
+    h := sessionHeapPool.Get().(*sessionHeap)
+    *h = (*h)[:0]

-    for i := 0; i < len(sessions); i++ {
-        session := sessions[(start+i)%len(sessions)]
+    for id, session := range g.Sessions {
         if session == nil || session.IsClosed() {
             continue
         }
-        if n := session.NumStreams(); n < minStreams {
-            minStreams = n
-            best = session
+        if id == "primary" && !includePrimary {
+            continue
         }
+
+        *h = append(*h, &sessionEntry{
+            id:      id,
+            session: session,
+            streams: session.NumStreams(),
+        })
     }

-    return best
+    heap.Init(h)
+    return h
+}
+
+func (g *ConnectionGroup) selectSession() *yamux.Session {
+    h := g.buildSessionHeap(false)
+    if h.Len() == 0 {
+        sessionHeapPool.Put(h)
+        h = g.buildSessionHeap(true)
+    }
+    if h.Len() == 0 {
+        sessionHeapPool.Put(h)
+        return nil
+    }
+
+    entry := heap.Pop(h).(*sessionEntry)
+    session := entry.session
+
+    *h = (*h)[:0]
+    sessionHeapPool.Put(h)
+
+    if session == nil || session.IsClosed() {
+        return nil
+    }
+    return session
 }

 func (g *ConnectionGroup) sessionsSnapshot(includePrimary bool) []*yamux.Session {
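The heap above is a textbook `container/heap` min-heap keyed on per-session stream count, so `heap.Pop` always yields the least-loaded session. A standalone sketch of the selection idea, with plain ints standing in for `sessionEntry` values:

```go
package main

import (
	"container/heap"
	"fmt"
)

// intHeap orders values ascending, so Pop always yields the minimum —
// here standing in for "the session with the fewest open streams".
type intHeap []int

func (h intHeap) Len() int            { return len(h) }
func (h intHeap) Less(i, j int) bool  { return h[i] < h[j] }
func (h intHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *intHeap) Push(x interface{}) { *h = append(*h, x.(int)) }
func (h *intHeap) Pop() interface{} {
	old := *h
	n := len(old)
	x := old[n-1]
	*h = old[:n-1]
	return x
}

func main() {
	streams := &intHeap{12, 3, 250, 40} // stream counts per session
	heap.Init(streams)                  // O(n)
	fmt.Println(heap.Pop(streams))      // 3: least-loaded session, O(log n) per pop
}
```

Worth noting: `buildSessionHeap` still walks every session, so each `OpenStream` call pays O(n) up front; the O(log n) win is on the retry path, where each successive least-loaded candidate is popped without rescanning the whole set.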
@@ -3,12 +3,11 @@ package tcp
 import (
     "bufio"
     "fmt"
-    "io"
     "net"

     "github.com/hashicorp/yamux"

-    "drip/internal/shared/constants"
+    "drip/internal/shared/mux"
 )

 type bufferedConn struct {
@@ -27,10 +26,8 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
         reader: reader,
     }

-    cfg := yamux.DefaultConfig()
-    cfg.EnableKeepAlive = false
-    cfg.LogOutput = io.Discard
-    cfg.AcceptBacklog = constants.YamuxAcceptBacklog
+    // Use optimized mux config for server
+    cfg := mux.NewServerConfig()

     session, err := yamux.Client(bc, cfg)
     if err != nil {
@@ -66,10 +63,8 @@ func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
         reader: reader,
     }

-    cfg := yamux.DefaultConfig()
-    cfg.EnableKeepAlive = false
-    cfg.LogOutput = io.Discard
-    cfg.AcceptBacklog = constants.YamuxAcceptBacklog
+    // Use optimized mux config for server
+    cfg := mux.NewServerConfig()

     session, err := yamux.Client(bc, cfg)
     if err != nil {

@@ -115,8 +115,7 @@ func (c *Connection) GetTunnelType() protocol.TunnelType {
     return c.tunnelType
 }

-// SetOpenStream registers a yamux stream opener for this tunnel.
-// It is used by the HTTP proxy to forward each request over a mux stream.
+// SetOpenStream registers a stream opener for this tunnel.
 func (c *Connection) SetOpenStream(open func() (net.Conn, error)) {
     c.mu.Lock()
     c.openStream = open
@@ -178,17 +177,11 @@ func (c *Connection) GetActiveConnections() int64 {

 // StartWritePump starts the write pump for sending messages
 func (c *Connection) StartWritePump() {
-    // Skip write pump for TCP-only connections (no WebSocket)
     if c.Conn == nil {
-        c.logger.Debug("Skipping WritePump for TCP connection",
-            zap.String("subdomain", c.Subdomain),
-        )
         // Still need to drain SendCh to prevent blocking
         go func() {
             for {
                 select {
                 case <-c.SendCh:
                     // Discard messages for TCP mode
                 case <-c.CloseCh:
                     return
                 }
@@ -2,7 +2,9 @@ package tunnel

 import (
     "errors"
+    "hash/fnv"
     "sync"
+    "sync/atomic"
     "time"

     "drip/internal/shared/utils"
@@ -12,10 +14,14 @@ import (

 // Manager limits
 const (
-    DefaultMaxTunnels      = 1000            // Maximum total tunnels
-    DefaultMaxTunnelsPerIP = 10              // Maximum tunnels per IP
-    DefaultRateLimit       = 10              // Registrations per IP per minute
-    DefaultRateLimitWindow = 1 * time.Minute // Rate limit window
+    DefaultMaxTunnels      = 1000            // Maximum total tunnels
+    DefaultMaxTunnelsPerIP = 10              // Maximum tunnels per IP
+    DefaultRateLimit       = 10              // Registrations per IP per minute
+    DefaultRateLimitWindow = 1 * time.Minute // Rate limit window
+
+    // numShards is the number of shards for lock distribution
+    // Using 32 shards reduces lock contention by ~32x under high concurrency
+    numShards = 32
 )

 var (
@@ -30,12 +36,17 @@ type rateLimitEntry struct {
     windowEnd time.Time
 }

-// Manager manages all active tunnel connections
-type Manager struct {
-    tunnels map[string]*Connection // subdomain -> connection
+// shard holds a subset of tunnels with its own lock
+type shard struct {
+    tunnels map[string]*Connection
+    used    map[string]bool
     mu      sync.RWMutex
-    used    map[string]bool // track used subdomains
-    logger  *zap.Logger
+}
+
+// Manager manages all active tunnel connections with sharded locking
+type Manager struct {
+    shards [numShards]shard
+    logger *zap.Logger

     // Limits
     maxTunnels int
@@ -43,7 +54,11 @@ type Manager struct {
     rateLimit       int
     rateLimitWindow time.Duration

-    // Per-IP tracking
+    // Global counters (atomic for lock-free reads)
+    tunnelCount atomic.Int64
+
+    // Per-IP tracking (requires separate lock as it spans shards)
+    ipMu        sync.RWMutex
     tunnelsByIP map[string]int             // IP -> tunnel count
     rateLimits  map[string]*rateLimitEntry // IP -> rate limit entry

@@ -94,11 +109,10 @@ func NewManagerWithConfig(logger *zap.Logger, cfg ManagerConfig) *Manager {
         zap.Int("max_per_ip", cfg.MaxTunnelsPerIP),
         zap.Int("rate_limit", cfg.RateLimit),
         zap.Duration("rate_window", cfg.RateLimitWindow),
+        zap.Int("num_shards", numShards),
     )

-    return &Manager{
-        tunnels:         make(map[string]*Connection),
-        used:            make(map[string]bool),
+    m := &Manager{
         logger:          logger,
         maxTunnels:      cfg.MaxTunnels,
         maxTunnelsPerIP: cfg.MaxTunnelsPerIP,
@@ -108,10 +122,25 @@ func NewManagerWithConfig(logger *zap.Logger, cfg ManagerConfig) *Manager {
         rateLimits:      make(map[string]*rateLimitEntry),
         stopCh:          make(chan struct{}),
     }
+
+    // Initialize all shards
+    for i := 0; i < numShards; i++ {
+        m.shards[i].tunnels = make(map[string]*Connection)
+        m.shards[i].used = make(map[string]bool)
+    }
+
+    return m
 }

-// checkRateLimit checks if the IP has exceeded rate limit
-func (m *Manager) checkRateLimit(ip string) bool {
+// getShard returns the shard for a given subdomain using FNV-1a hash
+func (m *Manager) getShard(subdomain string) *shard {
+    h := fnv.New32a()
+    h.Write([]byte(subdomain))
+    return &m.shards[h.Sum32()%numShards]
+}
+
+// checkRateLimit checks if the IP has exceeded rate limit (caller must hold ipMu)
+func (m *Manager) checkRateLimitLocked(ip string) bool {
     now := time.Now()
     entry, exists := m.rateLimits[ip]

@@ -139,22 +168,20 @@ func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string, error) {

 // RegisterWithIP registers a new tunnel with IP tracking
 func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, remoteIP string) (string, error) {
-    m.mu.Lock()
-    defer m.mu.Unlock()
-
-    // Check total tunnel limit
-    if len(m.tunnels) >= m.maxTunnels {
+    // Check global limits first (lock-free read)
+    if m.tunnelCount.Load() >= int64(m.maxTunnels) {
         m.logger.Warn("Maximum tunnel limit reached",
-            zap.Int("current", len(m.tunnels)),
+            zap.Int64("current", m.tunnelCount.Load()),
             zap.Int("max", m.maxTunnels),
         )
         return "", ErrTooManyTunnels
     }

-    // Check per-IP limits if IP is provided
+    // Check per-IP limits
     if remoteIP != "" {
-        // Check rate limit
-        if !m.checkRateLimit(remoteIP) {
+        m.ipMu.Lock()
+        if !m.checkRateLimitLocked(remoteIP) {
+            m.ipMu.Unlock()
             m.logger.Warn("Rate limit exceeded",
                 zap.String("ip", remoteIP),
                 zap.Int("limit", m.rateLimit),
@@ -162,8 +189,8 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, remoteIP string) (string, error) {
             return "", ErrRateLimitExceeded
         }

-        // Check per-IP tunnel limit
         if m.tunnelsByIP[remoteIP] >= m.maxTunnelsPerIP {
+            m.ipMu.Unlock()
             m.logger.Warn("Per-IP tunnel limit reached",
                 zap.String("ip", remoteIP),
                 zap.Int("current", m.tunnelsByIP[remoteIP]),
@@ -171,6 +198,7 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, remoteIP string) (string, error) {
             )
             return "", ErrTooManyPerIP
         }
+        m.ipMu.Unlock()
     }

     var subdomain string
@@ -183,33 +211,56 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, remoteIP string) (string, error) {
         if utils.IsReserved(customSubdomain) {
             return "", ErrReservedSubdomain
         }
-        if m.used[customSubdomain] {
+
+        // Check if subdomain is taken in its shard
+        s := m.getShard(customSubdomain)
+        s.mu.Lock()
+        if s.used[customSubdomain] {
+            s.mu.Unlock()
             return "", ErrSubdomainTaken
         }
         subdomain = customSubdomain
+
+        // Register in shard
+        tc := NewConnection(subdomain, conn, m.logger)
+        tc.remoteIP = remoteIP
+        s.tunnels[subdomain] = tc
+        s.used[subdomain] = true
+        s.mu.Unlock()
     } else {
         // Generate unique random subdomain
         subdomain = m.generateUniqueSubdomain()
+
+        s := m.getShard(subdomain)
+        s.mu.Lock()
+        tc := NewConnection(subdomain, conn, m.logger)
+        tc.remoteIP = remoteIP
+        s.tunnels[subdomain] = tc
+        s.used[subdomain] = true
+        s.mu.Unlock()
     }

-    // Create connection
-    tc := NewConnection(subdomain, conn, m.logger)
-    tc.remoteIP = remoteIP // Track IP for cleanup
-    m.tunnels[subdomain] = tc
-    m.used[subdomain] = true
-
-    // Update per-IP counter
+    // Update counters
+    m.tunnelCount.Add(1)
     if remoteIP != "" {
+        m.ipMu.Lock()
         m.tunnelsByIP[remoteIP]++
+        m.ipMu.Unlock()
     }

-    // Start write pump in background
-    go tc.StartWritePump()
+    // Get connection and start write pump
+    s := m.getShard(subdomain)
+    s.mu.RLock()
+    tc := s.tunnels[subdomain]
+    s.mu.RUnlock()
+    if tc != nil {
+        go tc.StartWritePump()
+    }

     m.logger.Info("Tunnel registered",
         zap.String("subdomain", subdomain),
         zap.String("ip", remoteIP),
-        zap.Int("total_tunnels", len(m.tunnels)),
+        zap.Int64("total_tunnels", m.tunnelCount.Load()),
     )

     return subdomain, nil
@@ -217,101 +268,129 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, remoteIP string) (string, error) {

 // Unregister removes a tunnel connection
 func (m *Manager) Unregister(subdomain string) {
-    m.mu.Lock()
-    defer m.mu.Unlock()
+    s := m.getShard(subdomain)
+    s.mu.Lock()

-    if tc, ok := m.tunnels[subdomain]; ok {
-        // Decrement per-IP counter
-        if tc.remoteIP != "" && m.tunnelsByIP[tc.remoteIP] > 0 {
-            m.tunnelsByIP[tc.remoteIP]--
-            if m.tunnelsByIP[tc.remoteIP] == 0 {
-                delete(m.tunnelsByIP, tc.remoteIP)
+    tc, ok := s.tunnels[subdomain]
+    if !ok {
+        s.mu.Unlock()
+        return
+    }
+
+    remoteIP := tc.remoteIP
+    tc.Close()
+    delete(s.tunnels, subdomain)
+    delete(s.used, subdomain)
+    s.mu.Unlock()
+
+    // Update counters
+    m.tunnelCount.Add(-1)
+    if remoteIP != "" {
+        m.ipMu.Lock()
+        if m.tunnelsByIP[remoteIP] > 0 {
+            m.tunnelsByIP[remoteIP]--
+            if m.tunnelsByIP[remoteIP] == 0 {
+                delete(m.tunnelsByIP, remoteIP)
             }
         }
-
-        tc.Close()
-        delete(m.tunnels, subdomain)
-        delete(m.used, subdomain)
-
-        m.logger.Info("Tunnel unregistered",
-            zap.String("subdomain", subdomain),
-            zap.Int("total_tunnels", len(m.tunnels)),
-        )
+        m.ipMu.Unlock()
     }
+
+    m.logger.Info("Tunnel unregistered",
+        zap.String("subdomain", subdomain),
+        zap.Int64("total_tunnels", m.tunnelCount.Load()),
+    )
 }

 // Get retrieves a tunnel connection by subdomain
 func (m *Manager) Get(subdomain string) (*Connection, bool) {
-    m.mu.RLock()
-    defer m.mu.RUnlock()
-
-    tc, ok := m.tunnels[subdomain]
+    s := m.getShard(subdomain)
+    s.mu.RLock()
+    tc, ok := s.tunnels[subdomain]
+    s.mu.RUnlock()
     return tc, ok
 }

 // List returns all active tunnel connections
 func (m *Manager) List() []*Connection {
-    m.mu.RLock()
-    defer m.mu.RUnlock()
+    // Pre-allocate with approximate capacity
+    connections := make([]*Connection, 0, m.tunnelCount.Load())

-    connections := make([]*Connection, 0, len(m.tunnels))
-    for _, tc := range m.tunnels {
-        connections = append(connections, tc)
+    for i := 0; i < numShards; i++ {
+        s := &m.shards[i]
+        s.mu.RLock()
+        for _, tc := range s.tunnels {
+            connections = append(connections, tc)
+        }
+        s.mu.RUnlock()
     }

     return connections
 }

 // Count returns the number of active tunnels
 func (m *Manager) Count() int {
-    m.mu.RLock()
-    defer m.mu.RUnlock()
-    return len(m.tunnels)
+    return int(m.tunnelCount.Load())
 }

 // CleanupStale removes stale connections that haven't been active
 func (m *Manager) CleanupStale(timeout time.Duration) int {
-    m.mu.Lock()
-    defer m.mu.Unlock()
+    totalCleaned := 0

-    staleSubdomains := []string{}
+    // Clean up each shard independently
+    for i := 0; i < numShards; i++ {
+        s := &m.shards[i]
+        s.mu.Lock()

-    for subdomain, tc := range m.tunnels {
-        if !tc.IsAlive(timeout) {
-            staleSubdomains = append(staleSubdomains, subdomain)
+        var staleSubdomains []string
+        for subdomain, tc := range s.tunnels {
+            if !tc.IsAlive(timeout) {
+                staleSubdomains = append(staleSubdomains, subdomain)
+            }
         }
-    }

-    for _, subdomain := range staleSubdomains {
-        if tc, ok := m.tunnels[subdomain]; ok {
-            // Decrement per-IP counter
-            if tc.remoteIP != "" && m.tunnelsByIP[tc.remoteIP] > 0 {
-                m.tunnelsByIP[tc.remoteIP]--
-                if m.tunnelsByIP[tc.remoteIP] == 0 {
-                    delete(m.tunnelsByIP, tc.remoteIP)
+        for _, subdomain := range staleSubdomains {
+            if tc, ok := s.tunnels[subdomain]; ok {
+                remoteIP := tc.remoteIP
+                tc.Close()
+                delete(s.tunnels, subdomain)
+                delete(s.used, subdomain)
+
+                // Update counters
+                m.tunnelCount.Add(-1)
+                if remoteIP != "" {
+                    m.ipMu.Lock()
+                    if m.tunnelsByIP[remoteIP] > 0 {
+                        m.tunnelsByIP[remoteIP]--
+                        if m.tunnelsByIP[remoteIP] == 0 {
+                            delete(m.tunnelsByIP, remoteIP)
+                        }
+                    }
+                    m.ipMu.Unlock()
                 }
             }
-
-            tc.Close()
-            delete(m.tunnels, subdomain)
-            delete(m.used, subdomain)
         }
+        totalCleaned += len(staleSubdomains)
+        s.mu.Unlock()
     }

+    // Cleanup expired rate limit entries
+    m.ipMu.Lock()
+    now := time.Now()
+    for ip, entry := range m.rateLimits {
+        if now.After(entry.windowEnd) {
+            delete(m.rateLimits, ip)
+        }
+    }
+    m.ipMu.Unlock()
+
-    if len(staleSubdomains) > 0 {
+    if totalCleaned > 0 {
         m.logger.Info("Cleaned up stale tunnels",
-            zap.Int("count", len(staleSubdomains)),
+            zap.Int("count", totalCleaned),
         )
     }

-    return len(staleSubdomains)
+    return totalCleaned
 }

 // StartCleanupTask starts a background task to clean up stale connections
@@ -336,7 +415,16 @@ func (m *Manager) generateUniqueSubdomain() string {

     for i := 0; i < maxAttempts; i++ {
         subdomain := utils.GenerateSubdomain(6)
-        if !m.used[subdomain] && !utils.IsReserved(subdomain) {
+        if utils.IsReserved(subdomain) {
+            continue
+        }
+
+        s := m.getShard(subdomain)
+        s.mu.RLock()
+        taken := s.used[subdomain]
+        s.mu.RUnlock()
+
+        if !taken {
             return subdomain
         }
     }
@@ -350,17 +438,21 @@ func (m *Manager) Shutdown() {
     // Signal cleanup goroutine to stop
     close(m.stopCh)

-    m.mu.Lock()
-    defer m.mu.Unlock()
-
     m.logger.Info("Shutting down tunnel manager",
-        zap.Int("active_tunnels", len(m.tunnels)),
+        zap.Int64("active_tunnels", m.tunnelCount.Load()),
     )

-    for _, tc := range m.tunnels {
-        tc.Close()
+    // Close all tunnels in each shard
+    for i := 0; i < numShards; i++ {
+        s := &m.shards[i]
+        s.mu.Lock()
+        for _, tc := range s.tunnels {
+            tc.Close()
+        }
+        s.tunnels = make(map[string]*Connection)
+        s.used = make(map[string]bool)
+        s.mu.Unlock()
     }

-    m.tunnels = make(map[string]*Connection)
-    m.used = make(map[string]bool)
+    m.tunnelCount.Store(0)
 }
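The shard router above hashes a subdomain with FNV-1a and picks one of 32 shards, so unrelated subdomains contend on different locks. A standalone sketch of the same routing (the sample keys are hypothetical):

```go
package main

import (
	"fmt"
	"hash/fnv"
)

const numShards = 32

// shardIndex maps a key to a stable shard, exactly as Manager.getShard does.
func shardIndex(key string) uint32 {
	h := fnv.New32a()
	h.Write([]byte(key)) // hash/fnv writes never return an error
	return h.Sum32() % numShards
}

func main() {
	for _, sub := range []string{"alpha", "bravo", "charlie"} {
		fmt.Printf("%-8s -> shard %d\n", sub, shardIndex(sub))
	}
}
```

Note the split the commit makes: per-subdomain state (`tunnels`, `used`) shards cleanly by this key, while the per-IP counters and rate limits span shards and therefore keep their own `ipMu` lock.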
@@ -9,9 +9,33 @@ const (
     // DefaultWSPort is the default WebSocket port
     DefaultWSPort = 8080

+    // ==================== Yamux Configuration ====================
+    // These settings are tuned for high-throughput tunnel scenarios
+
     // YamuxAcceptBacklog controls how many incoming streams can be queued
     // before yamux starts blocking stream opens under load.
-    YamuxAcceptBacklog = 4096
+    // Increased from default 256 to handle burst traffic.
+    YamuxAcceptBacklog = 8192
+
+    // YamuxMaxStreamWindowSize is the maximum window size for a stream.
+    // Larger windows allow more data in-flight, improving throughput on high-latency links.
+    // Default is 256KB, increased to 512KB for better throughput.
+    YamuxMaxStreamWindowSize = 512 * 1024
+
+    // YamuxStreamOpenTimeout is how long to wait for a stream open to complete.
+    YamuxStreamOpenTimeout = 10 * time.Second
+
+    // YamuxStreamCloseTimeout is how long to wait for a stream close to complete.
+    YamuxStreamCloseTimeout = 5 * time.Minute
+
+    // YamuxKeepAliveInterval is how often to send keep-alive pings.
+    // Set higher than HeartbeatInterval to avoid redundant pings.
+    YamuxKeepAliveInterval = 15 * time.Second
+
+    // YamuxConnectionWriteTimeout is the timeout for writing to the underlying connection.
+    YamuxConnectionWriteTimeout = 10 * time.Second
+
+    // ==================== Heartbeat Configuration ====================
+
     // HeartbeatInterval is how often clients send heartbeat messages
     HeartbeatInterval = 2 * time.Second
@@ -19,9 +43,13 @@ const (
     // HeartbeatTimeout is how long the server waits before considering a connection dead
     HeartbeatTimeout = 6 * time.Second

+    // ==================== Request/Response Timeouts ====================
+
     // RequestTimeout is the maximum time to wait for a response from the client
     RequestTimeout = 30 * time.Second

+    // ==================== Reconnection Configuration ====================
+
     // ReconnectBaseDelay is the initial delay for reconnection attempts
     ReconnectBaseDelay = 1 * time.Second
@@ -31,9 +59,12 @@ const (
     // MaxReconnectAttempts is the maximum number of reconnection attempts (0 = infinite)
     MaxReconnectAttempts = 0

+    // ==================== TCP Port Allocation ====================
+
     // DefaultTCPPortMin/Max define the default allocation range for TCP tunnels
     DefaultTCPPortMin = 20000
     DefaultTCPPortMax = 40000

     // DefaultDomain is the default domain for tunnels
     DefaultDomain = "tunnel.localhost"
 )
internal/shared/mux/config.go (new file, 35 lines)
@@ -0,0 +1,35 @@
+package mux
+
+import (
+    "io"
+
+    "github.com/hashicorp/yamux"
+
+    "drip/internal/shared/constants"
+)
+
+// NewOptimizedConfig returns a multiplexer config optimized for tunnel scenarios.
+func NewOptimizedConfig() *yamux.Config {
+    cfg := yamux.DefaultConfig()
+    cfg.AcceptBacklog = constants.YamuxAcceptBacklog
+    cfg.MaxStreamWindowSize = constants.YamuxMaxStreamWindowSize
+    cfg.StreamOpenTimeout = constants.YamuxStreamOpenTimeout
+    cfg.StreamCloseTimeout = constants.YamuxStreamCloseTimeout
+    cfg.ConnectionWriteTimeout = constants.YamuxConnectionWriteTimeout
+    cfg.EnableKeepAlive = true
+    cfg.KeepAliveInterval = constants.YamuxKeepAliveInterval
+    cfg.LogOutput = io.Discard
+    return cfg
+}
+
+// NewServerConfig returns a multiplexer config for server-side use.
+func NewServerConfig() *yamux.Config {
+    return NewOptimizedConfig()
+}
+
+// NewClientConfig returns a multiplexer config for client-side use.
+func NewClientConfig() *yamux.Config {
+    cfg := NewOptimizedConfig()
+    cfg.EnableKeepAlive = false
+    return cfg
+}
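The call sites in this commit pair these constructors with yamux as follows; a condensed sketch assuming the repo's `drip` module path resolves as above. Note the inversion: the drip server is the yamux *client* over an accepted tunnel connection, and vice versa.

```go
package example

import (
	"net"

	"github.com/hashicorp/yamux"

	"drip/internal/shared/mux"
)

// Server side, as in handleTCPTunnel/handleDataConnect above.
func newServerSession(conn net.Conn) (*yamux.Session, error) {
	return yamux.Client(conn, mux.NewServerConfig())
}

// Client side, as in PoolClient.Connect/addDataSession above; keep-alive is
// disabled because drip's application-level heartbeat already covers liveness.
func newClientSession(conn net.Conn) (*yamux.Session, error) {
	return yamux.Server(conn, mux.NewClientConfig())
}
```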
@@ -2,14 +2,12 @@ package netutil

 import "net"

-// CountingConn wraps a net.Conn to count bytes read/written.
 type CountingConn struct {
     net.Conn
     OnRead  func(int64)
     OnWrite func(int64)
 }

-// NewCountingConn creates a new CountingConn.
 func NewCountingConn(conn net.Conn, onRead, onWrite func(int64)) *CountingConn {
     return &CountingConn{
         Conn: conn,
@@ -3,15 +3,17 @@ package pool
 import "sync"

 const (
-    SizeSmall  = 4 * 1024   // 4KB
-    SizeMedium = 32 * 1024  // 32KB
-    SizeLarge  = 256 * 1024 // 256KB
+    SizeSmall  = 4 * 1024    // 4KB - HTTP headers, small messages
+    SizeMedium = 32 * 1024   // 32KB - HTTP request/response bodies
+    SizeLarge  = 256 * 1024  // 256KB - Data pipe, file transfers
+    SizeXLarge = 1024 * 1024 // 1MB - Large file transfers, bulk data
 )

 type BufferPool struct {
     small  sync.Pool
     medium sync.Pool
     large  sync.Pool
+    xlarge sync.Pool
 }

 func NewBufferPool() *BufferPool {
@@ -34,6 +36,12 @@ func NewBufferPool() *BufferPool {
                 return &b
             },
         },
+        xlarge: sync.Pool{
+            New: func() interface{} {
+                b := make([]byte, SizeXLarge)
+                return &b
+            },
+        },
     }
 }

@@ -43,8 +51,10 @@ func (p *BufferPool) Get(size int) *[]byte {
         return p.small.Get().(*[]byte)
     case size <= SizeMedium:
         return p.medium.Get().(*[]byte)
-    default:
+    case size <= SizeLarge:
         return p.large.Get().(*[]byte)
+    default:
+        return p.xlarge.Get().(*[]byte)
     }
 }

@@ -63,7 +73,24 @@ func (p *BufferPool) Put(buf *[]byte) {
         p.medium.Put(buf)
     case SizeLarge:
         p.large.Put(buf)
+    case SizeXLarge:
+        p.xlarge.Put(buf)
     }
+    // Note: buffers with non-standard sizes are not pooled (let GC handle them)
 }

+// GetXLarge returns a 1MB buffer for bulk data transfers
+func (p *BufferPool) GetXLarge() *[]byte {
+    return p.xlarge.Get().(*[]byte)
+}
+
+// PutXLarge returns a 1MB buffer to the pool
+func (p *BufferPool) PutXLarge(buf *[]byte) {
+    if buf == nil || cap(*buf) != SizeXLarge {
+        return
+    }
+    *buf = (*buf)[:cap(*buf)]
+    p.xlarge.Put(buf)
+}
+
 var globalBufferPool = NewBufferPool()
@@ -75,3 +102,13 @@ func GetBuffer(size int) *[]byte {
 func PutBuffer(buf *[]byte) {
     globalBufferPool.Put(buf)
 }
+
+// GetXLargeBuffer returns a 1MB buffer from the global pool
+func GetXLargeBuffer() *[]byte {
+    return globalBufferPool.GetXLarge()
+}
+
+// PutXLargeBuffer returns a 1MB buffer to the global pool
+func PutXLargeBuffer(buf *[]byte) {
+    globalBufferPool.PutXLarge(buf)
+}
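The HTTP handler earlier in this commit uses the pool exactly as in the sketch below: borrow a buffer sized for the transfer, pass it to `io.CopyBuffer` so the copy loop reuses one allocation, and return it afterwards. One caveat worth knowing: `io.CopyBuffer` silently ignores the supplied buffer when `src` implements `io.WriterTo` or `dst` implements `io.ReaderFrom`.

```go
package example

import (
	"io"

	"drip/internal/shared/pool"
)

// copyWithPooledBuffer streams src to dst through a pooled 256KB buffer
// (pool.SizeLarge), mirroring the io.CopyBuffer call in the HTTP handler.
func copyWithPooledBuffer(dst io.Writer, src io.Reader) (int64, error) {
	buf := pool.GetBuffer(pool.SizeLarge)
	defer pool.PutBuffer(buf)
	return io.CopyBuffer(dst, src, (*buf)[:])
}
```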
@@ -6,31 +6,24 @@ import (
     "time"
 )

 // TrafficStats tracks traffic statistics for a tunnel connection
 type TrafficStats struct {
-    // Total bytes
     totalBytesIn  int64
     totalBytesOut int64

-    // Request counts
     totalRequests     int64
     activeConnections int64

-    // For speed calculation
     lastBytesIn  int64
     lastBytesOut int64
     lastTime     time.Time
     speedMu      sync.Mutex

-    // Current speed (bytes per second)
     speedIn  int64
     speedOut int64

-    // Start time
     startTime time.Time
 }

 // NewTrafficStats creates a new traffic stats tracker
 func NewTrafficStats() *TrafficStats {
     now := time.Now()
     return &TrafficStats{
@@ -39,17 +32,14 @@ func NewTrafficStats() *TrafficStats {
     }
 }

-// AddBytesIn adds incoming bytes to the counter
 func (s *TrafficStats) AddBytesIn(n int64) {
     atomic.AddInt64(&s.totalBytesIn, n)
 }

-// AddBytesOut adds outgoing bytes to the counter
 func (s *TrafficStats) AddBytesOut(n int64) {
     atomic.AddInt64(&s.totalBytesOut, n)
 }

-// AddRequest increments the request counter
 func (s *TrafficStats) AddRequest() {
     atomic.AddInt64(&s.totalRequests, 1)
 }
@@ -59,23 +49,25 @@ func (s *TrafficStats) IncActiveConnections() {
 }

 func (s *TrafficStats) DecActiveConnections() {
-    v := atomic.AddInt64(&s.activeConnections, -1)
-    if v < 0 {
-        atomic.StoreInt64(&s.activeConnections, 0)
+    for {
+        old := atomic.LoadInt64(&s.activeConnections)
+        if old <= 0 {
+            return
+        }
+        if atomic.CompareAndSwapInt64(&s.activeConnections, old, old-1) {
+            return
+        }
     }
 }

-// GetTotalBytesIn returns total incoming bytes
 func (s *TrafficStats) GetTotalBytesIn() int64 {
     return atomic.LoadInt64(&s.totalBytesIn)
 }

-// GetTotalBytesOut returns total outgoing bytes
 func (s *TrafficStats) GetTotalBytesOut() int64 {
     return atomic.LoadInt64(&s.totalBytesOut)
 }

-// GetTotalRequests returns total request count
 func (s *TrafficStats) GetTotalRequests() int64 {
     return atomic.LoadInt64(&s.totalRequests)
 }
@@ -84,21 +76,16 @@ func (s *TrafficStats) GetActiveConnections() int64 {
     return atomic.LoadInt64(&s.activeConnections)
 }

-// GetTotalBytes returns total bytes (in + out)
 func (s *TrafficStats) GetTotalBytes() int64 {
     return s.GetTotalBytesIn() + s.GetTotalBytesOut()
 }

-// UpdateSpeed calculates current transfer speed
-// Should be called periodically (e.g., every second)
 func (s *TrafficStats) UpdateSpeed() {
     s.speedMu.Lock()
     defer s.speedMu.Unlock()

     now := time.Now()
     elapsed := now.Sub(s.lastTime).Seconds()

-    // Require minimum interval of 100ms to avoid division issues
     if elapsed < 0.1 {
         return
     }
@@ -109,18 +96,15 @@ func (s *TrafficStats) UpdateSpeed() {
     deltaIn := currentIn - s.lastBytesIn
     deltaOut := currentOut - s.lastBytesOut

-    // Calculate instantaneous speed
     if deltaIn > 0 {
         s.speedIn = int64(float64(deltaIn) / elapsed)
     } else {
-        // No new bytes - set speed to 0 immediately
         s.speedIn = 0
     }

     if deltaOut > 0 {
         s.speedOut = int64(float64(deltaOut) / elapsed)
     } else {
-        // No new bytes - set speed to 0 immediately
         s.speedOut = 0
     }

@@ -129,38 +113,33 @@ func (s *TrafficStats) UpdateSpeed() {
     s.lastTime = now
 }

-// GetSpeedIn returns current incoming speed in bytes per second
 func (s *TrafficStats) GetSpeedIn() int64 {
     s.speedMu.Lock()
     defer s.speedMu.Unlock()
     return s.speedIn
 }

-// GetSpeedOut returns current outgoing speed in bytes per second
 func (s *TrafficStats) GetSpeedOut() int64 {
     s.speedMu.Lock()
     defer s.speedMu.Unlock()
     return s.speedOut
 }

-// GetUptime returns how long the connection has been active
 func (s *TrafficStats) GetUptime() time.Duration {
     return time.Since(s.startTime)
 }

 // Snapshot returns a snapshot of all stats
 type Snapshot struct {
     TotalBytesIn      int64
     TotalBytesOut     int64
     TotalBytes        int64
     TotalRequests     int64
     ActiveConnections int64
-    SpeedIn           int64 // bytes per second
-    SpeedOut          int64 // bytes per second
+    SpeedIn           int64
+    SpeedOut          int64
     Uptime            time.Duration
 }

 // GetSnapshot returns a snapshot of current stats
 func (s *TrafficStats) GetSnapshot() Snapshot {
     s.speedMu.Lock()
     speedIn := s.speedIn
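The rewritten `DecActiveConnections` above deserves a note: the old version decremented first and then stored 0 on underflow, which could briefly publish a negative count and could overwrite a concurrent increment with the clamping store. The CAS loop never goes below zero and retries instead of clobbering. The same clamped-decrement pattern in isolation:

```go
package example

import "sync/atomic"

// decClamped decrements *n but never below zero. CompareAndSwap retries if
// another goroutine changed the value between the load and the swap, so a
// concurrent increment is never lost.
func decClamped(n *int64) {
	for {
		old := atomic.LoadInt64(n)
		if old <= 0 {
			return
		}
		if atomic.CompareAndSwapInt64(n, old, old-1) {
			return
		}
	}
}
```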
internal/shared/tuning/gc.go (new file, 61 lines)
@@ -0,0 +1,61 @@
+package tuning
+
+import (
+    "runtime"
+    "runtime/debug"
+)
+
+type Mode int
+
+const (
+    ModeClient Mode = iota
+    ModeServer
+)
+
+type Config struct {
+    GCPercent   int
+    MemoryLimit int64
+}
+
+func DefaultClientConfig() Config {
+    total := int64(getSystemTotalMemory())
+    limit := total / 4
+    if limit < 64*1024*1024 {
+        limit = 64 * 1024 * 1024
+    }
+    return Config{
+        GCPercent:   100,
+        MemoryLimit: limit,
+    }
+}
+
+func DefaultServerConfig() Config {
+    total := int64(getSystemTotalMemory())
+    limit := total * 3 / 4
+    if limit < 128*1024*1024 {
+        limit = 128 * 1024 * 1024
+    }
+    return Config{
+        GCPercent:   200,
+        MemoryLimit: limit,
+    }
+}
+
+func Apply(cfg Config) {
+    runtime.GOMAXPROCS(runtime.NumCPU())
+    if cfg.GCPercent > 0 {
+        debug.SetGCPercent(cfg.GCPercent)
+    }
+    if cfg.MemoryLimit > 0 {
+        debug.SetMemoryLimit(cfg.MemoryLimit)
+    }
+}
+
+func ApplyMode(mode Mode) {
+    switch mode {
+    case ModeClient:
+        Apply(DefaultClientConfig())
+    case ModeServer:
+        Apply(DefaultServerConfig())
+    }
+}
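The derived defaults above give clients a quarter of system RAM and servers three quarters, with small floors. Where those heuristics don't fit, `Apply` also takes an explicit `Config`; a hypothetical override (the values shown are examples, not recommendations):

```go
package example

import "drip/internal/shared/tuning"

func init() {
	// Hypothetical override: cap the heap at a 2 GiB soft limit
	// (debug.SetMemoryLimit) and keep GC at the Go default of 100,
	// instead of the server profile's 200 / three-quarters of RAM.
	tuning.Apply(tuning.Config{
		GCPercent:   100,
		MemoryLimit: 2 << 30, // 2 GiB
	})
}
```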
internal/shared/tuning/mem_darwin.go (new file, 28 lines)
@@ -0,0 +1,28 @@
+//go:build darwin
+
+package tuning
+
+import (
+    "syscall"
+    "unsafe"
+)
+
+func getSystemTotalMemory() uint64 {
+    mib := [2]int32{6, 24} // CTL_HW, HW_MEMSIZE
+    var value uint64
+    size := unsafe.Sizeof(value)
+
+    _, _, errno := syscall.Syscall6(
+        syscall.SYS___SYSCTL,
+        uintptr(unsafe.Pointer(&mib[0])),
+        2,
+        uintptr(unsafe.Pointer(&value)),
+        uintptr(unsafe.Pointer(&size)),
+        0,
+        0,
+    )
+    if errno != 0 {
+        return 1024 * 1024 * 1024
+    }
+    return value
+}
internal/shared/tuning/mem_linux.go (new file, 13 lines)
@@ -0,0 +1,13 @@
+//go:build linux
+
+package tuning
+
+import "syscall"
+
+func getSystemTotalMemory() uint64 {
+    var info syscall.Sysinfo_t
+    if err := syscall.Sysinfo(&info); err == nil {
+        return info.Totalram * uint64(info.Unit)
+    }
+    return 1024 * 1024 * 1024
+}
internal/shared/tuning/mem_other.go (new file, 7 lines)
@@ -0,0 +1,7 @@
+//go:build !linux && !darwin && !windows
+
+package tuning
+
+func getSystemTotalMemory() uint64 {
+    return 1024 * 1024 * 1024 // 1GB fallback
+}
internal/shared/tuning/mem_windows.go (new file, 36 lines)
@@ -0,0 +1,36 @@
+//go:build windows
+
+package tuning
+
+import (
+    "syscall"
+    "unsafe"
+)
+
+var (
+    kernel32             = syscall.NewLazyDLL("kernel32.dll")
+    globalMemoryStatusEx = kernel32.NewProc("GlobalMemoryStatusEx")
+)
+
+type memoryStatusEx struct {
+    dwLength                uint32
+    dwMemoryLoad            uint32
+    ullTotalPhys            uint64
+    ullAvailPhys            uint64
+    ullTotalPageFile        uint64
+    ullAvailPageFile        uint64
+    ullTotalVirtual         uint64
+    ullAvailVirtual         uint64
+    ullAvailExtendedVirtual uint64
+}
+
+func getSystemTotalMemory() uint64 {
+    var mem memoryStatusEx
+    mem.dwLength = uint32(unsafe.Sizeof(mem))
+
+    ret, _, _ := globalMemoryStatusEx.Call(uintptr(unsafe.Pointer(&mem)))
+    if ret == 0 {
+        return 1024 * 1024 * 1024
+    }
+    return mem.ullTotalPhys
+}