mirror of
https://github.com/Gouryella/drip.git
synced 2026-03-01 15:52:32 +00:00
feat(cli): Add bandwidth limit function support
Added bandwidth limiting functionality, allowing users to limit the bandwidth of tunnel connections via the --bandwidth parameter. Supported formats include: 1K/1KB (kilobytes), 1M/1MB (megabytes), 1G/1GB (gigabytes) or Raw number (bytes).
This commit is contained in:
@@ -24,6 +24,7 @@ import (
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/qos"
|
||||
"drip/internal/shared/wsutil"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
@@ -171,24 +172,22 @@ func (h *Handler) SetAllowedTunnelTypes(types []string) {
|
||||
|
||||
// IsTransportAllowed checks if a transport is allowed
|
||||
func (h *Handler) IsTransportAllowed(transport string) bool {
|
||||
if len(h.allowedTransports) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range h.allowedTransports {
|
||||
if strings.EqualFold(t, transport) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return containsFold(h.allowedTransports, transport)
|
||||
}
|
||||
|
||||
// IsTunnelTypeAllowed checks if a tunnel type is allowed
|
||||
func (h *Handler) IsTunnelTypeAllowed(tunnelType string) bool {
|
||||
if len(h.allowedTunnelTypes) == 0 {
|
||||
return containsFold(h.allowedTunnelTypes, tunnelType)
|
||||
}
|
||||
|
||||
// containsFold returns true if the slice is empty (allow all) or contains the
|
||||
// value in a case-insensitive comparison.
|
||||
func containsFold(allowed []string, value string) bool {
|
||||
if len(allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range h.allowedTunnelTypes {
|
||||
if strings.EqualFold(t, tunnelType) {
|
||||
for _, a := range allowed {
|
||||
if strings.EqualFold(a, value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -299,7 +298,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
tconn.IncActiveConnections()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
countingStream := netutil.NewCountingConn(stream,
|
||||
var limitedStream net.Conn = stream
|
||||
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
|
||||
limitedStream = qos.NewLimitedConn(r.Context(), stream, limiter)
|
||||
}
|
||||
|
||||
countingStream := netutil.NewCountingConn(limitedStream,
|
||||
tconn.AddBytesOut,
|
||||
tconn.AddBytesIn,
|
||||
)
|
||||
@@ -428,6 +432,11 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
|
||||
return
|
||||
}
|
||||
|
||||
var limitedStream net.Conn = stream
|
||||
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
|
||||
limitedStream = qos.NewLimitedConn(context.Background(), stream, limiter)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
defer clientConn.Close()
|
||||
@@ -441,7 +450,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
|
||||
}
|
||||
}
|
||||
|
||||
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
|
||||
_ = netutil.PipeWithCallbacks(context.Background(), limitedStream, clientRW,
|
||||
func(n int64) { tconn.AddBytesOut(n) },
|
||||
func(n int64) { tconn.AddBytesIn(n) },
|
||||
)
|
||||
|
||||
168
internal/server/tcp/bandwidth_test.go
Normal file
168
internal/server/tcp/bandwidth_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEffectiveBandwidthSelection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverBW int64
|
||||
clientBW int64
|
||||
wantEffective int64
|
||||
}{
|
||||
{
|
||||
name: "server only",
|
||||
serverBW: 1024 * 1024,
|
||||
clientBW: 0,
|
||||
wantEffective: 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "client only",
|
||||
serverBW: 0,
|
||||
clientBW: 512 * 1024,
|
||||
wantEffective: 512 * 1024,
|
||||
},
|
||||
{
|
||||
name: "both unlimited",
|
||||
serverBW: 0,
|
||||
clientBW: 0,
|
||||
wantEffective: 0,
|
||||
},
|
||||
{
|
||||
name: "client lower than server",
|
||||
serverBW: 10 * 1024 * 1024,
|
||||
clientBW: 1 * 1024 * 1024,
|
||||
wantEffective: 1 * 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "client higher than server - server wins",
|
||||
serverBW: 1 * 1024 * 1024,
|
||||
clientBW: 10 * 1024 * 1024,
|
||||
wantEffective: 1 * 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "client equal to server",
|
||||
serverBW: 5 * 1024 * 1024,
|
||||
clientBW: 5 * 1024 * 1024,
|
||||
wantEffective: 5 * 1024 * 1024,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
effectiveBandwidth := tt.serverBW
|
||||
if tt.clientBW > 0 {
|
||||
if effectiveBandwidth == 0 || tt.clientBW < effectiveBandwidth {
|
||||
effectiveBandwidth = tt.clientBW
|
||||
}
|
||||
}
|
||||
|
||||
if effectiveBandwidth != tt.wantEffective {
|
||||
t.Errorf("effectiveBandwidth = %d, want %d", effectiveBandwidth, tt.wantEffective)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionSetBandwidthConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
wantBandwidth int64
|
||||
wantMultiplier float64
|
||||
}{
|
||||
{
|
||||
name: "1MB/s with 2x burst",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 2.0,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantMultiplier: 2.0,
|
||||
},
|
||||
{
|
||||
name: "1MB/s with 2.5x burst",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 2.5,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantMultiplier: 2.5,
|
||||
},
|
||||
{
|
||||
name: "default multiplier when 0",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 0,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantMultiplier: 2.0,
|
||||
},
|
||||
{
|
||||
name: "default multiplier when negative",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: -1.0,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantMultiplier: 2.0,
|
||||
},
|
||||
{
|
||||
name: "unlimited bandwidth",
|
||||
bandwidth: 0,
|
||||
burstMultiplier: 2.5,
|
||||
wantBandwidth: 0,
|
||||
wantMultiplier: 2.5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
conn := &Connection{}
|
||||
conn.SetBandwidthConfig(tt.bandwidth, tt.burstMultiplier)
|
||||
|
||||
if conn.bandwidth != tt.wantBandwidth {
|
||||
t.Errorf("bandwidth = %v, want %v", conn.bandwidth, tt.wantBandwidth)
|
||||
}
|
||||
|
||||
if conn.burstMultiplier != tt.wantMultiplier {
|
||||
t.Errorf("burstMultiplier = %v, want %v", conn.burstMultiplier, tt.wantMultiplier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerBandwidthConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
wantBandwidth int64
|
||||
wantMultiplier float64
|
||||
}{
|
||||
{
|
||||
name: "set bandwidth and multiplier",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 2.5,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantMultiplier: 2.5,
|
||||
},
|
||||
{
|
||||
name: "default multiplier",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 0,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantMultiplier: 2.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l := &Listener{}
|
||||
l.SetBandwidth(tt.bandwidth)
|
||||
l.SetBurstMultiplier(tt.burstMultiplier)
|
||||
|
||||
if l.bandwidth != tt.wantBandwidth {
|
||||
t.Errorf("bandwidth = %v, want %v", l.bandwidth, tt.wantBandwidth)
|
||||
}
|
||||
|
||||
if l.burstMultiplier != tt.wantMultiplier {
|
||||
t.Errorf("burstMultiplier = %v, want %v", l.burstMultiplier, tt.wantMultiplier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -59,12 +59,12 @@ type Connection struct {
|
||||
httpListener *connQueueListener
|
||||
handedOff bool
|
||||
|
||||
// Server capabilities
|
||||
allowedTunnelTypes []string
|
||||
allowedTransports []string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c := &Connection{
|
||||
@@ -99,22 +99,11 @@ func (c *Connection) Handle() error {
|
||||
return fmt.Errorf("failed to peek connection: %w", err)
|
||||
}
|
||||
|
||||
peekStr := string(peek)
|
||||
httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"}
|
||||
isHTTP := false
|
||||
for _, method := range httpMethods {
|
||||
if strings.HasPrefix(peekStr, method) {
|
||||
isHTTP = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isHTTP {
|
||||
if isHTTPMethod(string(peek)) {
|
||||
c.logger.Info("Detected HTTP request on TCP port, handling as HTTP")
|
||||
return c.handleHTTPRequest(reader)
|
||||
}
|
||||
|
||||
// Check if TCP transport is allowed (only for Drip protocol connections, not HTTP)
|
||||
if !c.isTransportAllowed("tcp") {
|
||||
c.logger.Warn("TCP transport not allowed, rejecting Drip protocol connection")
|
||||
return fmt.Errorf("TCP transport not allowed")
|
||||
@@ -142,7 +131,6 @@ func (c *Connection) Handle() error {
|
||||
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
// Check if tunnel type is allowed
|
||||
if !c.isTunnelTypeAllowed(string(req.TunnelType)) {
|
||||
c.sendError("tunnel_type_not_allowed", fmt.Sprintf("Tunnel type '%s' is not allowed on this server", req.TunnelType))
|
||||
return fmt.Errorf("tunnel type not allowed: %s", req.TunnelType)
|
||||
@@ -197,7 +185,6 @@ func (c *Connection) Handle() error {
|
||||
|
||||
c.tunnelConn.Conn = nil
|
||||
c.tunnelConn.SetTunnelType(req.TunnelType)
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) {
|
||||
c.tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs)
|
||||
@@ -215,6 +202,31 @@ func (c *Connection) Handle() error {
|
||||
)
|
||||
}
|
||||
|
||||
effectiveBandwidth := c.bandwidth
|
||||
if req.Bandwidth > 0 {
|
||||
if effectiveBandwidth == 0 || req.Bandwidth < effectiveBandwidth {
|
||||
effectiveBandwidth = req.Bandwidth
|
||||
}
|
||||
}
|
||||
if effectiveBandwidth > 0 {
|
||||
burstMultiplier := c.burstMultiplier
|
||||
if burstMultiplier <= 0 {
|
||||
burstMultiplier = 2.0
|
||||
}
|
||||
c.tunnelConn.SetBandwidthWithBurst(effectiveBandwidth, burstMultiplier)
|
||||
|
||||
source := "server"
|
||||
if req.Bandwidth > 0 && (c.bandwidth == 0 || req.Bandwidth < c.bandwidth) {
|
||||
source = "client"
|
||||
}
|
||||
c.logger.Info("Bandwidth limit configured",
|
||||
zap.String("subdomain", subdomain),
|
||||
zap.Int64("bandwidth_bytes_sec", effectiveBandwidth),
|
||||
zap.Float64("burst_multiplier", burstMultiplier),
|
||||
zap.String("source", source),
|
||||
)
|
||||
}
|
||||
|
||||
c.logger.Info("Tunnel registered",
|
||||
zap.String("subdomain", subdomain),
|
||||
zap.String("tunnel_type", string(req.TunnelType)),
|
||||
@@ -258,6 +270,7 @@ func (c *Connection) Handle() error {
|
||||
TunnelID: tunnelID,
|
||||
SupportsDataConn: supportsDataConn,
|
||||
RecommendedConns: recommendedConns,
|
||||
Bandwidth: c.tunnelConn.GetBandwidth(),
|
||||
}
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
@@ -389,7 +402,6 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
|
||||
zap.String("host", req.Host),
|
||||
)
|
||||
|
||||
// Get writer from pool to reduce GC pressure
|
||||
pooledWriter := bufioWriterPool.Get().(*bufio.Writer)
|
||||
pooledWriter.Reset(c.conn)
|
||||
|
||||
@@ -405,30 +417,17 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
|
||||
c.logger.Debug("Failed to flush HTTP response", zap.Error(err))
|
||||
}
|
||||
|
||||
// Return writer to pool
|
||||
pooledWriter.Reset(nil) // Clear reference to connection
|
||||
pooledWriter.Reset(nil)
|
||||
bufioWriterPool.Put(pooledWriter)
|
||||
|
||||
// Keep TCP_NODELAY enabled for low latency HTTP responses
|
||||
// (removed the toggle that was disabling it)
|
||||
|
||||
c.logger.Debug("HTTP request processing completed",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("url", req.URL.String()),
|
||||
)
|
||||
|
||||
shouldClose := false
|
||||
if req.Close {
|
||||
shouldClose = true
|
||||
} else if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
||||
if req.Header.Get("Connection") != "keep-alive" {
|
||||
shouldClose = true
|
||||
}
|
||||
}
|
||||
|
||||
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
|
||||
shouldClose = true
|
||||
}
|
||||
shouldClose := req.Close ||
|
||||
(req.ProtoMajor == 1 && req.ProtoMinor == 0 && req.Header.Get("Connection") != "keep-alive") ||
|
||||
(respWriter.headerWritten && respWriter.header.Get("Connection") == "close")
|
||||
|
||||
if shouldClose {
|
||||
c.logger.Debug("Closing connection as requested by client or server")
|
||||
@@ -636,7 +635,7 @@ func (w *httpResponseWriter) WriteHeader(statusCode int) {
|
||||
}
|
||||
|
||||
w.writer.WriteString("HTTP/1.1 ")
|
||||
w.writer.WriteString(fmt.Sprintf("%d", statusCode))
|
||||
w.writer.WriteString(strconv.Itoa(statusCode))
|
||||
w.writer.WriteByte(' ')
|
||||
w.writer.WriteString(statusText)
|
||||
w.writer.WriteString("\r\n")
|
||||
@@ -755,6 +754,14 @@ func isTimeoutError(err error) bool {
|
||||
return strings.Contains(err.Error(), "i/o timeout")
|
||||
}
|
||||
|
||||
func isHTTPMethod(peek string) bool {
|
||||
switch peek {
|
||||
case "GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Connection) sendDataConnectError(code, message string) {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: false,
|
||||
@@ -769,38 +776,40 @@ func (c *Connection) sendDataConnectError(code, message string) {
|
||||
_ = protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
|
||||
// SetAllowedTunnelTypes sets the allowed tunnel types for this connection
|
||||
func (c *Connection) SetAllowedTunnelTypes(types []string) {
|
||||
c.allowedTunnelTypes = types
|
||||
}
|
||||
|
||||
// SetAllowedTransports sets the allowed transports for this connection
|
||||
func (c *Connection) SetAllowedTransports(transports []string) {
|
||||
c.allowedTransports = transports
|
||||
}
|
||||
|
||||
// isTransportAllowed checks if a transport is allowed
|
||||
func (c *Connection) isTransportAllowed(transport string) bool {
|
||||
if len(c.allowedTransports) == 0 {
|
||||
return containsFold(c.allowedTransports, transport)
|
||||
}
|
||||
|
||||
func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
|
||||
return containsFold(c.allowedTunnelTypes, tunnelType)
|
||||
}
|
||||
|
||||
// containsFold returns true if the slice is empty (allow all) or contains the
|
||||
// value case-insensitively.
|
||||
func containsFold(allowed []string, value string) bool {
|
||||
if len(allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range c.allowedTransports {
|
||||
if strings.EqualFold(t, transport) {
|
||||
for _, a := range allowed {
|
||||
if strings.EqualFold(a, value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTunnelTypeAllowed checks if a tunnel type is allowed
|
||||
func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
|
||||
if len(c.allowedTunnelTypes) == 0 {
|
||||
return true // Allow all by default
|
||||
func (c *Connection) SetBandwidthConfig(bandwidth int64, burstMultiplier float64) {
|
||||
c.bandwidth = bandwidth
|
||||
if burstMultiplier <= 0 {
|
||||
burstMultiplier = 2.0
|
||||
}
|
||||
for _, t := range c.allowedTunnelTypes {
|
||||
if strings.EqualFold(t, tunnelType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
c.burstMultiplier = burstMultiplier
|
||||
}
|
||||
|
||||
@@ -43,9 +43,10 @@ type Listener struct {
|
||||
httpServer *http.Server
|
||||
httpListener *connQueueListener
|
||||
|
||||
// Server capabilities
|
||||
allowedTransports []string
|
||||
allowedTunnelTypes []string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
}
|
||||
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler) *Listener {
|
||||
@@ -63,7 +64,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
panicMetrics := recovery.NewPanicMetrics(logger, nil)
|
||||
recoverer := recovery.NewRecoverer(logger, panicMetrics)
|
||||
|
||||
// Initialize worker pool metrics
|
||||
metrics.WorkerPoolSize.Set(float64(workers))
|
||||
|
||||
l := &Listener{
|
||||
@@ -85,7 +85,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
groupManager: NewConnectionGroupManager(logger),
|
||||
}
|
||||
|
||||
// Set up WebSocket connection handler if httpHandler supports it
|
||||
if h, ok := httpHandler.(*proxy.Handler); ok {
|
||||
h.SetWSConnectionHandler(l)
|
||||
h.SetPublicPort(publicPort)
|
||||
@@ -97,7 +96,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
func (l *Listener) Start() error {
|
||||
var err error
|
||||
|
||||
// Support both TLS and plain TCP modes
|
||||
if l.tlsConfig != nil {
|
||||
l.listener, err = tls.Listen("tcp", l.address, l.tlsConfig)
|
||||
if err != nil {
|
||||
@@ -269,57 +267,13 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
)
|
||||
}
|
||||
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
|
||||
conn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
|
||||
conn.SetAllowedTransports(l.allowedTransports)
|
||||
conn := l.newConfiguredConnection(netConn)
|
||||
|
||||
connID := netConn.RemoteAddr().String()
|
||||
l.connMu.Lock()
|
||||
l.connections[connID] = conn
|
||||
l.connMu.Unlock()
|
||||
l.trackConnection(connID, conn)
|
||||
defer l.untrackConnection(connID, conn, netConn)
|
||||
|
||||
// Update connection metrics
|
||||
metrics.TotalConnections.Inc()
|
||||
metrics.ActiveConnections.Inc()
|
||||
|
||||
defer func() {
|
||||
l.connMu.Lock()
|
||||
delete(l.connections, connID)
|
||||
l.connMu.Unlock()
|
||||
|
||||
metrics.ActiveConnections.Dec()
|
||||
|
||||
if !conn.IsHandedOff() {
|
||||
netConn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := conn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
if strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(errStr, "payload too large") ||
|
||||
strings.Contains(errStr, "failed to read registration frame") ||
|
||||
strings.Contains(errStr, "expected register frame") ||
|
||||
strings.Contains(errStr, "failed to parse registration request") ||
|
||||
strings.Contains(errStr, "failed to parse HTTP request") {
|
||||
l.logger.Warn("Protocol validation failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
l.logger.Error("Connection handling failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
l.logConnectionError(conn.Handle(), connID, "Connection")
|
||||
}
|
||||
|
||||
func (l *Listener) Stop() error {
|
||||
@@ -372,7 +326,6 @@ func (l *Listener) GetActiveConnections() int {
|
||||
return len(l.connections)
|
||||
}
|
||||
|
||||
// HandleWSConnection implements proxy.WSConnectionHandler for WebSocket tunnel connections
|
||||
func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
|
||||
l.wg.Add(1)
|
||||
defer l.wg.Done()
|
||||
@@ -386,77 +339,103 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
|
||||
zap.String("remote_addr", connID),
|
||||
)
|
||||
|
||||
// Create connection handler (no TLS verification needed - already done by HTTP server)
|
||||
tcpConn := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
|
||||
tcpConn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
|
||||
tcpConn := l.newConfiguredConnection(conn)
|
||||
|
||||
l.connMu.Lock()
|
||||
l.connections[connID] = tcpConn
|
||||
l.connMu.Unlock()
|
||||
l.trackConnection(connID, tcpConn)
|
||||
defer l.untrackConnection(connID, tcpConn, conn)
|
||||
|
||||
metrics.TotalConnections.Inc()
|
||||
metrics.ActiveConnections.Inc()
|
||||
|
||||
defer func() {
|
||||
l.connMu.Lock()
|
||||
delete(l.connections, connID)
|
||||
l.connMu.Unlock()
|
||||
|
||||
metrics.ActiveConnections.Dec()
|
||||
|
||||
if !tcpConn.IsHandedOff() {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := tcpConn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
if strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") ||
|
||||
strings.Contains(errStr, "websocket: close") {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(errStr, "payload too large") ||
|
||||
strings.Contains(errStr, "failed to read registration frame") ||
|
||||
strings.Contains(errStr, "expected register frame") ||
|
||||
strings.Contains(errStr, "failed to parse registration request") ||
|
||||
strings.Contains(errStr, "tunnel type not allowed") {
|
||||
l.logger.Warn("WebSocket tunnel protocol validation failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
l.logger.Error("WebSocket tunnel connection handling failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
l.logConnectionError(tcpConn.Handle(), connID, "WebSocket tunnel")
|
||||
}
|
||||
|
||||
// SetAllowedTransports sets the allowed transport protocols
|
||||
func (l *Listener) SetAllowedTransports(transports []string) {
|
||||
l.allowedTransports = transports
|
||||
}
|
||||
|
||||
// SetAllowedTunnelTypes sets the allowed tunnel types
|
||||
func (l *Listener) SetAllowedTunnelTypes(types []string) {
|
||||
l.allowedTunnelTypes = types
|
||||
}
|
||||
|
||||
// IsTransportAllowed checks if a transport is allowed
|
||||
func (l *Listener) IsTransportAllowed(transport string) bool {
|
||||
if len(l.allowedTransports) == 0 {
|
||||
return true
|
||||
return containsFold(l.allowedTransports, transport)
|
||||
}
|
||||
|
||||
func (l *Listener) SetBurstMultiplier(multiplier float64) {
|
||||
if multiplier <= 0 {
|
||||
multiplier = 2.0
|
||||
}
|
||||
for _, t := range l.allowedTransports {
|
||||
if strings.EqualFold(t, transport) {
|
||||
return true
|
||||
l.burstMultiplier = multiplier
|
||||
}
|
||||
|
||||
func (l *Listener) SetBandwidth(bandwidth int64) {
|
||||
l.bandwidth = bandwidth
|
||||
}
|
||||
|
||||
func (l *Listener) newConfiguredConnection(conn net.Conn) *Connection {
|
||||
c := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
|
||||
c.SetAllowedTunnelTypes(l.allowedTunnelTypes)
|
||||
c.SetAllowedTransports(l.allowedTransports)
|
||||
c.SetBandwidthConfig(l.bandwidth, l.burstMultiplier)
|
||||
return c
|
||||
}
|
||||
|
||||
func (l *Listener) trackConnection(connID string, conn *Connection) {
|
||||
l.connMu.Lock()
|
||||
l.connections[connID] = conn
|
||||
l.connMu.Unlock()
|
||||
|
||||
metrics.TotalConnections.Inc()
|
||||
metrics.ActiveConnections.Inc()
|
||||
}
|
||||
|
||||
func (l *Listener) untrackConnection(connID string, conn *Connection, netConn net.Conn) {
|
||||
l.connMu.Lock()
|
||||
delete(l.connections, connID)
|
||||
l.connMu.Unlock()
|
||||
|
||||
metrics.ActiveConnections.Dec()
|
||||
|
||||
if !conn.IsHandedOff() {
|
||||
netConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// logConnectionError classifies and logs a connection handling error.
|
||||
// Transient network errors are silently ignored, protocol errors are warned,
|
||||
// and everything else is logged as an error.
|
||||
func (l *Listener) logConnectionError(err error, connID, label string) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
|
||||
// Transient / expected disconnects — ignore silently
|
||||
for _, substr := range []string{
|
||||
"EOF", "connection reset by peer", "broken pipe",
|
||||
"connection refused", "websocket: close",
|
||||
} {
|
||||
if strings.Contains(errStr, substr) {
|
||||
return
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
// Protocol-level validation failures — warn
|
||||
for _, substr := range []string{
|
||||
"payload too large", "failed to read registration frame",
|
||||
"expected register frame", "failed to parse registration request",
|
||||
"failed to parse HTTP request", "tunnel type not allowed",
|
||||
} {
|
||||
if strings.Contains(errStr, substr) {
|
||||
l.logger.Warn(label+" protocol validation failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
l.logger.Error(label+" handling failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/qos"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -34,6 +35,7 @@ type Proxy struct {
|
||||
cancel context.CancelFunc
|
||||
|
||||
checkIPAccess func(ip string) bool
|
||||
limiter *qos.Limiter
|
||||
}
|
||||
|
||||
type trafficStats interface {
|
||||
@@ -49,12 +51,6 @@ func NewProxy(ctx context.Context, port int, subdomain string, openStream func()
|
||||
}
|
||||
cctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
const maxConcurrentConnections = 10000
|
||||
var sem chan struct{}
|
||||
if maxConcurrentConnections > 0 {
|
||||
sem = make(chan struct{}, maxConcurrentConnections)
|
||||
}
|
||||
|
||||
return &Proxy{
|
||||
port: port,
|
||||
subdomain: subdomain,
|
||||
@@ -62,17 +58,20 @@ func NewProxy(ctx context.Context, port int, subdomain string, openStream func()
|
||||
stopCh: make(chan struct{}),
|
||||
openStream: openStream,
|
||||
stats: stats,
|
||||
sem: sem,
|
||||
sem: make(chan struct{}, 10000),
|
||||
ctx: cctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// SetIPAccessCheck sets the IP access control check function.
|
||||
func (p *Proxy) SetIPAccessCheck(check func(ip string) bool) {
|
||||
p.checkIPAccess = check
|
||||
}
|
||||
|
||||
func (p *Proxy) SetLimiter(limiter *qos.Limiter) {
|
||||
p.limiter = limiter
|
||||
}
|
||||
|
||||
func (p *Proxy) Start() error {
|
||||
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
|
||||
|
||||
@@ -174,13 +173,11 @@ func (p *Proxy) handleConn(conn net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
if p.sem != nil {
|
||||
select {
|
||||
case p.sem <- struct{}{}:
|
||||
defer func() { <-p.sem }()
|
||||
default:
|
||||
return
|
||||
}
|
||||
select {
|
||||
case p.sem <- struct{}{}:
|
||||
defer func() { <-p.sem }()
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
if p.stats != nil {
|
||||
@@ -243,7 +240,7 @@ func (p *Proxy) handleConn(conn net.Conn) {
|
||||
_ = netutil.PipeWithCallbacksAndBufferSize(
|
||||
p.ctx,
|
||||
conn,
|
||||
stream,
|
||||
qos.NewLimitedConn(p.ctx, stream, p.limiter),
|
||||
pool.SizeLarge,
|
||||
func(n int64) {
|
||||
if p.stats != nil {
|
||||
|
||||
@@ -19,19 +19,17 @@ func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
// initMuxSession creates a yamux session over the buffered connection and
|
||||
// returns the openStream function (possibly group-aware).
|
||||
func (c *Connection) initMuxSession(reader *bufio.Reader) (func() (net.Conn, error), *yamux.Session, error) {
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
// Use optimized mux config for server
|
||||
cfg := mux.NewServerConfig()
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
session, err := yamux.Client(bc, mux.NewServerConfig())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
@@ -43,10 +41,22 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
}
|
||||
}
|
||||
|
||||
return openStream, session, nil
|
||||
}
|
||||
|
||||
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
openStream, session, err := c.initMuxSession(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.proxy = NewProxy(c.ctx, c.port, c.subdomain, openStream, c.tunnelConn, c.logger)
|
||||
if c.tunnelConn != nil && c.tunnelConn.HasIPAccessControl() {
|
||||
c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed)
|
||||
}
|
||||
if c.tunnelConn != nil {
|
||||
c.proxy.SetLimiter(c.tunnelConn.GetLimiter())
|
||||
}
|
||||
|
||||
if err := c.proxy.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start tcp proxy: %w", err)
|
||||
@@ -61,27 +71,9 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
}
|
||||
|
||||
func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
// Use optimized mux config for server
|
||||
cfg := mux.NewServerConfig()
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
openStream, session, err := c.initMuxSession(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
openStream := session.Open
|
||||
if c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
group.AddSession("primary", session)
|
||||
openStream = group.OpenStream
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if c.tunnelConn != nil {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"drip/internal/server/metrics"
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/qos"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -32,6 +33,9 @@ type Connection struct {
|
||||
|
||||
ipAccessChecker *netutil.IPAccessChecker
|
||||
proxyAuth *protocol.ProxyAuth
|
||||
|
||||
bandwidth int64
|
||||
limiter *qos.Limiter
|
||||
}
|
||||
|
||||
func NewConnection(subdomain string, conn *websocket.Conn, logger *zap.Logger) *Connection {
|
||||
@@ -214,6 +218,34 @@ func (c *Connection) ValidateProxyAuth(password string) bool {
|
||||
return auth.Password == password
|
||||
}
|
||||
|
||||
func (c *Connection) SetBandwidth(bandwidth int64) {
|
||||
c.SetBandwidthWithBurst(bandwidth, 2.0)
|
||||
}
|
||||
|
||||
func (c *Connection) SetBandwidthWithBurst(bandwidth int64, burstMultiplier float64) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.bandwidth = bandwidth
|
||||
if bandwidth > 0 {
|
||||
burst := int(float64(bandwidth) * burstMultiplier)
|
||||
c.limiter = qos.NewLimiter(qos.Config{Bandwidth: bandwidth, Burst: burst})
|
||||
} else {
|
||||
c.limiter = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) GetBandwidth() int64 {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.bandwidth
|
||||
}
|
||||
|
||||
func (c *Connection) GetLimiter() *qos.Limiter {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.limiter
|
||||
}
|
||||
|
||||
func (c *Connection) StartWritePump() {
|
||||
if c.Conn == nil {
|
||||
go func() {
|
||||
|
||||
117
internal/server/tunnel/connection_test.go
Normal file
117
internal/server/tunnel/connection_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConnectionBandwidthWithBurst(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
wantBandwidth int64
|
||||
wantBurst int
|
||||
}{
|
||||
{
|
||||
name: "1MB/s with 2x burst",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 2.0,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantBurst: 2 * 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "1MB/s with 2.5x burst",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 2.5,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantBurst: int(float64(1024*1024) * 2.5),
|
||||
},
|
||||
{
|
||||
name: "500KB/s with 3x burst",
|
||||
bandwidth: 500 * 1024,
|
||||
burstMultiplier: 3.0,
|
||||
wantBandwidth: 500 * 1024,
|
||||
wantBurst: 3 * 500 * 1024,
|
||||
},
|
||||
{
|
||||
name: "10MB/s with 1.5x burst",
|
||||
bandwidth: 10 * 1024 * 1024,
|
||||
burstMultiplier: 1.5,
|
||||
wantBandwidth: 10 * 1024 * 1024,
|
||||
wantBurst: int(float64(10*1024*1024) * 1.5),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
conn := NewConnection("test-subdomain", nil, logger)
|
||||
|
||||
conn.SetBandwidthWithBurst(tt.bandwidth, tt.burstMultiplier)
|
||||
|
||||
if conn.GetBandwidth() != tt.wantBandwidth {
|
||||
t.Errorf("GetBandwidth() = %v, want %v", conn.GetBandwidth(), tt.wantBandwidth)
|
||||
}
|
||||
|
||||
limiter := conn.GetLimiter()
|
||||
if limiter == nil {
|
||||
t.Fatal("GetLimiter() should not be nil")
|
||||
}
|
||||
|
||||
if !limiter.IsLimited() {
|
||||
t.Error("Limiter should be limited")
|
||||
}
|
||||
|
||||
if limiter.RateLimiter().Burst() != tt.wantBurst {
|
||||
t.Errorf("Burst() = %v, want %v", limiter.RateLimiter().Burst(), tt.wantBurst)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionBandwidthUnlimited(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test-subdomain", nil, logger)
|
||||
|
||||
if conn.GetBandwidth() != 0 {
|
||||
t.Errorf("Default bandwidth should be 0, got %v", conn.GetBandwidth())
|
||||
}
|
||||
|
||||
if conn.GetLimiter() != nil {
|
||||
t.Error("Default limiter should be nil")
|
||||
}
|
||||
|
||||
conn.SetBandwidth(0)
|
||||
if conn.GetLimiter() != nil {
|
||||
t.Error("Limiter should be nil when bandwidth is 0")
|
||||
}
|
||||
|
||||
conn.SetBandwidthWithBurst(0, 2.0)
|
||||
if conn.GetLimiter() != nil {
|
||||
t.Error("Limiter should be nil when bandwidth is 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionSetBandwidth(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test-subdomain", nil, logger)
|
||||
|
||||
conn.SetBandwidth(1024 * 1024)
|
||||
|
||||
if conn.GetBandwidth() != 1024*1024 {
|
||||
t.Errorf("GetBandwidth() = %v, want %v", conn.GetBandwidth(), 1024*1024)
|
||||
}
|
||||
|
||||
limiter := conn.GetLimiter()
|
||||
if limiter == nil {
|
||||
t.Fatal("GetLimiter() should not be nil")
|
||||
}
|
||||
|
||||
expectedBurst := 2 * 1024 * 1024
|
||||
if limiter.RateLimiter().Burst() != expectedBurst {
|
||||
t.Errorf("Burst() = %v, want %v", limiter.RateLimiter().Burst(), expectedBurst)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user