feat(cli): Add bandwidth limit function support

Added bandwidth limiting functionality, allowing users to limit the bandwidth of tunnel connections via the --bandwidth parameter.
Supported formats include: 1K/1KB (kilobytes), 1M/1MB (megabytes), 1G/1GB (gigabytes) or
Raw number (bytes).
This commit is contained in:
Gouryella
2026-02-14 14:20:21 +08:00
parent 3872bd9326
commit f90df37d7c
28 changed files with 2115 additions and 291 deletions

View File

@@ -24,6 +24,7 @@ import (
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/qos"
"drip/internal/shared/wsutil"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -171,24 +172,22 @@ func (h *Handler) SetAllowedTunnelTypes(types []string) {
// IsTransportAllowed checks if a transport is allowed
func (h *Handler) IsTransportAllowed(transport string) bool {
if len(h.allowedTransports) == 0 {
return true
}
for _, t := range h.allowedTransports {
if strings.EqualFold(t, transport) {
return true
}
}
return false
return containsFold(h.allowedTransports, transport)
}
// IsTunnelTypeAllowed checks if a tunnel type is allowed
func (h *Handler) IsTunnelTypeAllowed(tunnelType string) bool {
if len(h.allowedTunnelTypes) == 0 {
return containsFold(h.allowedTunnelTypes, tunnelType)
}
// containsFold returns true if the slice is empty (allow all) or contains the
// value in a case-insensitive comparison.
func containsFold(allowed []string, value string) bool {
if len(allowed) == 0 {
return true
}
for _, t := range h.allowedTunnelTypes {
if strings.EqualFold(t, tunnelType) {
for _, a := range allowed {
if strings.EqualFold(a, value) {
return true
}
}
@@ -299,7 +298,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
tconn.IncActiveConnections()
defer tconn.DecActiveConnections()
countingStream := netutil.NewCountingConn(stream,
var limitedStream net.Conn = stream
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
limitedStream = qos.NewLimitedConn(r.Context(), stream, limiter)
}
countingStream := netutil.NewCountingConn(limitedStream,
tconn.AddBytesOut,
tconn.AddBytesIn,
)
@@ -428,6 +432,11 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
return
}
var limitedStream net.Conn = stream
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
limitedStream = qos.NewLimitedConn(context.Background(), stream, limiter)
}
go func() {
defer stream.Close()
defer clientConn.Close()
@@ -441,7 +450,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
}
}
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
_ = netutil.PipeWithCallbacks(context.Background(), limitedStream, clientRW,
func(n int64) { tconn.AddBytesOut(n) },
func(n int64) { tconn.AddBytesIn(n) },
)