mirror of
https://github.com/Gouryella/drip.git
synced 2026-03-01 15:52:32 +00:00
feat(cli): Add bandwidth limit function support
Added bandwidth limiting functionality, allowing users to limit the bandwidth of tunnel connections via the --bandwidth parameter. Supported formats include: 1K/1KB (kilobytes), 1M/1MB (megabytes), 1G/1GB (gigabytes) or Raw number (bytes).
This commit is contained in:
@@ -24,6 +24,7 @@ import (
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/qos"
|
||||
"drip/internal/shared/wsutil"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
@@ -171,24 +172,22 @@ func (h *Handler) SetAllowedTunnelTypes(types []string) {
|
||||
|
||||
// IsTransportAllowed checks if a transport is allowed
|
||||
func (h *Handler) IsTransportAllowed(transport string) bool {
|
||||
if len(h.allowedTransports) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range h.allowedTransports {
|
||||
if strings.EqualFold(t, transport) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return containsFold(h.allowedTransports, transport)
|
||||
}
|
||||
|
||||
// IsTunnelTypeAllowed checks if a tunnel type is allowed
|
||||
func (h *Handler) IsTunnelTypeAllowed(tunnelType string) bool {
|
||||
if len(h.allowedTunnelTypes) == 0 {
|
||||
return containsFold(h.allowedTunnelTypes, tunnelType)
|
||||
}
|
||||
|
||||
// containsFold returns true if the slice is empty (allow all) or contains the
|
||||
// value in a case-insensitive comparison.
|
||||
func containsFold(allowed []string, value string) bool {
|
||||
if len(allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range h.allowedTunnelTypes {
|
||||
if strings.EqualFold(t, tunnelType) {
|
||||
for _, a := range allowed {
|
||||
if strings.EqualFold(a, value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -299,7 +298,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
tconn.IncActiveConnections()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
countingStream := netutil.NewCountingConn(stream,
|
||||
var limitedStream net.Conn = stream
|
||||
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
|
||||
limitedStream = qos.NewLimitedConn(r.Context(), stream, limiter)
|
||||
}
|
||||
|
||||
countingStream := netutil.NewCountingConn(limitedStream,
|
||||
tconn.AddBytesOut,
|
||||
tconn.AddBytesIn,
|
||||
)
|
||||
@@ -428,6 +432,11 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
|
||||
return
|
||||
}
|
||||
|
||||
var limitedStream net.Conn = stream
|
||||
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
|
||||
limitedStream = qos.NewLimitedConn(context.Background(), stream, limiter)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
defer clientConn.Close()
|
||||
@@ -441,7 +450,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
|
||||
}
|
||||
}
|
||||
|
||||
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
|
||||
_ = netutil.PipeWithCallbacks(context.Background(), limitedStream, clientRW,
|
||||
func(n int64) { tconn.AddBytesOut(n) },
|
||||
func(n int64) { tconn.AddBytesIn(n) },
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user