From 89f67ab1450df22e563b84d33ae7c229cffb0356 Mon Sep 17 00:00:00 2001 From: Gouryella Date: Sun, 15 Feb 2026 02:39:50 +0800 Subject: [PATCH 1/2] feat(client): Add bandwidth limit function support - Implement client bandwidth limitation parameter --bandwidth, supporting 1M, 1MB, 1G and other formats - Added parseBandwidth function to parse bandwidth values and verify them - Added bandwidth limit option in HTTP, HTTPS, TCP commands - Pass bandwidth configuration to the server through protocol - Add relevant test cases to verify the bandwidth analysis function feat(server): implements server-side bandwidth limitation function - Add bandwidth limitation logic in connection processing, using token bucket algorithm - Implement an effective rate limiting strategy that minimizes the bandwidth of the client and server - Added QoS limiter and restricted connection wrapper - Integrated bandwidth throttling in HTTP and WebSocket proxies - Added global bandwidth limit and burst multiplier settings in server configuration docs: Updated documentation to describe bandwidth limiting functionality - Add 2025-02-14 version update instructions in README and README_CN - Add bandwidth limit function description and usage examples - Provide client and server configuration examples and parameter descriptions --- README.md | 16 ++ README_CN.md | 16 ++ go.mod | 1 + go.sum | 2 + internal/client/cli/http.go | 55 ++++- internal/client/cli/http_test.go | 59 +++++ internal/client/cli/https.go | 16 +- internal/client/cli/server.go | 18 ++ internal/client/cli/start.go | 6 + internal/client/cli/tcp.go | 8 + internal/client/cli/tunnel_helpers.go | 3 + internal/client/tcp/connector.go | 3 + internal/client/tcp/pool_client.go | 12 + internal/server/proxy/handler.go | 10 +- internal/server/proxy/websocket_handler.go | 10 +- internal/server/tcp/bandwidth_test.go | 93 ++++++++ internal/server/tcp/connection.go | 44 ++++ internal/server/tcp/listener.go | 16 ++ internal/server/tcp/proxy.go | 16 +- internal/server/tcp/tunnel.go | 3 + internal/server/tunnel/connection.go | 29 +++ internal/server/tunnel/connection_test.go | 59 +++++ internal/shared/protocol/messages.go | 2 + internal/shared/qos/conn.go | 112 +++++++++ internal/shared/qos/conn_test.go | 254 +++++++++++++++++++++ internal/shared/qos/limiter.go | 34 +++ internal/shared/qos/limiter_test.go | 172 ++++++++++++++ pkg/config/client_config.go | 1 + pkg/config/config.go | 4 + pkg/config/config_test.go | 63 +++++ 30 files changed, 1132 insertions(+), 5 deletions(-) create mode 100644 internal/client/cli/http_test.go create mode 100644 internal/server/tcp/bandwidth_test.go create mode 100644 internal/server/tunnel/connection_test.go create mode 100644 internal/shared/qos/conn.go create mode 100644 internal/shared/qos/conn_test.go create mode 100644 internal/shared/qos/limiter.go create mode 100644 internal/shared/qos/limiter_test.go create mode 100644 pkg/config/config_test.go diff --git a/README.md b/README.md index 3f36f99..9406daa 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,22 @@ ## Recent Changes +### 2025-02-14 + +- **Bandwidth Limiting (QoS)** - Per-tunnel bandwidth control with token bucket algorithm, server enforces `min(client, server)` as effective limit +- **Transport Protocol Control** - Support independent configuration for service domain and tunnel domain + +```bash +# Client: limit to 1MB/s +drip http 3000 --bandwidth 1M +``` + +```yaml +# Server: global limit (config.yaml) +bandwidth: 10M +burst_multiplier: 2.5 +``` + ### 2025-01-29 - **Bearer Token Authentication** - Added bearer token authentication support for tunnel access control diff --git a/README_CN.md b/README_CN.md index 24d1d91..14aa7c1 100644 --- a/README_CN.md +++ b/README_CN.md @@ -35,6 +35,22 @@ ## 最近更新 +### 2025-02-14 + +- **带宽限速 (QoS)** - 支持按隧道粒度进行带宽控制,使用令牌桶算法,服务端按 `min(client, server)` 作为实际生效限速 +- **传输协议控制** - 支持服务域名与隧道域名的独立配置 + +```bash +# Client: limit to 1MB/s +drip http 3000 --bandwidth 1M +``` + +```yaml +# Server: global limit (config.yaml) +bandwidth: 10M +burst_multiplier: 2.5 +``` + ### 2025-01-29 - **Bearer Token 认证** - 新增 Bearer Token 认证支持,用于隧道访问控制 diff --git a/go.mod b/go.mod index 9eabd79..e245534 100644 --- a/go.mod +++ b/go.mod @@ -42,5 +42,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/text v0.33.0 // indirect + golang.org/x/time v0.14.0 // indirect google.golang.org/protobuf v1.36.11 // indirect ) diff --git a/go.sum b/go.sum index a2e8541..1da54f4 100644 --- a/go.sum +++ b/go.sum @@ -95,6 +95,8 @@ golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/client/cli/http.go b/internal/client/cli/http.go index cac5c66..894d1f3 100644 --- a/internal/client/cli/http.go +++ b/internal/client/cli/http.go @@ -21,6 +21,7 @@ var ( authPass string authBearer string transport string + bandwidth string ) var httpCmd = &cobra.Command{ @@ -37,6 +38,7 @@ Example: drip http 3000 --auth secret Enable proxy authentication with password drip http 3000 --auth-bearer sk-xxx Enable proxy authentication with bearer token drip http 3000 --transport wss Use WebSocket over TLS (CDN-friendly) + drip http 3000 --bandwidth 1M Limit bandwidth to 1 MB/s Configuration: First time: Run 'drip config init' to save server and token @@ -45,7 +47,13 @@ Configuration: Transport options: auto - Automatically select based on server address (default) tcp - Direct TLS 1.3 connection - wss - WebSocket over TLS (works through CDN like Cloudflare)`, + wss - WebSocket over TLS (works through CDN like Cloudflare) + +Bandwidth format: + 1K, 1KB - 1 kilobyte per second (1024 bytes/s) + 1M, 1MB - 1 megabyte per second (1048576 bytes/s) + 1G, 1GB - 1 gigabyte per second + 1024 - 1024 bytes per second (raw number)`, Args: cobra.ExactArgs(1), RunE: runHTTP, SilenceUsage: true, @@ -61,6 +69,7 @@ func init() { httpCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication") httpCmd.Flags().StringVar(&authBearer, "auth-bearer", "", "Bearer token for proxy authentication") httpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)") + httpCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)") httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process") httpCmd.Flags().MarkHidden("daemon-child") rootCmd.AddCommand(httpCmd) @@ -85,6 +94,11 @@ func runHTTP(_ *cobra.Command, args []string) error { return err } + bw, err := parseBandwidth(bandwidth) + if err != nil { + return err + } + connConfig := &tcp.ConnectorConfig{ ServerAddr: serverAddr, Token: token, @@ -98,6 +112,7 @@ func runHTTP(_ *cobra.Command, args []string) error { AuthPass: authPass, AuthBearer: authBearer, Transport: parseTransport(transport), + Bandwidth: bw, } var daemon *DaemonInfo @@ -118,3 +133,41 @@ func parseTransport(s string) tcp.TransportType { return tcp.TransportAuto } } + +func parseBandwidth(s string) (int64, error) { + if s == "" { + return 0, nil + } + + s = strings.TrimSpace(strings.ToUpper(s)) + if s == "" { + return 0, nil + } + + var multiplier int64 = 1 + switch { + case strings.HasSuffix(s, "GB") || strings.HasSuffix(s, "G"): + multiplier = 1024 * 1024 * 1024 + s = strings.TrimSuffix(strings.TrimSuffix(s, "GB"), "G") + case strings.HasSuffix(s, "MB") || strings.HasSuffix(s, "M"): + multiplier = 1024 * 1024 + s = strings.TrimSuffix(strings.TrimSuffix(s, "MB"), "M") + case strings.HasSuffix(s, "KB") || strings.HasSuffix(s, "K"): + multiplier = 1024 + s = strings.TrimSuffix(strings.TrimSuffix(s, "KB"), "K") + case strings.HasSuffix(s, "B"): + s = strings.TrimSuffix(s, "B") + } + + val, err := strconv.ParseInt(s, 10, 64) + if err != nil || val < 0 { + return 0, fmt.Errorf("invalid bandwidth value: %q (use format like 1M, 500K, 1G)", s) + } + + result := val * multiplier + if val > 0 && result/multiplier != val { + return 0, fmt.Errorf("bandwidth value overflow: %q", s) + } + + return result, nil +} diff --git a/internal/client/cli/http_test.go b/internal/client/cli/http_test.go new file mode 100644 index 0000000..a76646c --- /dev/null +++ b/internal/client/cli/http_test.go @@ -0,0 +1,59 @@ +package cli + +import ( + "testing" +) + +func TestParseBandwidth(t *testing.T) { + tests := []struct { + input string + want int64 + wantErr bool + }{ + {"", 0, false}, + {"0", 0, false}, + {"1024", 1024, false}, + {"1K", 1024, false}, + {"1KB", 1024, false}, + {"1k", 1024, false}, + {"1M", 1024 * 1024, false}, + {"1MB", 1024 * 1024, false}, + {"1m", 1024 * 1024, false}, + {"10M", 10 * 1024 * 1024, false}, + {"1G", 1024 * 1024 * 1024, false}, + {"1GB", 1024 * 1024 * 1024, false}, + {"500K", 500 * 1024, false}, + {"100M", 100 * 1024 * 1024, false}, + {" 1M ", 1024 * 1024, false}, + {"1B", 1, false}, + {"100B", 100, false}, + {"invalid", 0, true}, + {"abc", 0, true}, + {"-1M", 0, true}, + {"-100", 0, true}, + {"1.5M", 0, true}, + {"M", 0, true}, + {"K", 0, true}, + {"9223372036854775807K", 0, true}, + {"9999999999999999999G", 0, true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got, err := parseBandwidth(tt.input) + if tt.wantErr { + if err == nil { + t.Errorf("parseBandwidth(%q) = %d, want error", tt.input, got) + } + return + } + if err != nil { + t.Errorf("parseBandwidth(%q) unexpected error: %v", tt.input, err) + return + } + if got != tt.want { + t.Errorf("parseBandwidth(%q) = %d, want %d", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/client/cli/https.go b/internal/client/cli/https.go index 1760db5..eea8bcc 100644 --- a/internal/client/cli/https.go +++ b/internal/client/cli/https.go @@ -24,6 +24,7 @@ Example: drip https 443 --auth secret Enable proxy authentication with password drip https 443 --auth-bearer sk-xxx Enable proxy authentication with bearer token drip https 443 --transport wss Use WebSocket over TLS (CDN-friendly) + drip https 443 --bandwidth 1M Limit bandwidth to 1 MB/s Configuration: First time: Run 'drip config init' to save server and token @@ -32,7 +33,13 @@ Configuration: Transport options: auto - Automatically select based on server address (default) tcp - Direct TLS 1.3 connection - wss - WebSocket over TLS (works through CDN like Cloudflare)`, + wss - WebSocket over TLS (works through CDN like Cloudflare) + +Bandwidth format: + 1K, 1KB - 1 kilobyte per second (1024 bytes/s) + 1M, 1MB - 1 megabyte per second (1048576 bytes/s) + 1G, 1GB - 1 gigabyte per second + 1024 - 1024 bytes per second (raw number)`, Args: cobra.ExactArgs(1), RunE: runHTTPS, SilenceUsage: true, @@ -48,6 +55,7 @@ func init() { httpsCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication") httpsCmd.Flags().StringVar(&authBearer, "auth-bearer", "", "Bearer token for proxy authentication") httpsCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)") + httpsCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)") httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process") httpsCmd.Flags().MarkHidden("daemon-child") rootCmd.AddCommand(httpsCmd) @@ -72,6 +80,11 @@ func runHTTPS(_ *cobra.Command, args []string) error { return err } + bw, err := parseBandwidth(bandwidth) + if err != nil { + return err + } + connConfig := &tcp.ConnectorConfig{ ServerAddr: serverAddr, Token: token, @@ -85,6 +98,7 @@ func runHTTPS(_ *cobra.Command, args []string) error { AuthPass: authPass, AuthBearer: authBearer, Transport: parseTransport(transport), + Bandwidth: bw, } var daemon *DaemonInfo diff --git a/internal/client/cli/server.go b/internal/client/cli/server.go index 9d16621..94b40cc 100644 --- a/internal/client/cli/server.go +++ b/internal/client/cli/server.go @@ -313,6 +313,24 @@ func runServer(cmd *cobra.Command, _ []string) error { listener.SetAllowedTransports(cfg.AllowedTransports) listener.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes) + bandwidth, err := parseBandwidth(cfg.Bandwidth) + if err != nil { + logger.Fatal("Invalid bandwidth configuration", zap.Error(err)) + } + burstMultiplier := cfg.BurstMultiplier + if burstMultiplier <= 0 { + burstMultiplier = 2.0 + } + listener.SetBandwidth(bandwidth) + listener.SetBurstMultiplier(burstMultiplier) + if bandwidth > 0 { + logger.Info("Bandwidth limit configured", + zap.String("bandwidth", cfg.Bandwidth), + zap.Int64("bandwidth_bytes_sec", bandwidth), + zap.Float64("burst_multiplier", burstMultiplier), + ) + } + if err := listener.Start(); err != nil { logger.Fatal("Failed to start TCP listener", zap.Error(err)) } diff --git a/internal/client/cli/start.go b/internal/client/cli/start.go index 3ade7fa..350dc8f 100644 --- a/internal/client/cli/start.go +++ b/internal/client/cli/start.go @@ -242,6 +242,7 @@ func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp AuthPass: t.Auth, AuthBearer: t.AuthBearer, Transport: transport, + Bandwidth: parseBandwidthOrZero(t.Bandwidth), } } @@ -251,3 +252,8 @@ func getAddress(t *config.TunnelConfig) string { } return "127.0.0.1" } + +func parseBandwidthOrZero(s string) int64 { + bw, _ := parseBandwidth(s) + return bw +} diff --git a/internal/client/cli/tcp.go b/internal/client/cli/tcp.go index 2ab643e..8ea48c8 100644 --- a/internal/client/cli/tcp.go +++ b/internal/client/cli/tcp.go @@ -24,6 +24,7 @@ Example: drip tcp 22 --allow-ip 10.0.0.1 Allow single IP drip tcp 22 --deny-ip 1.2.3.4 Block specific IP drip tcp 22 --transport wss Use WebSocket over TLS (CDN-friendly) + drip tcp 22 --bandwidth 1M Limit bandwidth to 1 MB/s Supported Services: - Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017) @@ -54,6 +55,7 @@ func init() { tcpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)") tcpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)") tcpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)") + tcpCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)") tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process") tcpCmd.Flags().MarkHidden("daemon-child") rootCmd.AddCommand(tcpCmd) @@ -74,6 +76,11 @@ func runTCP(_ *cobra.Command, args []string) error { return err } + bw, err := parseBandwidth(bandwidth) + if err != nil { + return err + } + connConfig := &tcp.ConnectorConfig{ ServerAddr: serverAddr, Token: token, @@ -85,6 +92,7 @@ func runTCP(_ *cobra.Command, args []string) error { AllowIPs: allowIPs, DenyIPs: denyIPs, Transport: parseTransport(transport), + Bandwidth: bw, } var daemon *DaemonInfo diff --git a/internal/client/cli/tunnel_helpers.go b/internal/client/cli/tunnel_helpers.go index 4863ec2..25c503f 100644 --- a/internal/client/cli/tunnel_helpers.go +++ b/internal/client/cli/tunnel_helpers.go @@ -30,6 +30,9 @@ func buildDaemonArgs(tunnelType string, args []string, subdomain string, localAd if authBearer != "" { daemonArgs = append(daemonArgs, "--auth-bearer", authBearer) } + if bandwidth != "" { + daemonArgs = append(daemonArgs, "--bandwidth", bandwidth) + } if insecure { daemonArgs = append(daemonArgs, "--insecure") } diff --git a/internal/client/tcp/connector.go b/internal/client/tcp/connector.go index 961aa2c..38afd57 100644 --- a/internal/client/tcp/connector.go +++ b/internal/client/tcp/connector.go @@ -46,6 +46,9 @@ type ConnectorConfig struct { // Transport protocol selection Transport TransportType + + // Bandwidth limit (bytes/sec), 0 = unlimited + Bandwidth int64 } type TunnelClient interface { diff --git a/internal/client/tcp/pool_client.go b/internal/client/tcp/pool_client.go index b913604..d131315 100644 --- a/internal/client/tcp/pool_client.go +++ b/internal/client/tcp/pool_client.go @@ -81,6 +81,9 @@ type PoolClient struct { // Session scaler scaler *SessionScaler + + // Bandwidth limit requested from server (bytes/sec), 0 = unlimited + bandwidth int64 } // NewPoolClient creates a new pool client. @@ -178,6 +181,7 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient { transport: transport, insecure: cfg.Insecure, dialer: NewConnectionDialer(serverAddr, tlsConfig, cfg.Token, transport, logger), + bandwidth: cfg.Bandwidth, } if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS { @@ -229,6 +233,10 @@ func (c *PoolClient) Connect() error { } } + if c.bandwidth > 0 { + req.Bandwidth = c.bandwidth + } + payload, err := json.Marshal(req) if err != nil { _ = primaryConn.Close() @@ -275,6 +283,10 @@ func (c *PoolClient) Connect() error { c.tunnelID = resp.TunnelID } + if resp.Bandwidth > 0 { + c.bandwidth = resp.Bandwidth + } + yamuxCfg := mux.NewClientConfig() session, err := yamux.Server(primaryConn, yamuxCfg) diff --git a/internal/server/proxy/handler.go b/internal/server/proxy/handler.go index 386932f..9b26a7d 100644 --- a/internal/server/proxy/handler.go +++ b/internal/server/proxy/handler.go @@ -20,6 +20,7 @@ import ( "drip/internal/shared/netutil" "drip/internal/shared/pool" "drip/internal/shared/protocol" + "drip/internal/shared/qos" ) // bufio.Reader pool to reduce allocations on hot path @@ -247,7 +248,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { tconn.IncActiveConnections() defer tconn.DecActiveConnections() - countingStream := netutil.NewCountingConn(stream, + var limitedStream net.Conn = stream + if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() { + if l, ok := limiter.(*qos.Limiter); ok { + limitedStream = qos.NewLimitedConn(r.Context(), stream, l) + } + } + + countingStream := netutil.NewCountingConn(limitedStream, tconn.AddBytesOut, tconn.AddBytesIn, ) diff --git a/internal/server/proxy/websocket_handler.go b/internal/server/proxy/websocket_handler.go index 41ec952..16ccf32 100644 --- a/internal/server/proxy/websocket_handler.go +++ b/internal/server/proxy/websocket_handler.go @@ -14,6 +14,7 @@ import ( "drip/internal/shared/httputil" "drip/internal/shared/netutil" "drip/internal/shared/protocol" + "drip/internal/shared/qos" "drip/internal/shared/wsutil" ) @@ -58,6 +59,13 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn return } + var limitedStream net.Conn = stream + if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() { + if l, ok := limiter.(*qos.Limiter); ok { + limitedStream = qos.NewLimitedConn(context.Background(), stream, l) + } + } + go func() { defer stream.Close() defer clientConn.Close() @@ -71,7 +79,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn } } - _ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW, + _ = netutil.PipeWithCallbacks(context.Background(), limitedStream, clientRW, func(n int64) { tconn.AddBytesOut(n) }, func(n int64) { tconn.AddBytesIn(n) }, ) diff --git a/internal/server/tcp/bandwidth_test.go b/internal/server/tcp/bandwidth_test.go new file mode 100644 index 0000000..5d2a6c2 --- /dev/null +++ b/internal/server/tcp/bandwidth_test.go @@ -0,0 +1,93 @@ +package tcp + +import ( + "testing" +) + +func TestEffectiveBandwidthSelection(t *testing.T) { + tests := []struct { + name string + serverBW int64 + clientBW int64 + wantEffective int64 + }{ + {"server only", 1024 * 1024, 0, 1024 * 1024}, + {"client only", 0, 512 * 1024, 512 * 1024}, + {"both unlimited", 0, 0, 0}, + {"client lower than server", 10 * 1024 * 1024, 1 * 1024 * 1024, 1 * 1024 * 1024}, + {"client higher than server - server wins", 1 * 1024 * 1024, 10 * 1024 * 1024, 1 * 1024 * 1024}, + {"client equal to server", 5 * 1024 * 1024, 5 * 1024 * 1024, 5 * 1024 * 1024}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + effectiveBandwidth := tt.serverBW + if tt.clientBW > 0 { + if effectiveBandwidth == 0 || tt.clientBW < effectiveBandwidth { + effectiveBandwidth = tt.clientBW + } + } + + if effectiveBandwidth != tt.wantEffective { + t.Errorf("effectiveBandwidth = %d, want %d", effectiveBandwidth, tt.wantEffective) + } + }) + } +} + +func TestConnectionSetBandwidthConfig(t *testing.T) { + tests := []struct { + name string + bandwidth int64 + burstMultiplier float64 + wantBandwidth int64 + wantMultiplier float64 + }{ + {"1MB/s with 2x burst", 1024 * 1024, 2.0, 1024 * 1024, 2.0}, + {"default multiplier when 0", 1024 * 1024, 0, 1024 * 1024, 2.0}, + {"default multiplier when negative", 1024 * 1024, -1.0, 1024 * 1024, 2.0}, + {"unlimited bandwidth", 0, 2.5, 0, 2.5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conn := &Connection{} + conn.SetBandwidthConfig(tt.bandwidth, tt.burstMultiplier) + + if conn.bandwidth != tt.wantBandwidth { + t.Errorf("bandwidth = %v, want %v", conn.bandwidth, tt.wantBandwidth) + } + if conn.burstMultiplier != tt.wantMultiplier { + t.Errorf("burstMultiplier = %v, want %v", conn.burstMultiplier, tt.wantMultiplier) + } + }) + } +} + +func TestListenerBandwidthConfig(t *testing.T) { + tests := []struct { + name string + bandwidth int64 + burstMultiplier float64 + wantBandwidth int64 + wantMultiplier float64 + }{ + {"set bandwidth and multiplier", 1024 * 1024, 2.5, 1024 * 1024, 2.5}, + {"default multiplier", 1024 * 1024, 0, 1024 * 1024, 2.0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &Listener{} + l.SetBandwidth(tt.bandwidth) + l.SetBurstMultiplier(tt.burstMultiplier) + + if l.bandwidth != tt.wantBandwidth { + t.Errorf("bandwidth = %v, want %v", l.bandwidth, tt.wantBandwidth) + } + if l.burstMultiplier != tt.wantMultiplier { + t.Errorf("burstMultiplier = %v, want %v", l.burstMultiplier, tt.wantMultiplier) + } + }) + } +} diff --git a/internal/server/tcp/connection.go b/internal/server/tcp/connection.go index ebbf59d..4a16942 100644 --- a/internal/server/tcp/connection.go +++ b/internal/server/tcp/connection.go @@ -19,6 +19,7 @@ import ( "drip/internal/shared/constants" "drip/internal/shared/httputil" "drip/internal/shared/protocol" + "drip/internal/shared/qos" "go.uber.org/zap" ) @@ -69,6 +70,8 @@ type Connection struct { // Server capabilities allowedTunnelTypes []string allowedTransports []string + bandwidth int64 + burstMultiplier float64 } // NewConnection creates a new connection handler @@ -231,11 +234,44 @@ func (c *Connection) Handle() error { ) } + // Configure bandwidth limiting + effectiveBandwidth := c.bandwidth + if req.Bandwidth > 0 { + if effectiveBandwidth == 0 || req.Bandwidth < effectiveBandwidth { + effectiveBandwidth = req.Bandwidth + } + } + if effectiveBandwidth > 0 { + burstMultiplier := c.burstMultiplier + if burstMultiplier <= 0 { + burstMultiplier = 2.0 + } + c.tunnelConn.SetBandwidthWithBurst(effectiveBandwidth, burstMultiplier) + + limiter := qos.NewLimiter(qos.Config{ + Bandwidth: effectiveBandwidth, + Burst: int(float64(effectiveBandwidth) * burstMultiplier), + }) + c.tunnelConn.SetLimiter(limiter) + + source := "server" + if req.Bandwidth > 0 && (c.bandwidth == 0 || req.Bandwidth < c.bandwidth) { + source = "client" + } + c.logger.Info("Bandwidth limit configured", + zap.String("subdomain", c.subdomain), + zap.Int64("bandwidth_bytes_sec", effectiveBandwidth), + zap.Float64("burst_multiplier", burstMultiplier), + zap.String("source", source), + ) + } + // Build and send registration response resp, err := regHandler.BuildRegistrationResponse(result) if err != nil { return fmt.Errorf("failed to build registration response: %w", err) } + resp.Bandwidth = c.tunnelConn.GetBandwidth() if err := regHandler.SendRegistrationResponse(c.conn, resp); err != nil { return fmt.Errorf("failed to send registration ack: %w", err) @@ -483,3 +519,11 @@ func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool { } return false } + +func (c *Connection) SetBandwidthConfig(bandwidth int64, burstMultiplier float64) { + c.bandwidth = bandwidth + if burstMultiplier <= 0 { + burstMultiplier = 2.0 + } + c.burstMultiplier = burstMultiplier +} diff --git a/internal/server/tcp/listener.go b/internal/server/tcp/listener.go index 93b50c8..c3cbf46 100644 --- a/internal/server/tcp/listener.go +++ b/internal/server/tcp/listener.go @@ -60,6 +60,8 @@ type Listener struct { // Server capabilities allowedTransports []string allowedTunnelTypes []string + bandwidth int64 + burstMultiplier float64 } func NewListener(cfg ListenerConfig) *Listener { @@ -298,6 +300,7 @@ func (l *Listener) handleConnection(netConn net.Conn) { }) conn.SetAllowedTunnelTypes(l.allowedTunnelTypes) conn.SetAllowedTransports(l.allowedTransports) + conn.SetBandwidthConfig(l.bandwidth, l.burstMultiplier) connID := netConn.RemoteAddr().String() l.connMu.Lock() @@ -420,6 +423,8 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) { HTTPListener: l.httpListener, }) tcpConn.SetAllowedTunnelTypes(l.allowedTunnelTypes) + tcpConn.SetAllowedTransports(l.allowedTransports) + tcpConn.SetBandwidthConfig(l.bandwidth, l.burstMultiplier) l.connMu.Lock() l.connections[connID] = tcpConn @@ -471,6 +476,17 @@ func (l *Listener) SetAllowedTunnelTypes(types []string) { l.allowedTunnelTypes = types } +func (l *Listener) SetBandwidth(bandwidth int64) { + l.bandwidth = bandwidth +} + +func (l *Listener) SetBurstMultiplier(multiplier float64) { + if multiplier <= 0 { + multiplier = 2.0 + } + l.burstMultiplier = multiplier +} + // IsTransportAllowed checks if a transport is allowed func (l *Listener) IsTransportAllowed(transport string) bool { if len(l.allowedTransports) == 0 { diff --git a/internal/server/tcp/proxy.go b/internal/server/tcp/proxy.go index 18bb68c..d054971 100644 --- a/internal/server/tcp/proxy.go +++ b/internal/server/tcp/proxy.go @@ -10,6 +10,7 @@ import ( "drip/internal/shared/netutil" "drip/internal/shared/pool" + "drip/internal/shared/qos" "go.uber.org/zap" ) @@ -34,6 +35,7 @@ type Proxy struct { cancel context.CancelFunc checkIPAccess func(ip string) bool + limiter interface{ IsLimited() bool } } type trafficStats interface { @@ -73,6 +75,11 @@ func (p *Proxy) SetIPAccessCheck(check func(ip string) bool) { p.checkIPAccess = check } +// SetLimiter sets the bandwidth limiter for this proxy. +func (p *Proxy) SetLimiter(limiter interface{ IsLimited() bool }) { + p.limiter = limiter +} + func (p *Proxy) Start() error { addr := fmt.Sprintf("0.0.0.0:%d", p.port) @@ -240,10 +247,17 @@ func (p *Proxy) handleConn(conn net.Conn) { defer stream.Close() + var limitedStream net.Conn = stream + if p.limiter != nil && p.limiter.IsLimited() { + if l, ok := p.limiter.(*qos.Limiter); ok { + limitedStream = qos.NewLimitedConn(p.ctx, stream, l) + } + } + _ = netutil.PipeWithCallbacksAndBufferSize( p.ctx, conn, - stream, + limitedStream, pool.SizeLarge, func(n int64) { if p.stats != nil { diff --git a/internal/server/tcp/tunnel.go b/internal/server/tcp/tunnel.go index 0d9f721..0f901ce 100644 --- a/internal/server/tcp/tunnel.go +++ b/internal/server/tcp/tunnel.go @@ -52,6 +52,9 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error { if c.tunnelConn != nil && c.tunnelConn.HasIPAccessControl() { c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed) } + if c.tunnelConn != nil { + c.proxy.SetLimiter(c.tunnelConn.GetLimiter()) + } // Update lifecycle manager with proxy if c.lifecycleManager != nil { diff --git a/internal/server/tunnel/connection.go b/internal/server/tunnel/connection.go index 6bbddbc..0221af1 100644 --- a/internal/server/tunnel/connection.go +++ b/internal/server/tunnel/connection.go @@ -32,6 +32,10 @@ type Connection struct { ipAccessChecker *netutil.IPAccessChecker proxyAuth *protocol.ProxyAuth + + bandwidth int64 + burstMultiplier float64 + limiter interface{ IsLimited() bool } } func NewConnection(subdomain string, conn *websocket.Conn, logger *zap.Logger) *Connection { @@ -214,6 +218,31 @@ func (c *Connection) ValidateProxyAuth(password string) bool { return auth.Password == password } +func (c *Connection) SetBandwidthWithBurst(bandwidth int64, burstMultiplier float64) { + c.mu.Lock() + defer c.mu.Unlock() + c.bandwidth = bandwidth + c.burstMultiplier = burstMultiplier +} + +func (c *Connection) GetBandwidth() int64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.bandwidth +} + +func (c *Connection) SetLimiter(limiter interface{ IsLimited() bool }) { + c.mu.Lock() + defer c.mu.Unlock() + c.limiter = limiter +} + +func (c *Connection) GetLimiter() interface{ IsLimited() bool } { + c.mu.RLock() + defer c.mu.RUnlock() + return c.limiter +} + func (c *Connection) StartWritePump() { if c.Conn == nil { go func() { diff --git a/internal/server/tunnel/connection_test.go b/internal/server/tunnel/connection_test.go new file mode 100644 index 0000000..68d7e01 --- /dev/null +++ b/internal/server/tunnel/connection_test.go @@ -0,0 +1,59 @@ +package tunnel + +import ( + "testing" + + "drip/internal/shared/qos" + + "go.uber.org/zap" +) + +func TestConnectionBandwidthWithBurst(t *testing.T) { + logger := zap.NewNop() + + tests := []struct { + name string + bandwidth int64 + burstMultiplier float64 + wantBandwidth int64 + }{ + {"1MB/s with 2x burst", 1024 * 1024, 2.0, 1024 * 1024}, + {"500KB/s with 3x burst", 500 * 1024, 3.0, 500 * 1024}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conn := NewConnection("test-subdomain", nil, logger) + conn.SetBandwidthWithBurst(tt.bandwidth, tt.burstMultiplier) + + if conn.GetBandwidth() != tt.wantBandwidth { + t.Errorf("GetBandwidth() = %v, want %v", conn.GetBandwidth(), tt.wantBandwidth) + } + + burst := int(float64(tt.bandwidth) * tt.burstMultiplier) + limiter := qos.NewLimiter(qos.Config{Bandwidth: tt.bandwidth, Burst: burst}) + conn.SetLimiter(limiter) + + got := conn.GetLimiter() + if got == nil { + t.Fatal("GetLimiter() should not be nil") + } + if !got.IsLimited() { + t.Error("Limiter should be limited") + } + }) + } +} + +func TestConnectionBandwidthUnlimited(t *testing.T) { + logger := zap.NewNop() + conn := NewConnection("test-subdomain", nil, logger) + + if conn.GetBandwidth() != 0 { + t.Errorf("Default bandwidth should be 0, got %v", conn.GetBandwidth()) + } + + if conn.GetLimiter() != nil { + t.Error("Default limiter should be nil") + } +} diff --git a/internal/shared/protocol/messages.go b/internal/shared/protocol/messages.go index d82947a..3f37d42 100644 --- a/internal/shared/protocol/messages.go +++ b/internal/shared/protocol/messages.go @@ -29,6 +29,7 @@ type RegisterRequest struct { PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"` IPAccess *IPAccessControl `json:"ip_access,omitempty"` ProxyAuth *ProxyAuth `json:"proxy_auth,omitempty"` + Bandwidth int64 `json:"bandwidth,omitempty"` } type RegisterResponse struct { @@ -39,6 +40,7 @@ type RegisterResponse struct { TunnelID string `json:"tunnel_id,omitempty"` SupportsDataConn bool `json:"supports_data_conn,omitempty"` RecommendedConns int `json:"recommended_conns,omitempty"` + Bandwidth int64 `json:"bandwidth,omitempty"` } type DataConnectRequest struct { diff --git a/internal/shared/qos/conn.go b/internal/shared/qos/conn.go new file mode 100644 index 0000000..6bc4bbc --- /dev/null +++ b/internal/shared/qos/conn.go @@ -0,0 +1,112 @@ +package qos + +import ( + "context" + "io" + "net" +) + +type LimitedConn struct { + net.Conn + limiter *Limiter + ctx context.Context +} + +func NewLimitedConn(ctx context.Context, conn net.Conn, limiter *Limiter) *LimitedConn { + return &LimitedConn{ + Conn: conn, + limiter: limiter, + ctx: ctx, + } +} + +func (c *LimitedConn) Read(b []byte) (n int, err error) { + if c.limiter == nil || !c.limiter.IsLimited() { + return c.Conn.Read(b) + } + + burst := c.limiter.RateLimiter().Burst() + if len(b) > burst { + b = b[:burst] + } + + n, err = c.Conn.Read(b) + if n > 0 { + if waitErr := c.limiter.RateLimiter().WaitN(c.ctx, n); waitErr != nil { + if err == nil { + err = waitErr + } + } + } + return n, err +} + +func (c *LimitedConn) Write(b []byte) (n int, err error) { + if c.limiter == nil || !c.limiter.IsLimited() { + return c.Conn.Write(b) + } + + burst := c.limiter.RateLimiter().Burst() + total := 0 + + for len(b) > 0 { + chunk := min(len(b), burst) + + if err := c.limiter.RateLimiter().WaitN(c.ctx, chunk); err != nil { + return total, err + } + + nw, err := c.Conn.Write(b[:chunk]) + total += nw + if err != nil { + return total, err + } + b = b[chunk:] + } + + return total, nil +} + +func (c *LimitedConn) ReadFrom(r io.Reader) (n int64, err error) { + buf := make([]byte, 32*1024) + for { + nr, er := r.Read(buf) + if nr > 0 { + nw, ew := c.Write(buf[:nr]) + n += int64(nw) + if ew != nil { + err = ew + break + } + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + return n, err +} + +func (c *LimitedConn) WriteTo(w io.Writer) (n int64, err error) { + buf := make([]byte, 32*1024) + for { + nr, er := c.Read(buf) + if nr > 0 { + nw, ew := w.Write(buf[:nr]) + n += int64(nw) + if ew != nil { + err = ew + break + } + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + return n, err +} diff --git a/internal/shared/qos/conn_test.go b/internal/shared/qos/conn_test.go new file mode 100644 index 0000000..aff4cf7 --- /dev/null +++ b/internal/shared/qos/conn_test.go @@ -0,0 +1,254 @@ +package qos + +import ( + "bytes" + "context" + "errors" + "io" + "net" + "sync" + "testing" + "time" +) + +type errorAfterConn struct { + mockConn + writeLimit int + written int +} + +func (c *errorAfterConn) Write(b []byte) (int, error) { + c.mu.Lock() + defer c.mu.Unlock() + remaining := c.writeLimit - c.written + if remaining <= 0 { + return 0, errors.New("write error") + } + if len(b) > remaining { + b = b[:remaining] + } + c.writeBuf = append(c.writeBuf, b...) + c.written += len(b) + return len(b), nil +} + +func TestWriteLargerThanBurst(t *testing.T) { + bandwidth := int64(10 * 1024) + burst := 1024 + limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst}) + + conn := newMockConn(nil) + lc := NewLimitedConn(context.Background(), conn, limiter) + + data := make([]byte, 5*1024) + for i := range data { + data[i] = byte(i % 256) + } + + n, err := lc.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + if n != len(data) { + t.Errorf("Write returned %d, want %d", n, len(data)) + } + + conn.mu.Lock() + defer conn.mu.Unlock() + if !bytes.Equal(conn.writeBuf, data) { + t.Error("Written data does not match input") + } +} + +func TestWriteZeroLength(t *testing.T) { + limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024}) + conn := newMockConn(nil) + lc := NewLimitedConn(context.Background(), conn, limiter) + + n, err := lc.Write(nil) + if err != nil { + t.Fatalf("Write(nil) failed: %v", err) + } + if n != 0 { + t.Errorf("Write(nil) returned %d, want 0", n) + } +} + +func TestWriteContextCancelDuringChunking(t *testing.T) { + limiter := NewLimiter(Config{Bandwidth: 100, Burst: 100}) + conn := newMockConn(nil) + ctx, cancel := context.WithCancel(context.Background()) + lc := NewLimitedConn(ctx, conn, limiter) + + _, err := lc.Write(make([]byte, 100)) + if err != nil { + t.Fatalf("First write failed: %v", err) + } + + cancel() + + _, err = lc.Write(make([]byte, 200)) + if err == nil { + t.Error("Write should fail after context cancellation") + } +} + +func TestReadCappedToBurst(t *testing.T) { + burst := 512 + limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst}) + + data := make([]byte, 4096) + for i := range data { + data[i] = byte(i % 256) + } + conn := newMockConn(data) + lc := NewLimitedConn(context.Background(), conn, limiter) + + buf := make([]byte, 4096) + n, err := lc.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + if n > burst { + t.Errorf("Read returned %d bytes, should be capped at burst=%d", n, burst) + } + if !bytes.Equal(buf[:n], data[:n]) { + t.Error("Read data mismatch") + } +} + +func TestReadEOF(t *testing.T) { + limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024}) + conn := newMockConn([]byte{}) + lc := NewLimitedConn(context.Background(), conn, limiter) + + buf := make([]byte, 100) + _, err := lc.Read(buf) + if err != io.EOF { + t.Errorf("Expected io.EOF, got %v", err) + } +} + +func TestNilLimiter(t *testing.T) { + conn := newMockConn([]byte("hello")) + lc := NewLimitedConn(context.Background(), conn, nil) + + buf := make([]byte, 10) + n, err := lc.Read(buf) + if err != nil { + t.Fatalf("Read with nil limiter failed: %v", err) + } + if n != 5 { + t.Errorf("Read returned %d, want 5", n) + } + + n, err = lc.Write([]byte("world")) + if err != nil { + t.Fatalf("Write with nil limiter failed: %v", err) + } + if n != 5 { + t.Errorf("Write returned %d, want 5", n) + } +} + +func TestConcurrentReadWrite(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + limiter := NewLimiter(Config{Bandwidth: 100 * 1024, Burst: 100 * 1024}) + lc := NewLimitedConn(context.Background(), serverConn, limiter) + + dataSize := 50 * 1024 + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + data := make([]byte, dataSize) + for i := range data { + data[i] = 0xAA + } + lc.Write(data) + }() + + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, dataSize) + total := 0 + for total < dataSize { + n, err := clientConn.Read(buf[total:]) + if err != nil { + return + } + total += n + } + }() + + wg.Wait() +} + +func TestReadFromBasic(t *testing.T) { + limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024}) + conn := newMockConn(nil) + lc := NewLimitedConn(context.Background(), conn, limiter) + + src := bytes.NewReader(make([]byte, 50*1024)) + n, err := lc.ReadFrom(src) + if err != nil { + t.Fatalf("ReadFrom failed: %v", err) + } + if n != 50*1024 { + t.Errorf("ReadFrom transferred %d bytes, want %d", n, 50*1024) + } + + conn.mu.Lock() + defer conn.mu.Unlock() + if len(conn.writeBuf) != 50*1024 { + t.Errorf("Underlying conn received %d bytes, want %d", len(conn.writeBuf), 50*1024) + } +} + +func TestWriteToBasic(t *testing.T) { + data := make([]byte, 50*1024) + for i := range data { + data[i] = byte(i % 256) + } + limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024}) + conn := newMockConn(data) + lc := NewLimitedConn(context.Background(), conn, limiter) + + var buf bytes.Buffer + n, err := lc.WriteTo(&buf) + if err != nil { + t.Fatalf("WriteTo failed: %v", err) + } + if n != int64(len(data)) { + t.Errorf("WriteTo transferred %d bytes, want %d", n, len(data)) + } + if !bytes.Equal(buf.Bytes(), data) { + t.Error("WriteTo data mismatch") + } +} + +func TestUnlimitedWrite(t *testing.T) { + limiter := NewLimiter(Config{Bandwidth: 0}) + conn := newMockConn(nil) + lc := NewLimitedConn(context.Background(), conn, limiter) + + data := make([]byte, 1024*1024) + start := time.Now() + n, err := lc.Write(data) + dur := time.Since(start) + + if err != nil { + t.Fatalf("Write failed: %v", err) + } + if n != len(data) { + t.Errorf("Write returned %d, want %d", n, len(data)) + } + if dur > 50*time.Millisecond { + t.Errorf("Unlimited write took too long: %v", dur) + } +} diff --git a/internal/shared/qos/limiter.go b/internal/shared/qos/limiter.go new file mode 100644 index 0000000..f5803fd --- /dev/null +++ b/internal/shared/qos/limiter.go @@ -0,0 +1,34 @@ +package qos + +import ( + "golang.org/x/time/rate" +) + +type Config struct { + Bandwidth int64 + Burst int +} + +type Limiter struct { + limiter *rate.Limiter +} + +func NewLimiter(cfg Config) *Limiter { + l := &Limiter{} + if cfg.Bandwidth > 0 { + burst := cfg.Burst + if burst <= 0 { + burst = int(cfg.Bandwidth * 2) + } + l.limiter = rate.NewLimiter(rate.Limit(cfg.Bandwidth), burst) + } + return l +} + +func (l *Limiter) RateLimiter() *rate.Limiter { + return l.limiter +} + +func (l *Limiter) IsLimited() bool { + return l.limiter != nil +} diff --git a/internal/shared/qos/limiter_test.go b/internal/shared/qos/limiter_test.go new file mode 100644 index 0000000..4034215 --- /dev/null +++ b/internal/shared/qos/limiter_test.go @@ -0,0 +1,172 @@ +package qos + +import ( + "context" + "io" + "net" + "sync" + "testing" + "time" +) + +func TestNewLimiter(t *testing.T) { + tests := []struct { + name string + cfg Config + wantLimit bool + wantBurst int + }{ + { + name: "unlimited when bandwidth is 0", + cfg: Config{Bandwidth: 0}, + wantLimit: false, + }, + { + name: "limited with default burst (2x)", + cfg: Config{Bandwidth: 1024}, + wantLimit: true, + wantBurst: 2048, + }, + { + name: "limited with custom burst", + cfg: Config{Bandwidth: 1024, Burst: 4096}, + wantLimit: true, + wantBurst: 4096, + }, + { + name: "1MB/s with 2x burst", + cfg: Config{Bandwidth: 1024 * 1024}, + wantLimit: true, + wantBurst: 2 * 1024 * 1024, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := NewLimiter(tt.cfg) + if l.IsLimited() != tt.wantLimit { + t.Errorf("IsLimited() = %v, want %v", l.IsLimited(), tt.wantLimit) + } + if tt.wantLimit { + if l.RateLimiter() == nil { + t.Error("RateLimiter() should not be nil when limited") + } + if l.RateLimiter().Burst() != tt.wantBurst { + t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst) + } + } + }) + } +} + +func TestLimiterBandwidthEnforcement(t *testing.T) { + bandwidth := int64(10 * 1024) + burst := int(bandwidth * 2) + + l := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst}) + + if !l.IsLimited() { + t.Fatal("Limiter should be limited") + } + + ctx := context.Background() + + start := time.Now() + err := l.RateLimiter().WaitN(ctx, burst) + if err != nil { + t.Fatalf("WaitN failed: %v", err) + } + burstDuration := time.Since(start) + if burstDuration > 100*time.Millisecond { + t.Errorf("Burst should be instant, took %v", burstDuration) + } + + start = time.Now() + err = l.RateLimiter().WaitN(ctx, int(bandwidth)) + if err != nil { + t.Fatalf("WaitN failed: %v", err) + } + limitedDuration := time.Since(start) + + if limitedDuration < 800*time.Millisecond { + t.Errorf("Rate limiting not working, took only %v for 1 second worth of data", limitedDuration) + } + if limitedDuration > 1500*time.Millisecond { + t.Errorf("Rate limiting too slow, took %v for 1 second worth of data", limitedDuration) + } +} + +type mockConn struct { + readBuf []byte + readPos int + writeBuf []byte + mu sync.Mutex +} + +func newMockConn(data []byte) *mockConn { + return &mockConn{readBuf: data} +} + +func (c *mockConn) Read(b []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.readPos >= len(c.readBuf) { + return 0, io.EOF + } + n = copy(b, c.readBuf[c.readPos:]) + c.readPos += n + return n, nil +} + +func (c *mockConn) Write(b []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() + c.writeBuf = append(c.writeBuf, b...) + return len(b), nil +} + +func (c *mockConn) Close() error { return nil } +func (c *mockConn) LocalAddr() net.Addr { return nil } +func (c *mockConn) RemoteAddr() net.Addr { return nil } +func (c *mockConn) SetDeadline(t time.Time) error { return nil } +func (c *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (c *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +func TestBurstMultiplier(t *testing.T) { + tests := []struct { + name string + bandwidth int64 + multiplier float64 + wantBurst int + }{ + { + name: "2x multiplier", + bandwidth: 1024 * 1024, + multiplier: 2.0, + wantBurst: 2 * 1024 * 1024, + }, + { + name: "1x multiplier (no extra burst)", + bandwidth: 1024 * 1024, + multiplier: 1.0, + wantBurst: 1024 * 1024, + }, + { + name: "3x multiplier", + bandwidth: 500 * 1024, + multiplier: 3.0, + wantBurst: 3 * 500 * 1024, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + burst := int(float64(tt.bandwidth) * tt.multiplier) + l := NewLimiter(Config{Bandwidth: tt.bandwidth, Burst: burst}) + + if l.RateLimiter().Burst() != tt.wantBurst { + t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst) + } + }) + } +} diff --git a/pkg/config/client_config.go b/pkg/config/client_config.go index 944dba5..d655423 100644 --- a/pkg/config/client_config.go +++ b/pkg/config/client_config.go @@ -22,6 +22,7 @@ type TunnelConfig struct { DenyIPs []string `yaml:"deny_ips,omitempty"` // Denied IPs/CIDRs Auth string `yaml:"auth,omitempty"` // Proxy authentication password (http/https only) AuthBearer string `yaml:"auth_bearer,omitempty"` // Proxy authentication bearer token (http/https only) + Bandwidth string `yaml:"bandwidth,omitempty"` // Bandwidth limit (e.g., 1M, 500K, 1G) } // Validate checks if the tunnel configuration is valid diff --git a/pkg/config/config.go b/pkg/config/config.go index 45a6664..a0f7713 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -41,6 +41,10 @@ type ServerConfig struct { // Allowed tunnel types: "http", "https", "tcp" (default: all) AllowedTunnelTypes []string `yaml:"tunnel_types"` + + // Bandwidth limiting + Bandwidth string `yaml:"bandwidth,omitempty"` + BurstMultiplier float64 `yaml:"burst_multiplier,omitempty"` } // Validate checks if the server configuration is valid diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 0000000..09defdd --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,63 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestServerConfigBandwidth(t *testing.T) { + tests := []struct { + name string + yaml string + wantBandwidth string + wantMultiplier float64 + }{ + { + name: "bandwidth 1M with 2.5x burst", + yaml: ` +port: 8443 +domain: example.com +tcp_port_min: 10000 +tcp_port_max: 20000 +bandwidth: 1M +burst_multiplier: 2.5 +`, + wantBandwidth: "1M", + wantMultiplier: 2.5, + }, + { + name: "no bandwidth limit", + yaml: ` +port: 8443 +domain: example.com +tcp_port_min: 10000 +tcp_port_max: 20000 +`, + wantBandwidth: "", + wantMultiplier: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte(tt.yaml), 0600); err != nil { + t.Fatalf("Failed to write config file: %v", err) + } + + cfg, err := LoadServerConfig(configPath) + if err != nil { + t.Fatalf("LoadServerConfig failed: %v", err) + } + + if cfg.Bandwidth != tt.wantBandwidth { + t.Errorf("Bandwidth = %q, want %q", cfg.Bandwidth, tt.wantBandwidth) + } + if cfg.BurstMultiplier != tt.wantMultiplier { + t.Errorf("BurstMultiplier = %v, want %v", cfg.BurstMultiplier, tt.wantMultiplier) + } + }) + } +} From 6f1f4da5d9e39f508bff6ce158f0c391a87e6439 Mon Sep 17 00:00:00 2001 From: Gouryella Date: Sun, 15 Feb 2026 03:04:15 +0800 Subject: [PATCH 2/2] feat(client): Add tunnel bandwidth verification and error handling fix(server): Improve burst value calculation of bandwidth limiter --- internal/client/cli/start.go | 37 ++++++++++++++++++++++++------- internal/server/tcp/connection.go | 30 ++++++++++++++++++++++++- 2 files changed, 58 insertions(+), 9 deletions(-) diff --git a/internal/client/cli/start.go b/internal/client/cli/start.go index 350dc8f..0444e47 100644 --- a/internal/client/cli/start.go +++ b/internal/client/cli/start.go @@ -106,6 +106,12 @@ func runStart(_ *cobra.Command, args []string) error { return fmt.Errorf("no tunnels to start") } + for _, t := range tunnelsToStart { + if err := validateTunnelBandwidth(t); err != nil { + return err + } + } + // Start tunnels if len(tunnelsToStart) == 1 { return startSingleTunnel(cfg, tunnelsToStart[0]) @@ -127,7 +133,10 @@ func formatTunnelInfo(t *config.TunnelConfig) string { } func startSingleTunnel(cfg *config.ClientConfig, t *config.TunnelConfig) error { - connConfig := buildConnectorConfig(cfg, t) + connConfig, err := buildConnectorConfig(cfg, t) + if err != nil { + return err + } fmt.Printf("Starting tunnel '%s' (%s %s:%d)\n", t.Name, t.Type, getAddress(t), t.Port) @@ -164,7 +173,11 @@ func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConf go func(tunnel *config.TunnelConfig) { defer wg.Done() - connConfig := buildConnectorConfig(cfg, tunnel) + connConfig, err := buildConnectorConfig(cfg, tunnel) + if err != nil { + errChan <- err + return + } fmt.Printf(" Starting %s (%s %s:%d)...\n", tunnel.Name, tunnel.Type, getAddress(tunnel), tunnel.Port) client := tcp.NewTunnelClient(connConfig, logger) @@ -212,7 +225,12 @@ func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConf return nil } -func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp.ConnectorConfig { +func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) (*tcp.ConnectorConfig, error) { + bw, err := parseBandwidth(t.Bandwidth) + if err != nil { + return nil, fmt.Errorf("invalid bandwidth for tunnel '%s': %w", t.Name, err) + } + tunnelType := protocol.TunnelTypeHTTP switch t.Type { case "https": @@ -242,8 +260,8 @@ func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp AuthPass: t.Auth, AuthBearer: t.AuthBearer, Transport: transport, - Bandwidth: parseBandwidthOrZero(t.Bandwidth), - } + Bandwidth: bw, + }, nil } func getAddress(t *config.TunnelConfig) string { @@ -253,7 +271,10 @@ func getAddress(t *config.TunnelConfig) string { return "127.0.0.1" } -func parseBandwidthOrZero(s string) int64 { - bw, _ := parseBandwidth(s) - return bw +func validateTunnelBandwidth(t *config.TunnelConfig) error { + _, err := parseBandwidth(t.Bandwidth) + if err != nil { + return fmt.Errorf("invalid bandwidth for tunnel '%s': %w", t.Name, err) + } + return nil } diff --git a/internal/server/tcp/connection.go b/internal/server/tcp/connection.go index 4a16942..c069981 100644 --- a/internal/server/tcp/connection.go +++ b/internal/server/tcp/connection.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "math" "net" "net/http" "strconv" @@ -247,10 +248,11 @@ func (c *Connection) Handle() error { burstMultiplier = 2.0 } c.tunnelConn.SetBandwidthWithBurst(effectiveBandwidth, burstMultiplier) + burst := limiterBurst(effectiveBandwidth, burstMultiplier) limiter := qos.NewLimiter(qos.Config{ Bandwidth: effectiveBandwidth, - Burst: int(float64(effectiveBandwidth) * burstMultiplier), + Burst: burst, }) c.tunnelConn.SetLimiter(limiter) @@ -262,6 +264,7 @@ func (c *Connection) Handle() error { zap.String("subdomain", c.subdomain), zap.Int64("bandwidth_bytes_sec", effectiveBandwidth), zap.Float64("burst_multiplier", burstMultiplier), + zap.Int("burst_bytes", burst), zap.String("source", source), ) } @@ -527,3 +530,28 @@ func (c *Connection) SetBandwidthConfig(bandwidth int64, burstMultiplier float64 } c.burstMultiplier = burstMultiplier } + +func limiterBurst(bandwidth int64, burstMultiplier float64) int { + if bandwidth <= 0 { + return 0 + } + + if burstMultiplier <= 0 || math.IsNaN(burstMultiplier) || math.IsInf(burstMultiplier, 0) { + burstMultiplier = 2.0 + } + + maxBurst := int64(^uint(0) >> 1) + rawBurst := float64(bandwidth) * burstMultiplier + if math.IsNaN(rawBurst) || rawBurst <= 0 { + return 1 + } + if rawBurst >= float64(maxBurst) { + return int(maxBurst) + } + + burst := int(rawBurst) + if burst <= 0 { + return 1 + } + return burst +}