feat(cli): Add bandwidth limit function support

Added bandwidth limiting functionality, allowing users to limit the bandwidth of tunnel connections via the --bandwidth parameter.
Supported formats include: 1K/1KB (kilobytes), 1M/1MB (megabytes), 1G/1GB (gigabytes) or
Raw number (bytes).
This commit is contained in:
Gouryella
2026-02-14 14:20:21 +08:00
parent 3872bd9326
commit f90df37d7c
28 changed files with 2115 additions and 291 deletions

1
go.mod
View File

@@ -42,5 +42,6 @@ require (
go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/text v0.32.0 // indirect
golang.org/x/time v0.14.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
)

2
go.sum
View File

@@ -95,6 +95,8 @@ golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -151,6 +151,9 @@ func runConfigShow(_ *cobra.Command, _ []string) error {
if t.Transport != "" {
fmt.Printf(" transport=%s", t.Transport)
}
if t.Bandwidth != "" {
fmt.Printf(" bandwidth=%s", t.Bandwidth)
}
if len(t.AllowIPs) > 0 {
fmt.Printf(" allow=%s", strings.Join(t.AllowIPs, ","))
}

View File

@@ -20,6 +20,7 @@ var (
denyIPs []string
authPass string
transport string
bandwidth string
)
var httpCmd = &cobra.Command{
@@ -35,6 +36,7 @@ Example:
drip http 3000 --deny-ip 1.2.3.4 Block specific IP
drip http 3000 --auth secret Enable proxy authentication with password
drip http 3000 --transport wss Use WebSocket over TLS (CDN-friendly)
drip http 3000 --bandwidth 1M Limit bandwidth to 1 MB/s
Configuration:
First time: Run 'drip config init' to save server and token
@@ -43,7 +45,13 @@ Configuration:
Transport options:
auto - Automatically select based on server address (default)
tcp - Direct TLS 1.3 connection
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
wss - WebSocket over TLS (works through CDN like Cloudflare)
Bandwidth format:
1K, 1KB - 1 kilobyte per second (1024 bytes/s)
1M, 1MB - 1 megabyte per second (1048576 bytes/s)
1G, 1GB - 1 gigabyte per second
1024 - 1024 bytes per second (raw number)`,
Args: cobra.ExactArgs(1),
RunE: runHTTP,
}
@@ -56,6 +64,7 @@ func init() {
httpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
httpCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
httpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
httpCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)")
httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(httpCmd)
@@ -76,6 +85,11 @@ func runHTTP(_ *cobra.Command, args []string) error {
return err
}
bw, err := parseBandwidth(bandwidth)
if err != nil {
return err
}
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
@@ -88,6 +102,7 @@ func runHTTP(_ *cobra.Command, args []string) error {
DenyIPs: denyIPs,
AuthPass: authPass,
Transport: parseTransport(transport),
Bandwidth: bw,
}
var daemon *DaemonInfo
@@ -108,3 +123,36 @@ func parseTransport(s string) tcp.TransportType {
return tcp.TransportAuto
}
}
func parseBandwidth(s string) (int64, error) {
if s == "" {
return 0, nil
}
s = strings.TrimSpace(strings.ToUpper(s))
if s == "" {
return 0, nil
}
var multiplier int64 = 1
switch {
case strings.HasSuffix(s, "GB") || strings.HasSuffix(s, "G"):
multiplier = 1024 * 1024 * 1024
s = strings.TrimSuffix(strings.TrimSuffix(s, "GB"), "G")
case strings.HasSuffix(s, "MB") || strings.HasSuffix(s, "M"):
multiplier = 1024 * 1024
s = strings.TrimSuffix(strings.TrimSuffix(s, "MB"), "M")
case strings.HasSuffix(s, "KB") || strings.HasSuffix(s, "K"):
multiplier = 1024
s = strings.TrimSuffix(strings.TrimSuffix(s, "KB"), "K")
case strings.HasSuffix(s, "B"):
s = strings.TrimSuffix(s, "B")
}
val, err := strconv.ParseInt(s, 10, 64)
if err != nil || val < 0 {
return 0, fmt.Errorf("invalid bandwidth value: %q (use format like 1M, 500K, 1G)", s)
}
return val * multiplier, nil
}

View File

@@ -0,0 +1,62 @@
package cli
import (
"testing"
)
func TestParseBandwidth(t *testing.T) {
tests := []struct {
input string
want int64
wantErr bool
}{
// Valid cases
{"", 0, false},
{"0", 0, false},
{"1024", 1024, false},
{"1K", 1024, false},
{"1KB", 1024, false},
{"1k", 1024, false},
{"1M", 1024 * 1024, false},
{"1MB", 1024 * 1024, false},
{"1m", 1024 * 1024, false},
{"10M", 10 * 1024 * 1024, false},
{"1G", 1024 * 1024 * 1024, false},
{"1GB", 1024 * 1024 * 1024, false},
{"500K", 500 * 1024, false},
{"100M", 100 * 1024 * 1024, false},
{" 1M ", 1024 * 1024, false},
{"1B", 1, false},
{"100B", 100, false},
// Error cases
{"invalid", 0, true},
{"abc", 0, true},
{"-1M", 0, true},
{"-100", 0, true},
{"1.5M", 0, true},
{"M", 0, true},
{"K", 0, true},
{"1T", 0, true},
{"1KM", 0, true},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got, err := parseBandwidth(tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("parseBandwidth(%q) = %d, want error", tt.input, got)
}
return
}
if err != nil {
t.Errorf("parseBandwidth(%q) unexpected error: %v", tt.input, err)
return
}
if got != tt.want {
t.Errorf("parseBandwidth(%q) = %d, want %d", tt.input, got, tt.want)
}
})
}
}

View File

@@ -23,6 +23,7 @@ Example:
drip https 443 --deny-ip 1.2.3.4 Block specific IP
drip https 443 --auth secret Enable proxy authentication with password
drip https 443 --transport wss Use WebSocket over TLS (CDN-friendly)
drip https 443 --bandwidth 1M Limit bandwidth to 1 MB/s
Configuration:
First time: Run 'drip config init' to save server and token
@@ -31,7 +32,13 @@ Configuration:
Transport options:
auto - Automatically select based on server address (default)
tcp - Direct TLS 1.3 connection
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
wss - WebSocket over TLS (works through CDN like Cloudflare)
Bandwidth format:
1K, 1KB - 1 kilobyte per second (1024 bytes/s)
1M, 1MB - 1 megabyte per second (1048576 bytes/s)
1G, 1GB - 1 gigabyte per second
1024 - 1024 bytes per second (raw number)`,
Args: cobra.ExactArgs(1),
RunE: runHTTPS,
}
@@ -44,6 +51,7 @@ func init() {
httpsCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
httpsCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
httpsCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
httpsCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)")
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpsCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(httpsCmd)
@@ -64,6 +72,11 @@ func runHTTPS(_ *cobra.Command, args []string) error {
return err
}
bw, err := parseBandwidth(bandwidth)
if err != nil {
return err
}
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
@@ -76,6 +89,7 @@ func runHTTPS(_ *cobra.Command, args []string) error {
DenyIPs: denyIPs,
AuthPass: authPass,
Transport: parseTransport(transport),
Bandwidth: bw,
}
var daemon *DaemonInfo

View File

@@ -22,21 +22,21 @@ import (
)
var (
serverPort int
serverPublicPort int
serverDomain string
serverTunnelDomain string
serverAuthToken string
serverMetricsToken string
serverDebug bool
serverTCPPortMin int
serverTCPPortMax int
serverTLSCert string
serverTLSKey string
serverPprofPort int
serverTransports string
serverTunnelTypes string
serverConfigFile string
serverPort int
serverPublicPort int
serverDomain string
serverTunnelDomain string
serverAuthToken string
serverMetricsToken string
serverDebug bool
serverTCPPortMin int
serverTCPPortMax int
serverTLSCert string
serverTLSKey string
serverPprofPort int
serverTransports string
serverTunnelTypes string
serverConfigFile string
)
var serverCmd = &cobra.Command{
@@ -293,6 +293,24 @@ func runServer(cmd *cobra.Command, _ []string) error {
listener.SetAllowedTransports(cfg.AllowedTransports)
listener.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes)
bandwidth, err := parseBandwidth(cfg.Bandwidth)
if err != nil {
logger.Fatal("Invalid bandwidth configuration", zap.Error(err))
}
burstMultiplier := cfg.BurstMultiplier
if burstMultiplier <= 0 {
burstMultiplier = 2.0
}
listener.SetBandwidth(bandwidth)
listener.SetBurstMultiplier(burstMultiplier)
if bandwidth > 0 {
logger.Info("Bandwidth limit configured",
zap.String("bandwidth", cfg.Bandwidth),
zap.Int64("bandwidth_bytes_sec", bandwidth),
zap.Float64("burst_multiplier", burstMultiplier),
)
}
if err := listener.Start(); err != nil {
logger.Fatal("Failed to start TCP listener", zap.Error(err))
}
@@ -326,7 +344,6 @@ func runServer(cmd *cobra.Command, _ []string) error {
return nil
}
// getEnvInt returns the environment variable value as int, or defaultVal if not set
func getEnvInt(key string, defaultVal int) int {
if val := os.Getenv(key); val != "" {
if i, err := strconv.Atoi(val); err == nil {
@@ -336,7 +353,6 @@ func getEnvInt(key string, defaultVal int) int {
return defaultVal
}
// getEnvString returns the environment variable value, or defaultVal if not set
func getEnvString(key string, defaultVal string) string {
if val := os.Getenv(key); val != "" {
return val
@@ -344,7 +360,6 @@ func getEnvString(key string, defaultVal string) string {
return defaultVal
}
// parseCommaSeparated splits a comma-separated string into a slice
func parseCommaSeparated(s string) []string {
if s == "" {
return nil

View File

@@ -125,7 +125,10 @@ func formatTunnelInfo(t *config.TunnelConfig) string {
}
func startSingleTunnel(cfg *config.ClientConfig, t *config.TunnelConfig) error {
connConfig := buildConnectorConfig(cfg, t)
connConfig, err := buildConnectorConfig(cfg, t)
if err != nil {
return err
}
fmt.Printf("Starting tunnel '%s' (%s %s:%d)\n", t.Name, t.Type, getAddress(t), t.Port)
@@ -162,7 +165,11 @@ func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConf
go func(tunnel *config.TunnelConfig) {
defer wg.Done()
connConfig := buildConnectorConfig(cfg, tunnel)
connConfig, err := buildConnectorConfig(cfg, tunnel)
if err != nil {
errChan <- fmt.Errorf("tunnel '%s': %w", tunnel.Name, err)
return
}
fmt.Printf(" Starting %s (%s %s:%d)...\n", tunnel.Name, tunnel.Type, getAddress(tunnel), tunnel.Port)
client := tcp.NewTunnelClient(connConfig, logger)
@@ -210,7 +217,7 @@ func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConf
return nil
}
func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp.ConnectorConfig {
func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) (*tcp.ConnectorConfig, error) {
tunnelType := protocol.TunnelTypeHTTP
switch t.Type {
case "https":
@@ -219,12 +226,11 @@ func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp
tunnelType = protocol.TunnelTypeTCP
}
transport := tcp.TransportAuto
switch strings.ToLower(t.Transport) {
case "tcp", "tls":
transport = tcp.TransportTCP
case "wss":
transport = tcp.TransportWebSocket
transport := parseTransport(t.Transport)
bw, err := parseBandwidth(t.Bandwidth)
if err != nil {
return nil, err
}
return &tcp.ConnectorConfig{
@@ -239,7 +245,8 @@ func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp
DenyIPs: t.DenyIPs,
AuthPass: t.Auth,
Transport: transport,
}
Bandwidth: bw,
}, nil
}
func getAddress(t *config.TunnelConfig) string {

View File

@@ -24,6 +24,7 @@ Example:
drip tcp 22 --allow-ip 10.0.0.1 Allow single IP
drip tcp 22 --deny-ip 1.2.3.4 Block specific IP
drip tcp 22 --transport wss Use WebSocket over TLS (CDN-friendly)
drip tcp 22 --bandwidth 1M Limit bandwidth to 1 MB/s
Supported Services:
- Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017)
@@ -39,6 +40,12 @@ Transport options:
tcp - Direct TLS 1.3 connection
wss - WebSocket over TLS (works through CDN like Cloudflare)
Bandwidth format:
1K, 1KB - 1 kilobyte per second (1024 bytes/s)
1M, 1MB - 1 megabyte per second (1048576 bytes/s)
1G, 1GB - 1 gigabyte per second
1024 - 1024 bytes per second (raw number)
Note: TCP tunnels require dynamic port allocation on the server.
When using CDN (--transport wss), the server must still expose the allocated port directly.`,
Args: cobra.ExactArgs(1),
@@ -52,6 +59,7 @@ func init() {
tcpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
tcpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
tcpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
tcpCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)")
tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
tcpCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(tcpCmd)
@@ -72,6 +80,11 @@ func runTCP(_ *cobra.Command, args []string) error {
return err
}
bw, err := parseBandwidth(bandwidth)
if err != nil {
return err
}
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
@@ -83,6 +96,7 @@ func runTCP(_ *cobra.Command, args []string) error {
AllowIPs: allowIPs,
DenyIPs: denyIPs,
Transport: parseTransport(transport),
Bandwidth: bw,
}
var daemon *DaemonInfo

View File

@@ -45,6 +45,9 @@ type ConnectorConfig struct {
// Transport protocol selection
Transport TransportType
// Bandwidth limiting (bytes/sec), 0 = unlimited
Bandwidth int64
}
type TunnelClient interface {

View File

@@ -13,8 +13,8 @@ import (
"sync/atomic"
"time"
"github.com/gorilla/websocket"
json "github.com/goccy/go-json"
"github.com/gorilla/websocket"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
@@ -26,7 +26,6 @@ import (
"drip/pkg/config"
)
// PoolClient manages a pool of yamux sessions for tunnel connections.
type PoolClient struct {
serverAddr string
tlsConfig *tls.Config
@@ -76,11 +75,12 @@ type PoolClient struct {
// Transport protocol selection
transport TransportType
insecure bool
// Bandwidth limit requested from server (bytes/sec), 0 = unlimited
bandwidth int64
}
// NewPoolClient creates a new pool client.
func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
// Parse server address to get host for TLS config
serverAddr := cfg.ServerAddr
host := serverAddr
@@ -96,7 +96,6 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
}
}
// Extract hostname without port for TLS
hostOnly, _, _ := net.SplitHostPort(host)
if hostOnly == "" {
hostOnly = host
@@ -140,7 +139,6 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
}
initialSessions = min(max(initialSessions, minSessions), maxSessions)
// Determine transport type
transport := cfg.Transport
if transport == "" {
transport = TransportAuto
@@ -171,6 +169,7 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
authPass: cfg.AuthPass,
transport: transport,
insecure: cfg.Insecure,
bandwidth: cfg.Bandwidth,
}
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
@@ -181,7 +180,6 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
return c
}
// Connect establishes the primary connection and starts background workers.
func (c *PoolClient) Connect() error {
primaryConn, err := c.dial()
if err != nil {
@@ -215,6 +213,10 @@ func (c *PoolClient) Connect() error {
}
}
if c.bandwidth > 0 {
req.Bandwidth = c.bandwidth
}
payload, err := json.Marshal(req)
if err != nil {
_ = primaryConn.Close()
@@ -261,6 +263,10 @@ func (c *PoolClient) Connect() error {
c.tunnelID = resp.TunnelID
}
if resp.Bandwidth > 0 {
c.bandwidth = resp.Bandwidth
}
yamuxCfg := mux.NewClientConfig()
session, err := yamux.Server(primaryConn, yamuxCfg)
@@ -335,13 +341,11 @@ func (c *PoolClient) dialTLS() (net.Conn, error) {
return conn, nil
}
// serverCapabilities holds the discovered server capabilities
type serverCapabilities struct {
Transports []string `json:"transports"`
Preferred string `json:"preferred"`
}
// dial selects the appropriate transport and establishes a connection
func (c *PoolClient) dial() (net.Conn, error) {
switch c.transport {
case TransportWebSocket:
@@ -377,7 +381,6 @@ func (c *PoolClient) dial() (net.Conn, error) {
}
}
// discoverServerCapabilities queries the server for its capabilities
func (c *PoolClient) discoverServerCapabilities() *serverCapabilities {
host, port, err := net.SplitHostPort(c.serverAddr)
if err != nil {
@@ -420,9 +423,7 @@ func (c *PoolClient) discoverServerCapabilities() *serverCapabilities {
return &caps
}
// dialWebSocket establishes a WebSocket connection to the server over TLS
func (c *PoolClient) dialWebSocket() (net.Conn, error) {
// Build WebSocket URL
host, port, err := net.SplitHostPort(c.serverAddr)
if err != nil {
// No port specified, use default
@@ -574,7 +575,6 @@ func (c *PoolClient) pingLoop(h *sessionHandle) {
}
}
// Close shuts down the client and all sessions.
func (c *PoolClient) Close() error {
var closeErr error
@@ -623,12 +623,12 @@ func (c *PoolClient) Close() error {
return closeErr
}
func (c *PoolClient) Wait() { <-c.doneCh }
func (c *PoolClient) GetURL() string { return c.assignedURL }
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
func (c *PoolClient) Wait() { <-c.doneCh }
func (c *PoolClient) GetURL() string { return c.assignedURL }
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
func (c *PoolClient) SetLatencyCallback(cb LatencyCallback) {
if cb == nil {

View File

@@ -24,6 +24,7 @@ import (
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/qos"
"drip/internal/shared/wsutil"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -171,24 +172,22 @@ func (h *Handler) SetAllowedTunnelTypes(types []string) {
// IsTransportAllowed checks if a transport is allowed
func (h *Handler) IsTransportAllowed(transport string) bool {
if len(h.allowedTransports) == 0 {
return true
}
for _, t := range h.allowedTransports {
if strings.EqualFold(t, transport) {
return true
}
}
return false
return containsFold(h.allowedTransports, transport)
}
// IsTunnelTypeAllowed checks if a tunnel type is allowed
func (h *Handler) IsTunnelTypeAllowed(tunnelType string) bool {
if len(h.allowedTunnelTypes) == 0 {
return containsFold(h.allowedTunnelTypes, tunnelType)
}
// containsFold returns true if the slice is empty (allow all) or contains the
// value in a case-insensitive comparison.
func containsFold(allowed []string, value string) bool {
if len(allowed) == 0 {
return true
}
for _, t := range h.allowedTunnelTypes {
if strings.EqualFold(t, tunnelType) {
for _, a := range allowed {
if strings.EqualFold(a, value) {
return true
}
}
@@ -299,7 +298,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
tconn.IncActiveConnections()
defer tconn.DecActiveConnections()
countingStream := netutil.NewCountingConn(stream,
var limitedStream net.Conn = stream
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
limitedStream = qos.NewLimitedConn(r.Context(), stream, limiter)
}
countingStream := netutil.NewCountingConn(limitedStream,
tconn.AddBytesOut,
tconn.AddBytesIn,
)
@@ -428,6 +432,11 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
return
}
var limitedStream net.Conn = stream
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
limitedStream = qos.NewLimitedConn(context.Background(), stream, limiter)
}
go func() {
defer stream.Close()
defer clientConn.Close()
@@ -441,7 +450,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
}
}
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
_ = netutil.PipeWithCallbacks(context.Background(), limitedStream, clientRW,
func(n int64) { tconn.AddBytesOut(n) },
func(n int64) { tconn.AddBytesIn(n) },
)

View File

@@ -0,0 +1,168 @@
package tcp
import (
"testing"
)
func TestEffectiveBandwidthSelection(t *testing.T) {
tests := []struct {
name string
serverBW int64
clientBW int64
wantEffective int64
}{
{
name: "server only",
serverBW: 1024 * 1024,
clientBW: 0,
wantEffective: 1024 * 1024,
},
{
name: "client only",
serverBW: 0,
clientBW: 512 * 1024,
wantEffective: 512 * 1024,
},
{
name: "both unlimited",
serverBW: 0,
clientBW: 0,
wantEffective: 0,
},
{
name: "client lower than server",
serverBW: 10 * 1024 * 1024,
clientBW: 1 * 1024 * 1024,
wantEffective: 1 * 1024 * 1024,
},
{
name: "client higher than server - server wins",
serverBW: 1 * 1024 * 1024,
clientBW: 10 * 1024 * 1024,
wantEffective: 1 * 1024 * 1024,
},
{
name: "client equal to server",
serverBW: 5 * 1024 * 1024,
clientBW: 5 * 1024 * 1024,
wantEffective: 5 * 1024 * 1024,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
effectiveBandwidth := tt.serverBW
if tt.clientBW > 0 {
if effectiveBandwidth == 0 || tt.clientBW < effectiveBandwidth {
effectiveBandwidth = tt.clientBW
}
}
if effectiveBandwidth != tt.wantEffective {
t.Errorf("effectiveBandwidth = %d, want %d", effectiveBandwidth, tt.wantEffective)
}
})
}
}
func TestConnectionSetBandwidthConfig(t *testing.T) {
tests := []struct {
name string
bandwidth int64
burstMultiplier float64
wantBandwidth int64
wantMultiplier float64
}{
{
name: "1MB/s with 2x burst",
bandwidth: 1024 * 1024,
burstMultiplier: 2.0,
wantBandwidth: 1024 * 1024,
wantMultiplier: 2.0,
},
{
name: "1MB/s with 2.5x burst",
bandwidth: 1024 * 1024,
burstMultiplier: 2.5,
wantBandwidth: 1024 * 1024,
wantMultiplier: 2.5,
},
{
name: "default multiplier when 0",
bandwidth: 1024 * 1024,
burstMultiplier: 0,
wantBandwidth: 1024 * 1024,
wantMultiplier: 2.0,
},
{
name: "default multiplier when negative",
bandwidth: 1024 * 1024,
burstMultiplier: -1.0,
wantBandwidth: 1024 * 1024,
wantMultiplier: 2.0,
},
{
name: "unlimited bandwidth",
bandwidth: 0,
burstMultiplier: 2.5,
wantBandwidth: 0,
wantMultiplier: 2.5,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conn := &Connection{}
conn.SetBandwidthConfig(tt.bandwidth, tt.burstMultiplier)
if conn.bandwidth != tt.wantBandwidth {
t.Errorf("bandwidth = %v, want %v", conn.bandwidth, tt.wantBandwidth)
}
if conn.burstMultiplier != tt.wantMultiplier {
t.Errorf("burstMultiplier = %v, want %v", conn.burstMultiplier, tt.wantMultiplier)
}
})
}
}
func TestListenerBandwidthConfig(t *testing.T) {
tests := []struct {
name string
bandwidth int64
burstMultiplier float64
wantBandwidth int64
wantMultiplier float64
}{
{
name: "set bandwidth and multiplier",
bandwidth: 1024 * 1024,
burstMultiplier: 2.5,
wantBandwidth: 1024 * 1024,
wantMultiplier: 2.5,
},
{
name: "default multiplier",
bandwidth: 1024 * 1024,
burstMultiplier: 0,
wantBandwidth: 1024 * 1024,
wantMultiplier: 2.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &Listener{}
l.SetBandwidth(tt.bandwidth)
l.SetBurstMultiplier(tt.burstMultiplier)
if l.bandwidth != tt.wantBandwidth {
t.Errorf("bandwidth = %v, want %v", l.bandwidth, tt.wantBandwidth)
}
if l.burstMultiplier != tt.wantMultiplier {
t.Errorf("burstMultiplier = %v, want %v", l.burstMultiplier, tt.wantMultiplier)
}
})
}
}

View File

@@ -59,12 +59,12 @@ type Connection struct {
httpListener *connQueueListener
handedOff bool
// Server capabilities
allowedTunnelTypes []string
allowedTransports []string
bandwidth int64
burstMultiplier float64
}
// NewConnection creates a new connection handler
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
ctx, cancel := context.WithCancel(context.Background())
c := &Connection{
@@ -99,22 +99,11 @@ func (c *Connection) Handle() error {
return fmt.Errorf("failed to peek connection: %w", err)
}
peekStr := string(peek)
httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"}
isHTTP := false
for _, method := range httpMethods {
if strings.HasPrefix(peekStr, method) {
isHTTP = true
break
}
}
if isHTTP {
if isHTTPMethod(string(peek)) {
c.logger.Info("Detected HTTP request on TCP port, handling as HTTP")
return c.handleHTTPRequest(reader)
}
// Check if TCP transport is allowed (only for Drip protocol connections, not HTTP)
if !c.isTransportAllowed("tcp") {
c.logger.Warn("TCP transport not allowed, rejecting Drip protocol connection")
return fmt.Errorf("TCP transport not allowed")
@@ -142,7 +131,6 @@ func (c *Connection) Handle() error {
c.tunnelType = req.TunnelType
// Check if tunnel type is allowed
if !c.isTunnelTypeAllowed(string(req.TunnelType)) {
c.sendError("tunnel_type_not_allowed", fmt.Sprintf("Tunnel type '%s' is not allowed on this server", req.TunnelType))
return fmt.Errorf("tunnel type not allowed: %s", req.TunnelType)
@@ -197,7 +185,6 @@ func (c *Connection) Handle() error {
c.tunnelConn.Conn = nil
c.tunnelConn.SetTunnelType(req.TunnelType)
c.tunnelType = req.TunnelType
if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) {
c.tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs)
@@ -215,6 +202,31 @@ func (c *Connection) Handle() error {
)
}
effectiveBandwidth := c.bandwidth
if req.Bandwidth > 0 {
if effectiveBandwidth == 0 || req.Bandwidth < effectiveBandwidth {
effectiveBandwidth = req.Bandwidth
}
}
if effectiveBandwidth > 0 {
burstMultiplier := c.burstMultiplier
if burstMultiplier <= 0 {
burstMultiplier = 2.0
}
c.tunnelConn.SetBandwidthWithBurst(effectiveBandwidth, burstMultiplier)
source := "server"
if req.Bandwidth > 0 && (c.bandwidth == 0 || req.Bandwidth < c.bandwidth) {
source = "client"
}
c.logger.Info("Bandwidth limit configured",
zap.String("subdomain", subdomain),
zap.Int64("bandwidth_bytes_sec", effectiveBandwidth),
zap.Float64("burst_multiplier", burstMultiplier),
zap.String("source", source),
)
}
c.logger.Info("Tunnel registered",
zap.String("subdomain", subdomain),
zap.String("tunnel_type", string(req.TunnelType)),
@@ -258,6 +270,7 @@ func (c *Connection) Handle() error {
TunnelID: tunnelID,
SupportsDataConn: supportsDataConn,
RecommendedConns: recommendedConns,
Bandwidth: c.tunnelConn.GetBandwidth(),
}
respData, err := json.Marshal(resp)
@@ -389,7 +402,6 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
zap.String("host", req.Host),
)
// Get writer from pool to reduce GC pressure
pooledWriter := bufioWriterPool.Get().(*bufio.Writer)
pooledWriter.Reset(c.conn)
@@ -405,30 +417,17 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
c.logger.Debug("Failed to flush HTTP response", zap.Error(err))
}
// Return writer to pool
pooledWriter.Reset(nil) // Clear reference to connection
pooledWriter.Reset(nil)
bufioWriterPool.Put(pooledWriter)
// Keep TCP_NODELAY enabled for low latency HTTP responses
// (removed the toggle that was disabling it)
c.logger.Debug("HTTP request processing completed",
zap.String("method", req.Method),
zap.String("url", req.URL.String()),
)
shouldClose := false
if req.Close {
shouldClose = true
} else if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
if req.Header.Get("Connection") != "keep-alive" {
shouldClose = true
}
}
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
shouldClose = true
}
shouldClose := req.Close ||
(req.ProtoMajor == 1 && req.ProtoMinor == 0 && req.Header.Get("Connection") != "keep-alive") ||
(respWriter.headerWritten && respWriter.header.Get("Connection") == "close")
if shouldClose {
c.logger.Debug("Closing connection as requested by client or server")
@@ -636,7 +635,7 @@ func (w *httpResponseWriter) WriteHeader(statusCode int) {
}
w.writer.WriteString("HTTP/1.1 ")
w.writer.WriteString(fmt.Sprintf("%d", statusCode))
w.writer.WriteString(strconv.Itoa(statusCode))
w.writer.WriteByte(' ')
w.writer.WriteString(statusText)
w.writer.WriteString("\r\n")
@@ -755,6 +754,14 @@ func isTimeoutError(err error) bool {
return strings.Contains(err.Error(), "i/o timeout")
}
func isHTTPMethod(peek string) bool {
switch peek {
case "GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC":
return true
}
return false
}
func (c *Connection) sendDataConnectError(code, message string) {
resp := protocol.DataConnectResponse{
Accepted: false,
@@ -769,38 +776,40 @@ func (c *Connection) sendDataConnectError(code, message string) {
_ = protocol.WriteFrame(c.conn, frame)
}
// SetAllowedTunnelTypes sets the allowed tunnel types for this connection
func (c *Connection) SetAllowedTunnelTypes(types []string) {
c.allowedTunnelTypes = types
}
// SetAllowedTransports sets the allowed transports for this connection
func (c *Connection) SetAllowedTransports(transports []string) {
c.allowedTransports = transports
}
// isTransportAllowed checks if a transport is allowed
func (c *Connection) isTransportAllowed(transport string) bool {
if len(c.allowedTransports) == 0 {
return containsFold(c.allowedTransports, transport)
}
func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
return containsFold(c.allowedTunnelTypes, tunnelType)
}
// containsFold returns true if the slice is empty (allow all) or contains the
// value case-insensitively.
func containsFold(allowed []string, value string) bool {
if len(allowed) == 0 {
return true
}
for _, t := range c.allowedTransports {
if strings.EqualFold(t, transport) {
for _, a := range allowed {
if strings.EqualFold(a, value) {
return true
}
}
return false
}
// isTunnelTypeAllowed checks if a tunnel type is allowed
func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
if len(c.allowedTunnelTypes) == 0 {
return true // Allow all by default
func (c *Connection) SetBandwidthConfig(bandwidth int64, burstMultiplier float64) {
c.bandwidth = bandwidth
if burstMultiplier <= 0 {
burstMultiplier = 2.0
}
for _, t := range c.allowedTunnelTypes {
if strings.EqualFold(t, tunnelType) {
return true
}
}
return false
c.burstMultiplier = burstMultiplier
}

View File

@@ -43,9 +43,10 @@ type Listener struct {
httpServer *http.Server
httpListener *connQueueListener
// Server capabilities
allowedTransports []string
allowedTunnelTypes []string
bandwidth int64
burstMultiplier float64
}
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler) *Listener {
@@ -63,7 +64,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
panicMetrics := recovery.NewPanicMetrics(logger, nil)
recoverer := recovery.NewRecoverer(logger, panicMetrics)
// Initialize worker pool metrics
metrics.WorkerPoolSize.Set(float64(workers))
l := &Listener{
@@ -85,7 +85,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
groupManager: NewConnectionGroupManager(logger),
}
// Set up WebSocket connection handler if httpHandler supports it
if h, ok := httpHandler.(*proxy.Handler); ok {
h.SetWSConnectionHandler(l)
h.SetPublicPort(publicPort)
@@ -97,7 +96,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
func (l *Listener) Start() error {
var err error
// Support both TLS and plain TCP modes
if l.tlsConfig != nil {
l.listener, err = tls.Listen("tcp", l.address, l.tlsConfig)
if err != nil {
@@ -269,57 +267,13 @@ func (l *Listener) handleConnection(netConn net.Conn) {
)
}
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
conn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
conn.SetAllowedTransports(l.allowedTransports)
conn := l.newConfiguredConnection(netConn)
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
l.connections[connID] = conn
l.connMu.Unlock()
l.trackConnection(connID, conn)
defer l.untrackConnection(connID, conn, netConn)
// Update connection metrics
metrics.TotalConnections.Inc()
metrics.ActiveConnections.Inc()
defer func() {
l.connMu.Lock()
delete(l.connections, connID)
l.connMu.Unlock()
metrics.ActiveConnections.Dec()
if !conn.IsHandedOff() {
netConn.Close()
}
}()
if err := conn.Handle(); err != nil {
errStr := err.Error()
if strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") {
return
}
if strings.Contains(errStr, "payload too large") ||
strings.Contains(errStr, "failed to read registration frame") ||
strings.Contains(errStr, "expected register frame") ||
strings.Contains(errStr, "failed to parse registration request") ||
strings.Contains(errStr, "failed to parse HTTP request") {
l.logger.Warn("Protocol validation failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
} else {
l.logger.Error("Connection handling failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
}
}
l.logConnectionError(conn.Handle(), connID, "Connection")
}
func (l *Listener) Stop() error {
@@ -372,7 +326,6 @@ func (l *Listener) GetActiveConnections() int {
return len(l.connections)
}
// HandleWSConnection implements proxy.WSConnectionHandler for WebSocket tunnel connections
func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
l.wg.Add(1)
defer l.wg.Done()
@@ -386,77 +339,103 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
zap.String("remote_addr", connID),
)
// Create connection handler (no TLS verification needed - already done by HTTP server)
tcpConn := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
tcpConn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
tcpConn := l.newConfiguredConnection(conn)
l.connMu.Lock()
l.connections[connID] = tcpConn
l.connMu.Unlock()
l.trackConnection(connID, tcpConn)
defer l.untrackConnection(connID, tcpConn, conn)
metrics.TotalConnections.Inc()
metrics.ActiveConnections.Inc()
defer func() {
l.connMu.Lock()
delete(l.connections, connID)
l.connMu.Unlock()
metrics.ActiveConnections.Dec()
if !tcpConn.IsHandedOff() {
conn.Close()
}
}()
if err := tcpConn.Handle(); err != nil {
errStr := err.Error()
if strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "websocket: close") {
return
}
if strings.Contains(errStr, "payload too large") ||
strings.Contains(errStr, "failed to read registration frame") ||
strings.Contains(errStr, "expected register frame") ||
strings.Contains(errStr, "failed to parse registration request") ||
strings.Contains(errStr, "tunnel type not allowed") {
l.logger.Warn("WebSocket tunnel protocol validation failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
} else {
l.logger.Error("WebSocket tunnel connection handling failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
}
}
l.logConnectionError(tcpConn.Handle(), connID, "WebSocket tunnel")
}
// SetAllowedTransports sets the allowed transport protocols
func (l *Listener) SetAllowedTransports(transports []string) {
l.allowedTransports = transports
}
// SetAllowedTunnelTypes sets the allowed tunnel types
func (l *Listener) SetAllowedTunnelTypes(types []string) {
l.allowedTunnelTypes = types
}
// IsTransportAllowed checks if a transport is allowed
func (l *Listener) IsTransportAllowed(transport string) bool {
if len(l.allowedTransports) == 0 {
return true
return containsFold(l.allowedTransports, transport)
}
func (l *Listener) SetBurstMultiplier(multiplier float64) {
if multiplier <= 0 {
multiplier = 2.0
}
for _, t := range l.allowedTransports {
if strings.EqualFold(t, transport) {
return true
l.burstMultiplier = multiplier
}
func (l *Listener) SetBandwidth(bandwidth int64) {
l.bandwidth = bandwidth
}
func (l *Listener) newConfiguredConnection(conn net.Conn) *Connection {
c := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
c.SetAllowedTunnelTypes(l.allowedTunnelTypes)
c.SetAllowedTransports(l.allowedTransports)
c.SetBandwidthConfig(l.bandwidth, l.burstMultiplier)
return c
}
func (l *Listener) trackConnection(connID string, conn *Connection) {
l.connMu.Lock()
l.connections[connID] = conn
l.connMu.Unlock()
metrics.TotalConnections.Inc()
metrics.ActiveConnections.Inc()
}
func (l *Listener) untrackConnection(connID string, conn *Connection, netConn net.Conn) {
l.connMu.Lock()
delete(l.connections, connID)
l.connMu.Unlock()
metrics.ActiveConnections.Dec()
if !conn.IsHandedOff() {
netConn.Close()
}
}
// logConnectionError classifies and logs a connection handling error.
// Transient network errors are silently ignored, protocol errors are warned,
// and everything else is logged as an error.
func (l *Listener) logConnectionError(err error, connID, label string) {
if err == nil {
return
}
errStr := err.Error()
// Transient / expected disconnects — ignore silently
for _, substr := range []string{
"EOF", "connection reset by peer", "broken pipe",
"connection refused", "websocket: close",
} {
if strings.Contains(errStr, substr) {
return
}
}
return false
// Protocol-level validation failures — warn
for _, substr := range []string{
"payload too large", "failed to read registration frame",
"expected register frame", "failed to parse registration request",
"failed to parse HTTP request", "tunnel type not allowed",
} {
if strings.Contains(errStr, substr) {
l.logger.Warn(label+" protocol validation failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
return
}
}
l.logger.Error(label+" handling failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
}

View File

@@ -10,6 +10,7 @@ import (
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/qos"
"go.uber.org/zap"
)
@@ -34,6 +35,7 @@ type Proxy struct {
cancel context.CancelFunc
checkIPAccess func(ip string) bool
limiter *qos.Limiter
}
type trafficStats interface {
@@ -49,12 +51,6 @@ func NewProxy(ctx context.Context, port int, subdomain string, openStream func()
}
cctx, cancel := context.WithCancel(ctx)
const maxConcurrentConnections = 10000
var sem chan struct{}
if maxConcurrentConnections > 0 {
sem = make(chan struct{}, maxConcurrentConnections)
}
return &Proxy{
port: port,
subdomain: subdomain,
@@ -62,17 +58,20 @@ func NewProxy(ctx context.Context, port int, subdomain string, openStream func()
stopCh: make(chan struct{}),
openStream: openStream,
stats: stats,
sem: sem,
sem: make(chan struct{}, 10000),
ctx: cctx,
cancel: cancel,
}
}
// SetIPAccessCheck sets the IP access control check function.
func (p *Proxy) SetIPAccessCheck(check func(ip string) bool) {
p.checkIPAccess = check
}
func (p *Proxy) SetLimiter(limiter *qos.Limiter) {
p.limiter = limiter
}
func (p *Proxy) Start() error {
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
@@ -174,13 +173,11 @@ func (p *Proxy) handleConn(conn net.Conn) {
}
}
if p.sem != nil {
select {
case p.sem <- struct{}{}:
defer func() { <-p.sem }()
default:
return
}
select {
case p.sem <- struct{}{}:
defer func() { <-p.sem }()
default:
return
}
if p.stats != nil {
@@ -243,7 +240,7 @@ func (p *Proxy) handleConn(conn net.Conn) {
_ = netutil.PipeWithCallbacksAndBufferSize(
p.ctx,
conn,
stream,
qos.NewLimitedConn(p.ctx, stream, p.limiter),
pool.SizeLarge,
func(n int64) {
if p.stats != nil {

View File

@@ -19,19 +19,17 @@ func (c *bufferedConn) Read(p []byte) (int, error) {
return c.reader.Read(p)
}
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
// Public server acts as yamux Client, client connector acts as yamux Server.
// initMuxSession creates a yamux session over the buffered connection and
// returns the openStream function (possibly group-aware).
func (c *Connection) initMuxSession(reader *bufio.Reader) (func() (net.Conn, error), *yamux.Session, error) {
bc := &bufferedConn{
Conn: c.conn,
reader: reader,
}
// Use optimized mux config for server
cfg := mux.NewServerConfig()
session, err := yamux.Client(bc, cfg)
session, err := yamux.Client(bc, mux.NewServerConfig())
if err != nil {
return fmt.Errorf("failed to init yamux session: %w", err)
return nil, nil, fmt.Errorf("failed to init yamux session: %w", err)
}
c.session = session
@@ -43,10 +41,22 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
}
}
return openStream, session, nil
}
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
openStream, session, err := c.initMuxSession(reader)
if err != nil {
return err
}
c.proxy = NewProxy(c.ctx, c.port, c.subdomain, openStream, c.tunnelConn, c.logger)
if c.tunnelConn != nil && c.tunnelConn.HasIPAccessControl() {
c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed)
}
if c.tunnelConn != nil {
c.proxy.SetLimiter(c.tunnelConn.GetLimiter())
}
if err := c.proxy.Start(); err != nil {
return fmt.Errorf("failed to start tcp proxy: %w", err)
@@ -61,27 +71,9 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
}
func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
// Public server acts as yamux Client, client connector acts as yamux Server.
bc := &bufferedConn{
Conn: c.conn,
reader: reader,
}
// Use optimized mux config for server
cfg := mux.NewServerConfig()
session, err := yamux.Client(bc, cfg)
openStream, session, err := c.initMuxSession(reader)
if err != nil {
return fmt.Errorf("failed to init yamux session: %w", err)
}
c.session = session
openStream := session.Open
if c.groupManager != nil {
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
group.AddSession("primary", session)
openStream = group.OpenStream
}
return err
}
if c.tunnelConn != nil {

View File

@@ -9,6 +9,7 @@ import (
"drip/internal/server/metrics"
"drip/internal/shared/netutil"
"drip/internal/shared/protocol"
"drip/internal/shared/qos"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
@@ -32,6 +33,9 @@ type Connection struct {
ipAccessChecker *netutil.IPAccessChecker
proxyAuth *protocol.ProxyAuth
bandwidth int64
limiter *qos.Limiter
}
func NewConnection(subdomain string, conn *websocket.Conn, logger *zap.Logger) *Connection {
@@ -214,6 +218,34 @@ func (c *Connection) ValidateProxyAuth(password string) bool {
return auth.Password == password
}
func (c *Connection) SetBandwidth(bandwidth int64) {
c.SetBandwidthWithBurst(bandwidth, 2.0)
}
func (c *Connection) SetBandwidthWithBurst(bandwidth int64, burstMultiplier float64) {
c.mu.Lock()
defer c.mu.Unlock()
c.bandwidth = bandwidth
if bandwidth > 0 {
burst := int(float64(bandwidth) * burstMultiplier)
c.limiter = qos.NewLimiter(qos.Config{Bandwidth: bandwidth, Burst: burst})
} else {
c.limiter = nil
}
}
func (c *Connection) GetBandwidth() int64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.bandwidth
}
func (c *Connection) GetLimiter() *qos.Limiter {
c.mu.RLock()
defer c.mu.RUnlock()
return c.limiter
}
func (c *Connection) StartWritePump() {
if c.Conn == nil {
go func() {

View File

@@ -0,0 +1,117 @@
package tunnel
import (
"testing"
"go.uber.org/zap"
)
func TestConnectionBandwidthWithBurst(t *testing.T) {
logger := zap.NewNop()
tests := []struct {
name string
bandwidth int64
burstMultiplier float64
wantBandwidth int64
wantBurst int
}{
{
name: "1MB/s with 2x burst",
bandwidth: 1024 * 1024,
burstMultiplier: 2.0,
wantBandwidth: 1024 * 1024,
wantBurst: 2 * 1024 * 1024,
},
{
name: "1MB/s with 2.5x burst",
bandwidth: 1024 * 1024,
burstMultiplier: 2.5,
wantBandwidth: 1024 * 1024,
wantBurst: int(float64(1024*1024) * 2.5),
},
{
name: "500KB/s with 3x burst",
bandwidth: 500 * 1024,
burstMultiplier: 3.0,
wantBandwidth: 500 * 1024,
wantBurst: 3 * 500 * 1024,
},
{
name: "10MB/s with 1.5x burst",
bandwidth: 10 * 1024 * 1024,
burstMultiplier: 1.5,
wantBandwidth: 10 * 1024 * 1024,
wantBurst: int(float64(10*1024*1024) * 1.5),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conn := NewConnection("test-subdomain", nil, logger)
conn.SetBandwidthWithBurst(tt.bandwidth, tt.burstMultiplier)
if conn.GetBandwidth() != tt.wantBandwidth {
t.Errorf("GetBandwidth() = %v, want %v", conn.GetBandwidth(), tt.wantBandwidth)
}
limiter := conn.GetLimiter()
if limiter == nil {
t.Fatal("GetLimiter() should not be nil")
}
if !limiter.IsLimited() {
t.Error("Limiter should be limited")
}
if limiter.RateLimiter().Burst() != tt.wantBurst {
t.Errorf("Burst() = %v, want %v", limiter.RateLimiter().Burst(), tt.wantBurst)
}
})
}
}
func TestConnectionBandwidthUnlimited(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test-subdomain", nil, logger)
if conn.GetBandwidth() != 0 {
t.Errorf("Default bandwidth should be 0, got %v", conn.GetBandwidth())
}
if conn.GetLimiter() != nil {
t.Error("Default limiter should be nil")
}
conn.SetBandwidth(0)
if conn.GetLimiter() != nil {
t.Error("Limiter should be nil when bandwidth is 0")
}
conn.SetBandwidthWithBurst(0, 2.0)
if conn.GetLimiter() != nil {
t.Error("Limiter should be nil when bandwidth is 0")
}
}
func TestConnectionSetBandwidth(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test-subdomain", nil, logger)
conn.SetBandwidth(1024 * 1024)
if conn.GetBandwidth() != 1024*1024 {
t.Errorf("GetBandwidth() = %v, want %v", conn.GetBandwidth(), 1024*1024)
}
limiter := conn.GetLimiter()
if limiter == nil {
t.Fatal("GetLimiter() should not be nil")
}
expectedBurst := 2 * 1024 * 1024
if limiter.RateLimiter().Burst() != expectedBurst {
t.Errorf("Burst() = %v, want %v", limiter.RateLimiter().Burst(), expectedBurst)
}
}

View File

@@ -27,6 +27,7 @@ type RegisterRequest struct {
PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"`
IPAccess *IPAccessControl `json:"ip_access,omitempty"`
ProxyAuth *ProxyAuth `json:"proxy_auth,omitempty"`
Bandwidth int64 `json:"bandwidth,omitempty"` // Bandwidth limit (bytes/sec), 0 = unlimited
}
type RegisterResponse struct {
@@ -37,6 +38,7 @@ type RegisterResponse struct {
TunnelID string `json:"tunnel_id,omitempty"`
SupportsDataConn bool `json:"supports_data_conn,omitempty"`
RecommendedConns int `json:"recommended_conns,omitempty"`
Bandwidth int64 `json:"bandwidth,omitempty"` // Applied bandwidth limit (bytes/sec)
}
type DataConnectRequest struct {

112
internal/shared/qos/conn.go Normal file
View File

@@ -0,0 +1,112 @@
package qos
import (
"context"
"io"
"net"
)
type LimitedConn struct {
net.Conn
limiter *Limiter
ctx context.Context
}
func NewLimitedConn(ctx context.Context, conn net.Conn, limiter *Limiter) *LimitedConn {
return &LimitedConn{
Conn: conn,
limiter: limiter,
ctx: ctx,
}
}
func (c *LimitedConn) Read(b []byte) (n int, err error) {
if c.limiter == nil || !c.limiter.IsLimited() {
return c.Conn.Read(b)
}
burst := c.limiter.RateLimiter().Burst()
if len(b) > burst {
b = b[:burst]
}
n, err = c.Conn.Read(b)
if n > 0 {
if waitErr := c.limiter.RateLimiter().WaitN(c.ctx, n); waitErr != nil {
if err == nil {
err = waitErr
}
}
}
return n, err
}
func (c *LimitedConn) Write(b []byte) (n int, err error) {
if c.limiter == nil || !c.limiter.IsLimited() {
return c.Conn.Write(b)
}
burst := c.limiter.RateLimiter().Burst()
total := 0
for len(b) > 0 {
chunk := min(len(b), burst)
if err := c.limiter.RateLimiter().WaitN(c.ctx, chunk); err != nil {
return total, err
}
nw, err := c.Conn.Write(b[:chunk])
total += nw
if err != nil {
return total, err
}
b = b[chunk:]
}
return total, nil
}
func (c *LimitedConn) ReadFrom(r io.Reader) (n int64, err error) {
buf := make([]byte, 32*1024)
for {
nr, er := r.Read(buf)
if nr > 0 {
nw, ew := c.Write(buf[:nr])
n += int64(nw)
if ew != nil {
err = ew
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return n, err
}
func (c *LimitedConn) WriteTo(w io.Writer) (n int64, err error) {
buf := make([]byte, 32*1024)
for {
nr, er := c.Read(buf)
if nr > 0 {
nw, ew := w.Write(buf[:nr])
n += int64(nw)
if ew != nil {
err = ew
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return n, err
}

View File

@@ -0,0 +1,484 @@
package qos
import (
"bytes"
"context"
"errors"
"io"
"net"
"sync"
"testing"
"time"
)
type errorAfterConn struct {
mockConn
writeLimit int
written int
}
func (c *errorAfterConn) Write(b []byte) (int, error) {
c.mu.Lock()
defer c.mu.Unlock()
remaining := c.writeLimit - c.written
if remaining <= 0 {
return 0, errors.New("write error")
}
if len(b) > remaining {
b = b[:remaining]
}
c.writeBuf = append(c.writeBuf, b...)
c.written += len(b)
return len(b), nil
}
func TestWriteLargerThanBurst(t *testing.T) {
// 10KB/s, burst=1KB — write 5KB should be chunked into 5 pieces
bandwidth := int64(10 * 1024)
burst := 1024
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
data := make([]byte, 5*1024)
for i := range data {
data[i] = byte(i % 256)
}
n, err := lc.Write(data)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != len(data) {
t.Errorf("Write returned %d, want %d", n, len(data))
}
conn.mu.Lock()
defer conn.mu.Unlock()
if !bytes.Equal(conn.writeBuf, data) {
t.Error("Written data does not match input")
}
}
func TestWriteExactBurstSize(t *testing.T) {
burst := 2048
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
data := make([]byte, burst)
start := time.Now()
n, err := lc.Write(data)
dur := time.Since(start)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != burst {
t.Errorf("Write returned %d, want %d", n, burst)
}
// Exact burst should be instant
if dur > 100*time.Millisecond {
t.Errorf("Exact burst write should be instant, took %v", dur)
}
}
func TestWriteZeroLength(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
n, err := lc.Write(nil)
if err != nil {
t.Fatalf("Write(nil) failed: %v", err)
}
if n != 0 {
t.Errorf("Write(nil) returned %d, want 0", n)
}
n, err = lc.Write([]byte{})
if err != nil {
t.Fatalf("Write([]) failed: %v", err)
}
if n != 0 {
t.Errorf("Write([]) returned %d, want 0", n)
}
}
func TestWriteContextCancelDuringChunking(t *testing.T) {
// Very slow rate, small burst so second chunk must wait
limiter := NewLimiter(Config{Bandwidth: 100, Burst: 100})
conn := newMockConn(nil)
ctx, cancel := context.WithCancel(context.Background())
lc := NewLimitedConn(ctx, conn, limiter)
// Use up burst
_, err := lc.Write(make([]byte, 100))
if err != nil {
t.Fatalf("First write failed: %v", err)
}
cancel()
// This write needs more tokens but context is cancelled
_, err = lc.Write(make([]byte, 200))
if err == nil {
t.Error("Write should fail after context cancellation")
}
}
func TestWritePartialErrorMidChunk(t *testing.T) {
// Underlying conn fails after 500 bytes
conn := &errorAfterConn{writeLimit: 500}
limiter := NewLimiter(Config{Bandwidth: 100000, Burst: 1024})
lc := NewLimitedConn(context.Background(), conn, limiter)
data := make([]byte, 2048)
n, err := lc.Write(data)
if err == nil {
t.Error("Expected write error")
}
if n < 500 {
t.Errorf("Expected at least 500 bytes written, got %d", n)
}
if n > 1024 {
t.Errorf("Expected at most 1024 bytes written (one chunk), got %d", n)
}
}
func TestReadCappedToBurst(t *testing.T) {
burst := 512
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
// Provide 4KB of data
data := make([]byte, 4096)
for i := range data {
data[i] = byte(i % 256)
}
conn := newMockConn(data)
lc := NewLimitedConn(context.Background(), conn, limiter)
// Request 4KB read, should get at most burst (512) bytes
buf := make([]byte, 4096)
n, err := lc.Read(buf)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if n > burst {
t.Errorf("Read returned %d bytes, should be capped at burst=%d", n, burst)
}
if !bytes.Equal(buf[:n], data[:n]) {
t.Error("Read data mismatch")
}
}
func TestReadSmallerThanBurst(t *testing.T) {
burst := 4096
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
data := make([]byte, 100)
conn := newMockConn(data)
lc := NewLimitedConn(context.Background(), conn, limiter)
buf := make([]byte, 100)
n, err := lc.Read(buf)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if n != 100 {
t.Errorf("Read returned %d, want 100", n)
}
}
func TestReadEOF(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024})
conn := newMockConn([]byte{}) // empty
lc := NewLimitedConn(context.Background(), conn, limiter)
buf := make([]byte, 100)
_, err := lc.Read(buf)
if err != io.EOF {
t.Errorf("Expected io.EOF, got %v", err)
}
}
func TestReadContextCancel(t *testing.T) {
// Slow rate, small burst
limiter := NewLimiter(Config{Bandwidth: 100, Burst: 100})
data := make([]byte, 200)
conn := newMockConn(data)
ctx, cancel := context.WithCancel(context.Background())
lc := NewLimitedConn(ctx, conn, limiter)
// First read uses burst
buf := make([]byte, 100)
_, err := lc.Read(buf)
if err != nil {
t.Fatalf("First read failed: %v", err)
}
cancel()
// Second read should fail on WaitN
_, err = lc.Read(buf)
if err == nil {
t.Error("Read should fail after context cancellation")
}
}
func TestReadFromInterface(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
var _ io.ReaderFrom = lc
}
func TestWriteToInterface(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
data := make([]byte, 100)
conn := newMockConn(data)
lc := NewLimitedConn(context.Background(), conn, limiter)
var _ io.WriterTo = lc
}
func TestReadFromBasic(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
src := bytes.NewReader(make([]byte, 50*1024))
n, err := lc.ReadFrom(src)
if err != nil {
t.Fatalf("ReadFrom failed: %v", err)
}
if n != 50*1024 {
t.Errorf("ReadFrom transferred %d bytes, want %d", n, 50*1024)
}
conn.mu.Lock()
defer conn.mu.Unlock()
if len(conn.writeBuf) != 50*1024 {
t.Errorf("Underlying conn received %d bytes, want %d", len(conn.writeBuf), 50*1024)
}
}
func TestReadFromReaderError(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
errReader := &failingReader{failAfter: 100}
n, err := lc.ReadFrom(errReader)
if err == nil {
t.Error("Expected error from ReadFrom")
}
if n != 100 {
t.Errorf("ReadFrom transferred %d bytes before error, want 100", n)
}
}
func TestWriteToBasic(t *testing.T) {
data := make([]byte, 50*1024)
for i := range data {
data[i] = byte(i % 256)
}
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
conn := newMockConn(data)
lc := NewLimitedConn(context.Background(), conn, limiter)
var buf bytes.Buffer
n, err := lc.WriteTo(&buf)
if err != nil {
t.Fatalf("WriteTo failed: %v", err)
}
if n != int64(len(data)) {
t.Errorf("WriteTo transferred %d bytes, want %d", n, len(data))
}
if !bytes.Equal(buf.Bytes(), data) {
t.Error("WriteTo data mismatch")
}
}
func TestWriteToWriterError(t *testing.T) {
data := make([]byte, 1024)
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
conn := newMockConn(data)
lc := NewLimitedConn(context.Background(), conn, limiter)
fw := &failingWriter{failAfter: 100}
_, err := lc.WriteTo(fw)
if err == nil {
t.Error("Expected error from WriteTo")
}
}
func TestReadFromRateLimited(t *testing.T) {
// 10KB/s, 10KB burst — 20KB transfer should take ~1s
limiter := NewLimiter(Config{Bandwidth: 10 * 1024, Burst: 10 * 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
src := bytes.NewReader(make([]byte, 20*1024))
start := time.Now()
n, err := lc.ReadFrom(src)
dur := time.Since(start)
if err != nil {
t.Fatalf("ReadFrom failed: %v", err)
}
if n != 20*1024 {
t.Errorf("ReadFrom transferred %d, want %d", n, 20*1024)
}
if dur < 800*time.Millisecond {
t.Errorf("ReadFrom too fast: %v (expected ~1s for 20KB at 10KB/s with 10KB burst)", dur)
}
}
// TestIoCopyUsesReadFrom verifies io.Copy goes through our ReadFrom,
// not the underlying conn's optimized path.
func TestIoCopyUsesReadFrom(t *testing.T) {
// Use a small burst so we can detect if rate limiting is applied
limiter := NewLimiter(Config{Bandwidth: 10 * 1024, Burst: 10 * 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
src := bytes.NewReader(make([]byte, 20*1024))
start := time.Now()
n, err := io.Copy(lc, src)
dur := time.Since(start)
if err != nil {
t.Fatalf("io.Copy failed: %v", err)
}
if n != 20*1024 {
t.Errorf("io.Copy transferred %d, want %d", n, 20*1024)
}
// If io.Copy bypassed our ReadFrom, it would be instant
if dur < 800*time.Millisecond {
t.Errorf("io.Copy too fast (%v), rate limiting may be bypassed", dur)
}
}
func TestUnlimitedWrite(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 0})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
data := make([]byte, 1024*1024) // 1MB
start := time.Now()
n, err := lc.Write(data)
dur := time.Since(start)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != len(data) {
t.Errorf("Write returned %d, want %d", n, len(data))
}
if dur > 50*time.Millisecond {
t.Errorf("Unlimited write took too long: %v", dur)
}
}
func TestNilLimiter(t *testing.T) {
conn := newMockConn([]byte("hello"))
lc := NewLimitedConn(context.Background(), conn, nil)
buf := make([]byte, 10)
n, err := lc.Read(buf)
if err != nil {
t.Fatalf("Read with nil limiter failed: %v", err)
}
if n != 5 {
t.Errorf("Read returned %d, want 5", n)
}
n, err = lc.Write([]byte("world"))
if err != nil {
t.Fatalf("Write with nil limiter failed: %v", err)
}
if n != 5 {
t.Errorf("Write returned %d, want 5", n)
}
}
func TestConcurrentReadWrite(t *testing.T) {
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
limiter := NewLimiter(Config{Bandwidth: 100 * 1024, Burst: 100 * 1024})
lc := NewLimitedConn(context.Background(), serverConn, limiter)
dataSize := 50 * 1024
var wg sync.WaitGroup
// Writer
wg.Add(1)
go func() {
defer wg.Done()
data := make([]byte, dataSize)
for i := range data {
data[i] = 0xAA
}
lc.Write(data)
}()
// Reader on the other end
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, dataSize)
total := 0
for total < dataSize {
n, err := clientConn.Read(buf[total:])
if err != nil {
return
}
total += n
}
}()
wg.Wait()
}
type failingReader struct {
failAfter int
read int
}
func (r *failingReader) Read(b []byte) (int, error) {
remaining := r.failAfter - r.read
if remaining <= 0 {
return 0, errors.New("reader error")
}
n := len(b)
if n > remaining {
n = remaining
}
r.read += n
return n, nil
}
type failingWriter struct {
failAfter int
written int
}
func (w *failingWriter) Write(b []byte) (int, error) {
remaining := w.failAfter - w.written
if remaining <= 0 {
return 0, errors.New("writer error")
}
n := len(b)
if n > remaining {
w.written += remaining
return remaining, errors.New("writer error")
}
w.written += n
return n, nil
}

View File

@@ -0,0 +1,269 @@
package qos
import (
"context"
"io"
"net"
"sync"
"testing"
"time"
)
func TestEndToEndBandwidthLimiting(t *testing.T) {
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
bandwidth := int64(100 * 1024)
burstMultiplier := 2.0
burst := int(float64(bandwidth) * burstMultiplier)
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
ctx := context.Background()
limitedServerConn := NewLimitedConn(ctx, serverConn, limiter)
dataSize := 500 * 1024
testData := make([]byte, dataSize)
for i := range testData {
testData[i] = byte(i % 256)
}
var wg sync.WaitGroup
var writeErr, readErr error
var writeDuration time.Duration
receivedData := make([]byte, dataSize)
wg.Add(1)
go func() {
defer wg.Done()
start := time.Now()
chunkSize := 32 * 1024
for i := 0; i < dataSize; i += chunkSize {
end := i + chunkSize
if end > dataSize {
end = dataSize
}
_, err := limitedServerConn.Write(testData[i:end])
if err != nil {
writeErr = err
return
}
}
writeDuration = time.Since(start)
}()
wg.Add(1)
go func() {
defer wg.Done()
totalRead := 0
for totalRead < dataSize {
n, err := clientConn.Read(receivedData[totalRead:])
if err != nil {
if err != io.EOF {
readErr = err
}
return
}
totalRead += n
}
}()
wg.Wait()
if writeErr != nil {
t.Fatalf("Write error: %v", writeErr)
}
if readErr != nil {
t.Fatalf("Read error: %v", readErr)
}
for i := 0; i < dataSize; i++ {
if receivedData[i] != testData[i] {
t.Fatalf("Data mismatch at byte %d: got %d, want %d", i, receivedData[i], testData[i])
}
}
expectedMinDuration := 2500 * time.Millisecond
expectedMaxDuration := 4000 * time.Millisecond
if writeDuration < expectedMinDuration {
t.Errorf("Transfer too fast: %v (expected >= %v)", writeDuration, expectedMinDuration)
}
if writeDuration > expectedMaxDuration {
t.Errorf("Transfer too slow: %v (expected <= %v)", writeDuration, expectedMaxDuration)
}
t.Logf("Transferred %d bytes in %v (rate: %.2f KB/s)",
dataSize, writeDuration, float64(dataSize)/writeDuration.Seconds()/1024)
}
func TestBidirectionalBandwidthLimiting(t *testing.T) {
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
bandwidth := int64(50 * 1024)
burst := int(bandwidth * 2)
serverLimiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
clientLimiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
ctx := context.Background()
limitedServerConn := NewLimitedConn(ctx, serverConn, serverLimiter)
limitedClientConn := NewLimitedConn(ctx, clientConn, clientLimiter)
dataSize := 200 * 1024
serverData := make([]byte, dataSize)
clientData := make([]byte, dataSize)
for i := range serverData {
serverData[i] = byte(i % 256)
clientData[i] = byte((i + 128) % 256)
}
var wg sync.WaitGroup
receivedByClient := make([]byte, dataSize)
receivedByServer := make([]byte, dataSize)
// Server writes to client
wg.Add(1)
go func() {
defer wg.Done()
chunkSize := 16 * 1024
for i := 0; i < dataSize; i += chunkSize {
end := i + chunkSize
if end > dataSize {
end = dataSize
}
limitedServerConn.Write(serverData[i:end])
}
}()
// Client writes to server
wg.Add(1)
go func() {
defer wg.Done()
chunkSize := 16 * 1024
for i := 0; i < dataSize; i += chunkSize {
end := i + chunkSize
if end > dataSize {
end = dataSize
}
limitedClientConn.Write(clientData[i:end])
}
}()
// Client reads from server
wg.Add(1)
go func() {
defer wg.Done()
totalRead := 0
for totalRead < dataSize {
n, err := limitedClientConn.Read(receivedByClient[totalRead:])
if err != nil {
return
}
totalRead += n
}
}()
// Server reads from client
wg.Add(1)
go func() {
defer wg.Done()
totalRead := 0
for totalRead < dataSize {
n, err := limitedServerConn.Read(receivedByServer[totalRead:])
if err != nil {
return
}
totalRead += n
}
}()
wg.Wait()
for i := 0; i < dataSize; i++ {
if receivedByClient[i] != serverData[i] {
t.Fatalf("Client received wrong data at byte %d", i)
}
if receivedByServer[i] != clientData[i] {
t.Fatalf("Server received wrong data at byte %d", i)
}
}
t.Log("Bidirectional transfer completed successfully")
}
func TestBurstBehavior(t *testing.T) {
bandwidth := int64(10 * 1024)
burst := 50 * 1024
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
ctx := context.Background()
start := time.Now()
err := limiter.RateLimiter().WaitN(ctx, burst)
if err != nil {
t.Fatalf("WaitN failed: %v", err)
}
burstDuration := time.Since(start)
if burstDuration > 100*time.Millisecond {
t.Errorf("Burst should be instant, took %v", burstDuration)
}
start = time.Now()
err = limiter.RateLimiter().WaitN(ctx, 10*1024)
if err != nil {
t.Fatalf("WaitN failed: %v", err)
}
limitedDuration := time.Since(start)
if limitedDuration < 900*time.Millisecond || limitedDuration > 1200*time.Millisecond {
t.Errorf("Rate limiting not working correctly, took %v (expected ~1s)", limitedDuration)
}
t.Logf("Burst: %v, Rate-limited: %v", burstDuration, limitedDuration)
}
func TestMultipleBurstMultipliers(t *testing.T) {
tests := []struct {
name string
bandwidth int64
multiplier float64
}{
{"1x burst", 10 * 1024, 1.0},
{"1.5x burst", 10 * 1024, 1.5},
{"2x burst", 10 * 1024, 2.0},
{"2.5x burst", 10 * 1024, 2.5},
{"3x burst", 10 * 1024, 3.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
burst := int(float64(tt.bandwidth) * tt.multiplier)
limiter := NewLimiter(Config{Bandwidth: tt.bandwidth, Burst: burst})
if !limiter.IsLimited() {
t.Error("Limiter should be limited")
}
actualBurst := limiter.RateLimiter().Burst()
if actualBurst != burst {
t.Errorf("Burst = %d, want %d", actualBurst, burst)
}
ctx := context.Background()
start := time.Now()
err := limiter.RateLimiter().WaitN(ctx, burst)
if err != nil {
t.Fatalf("WaitN failed: %v", err)
}
duration := time.Since(start)
if duration > 50*time.Millisecond {
t.Errorf("Burst should be instant, took %v", duration)
}
})
}
}

View File

@@ -0,0 +1,34 @@
package qos
import (
"golang.org/x/time/rate"
)
type Config struct {
Bandwidth int64
Burst int
}
type Limiter struct {
limiter *rate.Limiter
}
func NewLimiter(cfg Config) *Limiter {
l := &Limiter{}
if cfg.Bandwidth > 0 {
burst := cfg.Burst
if burst <= 0 {
burst = int(cfg.Bandwidth * 2)
}
l.limiter = rate.NewLimiter(rate.Limit(cfg.Bandwidth), burst)
}
return l
}
func (l *Limiter) RateLimiter() *rate.Limiter {
return l.limiter
}
func (l *Limiter) IsLimited() bool {
return l.limiter != nil
}

View File

@@ -0,0 +1,313 @@
package qos
import (
"context"
"io"
"net"
"sync"
"testing"
"time"
)
func TestNewLimiter(t *testing.T) {
tests := []struct {
name string
cfg Config
wantLimit bool
wantBurst int
}{
{
name: "unlimited when bandwidth is 0",
cfg: Config{Bandwidth: 0},
wantLimit: false,
},
{
name: "limited with default burst (2x)",
cfg: Config{Bandwidth: 1024},
wantLimit: true,
wantBurst: 2048, // 2x bandwidth
},
{
name: "limited with custom burst",
cfg: Config{Bandwidth: 1024, Burst: 4096},
wantLimit: true,
wantBurst: 4096,
},
{
name: "1MB/s with 2x burst",
cfg: Config{Bandwidth: 1024 * 1024},
wantLimit: true,
wantBurst: 2 * 1024 * 1024,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := NewLimiter(tt.cfg)
if l.IsLimited() != tt.wantLimit {
t.Errorf("IsLimited() = %v, want %v", l.IsLimited(), tt.wantLimit)
}
if tt.wantLimit {
if l.RateLimiter() == nil {
t.Error("RateLimiter() should not be nil when limited")
}
if l.RateLimiter().Burst() != tt.wantBurst {
t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst)
}
}
})
}
}
func TestLimiterBandwidthEnforcement(t *testing.T) {
bandwidth := int64(10 * 1024)
burst := int(bandwidth * 2)
l := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
if !l.IsLimited() {
t.Fatal("Limiter should be limited")
}
ctx := context.Background()
start := time.Now()
err := l.RateLimiter().WaitN(ctx, burst)
if err != nil {
t.Fatalf("WaitN failed: %v", err)
}
burstDuration := time.Since(start)
if burstDuration > 100*time.Millisecond {
t.Errorf("Burst should be instant, took %v", burstDuration)
}
start = time.Now()
err = l.RateLimiter().WaitN(ctx, int(bandwidth)) // Request 1 second worth
if err != nil {
t.Fatalf("WaitN failed: %v", err)
}
limitedDuration := time.Since(start)
if limitedDuration < 800*time.Millisecond {
t.Errorf("Rate limiting not working, took only %v for 1 second worth of data", limitedDuration)
}
if limitedDuration > 1500*time.Millisecond {
t.Errorf("Rate limiting too slow, took %v for 1 second worth of data", limitedDuration)
}
}
type mockConn struct {
readBuf []byte
readPos int
writeBuf []byte
mu sync.Mutex
}
func newMockConn(data []byte) *mockConn {
return &mockConn{readBuf: data}
}
func (c *mockConn) Read(b []byte) (n int, err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.readPos >= len(c.readBuf) {
return 0, io.EOF
}
n = copy(b, c.readBuf[c.readPos:])
c.readPos += n
return n, nil
}
func (c *mockConn) Write(b []byte) (n int, err error) {
c.mu.Lock()
defer c.mu.Unlock()
c.writeBuf = append(c.writeBuf, b...)
return len(b), nil
}
func (c *mockConn) Close() error { return nil }
func (c *mockConn) LocalAddr() net.Addr { return nil }
func (c *mockConn) RemoteAddr() net.Addr { return nil }
func (c *mockConn) SetDeadline(t time.Time) error { return nil }
func (c *mockConn) SetReadDeadline(t time.Time) error { return nil }
func (c *mockConn) SetWriteDeadline(t time.Time) error { return nil }
func TestLimitedConnRead(t *testing.T) {
dataSize := 20 * 1024
testData := make([]byte, dataSize)
for i := range testData {
testData[i] = byte(i % 256)
}
// 10KB/s limit, 20KB burst
bandwidth := int64(10 * 1024)
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: int(bandwidth * 2)})
conn := newMockConn(testData)
ctx := context.Background()
limitedConn := NewLimitedConn(ctx, conn, limiter)
buf := make([]byte, dataSize)
start := time.Now()
totalRead := 0
for totalRead < dataSize {
n, err := limitedConn.Read(buf[totalRead:])
if err != nil && err != io.EOF {
t.Fatalf("Read failed: %v", err)
}
totalRead += n
if err == io.EOF {
break
}
}
duration := time.Since(start)
if totalRead != dataSize {
t.Errorf("Read %d bytes, want %d", totalRead, dataSize)
}
t.Logf("Read %d bytes in %v", totalRead, duration)
}
func TestLimitedConnWrite(t *testing.T) {
bandwidth := int64(10 * 1024)
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: int(bandwidth * 2)})
conn := newMockConn(nil)
ctx := context.Background()
limitedConn := NewLimitedConn(ctx, conn, limiter)
dataSize := 30 * 1024
testData := make([]byte, dataSize)
for i := range testData {
testData[i] = byte(i % 256)
}
start := time.Now()
chunkSize := 10 * 1024
for i := 0; i < dataSize; i += chunkSize {
end := i + chunkSize
if end > dataSize {
end = dataSize
}
n, err := limitedConn.Write(testData[i:end])
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != end-i {
t.Errorf("Write returned %d, want %d", n, end-i)
}
}
duration := time.Since(start)
// 30KB data, 10KB/s rate, 20KB burst → ~1s for remaining 10KB
if duration < 800*time.Millisecond {
t.Errorf("Write too fast, took %v for 30KB with 10KB/s limit and 20KB burst", duration)
}
t.Logf("Wrote %d bytes in %v", dataSize, duration)
}
func TestLimitedConnUnlimited(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 0})
testData := make([]byte, 100*1024)
conn := newMockConn(testData)
ctx := context.Background()
limitedConn := NewLimitedConn(ctx, conn, limiter)
buf := make([]byte, len(testData))
start := time.Now()
totalRead := 0
for totalRead < len(testData) {
n, err := limitedConn.Read(buf[totalRead:])
if err != nil && err != io.EOF {
t.Fatalf("Read failed: %v", err)
}
totalRead += n
if err == io.EOF {
break
}
}
duration := time.Since(start)
if totalRead != len(testData) {
t.Errorf("Read %d bytes, want %d", totalRead, len(testData))
}
if duration > 100*time.Millisecond {
t.Errorf("Unlimited read took too long: %v", duration)
}
}
func TestLimitedConnContextCancellation(t *testing.T) {
bandwidth := int64(100)
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: 100})
conn := newMockConn(nil)
ctx, cancel := context.WithCancel(context.Background())
limitedConn := NewLimitedConn(ctx, conn, limiter)
_, err := limitedConn.Write(make([]byte, 100))
if err != nil {
t.Fatalf("First write failed: %v", err)
}
cancel()
_, err = limitedConn.Write(make([]byte, 1000))
if err == nil {
t.Error("Write should fail after context cancellation")
}
}
func TestBurstMultiplier(t *testing.T) {
tests := []struct {
name string
bandwidth int64
multiplier float64
wantBurst int
}{
{
name: "2x multiplier",
bandwidth: 1024 * 1024, // 1MB/s
multiplier: 2.0,
wantBurst: 2 * 1024 * 1024,
},
{
name: "2.5x multiplier",
bandwidth: 1024 * 1024,
multiplier: 2.5,
wantBurst: int(float64(1024*1024) * 2.5),
},
{
name: "1x multiplier (no extra burst)",
bandwidth: 1024 * 1024,
multiplier: 1.0,
wantBurst: 1024 * 1024,
},
{
name: "3x multiplier",
bandwidth: 500 * 1024, // 500KB/s
multiplier: 3.0,
wantBurst: 3 * 500 * 1024,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
burst := int(float64(tt.bandwidth) * tt.multiplier)
l := NewLimiter(Config{Bandwidth: tt.bandwidth, Burst: burst})
if l.RateLimiter().Burst() != tt.wantBurst {
t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst)
}
})
}
}

View File

@@ -10,7 +10,6 @@ import (
"gopkg.in/yaml.v3"
)
// TunnelConfig holds configuration for a predefined tunnel
type TunnelConfig struct {
Name string `yaml:"name"` // Tunnel name (required, unique identifier)
Type string `yaml:"type"` // Tunnel type: http, https, tcp (required)
@@ -21,9 +20,9 @@ type TunnelConfig struct {
AllowIPs []string `yaml:"allow_ips,omitempty"` // Allowed IPs/CIDRs
DenyIPs []string `yaml:"deny_ips,omitempty"` // Denied IPs/CIDRs
Auth string `yaml:"auth,omitempty"` // Proxy authentication password (http/https only)
Bandwidth string `yaml:"bandwidth,omitempty"` // Bandwidth limit (e.g., 1M, 500K, 1G)
}
// Validate checks if the tunnel configuration is valid
func (t *TunnelConfig) Validate() error {
if t.Name == "" {
return fmt.Errorf("tunnel name is required")
@@ -47,7 +46,6 @@ func (t *TunnelConfig) Validate() error {
return nil
}
// ClientConfig represents the client configuration
type ClientConfig struct {
Server string `yaml:"server"` // Server address (e.g., tunnel.example.com:443)
Token string `yaml:"token"` // Authentication token
@@ -55,7 +53,6 @@ type ClientConfig struct {
Tunnels []*TunnelConfig `yaml:"tunnels,omitempty"` // Predefined tunnels
}
// Validate checks if the client configuration is valid
func (c *ClientConfig) Validate() error {
if c.Server == "" {
return fmt.Errorf("server address is required")
@@ -92,7 +89,6 @@ func (c *ClientConfig) Validate() error {
return nil
}
// GetTunnel returns a tunnel by name
func (c *ClientConfig) GetTunnel(name string) *TunnelConfig {
for _, t := range c.Tunnels {
if t.Name == name {
@@ -102,7 +98,6 @@ func (c *ClientConfig) GetTunnel(name string) *TunnelConfig {
return nil
}
// GetTunnelNames returns all tunnel names
func (c *ClientConfig) GetTunnelNames() []string {
names := make([]string, len(c.Tunnels))
for i, t := range c.Tunnels {
@@ -111,7 +106,6 @@ func (c *ClientConfig) GetTunnelNames() []string {
return names
}
// DefaultClientConfig returns the default configuration path
func DefaultClientConfigPath() string {
home, err := os.UserHomeDir()
if err != nil {
@@ -120,7 +114,6 @@ func DefaultClientConfigPath() string {
return filepath.Join(home, ".drip", "config.yaml")
}
// LoadClientConfig loads configuration from file
func LoadClientConfig(path string) (*ClientConfig, error) {
if path == "" {
path = DefaultClientConfigPath()
@@ -146,7 +139,6 @@ func LoadClientConfig(path string) (*ClientConfig, error) {
return &config, nil
}
// SaveClientConfig saves configuration to file
func SaveClientConfig(config *ClientConfig, path string) error {
if path == "" {
path = DefaultClientConfigPath()
@@ -162,7 +154,6 @@ func SaveClientConfig(config *ClientConfig, path string) error {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Write to file with secure permissions
if err := os.WriteFile(path, data, 0600); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
@@ -170,7 +161,6 @@ func SaveClientConfig(config *ClientConfig, path string) error {
return nil
}
// ConfigExists checks if config file exists
func ConfigExists(path string) bool {
if path == "" {
path = DefaultClientConfigPath()

View File

@@ -10,7 +10,6 @@ import (
"gopkg.in/yaml.v3"
)
// ServerConfig holds the server configuration
type ServerConfig struct {
Port int `yaml:"port"`
PublicPort int `yaml:"public_port"` // Port to display in URLs (for reverse proxy scenarios)
@@ -37,13 +36,12 @@ type ServerConfig struct {
PprofPort int `yaml:"pprof_port"`
// Allowed transports: "tcp", "wss", or "tcp,wss" (default: "tcp,wss")
AllowedTransports []string `yaml:"transports"`
// Allowed tunnel types: "http", "https", "tcp" (default: all)
AllowedTransports []string `yaml:"transports"`
AllowedTunnelTypes []string `yaml:"tunnel_types"`
Bandwidth string `yaml:"bandwidth,omitempty"`
BurstMultiplier float64 `yaml:"burst_multiplier,omitempty"`
}
// Validate checks if the server configuration is valid
func (c *ServerConfig) Validate() error {
// Validate port
if c.Port < 1 || c.Port > 65535 {
@@ -92,7 +90,6 @@ func (c *ServerConfig) Validate() error {
return nil
}
// LoadTLSConfig loads TLS configuration
func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
if !c.TLSEnabled {
return nil, nil
@@ -131,7 +128,6 @@ func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
return tlsConfig, nil
}
// GetClientTLSConfig returns TLS config for client connections
func GetClientTLSConfig(serverName string) *tls.Config {
return &tls.Config{
ServerName: serverName,
@@ -147,8 +143,6 @@ func GetClientTLSConfig(serverName string) *tls.Config {
}
}
// GetClientTLSConfigInsecure returns TLS config for client with InsecureSkipVerify
// WARNING: Only use for testing!
func GetClientTLSConfigInsecure() *tls.Config {
return &tls.Config{
InsecureSkipVerify: true,
@@ -164,7 +158,6 @@ func GetClientTLSConfigInsecure() *tls.Config {
}
}
// DefaultServerConfigPath returns the default server configuration path
func DefaultServerConfigPath() string {
// Check /etc/drip/config.yaml first (system-wide)
systemPath := "/etc/drip/config.yaml"
@@ -180,7 +173,6 @@ func DefaultServerConfigPath() string {
return filepath.Join(home, ".drip", "server.yaml")
}
// LoadServerConfig loads server configuration from file
func LoadServerConfig(path string) (*ServerConfig, error) {
if path == "" {
path = DefaultServerConfigPath()
@@ -202,7 +194,6 @@ func LoadServerConfig(path string) (*ServerConfig, error) {
return &config, nil
}
// SaveServerConfig saves server configuration to file
func SaveServerConfig(config *ServerConfig, path string) error {
if path == "" {
path = DefaultServerConfigPath()
@@ -225,7 +216,6 @@ func SaveServerConfig(config *ServerConfig, path string) error {
return nil
}
// ServerConfigExists checks if server config file exists
func ServerConfigExists(path string) bool {
if path == "" {
path = DefaultServerConfigPath()

158
pkg/config/config_test.go Normal file
View File

@@ -0,0 +1,158 @@
package config
import (
"os"
"path/filepath"
"strconv"
"strings"
"testing"
)
func TestServerConfigBandwidth(t *testing.T) {
tests := []struct {
name string
yaml string
wantBandwidth string
wantMultiplier float64
}{
{
name: "bandwidth 1M with 2.5x burst",
yaml: `
port: 8443
domain: example.com
tcp_port_min: 10000
tcp_port_max: 20000
bandwidth: 1M
burst_multiplier: 2.5
`,
wantBandwidth: "1M",
wantMultiplier: 2.5,
},
{
name: "bandwidth 10M with default burst",
yaml: `
port: 8443
domain: example.com
tcp_port_min: 10000
tcp_port_max: 20000
bandwidth: 10M
`,
wantBandwidth: "10M",
wantMultiplier: 0, // not set, will use default 2.0 in code
},
{
name: "no bandwidth limit",
yaml: `
port: 8443
domain: example.com
tcp_port_min: 10000
tcp_port_max: 20000
`,
wantBandwidth: "",
wantMultiplier: 0,
},
{
name: "bandwidth 500K with 3x burst",
yaml: `
port: 8443
domain: example.com
tcp_port_min: 10000
tcp_port_max: 20000
bandwidth: 500K
burst_multiplier: 3.0
`,
wantBandwidth: "500K",
wantMultiplier: 3.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte(tt.yaml), 0600); err != nil {
t.Fatalf("Failed to write config file: %v", err)
}
cfg, err := LoadServerConfig(configPath)
if err != nil {
t.Fatalf("LoadServerConfig failed: %v", err)
}
if cfg.Bandwidth != tt.wantBandwidth {
t.Errorf("Bandwidth = %q, want %q", cfg.Bandwidth, tt.wantBandwidth)
}
if cfg.BurstMultiplier != tt.wantMultiplier {
t.Errorf("BurstMultiplier = %v, want %v", cfg.BurstMultiplier, tt.wantMultiplier)
}
})
}
}
func TestParseBandwidth(t *testing.T) {
tests := []struct {
input string
want int64
}{
{"", 0},
{"0", 0},
{"1024", 1024},
{"1K", 1024},
{"1KB", 1024},
{"1k", 1024},
{"1M", 1024 * 1024},
{"1MB", 1024 * 1024},
{"1m", 1024 * 1024},
{"10M", 10 * 1024 * 1024},
{"1G", 1024 * 1024 * 1024},
{"1GB", 1024 * 1024 * 1024},
{"500K", 500 * 1024},
{"100M", 100 * 1024 * 1024},
{" 1M ", 1024 * 1024}, // with spaces
{"invalid", 0},
{"-1M", 0}, // negative
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := parseBandwidthString(tt.input)
if got != tt.want {
t.Errorf("parseBandwidthString(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func parseBandwidthString(s string) int64 {
if s == "" {
return 0
}
s = strings.TrimSpace(strings.ToUpper(s))
if s == "" {
return 0
}
var multiplier int64 = 1
switch {
case strings.HasSuffix(s, "GB") || strings.HasSuffix(s, "G"):
multiplier = 1024 * 1024 * 1024
s = strings.TrimSuffix(strings.TrimSuffix(s, "GB"), "G")
case strings.HasSuffix(s, "MB") || strings.HasSuffix(s, "M"):
multiplier = 1024 * 1024
s = strings.TrimSuffix(strings.TrimSuffix(s, "MB"), "M")
case strings.HasSuffix(s, "KB") || strings.HasSuffix(s, "K"):
multiplier = 1024
s = strings.TrimSuffix(strings.TrimSuffix(s, "KB"), "K")
case strings.HasSuffix(s, "B"):
s = strings.TrimSuffix(s, "B")
}
val, err := strconv.ParseInt(s, 10, 64)
if err != nil || val < 0 {
return 0
}
return val * multiplier
}