feat(cli): Add bandwidth limit function support

Added bandwidth limiting functionality, allowing users to limit the bandwidth of tunnel connections via the --bandwidth parameter.
Supported formats include: 1K/1KB (kilobytes), 1M/1MB (megabytes), 1G/1GB (gigabytes) or
Raw number (bytes).
This commit is contained in:
Gouryella
2026-02-14 14:20:21 +08:00
parent 3872bd9326
commit f90df37d7c
28 changed files with 2115 additions and 291 deletions

View File

@@ -24,6 +24,7 @@ import (
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/qos"
"drip/internal/shared/wsutil"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -171,24 +172,22 @@ func (h *Handler) SetAllowedTunnelTypes(types []string) {
// IsTransportAllowed checks if a transport is allowed
func (h *Handler) IsTransportAllowed(transport string) bool {
if len(h.allowedTransports) == 0 {
return true
}
for _, t := range h.allowedTransports {
if strings.EqualFold(t, transport) {
return true
}
}
return false
return containsFold(h.allowedTransports, transport)
}
// IsTunnelTypeAllowed checks if a tunnel type is allowed
func (h *Handler) IsTunnelTypeAllowed(tunnelType string) bool {
if len(h.allowedTunnelTypes) == 0 {
return containsFold(h.allowedTunnelTypes, tunnelType)
}
// containsFold returns true if the slice is empty (allow all) or contains the
// value in a case-insensitive comparison.
func containsFold(allowed []string, value string) bool {
if len(allowed) == 0 {
return true
}
for _, t := range h.allowedTunnelTypes {
if strings.EqualFold(t, tunnelType) {
for _, a := range allowed {
if strings.EqualFold(a, value) {
return true
}
}
@@ -299,7 +298,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
tconn.IncActiveConnections()
defer tconn.DecActiveConnections()
countingStream := netutil.NewCountingConn(stream,
var limitedStream net.Conn = stream
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
limitedStream = qos.NewLimitedConn(r.Context(), stream, limiter)
}
countingStream := netutil.NewCountingConn(limitedStream,
tconn.AddBytesOut,
tconn.AddBytesIn,
)
@@ -428,6 +432,11 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
return
}
var limitedStream net.Conn = stream
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
limitedStream = qos.NewLimitedConn(context.Background(), stream, limiter)
}
go func() {
defer stream.Close()
defer clientConn.Close()
@@ -441,7 +450,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
}
}
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
_ = netutil.PipeWithCallbacks(context.Background(), limitedStream, clientRW,
func(n int64) { tconn.AddBytesOut(n) },
func(n int64) { tconn.AddBytesIn(n) },
)

View File

@@ -0,0 +1,168 @@
package tcp
import (
"testing"
)
func TestEffectiveBandwidthSelection(t *testing.T) {
tests := []struct {
name string
serverBW int64
clientBW int64
wantEffective int64
}{
{
name: "server only",
serverBW: 1024 * 1024,
clientBW: 0,
wantEffective: 1024 * 1024,
},
{
name: "client only",
serverBW: 0,
clientBW: 512 * 1024,
wantEffective: 512 * 1024,
},
{
name: "both unlimited",
serverBW: 0,
clientBW: 0,
wantEffective: 0,
},
{
name: "client lower than server",
serverBW: 10 * 1024 * 1024,
clientBW: 1 * 1024 * 1024,
wantEffective: 1 * 1024 * 1024,
},
{
name: "client higher than server - server wins",
serverBW: 1 * 1024 * 1024,
clientBW: 10 * 1024 * 1024,
wantEffective: 1 * 1024 * 1024,
},
{
name: "client equal to server",
serverBW: 5 * 1024 * 1024,
clientBW: 5 * 1024 * 1024,
wantEffective: 5 * 1024 * 1024,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
effectiveBandwidth := tt.serverBW
if tt.clientBW > 0 {
if effectiveBandwidth == 0 || tt.clientBW < effectiveBandwidth {
effectiveBandwidth = tt.clientBW
}
}
if effectiveBandwidth != tt.wantEffective {
t.Errorf("effectiveBandwidth = %d, want %d", effectiveBandwidth, tt.wantEffective)
}
})
}
}
func TestConnectionSetBandwidthConfig(t *testing.T) {
tests := []struct {
name string
bandwidth int64
burstMultiplier float64
wantBandwidth int64
wantMultiplier float64
}{
{
name: "1MB/s with 2x burst",
bandwidth: 1024 * 1024,
burstMultiplier: 2.0,
wantBandwidth: 1024 * 1024,
wantMultiplier: 2.0,
},
{
name: "1MB/s with 2.5x burst",
bandwidth: 1024 * 1024,
burstMultiplier: 2.5,
wantBandwidth: 1024 * 1024,
wantMultiplier: 2.5,
},
{
name: "default multiplier when 0",
bandwidth: 1024 * 1024,
burstMultiplier: 0,
wantBandwidth: 1024 * 1024,
wantMultiplier: 2.0,
},
{
name: "default multiplier when negative",
bandwidth: 1024 * 1024,
burstMultiplier: -1.0,
wantBandwidth: 1024 * 1024,
wantMultiplier: 2.0,
},
{
name: "unlimited bandwidth",
bandwidth: 0,
burstMultiplier: 2.5,
wantBandwidth: 0,
wantMultiplier: 2.5,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conn := &Connection{}
conn.SetBandwidthConfig(tt.bandwidth, tt.burstMultiplier)
if conn.bandwidth != tt.wantBandwidth {
t.Errorf("bandwidth = %v, want %v", conn.bandwidth, tt.wantBandwidth)
}
if conn.burstMultiplier != tt.wantMultiplier {
t.Errorf("burstMultiplier = %v, want %v", conn.burstMultiplier, tt.wantMultiplier)
}
})
}
}
func TestListenerBandwidthConfig(t *testing.T) {
tests := []struct {
name string
bandwidth int64
burstMultiplier float64
wantBandwidth int64
wantMultiplier float64
}{
{
name: "set bandwidth and multiplier",
bandwidth: 1024 * 1024,
burstMultiplier: 2.5,
wantBandwidth: 1024 * 1024,
wantMultiplier: 2.5,
},
{
name: "default multiplier",
bandwidth: 1024 * 1024,
burstMultiplier: 0,
wantBandwidth: 1024 * 1024,
wantMultiplier: 2.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &Listener{}
l.SetBandwidth(tt.bandwidth)
l.SetBurstMultiplier(tt.burstMultiplier)
if l.bandwidth != tt.wantBandwidth {
t.Errorf("bandwidth = %v, want %v", l.bandwidth, tt.wantBandwidth)
}
if l.burstMultiplier != tt.wantMultiplier {
t.Errorf("burstMultiplier = %v, want %v", l.burstMultiplier, tt.wantMultiplier)
}
})
}
}

View File

@@ -59,12 +59,12 @@ type Connection struct {
httpListener *connQueueListener
handedOff bool
// Server capabilities
allowedTunnelTypes []string
allowedTransports []string
bandwidth int64
burstMultiplier float64
}
// NewConnection creates a new connection handler
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
ctx, cancel := context.WithCancel(context.Background())
c := &Connection{
@@ -99,22 +99,11 @@ func (c *Connection) Handle() error {
return fmt.Errorf("failed to peek connection: %w", err)
}
peekStr := string(peek)
httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"}
isHTTP := false
for _, method := range httpMethods {
if strings.HasPrefix(peekStr, method) {
isHTTP = true
break
}
}
if isHTTP {
if isHTTPMethod(string(peek)) {
c.logger.Info("Detected HTTP request on TCP port, handling as HTTP")
return c.handleHTTPRequest(reader)
}
// Check if TCP transport is allowed (only for Drip protocol connections, not HTTP)
if !c.isTransportAllowed("tcp") {
c.logger.Warn("TCP transport not allowed, rejecting Drip protocol connection")
return fmt.Errorf("TCP transport not allowed")
@@ -142,7 +131,6 @@ func (c *Connection) Handle() error {
c.tunnelType = req.TunnelType
// Check if tunnel type is allowed
if !c.isTunnelTypeAllowed(string(req.TunnelType)) {
c.sendError("tunnel_type_not_allowed", fmt.Sprintf("Tunnel type '%s' is not allowed on this server", req.TunnelType))
return fmt.Errorf("tunnel type not allowed: %s", req.TunnelType)
@@ -197,7 +185,6 @@ func (c *Connection) Handle() error {
c.tunnelConn.Conn = nil
c.tunnelConn.SetTunnelType(req.TunnelType)
c.tunnelType = req.TunnelType
if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) {
c.tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs)
@@ -215,6 +202,31 @@ func (c *Connection) Handle() error {
)
}
effectiveBandwidth := c.bandwidth
if req.Bandwidth > 0 {
if effectiveBandwidth == 0 || req.Bandwidth < effectiveBandwidth {
effectiveBandwidth = req.Bandwidth
}
}
if effectiveBandwidth > 0 {
burstMultiplier := c.burstMultiplier
if burstMultiplier <= 0 {
burstMultiplier = 2.0
}
c.tunnelConn.SetBandwidthWithBurst(effectiveBandwidth, burstMultiplier)
source := "server"
if req.Bandwidth > 0 && (c.bandwidth == 0 || req.Bandwidth < c.bandwidth) {
source = "client"
}
c.logger.Info("Bandwidth limit configured",
zap.String("subdomain", subdomain),
zap.Int64("bandwidth_bytes_sec", effectiveBandwidth),
zap.Float64("burst_multiplier", burstMultiplier),
zap.String("source", source),
)
}
c.logger.Info("Tunnel registered",
zap.String("subdomain", subdomain),
zap.String("tunnel_type", string(req.TunnelType)),
@@ -258,6 +270,7 @@ func (c *Connection) Handle() error {
TunnelID: tunnelID,
SupportsDataConn: supportsDataConn,
RecommendedConns: recommendedConns,
Bandwidth: c.tunnelConn.GetBandwidth(),
}
respData, err := json.Marshal(resp)
@@ -389,7 +402,6 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
zap.String("host", req.Host),
)
// Get writer from pool to reduce GC pressure
pooledWriter := bufioWriterPool.Get().(*bufio.Writer)
pooledWriter.Reset(c.conn)
@@ -405,30 +417,17 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
c.logger.Debug("Failed to flush HTTP response", zap.Error(err))
}
// Return writer to pool
pooledWriter.Reset(nil) // Clear reference to connection
pooledWriter.Reset(nil)
bufioWriterPool.Put(pooledWriter)
// Keep TCP_NODELAY enabled for low latency HTTP responses
// (removed the toggle that was disabling it)
c.logger.Debug("HTTP request processing completed",
zap.String("method", req.Method),
zap.String("url", req.URL.String()),
)
shouldClose := false
if req.Close {
shouldClose = true
} else if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
if req.Header.Get("Connection") != "keep-alive" {
shouldClose = true
}
}
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
shouldClose = true
}
shouldClose := req.Close ||
(req.ProtoMajor == 1 && req.ProtoMinor == 0 && req.Header.Get("Connection") != "keep-alive") ||
(respWriter.headerWritten && respWriter.header.Get("Connection") == "close")
if shouldClose {
c.logger.Debug("Closing connection as requested by client or server")
@@ -636,7 +635,7 @@ func (w *httpResponseWriter) WriteHeader(statusCode int) {
}
w.writer.WriteString("HTTP/1.1 ")
w.writer.WriteString(fmt.Sprintf("%d", statusCode))
w.writer.WriteString(strconv.Itoa(statusCode))
w.writer.WriteByte(' ')
w.writer.WriteString(statusText)
w.writer.WriteString("\r\n")
@@ -755,6 +754,14 @@ func isTimeoutError(err error) bool {
return strings.Contains(err.Error(), "i/o timeout")
}
func isHTTPMethod(peek string) bool {
switch peek {
case "GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC":
return true
}
return false
}
func (c *Connection) sendDataConnectError(code, message string) {
resp := protocol.DataConnectResponse{
Accepted: false,
@@ -769,38 +776,40 @@ func (c *Connection) sendDataConnectError(code, message string) {
_ = protocol.WriteFrame(c.conn, frame)
}
// SetAllowedTunnelTypes sets the allowed tunnel types for this connection
func (c *Connection) SetAllowedTunnelTypes(types []string) {
c.allowedTunnelTypes = types
}
// SetAllowedTransports sets the allowed transports for this connection
func (c *Connection) SetAllowedTransports(transports []string) {
c.allowedTransports = transports
}
// isTransportAllowed checks if a transport is allowed
func (c *Connection) isTransportAllowed(transport string) bool {
if len(c.allowedTransports) == 0 {
return containsFold(c.allowedTransports, transport)
}
func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
return containsFold(c.allowedTunnelTypes, tunnelType)
}
// containsFold returns true if the slice is empty (allow all) or contains the
// value case-insensitively.
func containsFold(allowed []string, value string) bool {
if len(allowed) == 0 {
return true
}
for _, t := range c.allowedTransports {
if strings.EqualFold(t, transport) {
for _, a := range allowed {
if strings.EqualFold(a, value) {
return true
}
}
return false
}
// isTunnelTypeAllowed checks if a tunnel type is allowed
func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
if len(c.allowedTunnelTypes) == 0 {
return true // Allow all by default
func (c *Connection) SetBandwidthConfig(bandwidth int64, burstMultiplier float64) {
c.bandwidth = bandwidth
if burstMultiplier <= 0 {
burstMultiplier = 2.0
}
for _, t := range c.allowedTunnelTypes {
if strings.EqualFold(t, tunnelType) {
return true
}
}
return false
c.burstMultiplier = burstMultiplier
}

View File

@@ -43,9 +43,10 @@ type Listener struct {
httpServer *http.Server
httpListener *connQueueListener
// Server capabilities
allowedTransports []string
allowedTunnelTypes []string
bandwidth int64
burstMultiplier float64
}
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler) *Listener {
@@ -63,7 +64,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
panicMetrics := recovery.NewPanicMetrics(logger, nil)
recoverer := recovery.NewRecoverer(logger, panicMetrics)
// Initialize worker pool metrics
metrics.WorkerPoolSize.Set(float64(workers))
l := &Listener{
@@ -85,7 +85,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
groupManager: NewConnectionGroupManager(logger),
}
// Set up WebSocket connection handler if httpHandler supports it
if h, ok := httpHandler.(*proxy.Handler); ok {
h.SetWSConnectionHandler(l)
h.SetPublicPort(publicPort)
@@ -97,7 +96,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
func (l *Listener) Start() error {
var err error
// Support both TLS and plain TCP modes
if l.tlsConfig != nil {
l.listener, err = tls.Listen("tcp", l.address, l.tlsConfig)
if err != nil {
@@ -269,57 +267,13 @@ func (l *Listener) handleConnection(netConn net.Conn) {
)
}
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
conn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
conn.SetAllowedTransports(l.allowedTransports)
conn := l.newConfiguredConnection(netConn)
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
l.connections[connID] = conn
l.connMu.Unlock()
l.trackConnection(connID, conn)
defer l.untrackConnection(connID, conn, netConn)
// Update connection metrics
metrics.TotalConnections.Inc()
metrics.ActiveConnections.Inc()
defer func() {
l.connMu.Lock()
delete(l.connections, connID)
l.connMu.Unlock()
metrics.ActiveConnections.Dec()
if !conn.IsHandedOff() {
netConn.Close()
}
}()
if err := conn.Handle(); err != nil {
errStr := err.Error()
if strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") {
return
}
if strings.Contains(errStr, "payload too large") ||
strings.Contains(errStr, "failed to read registration frame") ||
strings.Contains(errStr, "expected register frame") ||
strings.Contains(errStr, "failed to parse registration request") ||
strings.Contains(errStr, "failed to parse HTTP request") {
l.logger.Warn("Protocol validation failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
} else {
l.logger.Error("Connection handling failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
}
}
l.logConnectionError(conn.Handle(), connID, "Connection")
}
func (l *Listener) Stop() error {
@@ -372,7 +326,6 @@ func (l *Listener) GetActiveConnections() int {
return len(l.connections)
}
// HandleWSConnection implements proxy.WSConnectionHandler for WebSocket tunnel connections
func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
l.wg.Add(1)
defer l.wg.Done()
@@ -386,77 +339,103 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
zap.String("remote_addr", connID),
)
// Create connection handler (no TLS verification needed - already done by HTTP server)
tcpConn := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
tcpConn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
tcpConn := l.newConfiguredConnection(conn)
l.connMu.Lock()
l.connections[connID] = tcpConn
l.connMu.Unlock()
l.trackConnection(connID, tcpConn)
defer l.untrackConnection(connID, tcpConn, conn)
metrics.TotalConnections.Inc()
metrics.ActiveConnections.Inc()
defer func() {
l.connMu.Lock()
delete(l.connections, connID)
l.connMu.Unlock()
metrics.ActiveConnections.Dec()
if !tcpConn.IsHandedOff() {
conn.Close()
}
}()
if err := tcpConn.Handle(); err != nil {
errStr := err.Error()
if strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "websocket: close") {
return
}
if strings.Contains(errStr, "payload too large") ||
strings.Contains(errStr, "failed to read registration frame") ||
strings.Contains(errStr, "expected register frame") ||
strings.Contains(errStr, "failed to parse registration request") ||
strings.Contains(errStr, "tunnel type not allowed") {
l.logger.Warn("WebSocket tunnel protocol validation failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
} else {
l.logger.Error("WebSocket tunnel connection handling failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
}
}
l.logConnectionError(tcpConn.Handle(), connID, "WebSocket tunnel")
}
// SetAllowedTransports sets the allowed transport protocols
func (l *Listener) SetAllowedTransports(transports []string) {
l.allowedTransports = transports
}
// SetAllowedTunnelTypes sets the allowed tunnel types
func (l *Listener) SetAllowedTunnelTypes(types []string) {
l.allowedTunnelTypes = types
}
// IsTransportAllowed checks if a transport is allowed
func (l *Listener) IsTransportAllowed(transport string) bool {
if len(l.allowedTransports) == 0 {
return true
return containsFold(l.allowedTransports, transport)
}
func (l *Listener) SetBurstMultiplier(multiplier float64) {
if multiplier <= 0 {
multiplier = 2.0
}
for _, t := range l.allowedTransports {
if strings.EqualFold(t, transport) {
return true
l.burstMultiplier = multiplier
}
func (l *Listener) SetBandwidth(bandwidth int64) {
l.bandwidth = bandwidth
}
func (l *Listener) newConfiguredConnection(conn net.Conn) *Connection {
c := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
c.SetAllowedTunnelTypes(l.allowedTunnelTypes)
c.SetAllowedTransports(l.allowedTransports)
c.SetBandwidthConfig(l.bandwidth, l.burstMultiplier)
return c
}
func (l *Listener) trackConnection(connID string, conn *Connection) {
l.connMu.Lock()
l.connections[connID] = conn
l.connMu.Unlock()
metrics.TotalConnections.Inc()
metrics.ActiveConnections.Inc()
}
func (l *Listener) untrackConnection(connID string, conn *Connection, netConn net.Conn) {
l.connMu.Lock()
delete(l.connections, connID)
l.connMu.Unlock()
metrics.ActiveConnections.Dec()
if !conn.IsHandedOff() {
netConn.Close()
}
}
// logConnectionError classifies and logs a connection handling error.
// Transient network errors are silently ignored, protocol errors are warned,
// and everything else is logged as an error.
func (l *Listener) logConnectionError(err error, connID, label string) {
if err == nil {
return
}
errStr := err.Error()
// Transient / expected disconnects — ignore silently
for _, substr := range []string{
"EOF", "connection reset by peer", "broken pipe",
"connection refused", "websocket: close",
} {
if strings.Contains(errStr, substr) {
return
}
}
return false
// Protocol-level validation failures — warn
for _, substr := range []string{
"payload too large", "failed to read registration frame",
"expected register frame", "failed to parse registration request",
"failed to parse HTTP request", "tunnel type not allowed",
} {
if strings.Contains(errStr, substr) {
l.logger.Warn(label+" protocol validation failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
return
}
}
l.logger.Error(label+" handling failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
}

View File

@@ -10,6 +10,7 @@ import (
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/qos"
"go.uber.org/zap"
)
@@ -34,6 +35,7 @@ type Proxy struct {
cancel context.CancelFunc
checkIPAccess func(ip string) bool
limiter *qos.Limiter
}
type trafficStats interface {
@@ -49,12 +51,6 @@ func NewProxy(ctx context.Context, port int, subdomain string, openStream func()
}
cctx, cancel := context.WithCancel(ctx)
const maxConcurrentConnections = 10000
var sem chan struct{}
if maxConcurrentConnections > 0 {
sem = make(chan struct{}, maxConcurrentConnections)
}
return &Proxy{
port: port,
subdomain: subdomain,
@@ -62,17 +58,20 @@ func NewProxy(ctx context.Context, port int, subdomain string, openStream func()
stopCh: make(chan struct{}),
openStream: openStream,
stats: stats,
sem: sem,
sem: make(chan struct{}, 10000),
ctx: cctx,
cancel: cancel,
}
}
// SetIPAccessCheck sets the IP access control check function.
func (p *Proxy) SetIPAccessCheck(check func(ip string) bool) {
p.checkIPAccess = check
}
func (p *Proxy) SetLimiter(limiter *qos.Limiter) {
p.limiter = limiter
}
func (p *Proxy) Start() error {
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
@@ -174,13 +173,11 @@ func (p *Proxy) handleConn(conn net.Conn) {
}
}
if p.sem != nil {
select {
case p.sem <- struct{}{}:
defer func() { <-p.sem }()
default:
return
}
select {
case p.sem <- struct{}{}:
defer func() { <-p.sem }()
default:
return
}
if p.stats != nil {
@@ -243,7 +240,7 @@ func (p *Proxy) handleConn(conn net.Conn) {
_ = netutil.PipeWithCallbacksAndBufferSize(
p.ctx,
conn,
stream,
qos.NewLimitedConn(p.ctx, stream, p.limiter),
pool.SizeLarge,
func(n int64) {
if p.stats != nil {

View File

@@ -19,19 +19,17 @@ func (c *bufferedConn) Read(p []byte) (int, error) {
return c.reader.Read(p)
}
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
// Public server acts as yamux Client, client connector acts as yamux Server.
// initMuxSession creates a yamux session over the buffered connection and
// returns the openStream function (possibly group-aware).
func (c *Connection) initMuxSession(reader *bufio.Reader) (func() (net.Conn, error), *yamux.Session, error) {
bc := &bufferedConn{
Conn: c.conn,
reader: reader,
}
// Use optimized mux config for server
cfg := mux.NewServerConfig()
session, err := yamux.Client(bc, cfg)
session, err := yamux.Client(bc, mux.NewServerConfig())
if err != nil {
return fmt.Errorf("failed to init yamux session: %w", err)
return nil, nil, fmt.Errorf("failed to init yamux session: %w", err)
}
c.session = session
@@ -43,10 +41,22 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
}
}
return openStream, session, nil
}
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
openStream, session, err := c.initMuxSession(reader)
if err != nil {
return err
}
c.proxy = NewProxy(c.ctx, c.port, c.subdomain, openStream, c.tunnelConn, c.logger)
if c.tunnelConn != nil && c.tunnelConn.HasIPAccessControl() {
c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed)
}
if c.tunnelConn != nil {
c.proxy.SetLimiter(c.tunnelConn.GetLimiter())
}
if err := c.proxy.Start(); err != nil {
return fmt.Errorf("failed to start tcp proxy: %w", err)
@@ -61,27 +71,9 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
}
func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
// Public server acts as yamux Client, client connector acts as yamux Server.
bc := &bufferedConn{
Conn: c.conn,
reader: reader,
}
// Use optimized mux config for server
cfg := mux.NewServerConfig()
session, err := yamux.Client(bc, cfg)
openStream, session, err := c.initMuxSession(reader)
if err != nil {
return fmt.Errorf("failed to init yamux session: %w", err)
}
c.session = session
openStream := session.Open
if c.groupManager != nil {
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
group.AddSession("primary", session)
openStream = group.OpenStream
}
return err
}
if c.tunnelConn != nil {

View File

@@ -9,6 +9,7 @@ import (
"drip/internal/server/metrics"
"drip/internal/shared/netutil"
"drip/internal/shared/protocol"
"drip/internal/shared/qos"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
@@ -32,6 +33,9 @@ type Connection struct {
ipAccessChecker *netutil.IPAccessChecker
proxyAuth *protocol.ProxyAuth
bandwidth int64
limiter *qos.Limiter
}
func NewConnection(subdomain string, conn *websocket.Conn, logger *zap.Logger) *Connection {
@@ -214,6 +218,34 @@ func (c *Connection) ValidateProxyAuth(password string) bool {
return auth.Password == password
}
func (c *Connection) SetBandwidth(bandwidth int64) {
c.SetBandwidthWithBurst(bandwidth, 2.0)
}
func (c *Connection) SetBandwidthWithBurst(bandwidth int64, burstMultiplier float64) {
c.mu.Lock()
defer c.mu.Unlock()
c.bandwidth = bandwidth
if bandwidth > 0 {
burst := int(float64(bandwidth) * burstMultiplier)
c.limiter = qos.NewLimiter(qos.Config{Bandwidth: bandwidth, Burst: burst})
} else {
c.limiter = nil
}
}
func (c *Connection) GetBandwidth() int64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.bandwidth
}
func (c *Connection) GetLimiter() *qos.Limiter {
c.mu.RLock()
defer c.mu.RUnlock()
return c.limiter
}
func (c *Connection) StartWritePump() {
if c.Conn == nil {
go func() {

View File

@@ -0,0 +1,117 @@
package tunnel
import (
"testing"
"go.uber.org/zap"
)
func TestConnectionBandwidthWithBurst(t *testing.T) {
logger := zap.NewNop()
tests := []struct {
name string
bandwidth int64
burstMultiplier float64
wantBandwidth int64
wantBurst int
}{
{
name: "1MB/s with 2x burst",
bandwidth: 1024 * 1024,
burstMultiplier: 2.0,
wantBandwidth: 1024 * 1024,
wantBurst: 2 * 1024 * 1024,
},
{
name: "1MB/s with 2.5x burst",
bandwidth: 1024 * 1024,
burstMultiplier: 2.5,
wantBandwidth: 1024 * 1024,
wantBurst: int(float64(1024*1024) * 2.5),
},
{
name: "500KB/s with 3x burst",
bandwidth: 500 * 1024,
burstMultiplier: 3.0,
wantBandwidth: 500 * 1024,
wantBurst: 3 * 500 * 1024,
},
{
name: "10MB/s with 1.5x burst",
bandwidth: 10 * 1024 * 1024,
burstMultiplier: 1.5,
wantBandwidth: 10 * 1024 * 1024,
wantBurst: int(float64(10*1024*1024) * 1.5),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conn := NewConnection("test-subdomain", nil, logger)
conn.SetBandwidthWithBurst(tt.bandwidth, tt.burstMultiplier)
if conn.GetBandwidth() != tt.wantBandwidth {
t.Errorf("GetBandwidth() = %v, want %v", conn.GetBandwidth(), tt.wantBandwidth)
}
limiter := conn.GetLimiter()
if limiter == nil {
t.Fatal("GetLimiter() should not be nil")
}
if !limiter.IsLimited() {
t.Error("Limiter should be limited")
}
if limiter.RateLimiter().Burst() != tt.wantBurst {
t.Errorf("Burst() = %v, want %v", limiter.RateLimiter().Burst(), tt.wantBurst)
}
})
}
}
func TestConnectionBandwidthUnlimited(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test-subdomain", nil, logger)
if conn.GetBandwidth() != 0 {
t.Errorf("Default bandwidth should be 0, got %v", conn.GetBandwidth())
}
if conn.GetLimiter() != nil {
t.Error("Default limiter should be nil")
}
conn.SetBandwidth(0)
if conn.GetLimiter() != nil {
t.Error("Limiter should be nil when bandwidth is 0")
}
conn.SetBandwidthWithBurst(0, 2.0)
if conn.GetLimiter() != nil {
t.Error("Limiter should be nil when bandwidth is 0")
}
}
func TestConnectionSetBandwidth(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test-subdomain", nil, logger)
conn.SetBandwidth(1024 * 1024)
if conn.GetBandwidth() != 1024*1024 {
t.Errorf("GetBandwidth() = %v, want %v", conn.GetBandwidth(), 1024*1024)
}
limiter := conn.GetLimiter()
if limiter == nil {
t.Fatal("GetLimiter() should not be nil")
}
expectedBurst := 2 * 1024 * 1024
if limiter.RateLimiter().Burst() != expectedBurst {
t.Errorf("Burst() = %v, want %v", limiter.RateLimiter().Burst(), expectedBurst)
}
}