mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 12:53:43 +00:00
Merge pull request #23 from Gouryella/feat/qos-bandwidth-limiting-v2
feat(client): Add bandwidth limit function support
This commit is contained in:
16
README.md
16
README.md
@@ -35,6 +35,22 @@
|
||||
|
||||
## Recent Changes
|
||||
|
||||
### 2025-02-14
|
||||
|
||||
- **Bandwidth Limiting (QoS)** - Per-tunnel bandwidth control with token bucket algorithm, server enforces `min(client, server)` as effective limit
|
||||
- **Transport Protocol Control** - Support independent configuration for service domain and tunnel domain
|
||||
|
||||
```bash
|
||||
# Client: limit to 1MB/s
|
||||
drip http 3000 --bandwidth 1M
|
||||
```
|
||||
|
||||
```yaml
|
||||
# Server: global limit (config.yaml)
|
||||
bandwidth: 10M
|
||||
burst_multiplier: 2.5
|
||||
```
|
||||
|
||||
### 2025-01-29
|
||||
|
||||
- **Bearer Token Authentication** - Added bearer token authentication support for tunnel access control
|
||||
|
||||
16
README_CN.md
16
README_CN.md
@@ -35,6 +35,22 @@
|
||||
|
||||
## 最近更新
|
||||
|
||||
### 2025-02-14
|
||||
|
||||
- **带宽限速 (QoS)** - 支持按隧道粒度进行带宽控制,使用令牌桶算法,服务端按 `min(client, server)` 作为实际生效限速
|
||||
- **传输协议控制** - 支持服务域名与隧道域名的独立配置
|
||||
|
||||
```bash
|
||||
# Client: limit to 1MB/s
|
||||
drip http 3000 --bandwidth 1M
|
||||
```
|
||||
|
||||
```yaml
|
||||
# Server: global limit (config.yaml)
|
||||
bandwidth: 10M
|
||||
burst_multiplier: 2.5
|
||||
```
|
||||
|
||||
### 2025-01-29
|
||||
|
||||
- **Bearer Token 认证** - 新增 Bearer Token 认证支持,用于隧道访问控制
|
||||
|
||||
1
go.mod
1
go.mod
@@ -42,5 +42,6 @@ require (
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
)
|
||||
|
||||
2
go.sum
2
go.sum
@@ -95,6 +95,8 @@ golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@@ -21,6 +21,7 @@ var (
|
||||
authPass string
|
||||
authBearer string
|
||||
transport string
|
||||
bandwidth string
|
||||
)
|
||||
|
||||
var httpCmd = &cobra.Command{
|
||||
@@ -37,6 +38,7 @@ Example:
|
||||
drip http 3000 --auth secret Enable proxy authentication with password
|
||||
drip http 3000 --auth-bearer sk-xxx Enable proxy authentication with bearer token
|
||||
drip http 3000 --transport wss Use WebSocket over TLS (CDN-friendly)
|
||||
drip http 3000 --bandwidth 1M Limit bandwidth to 1 MB/s
|
||||
|
||||
Configuration:
|
||||
First time: Run 'drip config init' to save server and token
|
||||
@@ -45,7 +47,13 @@ Configuration:
|
||||
Transport options:
|
||||
auto - Automatically select based on server address (default)
|
||||
tcp - Direct TLS 1.3 connection
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)
|
||||
|
||||
Bandwidth format:
|
||||
1K, 1KB - 1 kilobyte per second (1024 bytes/s)
|
||||
1M, 1MB - 1 megabyte per second (1048576 bytes/s)
|
||||
1G, 1GB - 1 gigabyte per second
|
||||
1024 - 1024 bytes per second (raw number)`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runHTTP,
|
||||
SilenceUsage: true,
|
||||
@@ -61,6 +69,7 @@ func init() {
|
||||
httpCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
|
||||
httpCmd.Flags().StringVar(&authBearer, "auth-bearer", "", "Bearer token for proxy authentication")
|
||||
httpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
|
||||
httpCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)")
|
||||
httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(httpCmd)
|
||||
@@ -85,6 +94,11 @@ func runHTTP(_ *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
bw, err := parseBandwidth(bandwidth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
@@ -98,6 +112,7 @@ func runHTTP(_ *cobra.Command, args []string) error {
|
||||
AuthPass: authPass,
|
||||
AuthBearer: authBearer,
|
||||
Transport: parseTransport(transport),
|
||||
Bandwidth: bw,
|
||||
}
|
||||
|
||||
var daemon *DaemonInfo
|
||||
@@ -118,3 +133,41 @@ func parseTransport(s string) tcp.TransportType {
|
||||
return tcp.TransportAuto
|
||||
}
|
||||
}
|
||||
|
||||
func parseBandwidth(s string) (int64, error) {
|
||||
if s == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
s = strings.TrimSpace(strings.ToUpper(s))
|
||||
if s == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var multiplier int64 = 1
|
||||
switch {
|
||||
case strings.HasSuffix(s, "GB") || strings.HasSuffix(s, "G"):
|
||||
multiplier = 1024 * 1024 * 1024
|
||||
s = strings.TrimSuffix(strings.TrimSuffix(s, "GB"), "G")
|
||||
case strings.HasSuffix(s, "MB") || strings.HasSuffix(s, "M"):
|
||||
multiplier = 1024 * 1024
|
||||
s = strings.TrimSuffix(strings.TrimSuffix(s, "MB"), "M")
|
||||
case strings.HasSuffix(s, "KB") || strings.HasSuffix(s, "K"):
|
||||
multiplier = 1024
|
||||
s = strings.TrimSuffix(strings.TrimSuffix(s, "KB"), "K")
|
||||
case strings.HasSuffix(s, "B"):
|
||||
s = strings.TrimSuffix(s, "B")
|
||||
}
|
||||
|
||||
val, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil || val < 0 {
|
||||
return 0, fmt.Errorf("invalid bandwidth value: %q (use format like 1M, 500K, 1G)", s)
|
||||
}
|
||||
|
||||
result := val * multiplier
|
||||
if val > 0 && result/multiplier != val {
|
||||
return 0, fmt.Errorf("bandwidth value overflow: %q", s)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
59
internal/client/cli/http_test.go
Normal file
59
internal/client/cli/http_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseBandwidth(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want int64
|
||||
wantErr bool
|
||||
}{
|
||||
{"", 0, false},
|
||||
{"0", 0, false},
|
||||
{"1024", 1024, false},
|
||||
{"1K", 1024, false},
|
||||
{"1KB", 1024, false},
|
||||
{"1k", 1024, false},
|
||||
{"1M", 1024 * 1024, false},
|
||||
{"1MB", 1024 * 1024, false},
|
||||
{"1m", 1024 * 1024, false},
|
||||
{"10M", 10 * 1024 * 1024, false},
|
||||
{"1G", 1024 * 1024 * 1024, false},
|
||||
{"1GB", 1024 * 1024 * 1024, false},
|
||||
{"500K", 500 * 1024, false},
|
||||
{"100M", 100 * 1024 * 1024, false},
|
||||
{" 1M ", 1024 * 1024, false},
|
||||
{"1B", 1, false},
|
||||
{"100B", 100, false},
|
||||
{"invalid", 0, true},
|
||||
{"abc", 0, true},
|
||||
{"-1M", 0, true},
|
||||
{"-100", 0, true},
|
||||
{"1.5M", 0, true},
|
||||
{"M", 0, true},
|
||||
{"K", 0, true},
|
||||
{"9223372036854775807K", 0, true},
|
||||
{"9999999999999999999G", 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got, err := parseBandwidth(tt.input)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("parseBandwidth(%q) = %d, want error", tt.input, got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("parseBandwidth(%q) unexpected error: %v", tt.input, err)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("parseBandwidth(%q) = %d, want %d", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,7 @@ Example:
|
||||
drip https 443 --auth secret Enable proxy authentication with password
|
||||
drip https 443 --auth-bearer sk-xxx Enable proxy authentication with bearer token
|
||||
drip https 443 --transport wss Use WebSocket over TLS (CDN-friendly)
|
||||
drip https 443 --bandwidth 1M Limit bandwidth to 1 MB/s
|
||||
|
||||
Configuration:
|
||||
First time: Run 'drip config init' to save server and token
|
||||
@@ -32,7 +33,13 @@ Configuration:
|
||||
Transport options:
|
||||
auto - Automatically select based on server address (default)
|
||||
tcp - Direct TLS 1.3 connection
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)
|
||||
|
||||
Bandwidth format:
|
||||
1K, 1KB - 1 kilobyte per second (1024 bytes/s)
|
||||
1M, 1MB - 1 megabyte per second (1048576 bytes/s)
|
||||
1G, 1GB - 1 gigabyte per second
|
||||
1024 - 1024 bytes per second (raw number)`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runHTTPS,
|
||||
SilenceUsage: true,
|
||||
@@ -48,6 +55,7 @@ func init() {
|
||||
httpsCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
|
||||
httpsCmd.Flags().StringVar(&authBearer, "auth-bearer", "", "Bearer token for proxy authentication")
|
||||
httpsCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
|
||||
httpsCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)")
|
||||
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpsCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(httpsCmd)
|
||||
@@ -72,6 +80,11 @@ func runHTTPS(_ *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
bw, err := parseBandwidth(bandwidth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
@@ -85,6 +98,7 @@ func runHTTPS(_ *cobra.Command, args []string) error {
|
||||
AuthPass: authPass,
|
||||
AuthBearer: authBearer,
|
||||
Transport: parseTransport(transport),
|
||||
Bandwidth: bw,
|
||||
}
|
||||
|
||||
var daemon *DaemonInfo
|
||||
|
||||
@@ -313,6 +313,24 @@ func runServer(cmd *cobra.Command, _ []string) error {
|
||||
listener.SetAllowedTransports(cfg.AllowedTransports)
|
||||
listener.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes)
|
||||
|
||||
bandwidth, err := parseBandwidth(cfg.Bandwidth)
|
||||
if err != nil {
|
||||
logger.Fatal("Invalid bandwidth configuration", zap.Error(err))
|
||||
}
|
||||
burstMultiplier := cfg.BurstMultiplier
|
||||
if burstMultiplier <= 0 {
|
||||
burstMultiplier = 2.0
|
||||
}
|
||||
listener.SetBandwidth(bandwidth)
|
||||
listener.SetBurstMultiplier(burstMultiplier)
|
||||
if bandwidth > 0 {
|
||||
logger.Info("Bandwidth limit configured",
|
||||
zap.String("bandwidth", cfg.Bandwidth),
|
||||
zap.Int64("bandwidth_bytes_sec", bandwidth),
|
||||
zap.Float64("burst_multiplier", burstMultiplier),
|
||||
)
|
||||
}
|
||||
|
||||
if err := listener.Start(); err != nil {
|
||||
logger.Fatal("Failed to start TCP listener", zap.Error(err))
|
||||
}
|
||||
|
||||
@@ -106,6 +106,12 @@ func runStart(_ *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("no tunnels to start")
|
||||
}
|
||||
|
||||
for _, t := range tunnelsToStart {
|
||||
if err := validateTunnelBandwidth(t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Start tunnels
|
||||
if len(tunnelsToStart) == 1 {
|
||||
return startSingleTunnel(cfg, tunnelsToStart[0])
|
||||
@@ -127,7 +133,10 @@ func formatTunnelInfo(t *config.TunnelConfig) string {
|
||||
}
|
||||
|
||||
func startSingleTunnel(cfg *config.ClientConfig, t *config.TunnelConfig) error {
|
||||
connConfig := buildConnectorConfig(cfg, t)
|
||||
connConfig, err := buildConnectorConfig(cfg, t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Starting tunnel '%s' (%s %s:%d)\n", t.Name, t.Type, getAddress(t), t.Port)
|
||||
|
||||
@@ -164,7 +173,11 @@ func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConf
|
||||
go func(tunnel *config.TunnelConfig) {
|
||||
defer wg.Done()
|
||||
|
||||
connConfig := buildConnectorConfig(cfg, tunnel)
|
||||
connConfig, err := buildConnectorConfig(cfg, tunnel)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
fmt.Printf(" Starting %s (%s %s:%d)...\n", tunnel.Name, tunnel.Type, getAddress(tunnel), tunnel.Port)
|
||||
|
||||
client := tcp.NewTunnelClient(connConfig, logger)
|
||||
@@ -212,7 +225,12 @@ func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConf
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp.ConnectorConfig {
|
||||
func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) (*tcp.ConnectorConfig, error) {
|
||||
bw, err := parseBandwidth(t.Bandwidth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid bandwidth for tunnel '%s': %w", t.Name, err)
|
||||
}
|
||||
|
||||
tunnelType := protocol.TunnelTypeHTTP
|
||||
switch t.Type {
|
||||
case "https":
|
||||
@@ -242,7 +260,8 @@ func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp
|
||||
AuthPass: t.Auth,
|
||||
AuthBearer: t.AuthBearer,
|
||||
Transport: transport,
|
||||
}
|
||||
Bandwidth: bw,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getAddress(t *config.TunnelConfig) string {
|
||||
@@ -251,3 +270,11 @@ func getAddress(t *config.TunnelConfig) string {
|
||||
}
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
func validateTunnelBandwidth(t *config.TunnelConfig) error {
|
||||
_, err := parseBandwidth(t.Bandwidth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid bandwidth for tunnel '%s': %w", t.Name, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ Example:
|
||||
drip tcp 22 --allow-ip 10.0.0.1 Allow single IP
|
||||
drip tcp 22 --deny-ip 1.2.3.4 Block specific IP
|
||||
drip tcp 22 --transport wss Use WebSocket over TLS (CDN-friendly)
|
||||
drip tcp 22 --bandwidth 1M Limit bandwidth to 1 MB/s
|
||||
|
||||
Supported Services:
|
||||
- Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017)
|
||||
@@ -54,6 +55,7 @@ func init() {
|
||||
tcpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
|
||||
tcpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
|
||||
tcpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
|
||||
tcpCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)")
|
||||
tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
tcpCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(tcpCmd)
|
||||
@@ -74,6 +76,11 @@ func runTCP(_ *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
bw, err := parseBandwidth(bandwidth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
@@ -85,6 +92,7 @@ func runTCP(_ *cobra.Command, args []string) error {
|
||||
AllowIPs: allowIPs,
|
||||
DenyIPs: denyIPs,
|
||||
Transport: parseTransport(transport),
|
||||
Bandwidth: bw,
|
||||
}
|
||||
|
||||
var daemon *DaemonInfo
|
||||
|
||||
@@ -30,6 +30,9 @@ func buildDaemonArgs(tunnelType string, args []string, subdomain string, localAd
|
||||
if authBearer != "" {
|
||||
daemonArgs = append(daemonArgs, "--auth-bearer", authBearer)
|
||||
}
|
||||
if bandwidth != "" {
|
||||
daemonArgs = append(daemonArgs, "--bandwidth", bandwidth)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
|
||||
@@ -46,6 +46,9 @@ type ConnectorConfig struct {
|
||||
|
||||
// Transport protocol selection
|
||||
Transport TransportType
|
||||
|
||||
// Bandwidth limit (bytes/sec), 0 = unlimited
|
||||
Bandwidth int64
|
||||
}
|
||||
|
||||
type TunnelClient interface {
|
||||
|
||||
@@ -81,6 +81,9 @@ type PoolClient struct {
|
||||
|
||||
// Session scaler
|
||||
scaler *SessionScaler
|
||||
|
||||
// Bandwidth limit requested from server (bytes/sec), 0 = unlimited
|
||||
bandwidth int64
|
||||
}
|
||||
|
||||
// NewPoolClient creates a new pool client.
|
||||
@@ -178,6 +181,7 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
transport: transport,
|
||||
insecure: cfg.Insecure,
|
||||
dialer: NewConnectionDialer(serverAddr, tlsConfig, cfg.Token, transport, logger),
|
||||
bandwidth: cfg.Bandwidth,
|
||||
}
|
||||
|
||||
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
|
||||
@@ -229,6 +233,10 @@ func (c *PoolClient) Connect() error {
|
||||
}
|
||||
}
|
||||
|
||||
if c.bandwidth > 0 {
|
||||
req.Bandwidth = c.bandwidth
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
_ = primaryConn.Close()
|
||||
@@ -275,6 +283,10 @@ func (c *PoolClient) Connect() error {
|
||||
c.tunnelID = resp.TunnelID
|
||||
}
|
||||
|
||||
if resp.Bandwidth > 0 {
|
||||
c.bandwidth = resp.Bandwidth
|
||||
}
|
||||
|
||||
yamuxCfg := mux.NewClientConfig()
|
||||
|
||||
session, err := yamux.Server(primaryConn, yamuxCfg)
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/qos"
|
||||
)
|
||||
|
||||
// bufio.Reader pool to reduce allocations on hot path
|
||||
@@ -247,7 +248,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
tconn.IncActiveConnections()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
countingStream := netutil.NewCountingConn(stream,
|
||||
var limitedStream net.Conn = stream
|
||||
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
|
||||
if l, ok := limiter.(*qos.Limiter); ok {
|
||||
limitedStream = qos.NewLimitedConn(r.Context(), stream, l)
|
||||
}
|
||||
}
|
||||
|
||||
countingStream := netutil.NewCountingConn(limitedStream,
|
||||
tconn.AddBytesOut,
|
||||
tconn.AddBytesIn,
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"drip/internal/shared/httputil"
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/qos"
|
||||
"drip/internal/shared/wsutil"
|
||||
)
|
||||
|
||||
@@ -58,6 +59,13 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
|
||||
return
|
||||
}
|
||||
|
||||
var limitedStream net.Conn = stream
|
||||
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
|
||||
if l, ok := limiter.(*qos.Limiter); ok {
|
||||
limitedStream = qos.NewLimitedConn(context.Background(), stream, l)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
defer clientConn.Close()
|
||||
@@ -71,7 +79,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
|
||||
}
|
||||
}
|
||||
|
||||
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
|
||||
_ = netutil.PipeWithCallbacks(context.Background(), limitedStream, clientRW,
|
||||
func(n int64) { tconn.AddBytesOut(n) },
|
||||
func(n int64) { tconn.AddBytesIn(n) },
|
||||
)
|
||||
|
||||
93
internal/server/tcp/bandwidth_test.go
Normal file
93
internal/server/tcp/bandwidth_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEffectiveBandwidthSelection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverBW int64
|
||||
clientBW int64
|
||||
wantEffective int64
|
||||
}{
|
||||
{"server only", 1024 * 1024, 0, 1024 * 1024},
|
||||
{"client only", 0, 512 * 1024, 512 * 1024},
|
||||
{"both unlimited", 0, 0, 0},
|
||||
{"client lower than server", 10 * 1024 * 1024, 1 * 1024 * 1024, 1 * 1024 * 1024},
|
||||
{"client higher than server - server wins", 1 * 1024 * 1024, 10 * 1024 * 1024, 1 * 1024 * 1024},
|
||||
{"client equal to server", 5 * 1024 * 1024, 5 * 1024 * 1024, 5 * 1024 * 1024},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
effectiveBandwidth := tt.serverBW
|
||||
if tt.clientBW > 0 {
|
||||
if effectiveBandwidth == 0 || tt.clientBW < effectiveBandwidth {
|
||||
effectiveBandwidth = tt.clientBW
|
||||
}
|
||||
}
|
||||
|
||||
if effectiveBandwidth != tt.wantEffective {
|
||||
t.Errorf("effectiveBandwidth = %d, want %d", effectiveBandwidth, tt.wantEffective)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionSetBandwidthConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
wantBandwidth int64
|
||||
wantMultiplier float64
|
||||
}{
|
||||
{"1MB/s with 2x burst", 1024 * 1024, 2.0, 1024 * 1024, 2.0},
|
||||
{"default multiplier when 0", 1024 * 1024, 0, 1024 * 1024, 2.0},
|
||||
{"default multiplier when negative", 1024 * 1024, -1.0, 1024 * 1024, 2.0},
|
||||
{"unlimited bandwidth", 0, 2.5, 0, 2.5},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
conn := &Connection{}
|
||||
conn.SetBandwidthConfig(tt.bandwidth, tt.burstMultiplier)
|
||||
|
||||
if conn.bandwidth != tt.wantBandwidth {
|
||||
t.Errorf("bandwidth = %v, want %v", conn.bandwidth, tt.wantBandwidth)
|
||||
}
|
||||
if conn.burstMultiplier != tt.wantMultiplier {
|
||||
t.Errorf("burstMultiplier = %v, want %v", conn.burstMultiplier, tt.wantMultiplier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerBandwidthConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
wantBandwidth int64
|
||||
wantMultiplier float64
|
||||
}{
|
||||
{"set bandwidth and multiplier", 1024 * 1024, 2.5, 1024 * 1024, 2.5},
|
||||
{"default multiplier", 1024 * 1024, 0, 1024 * 1024, 2.0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l := &Listener{}
|
||||
l.SetBandwidth(tt.bandwidth)
|
||||
l.SetBurstMultiplier(tt.burstMultiplier)
|
||||
|
||||
if l.bandwidth != tt.wantBandwidth {
|
||||
t.Errorf("bandwidth = %v, want %v", l.bandwidth, tt.wantBandwidth)
|
||||
}
|
||||
if l.burstMultiplier != tt.wantMultiplier {
|
||||
t.Errorf("burstMultiplier = %v, want %v", l.burstMultiplier, tt.wantMultiplier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/httputil"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/qos"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -69,6 +71,8 @@ type Connection struct {
|
||||
// Server capabilities
|
||||
allowedTunnelTypes []string
|
||||
allowedTransports []string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
@@ -231,11 +235,46 @@ func (c *Connection) Handle() error {
|
||||
)
|
||||
}
|
||||
|
||||
// Configure bandwidth limiting
|
||||
effectiveBandwidth := c.bandwidth
|
||||
if req.Bandwidth > 0 {
|
||||
if effectiveBandwidth == 0 || req.Bandwidth < effectiveBandwidth {
|
||||
effectiveBandwidth = req.Bandwidth
|
||||
}
|
||||
}
|
||||
if effectiveBandwidth > 0 {
|
||||
burstMultiplier := c.burstMultiplier
|
||||
if burstMultiplier <= 0 {
|
||||
burstMultiplier = 2.0
|
||||
}
|
||||
c.tunnelConn.SetBandwidthWithBurst(effectiveBandwidth, burstMultiplier)
|
||||
burst := limiterBurst(effectiveBandwidth, burstMultiplier)
|
||||
|
||||
limiter := qos.NewLimiter(qos.Config{
|
||||
Bandwidth: effectiveBandwidth,
|
||||
Burst: burst,
|
||||
})
|
||||
c.tunnelConn.SetLimiter(limiter)
|
||||
|
||||
source := "server"
|
||||
if req.Bandwidth > 0 && (c.bandwidth == 0 || req.Bandwidth < c.bandwidth) {
|
||||
source = "client"
|
||||
}
|
||||
c.logger.Info("Bandwidth limit configured",
|
||||
zap.String("subdomain", c.subdomain),
|
||||
zap.Int64("bandwidth_bytes_sec", effectiveBandwidth),
|
||||
zap.Float64("burst_multiplier", burstMultiplier),
|
||||
zap.Int("burst_bytes", burst),
|
||||
zap.String("source", source),
|
||||
)
|
||||
}
|
||||
|
||||
// Build and send registration response
|
||||
resp, err := regHandler.BuildRegistrationResponse(result)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build registration response: %w", err)
|
||||
}
|
||||
resp.Bandwidth = c.tunnelConn.GetBandwidth()
|
||||
|
||||
if err := regHandler.SendRegistrationResponse(c.conn, resp); err != nil {
|
||||
return fmt.Errorf("failed to send registration ack: %w", err)
|
||||
@@ -483,3 +522,36 @@ func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Connection) SetBandwidthConfig(bandwidth int64, burstMultiplier float64) {
|
||||
c.bandwidth = bandwidth
|
||||
if burstMultiplier <= 0 {
|
||||
burstMultiplier = 2.0
|
||||
}
|
||||
c.burstMultiplier = burstMultiplier
|
||||
}
|
||||
|
||||
func limiterBurst(bandwidth int64, burstMultiplier float64) int {
|
||||
if bandwidth <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
if burstMultiplier <= 0 || math.IsNaN(burstMultiplier) || math.IsInf(burstMultiplier, 0) {
|
||||
burstMultiplier = 2.0
|
||||
}
|
||||
|
||||
maxBurst := int64(^uint(0) >> 1)
|
||||
rawBurst := float64(bandwidth) * burstMultiplier
|
||||
if math.IsNaN(rawBurst) || rawBurst <= 0 {
|
||||
return 1
|
||||
}
|
||||
if rawBurst >= float64(maxBurst) {
|
||||
return int(maxBurst)
|
||||
}
|
||||
|
||||
burst := int(rawBurst)
|
||||
if burst <= 0 {
|
||||
return 1
|
||||
}
|
||||
return burst
|
||||
}
|
||||
|
||||
@@ -60,6 +60,8 @@ type Listener struct {
|
||||
// Server capabilities
|
||||
allowedTransports []string
|
||||
allowedTunnelTypes []string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
}
|
||||
|
||||
func NewListener(cfg ListenerConfig) *Listener {
|
||||
@@ -298,6 +300,7 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
})
|
||||
conn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
|
||||
conn.SetAllowedTransports(l.allowedTransports)
|
||||
conn.SetBandwidthConfig(l.bandwidth, l.burstMultiplier)
|
||||
|
||||
connID := netConn.RemoteAddr().String()
|
||||
l.connMu.Lock()
|
||||
@@ -420,6 +423,8 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
|
||||
HTTPListener: l.httpListener,
|
||||
})
|
||||
tcpConn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
|
||||
tcpConn.SetAllowedTransports(l.allowedTransports)
|
||||
tcpConn.SetBandwidthConfig(l.bandwidth, l.burstMultiplier)
|
||||
|
||||
l.connMu.Lock()
|
||||
l.connections[connID] = tcpConn
|
||||
@@ -471,6 +476,17 @@ func (l *Listener) SetAllowedTunnelTypes(types []string) {
|
||||
l.allowedTunnelTypes = types
|
||||
}
|
||||
|
||||
func (l *Listener) SetBandwidth(bandwidth int64) {
|
||||
l.bandwidth = bandwidth
|
||||
}
|
||||
|
||||
func (l *Listener) SetBurstMultiplier(multiplier float64) {
|
||||
if multiplier <= 0 {
|
||||
multiplier = 2.0
|
||||
}
|
||||
l.burstMultiplier = multiplier
|
||||
}
|
||||
|
||||
// IsTransportAllowed checks if a transport is allowed
|
||||
func (l *Listener) IsTransportAllowed(transport string) bool {
|
||||
if len(l.allowedTransports) == 0 {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/qos"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -34,6 +35,7 @@ type Proxy struct {
|
||||
cancel context.CancelFunc
|
||||
|
||||
checkIPAccess func(ip string) bool
|
||||
limiter interface{ IsLimited() bool }
|
||||
}
|
||||
|
||||
type trafficStats interface {
|
||||
@@ -73,6 +75,11 @@ func (p *Proxy) SetIPAccessCheck(check func(ip string) bool) {
|
||||
p.checkIPAccess = check
|
||||
}
|
||||
|
||||
// SetLimiter sets the bandwidth limiter for this proxy.
|
||||
func (p *Proxy) SetLimiter(limiter interface{ IsLimited() bool }) {
|
||||
p.limiter = limiter
|
||||
}
|
||||
|
||||
func (p *Proxy) Start() error {
|
||||
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
|
||||
|
||||
@@ -240,10 +247,17 @@ func (p *Proxy) handleConn(conn net.Conn) {
|
||||
|
||||
defer stream.Close()
|
||||
|
||||
var limitedStream net.Conn = stream
|
||||
if p.limiter != nil && p.limiter.IsLimited() {
|
||||
if l, ok := p.limiter.(*qos.Limiter); ok {
|
||||
limitedStream = qos.NewLimitedConn(p.ctx, stream, l)
|
||||
}
|
||||
}
|
||||
|
||||
_ = netutil.PipeWithCallbacksAndBufferSize(
|
||||
p.ctx,
|
||||
conn,
|
||||
stream,
|
||||
limitedStream,
|
||||
pool.SizeLarge,
|
||||
func(n int64) {
|
||||
if p.stats != nil {
|
||||
|
||||
@@ -52,6 +52,9 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
if c.tunnelConn != nil && c.tunnelConn.HasIPAccessControl() {
|
||||
c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed)
|
||||
}
|
||||
if c.tunnelConn != nil {
|
||||
c.proxy.SetLimiter(c.tunnelConn.GetLimiter())
|
||||
}
|
||||
|
||||
// Update lifecycle manager with proxy
|
||||
if c.lifecycleManager != nil {
|
||||
|
||||
@@ -32,6 +32,10 @@ type Connection struct {
|
||||
|
||||
ipAccessChecker *netutil.IPAccessChecker
|
||||
proxyAuth *protocol.ProxyAuth
|
||||
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
limiter interface{ IsLimited() bool }
|
||||
}
|
||||
|
||||
func NewConnection(subdomain string, conn *websocket.Conn, logger *zap.Logger) *Connection {
|
||||
@@ -214,6 +218,31 @@ func (c *Connection) ValidateProxyAuth(password string) bool {
|
||||
return auth.Password == password
|
||||
}
|
||||
|
||||
func (c *Connection) SetBandwidthWithBurst(bandwidth int64, burstMultiplier float64) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.bandwidth = bandwidth
|
||||
c.burstMultiplier = burstMultiplier
|
||||
}
|
||||
|
||||
func (c *Connection) GetBandwidth() int64 {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.bandwidth
|
||||
}
|
||||
|
||||
func (c *Connection) SetLimiter(limiter interface{ IsLimited() bool }) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.limiter = limiter
|
||||
}
|
||||
|
||||
func (c *Connection) GetLimiter() interface{ IsLimited() bool } {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.limiter
|
||||
}
|
||||
|
||||
func (c *Connection) StartWritePump() {
|
||||
if c.Conn == nil {
|
||||
go func() {
|
||||
|
||||
59
internal/server/tunnel/connection_test.go
Normal file
59
internal/server/tunnel/connection_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"drip/internal/shared/qos"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConnectionBandwidthWithBurst(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
wantBandwidth int64
|
||||
}{
|
||||
{"1MB/s with 2x burst", 1024 * 1024, 2.0, 1024 * 1024},
|
||||
{"500KB/s with 3x burst", 500 * 1024, 3.0, 500 * 1024},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
conn := NewConnection("test-subdomain", nil, logger)
|
||||
conn.SetBandwidthWithBurst(tt.bandwidth, tt.burstMultiplier)
|
||||
|
||||
if conn.GetBandwidth() != tt.wantBandwidth {
|
||||
t.Errorf("GetBandwidth() = %v, want %v", conn.GetBandwidth(), tt.wantBandwidth)
|
||||
}
|
||||
|
||||
burst := int(float64(tt.bandwidth) * tt.burstMultiplier)
|
||||
limiter := qos.NewLimiter(qos.Config{Bandwidth: tt.bandwidth, Burst: burst})
|
||||
conn.SetLimiter(limiter)
|
||||
|
||||
got := conn.GetLimiter()
|
||||
if got == nil {
|
||||
t.Fatal("GetLimiter() should not be nil")
|
||||
}
|
||||
if !got.IsLimited() {
|
||||
t.Error("Limiter should be limited")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionBandwidthUnlimited(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test-subdomain", nil, logger)
|
||||
|
||||
if conn.GetBandwidth() != 0 {
|
||||
t.Errorf("Default bandwidth should be 0, got %v", conn.GetBandwidth())
|
||||
}
|
||||
|
||||
if conn.GetLimiter() != nil {
|
||||
t.Error("Default limiter should be nil")
|
||||
}
|
||||
}
|
||||
@@ -29,6 +29,7 @@ type RegisterRequest struct {
|
||||
PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"`
|
||||
IPAccess *IPAccessControl `json:"ip_access,omitempty"`
|
||||
ProxyAuth *ProxyAuth `json:"proxy_auth,omitempty"`
|
||||
Bandwidth int64 `json:"bandwidth,omitempty"`
|
||||
}
|
||||
|
||||
type RegisterResponse struct {
|
||||
@@ -39,6 +40,7 @@ type RegisterResponse struct {
|
||||
TunnelID string `json:"tunnel_id,omitempty"`
|
||||
SupportsDataConn bool `json:"supports_data_conn,omitempty"`
|
||||
RecommendedConns int `json:"recommended_conns,omitempty"`
|
||||
Bandwidth int64 `json:"bandwidth,omitempty"`
|
||||
}
|
||||
|
||||
type DataConnectRequest struct {
|
||||
|
||||
112
internal/shared/qos/conn.go
Normal file
112
internal/shared/qos/conn.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type LimitedConn struct {
|
||||
net.Conn
|
||||
limiter *Limiter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewLimitedConn(ctx context.Context, conn net.Conn, limiter *Limiter) *LimitedConn {
|
||||
return &LimitedConn{
|
||||
Conn: conn,
|
||||
limiter: limiter,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LimitedConn) Read(b []byte) (n int, err error) {
|
||||
if c.limiter == nil || !c.limiter.IsLimited() {
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
burst := c.limiter.RateLimiter().Burst()
|
||||
if len(b) > burst {
|
||||
b = b[:burst]
|
||||
}
|
||||
|
||||
n, err = c.Conn.Read(b)
|
||||
if n > 0 {
|
||||
if waitErr := c.limiter.RateLimiter().WaitN(c.ctx, n); waitErr != nil {
|
||||
if err == nil {
|
||||
err = waitErr
|
||||
}
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *LimitedConn) Write(b []byte) (n int, err error) {
|
||||
if c.limiter == nil || !c.limiter.IsLimited() {
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
burst := c.limiter.RateLimiter().Burst()
|
||||
total := 0
|
||||
|
||||
for len(b) > 0 {
|
||||
chunk := min(len(b), burst)
|
||||
|
||||
if err := c.limiter.RateLimiter().WaitN(c.ctx, chunk); err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
nw, err := c.Conn.Write(b[:chunk])
|
||||
total += nw
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
b = b[chunk:]
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (c *LimitedConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
nr, er := r.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := c.Write(buf[:nr])
|
||||
n += int64(nw)
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er != io.EOF {
|
||||
err = er
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *LimitedConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
nr, er := c.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := w.Write(buf[:nr])
|
||||
n += int64(nw)
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er != io.EOF {
|
||||
err = er
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
254
internal/shared/qos/conn_test.go
Normal file
254
internal/shared/qos/conn_test.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type errorAfterConn struct {
|
||||
mockConn
|
||||
writeLimit int
|
||||
written int
|
||||
}
|
||||
|
||||
func (c *errorAfterConn) Write(b []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
remaining := c.writeLimit - c.written
|
||||
if remaining <= 0 {
|
||||
return 0, errors.New("write error")
|
||||
}
|
||||
if len(b) > remaining {
|
||||
b = b[:remaining]
|
||||
}
|
||||
c.writeBuf = append(c.writeBuf, b...)
|
||||
c.written += len(b)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func TestWriteLargerThanBurst(t *testing.T) {
|
||||
bandwidth := int64(10 * 1024)
|
||||
burst := 1024
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
data := make([]byte, 5*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
n, err := lc.Write(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("Write returned %d, want %d", n, len(data))
|
||||
}
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
if !bytes.Equal(conn.writeBuf, data) {
|
||||
t.Error("Written data does not match input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteZeroLength(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
n, err := lc.Write(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Write(nil) failed: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("Write(nil) returned %d, want 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteContextCancelDuringChunking(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 100, Burst: 100})
|
||||
conn := newMockConn(nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
lc := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
_, err := lc.Write(make([]byte, 100))
|
||||
if err != nil {
|
||||
t.Fatalf("First write failed: %v", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
_, err = lc.Write(make([]byte, 200))
|
||||
if err == nil {
|
||||
t.Error("Write should fail after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCappedToBurst(t *testing.T) {
|
||||
burst := 512
|
||||
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
|
||||
|
||||
data := make([]byte, 4096)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
conn := newMockConn(data)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := lc.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
if n > burst {
|
||||
t.Errorf("Read returned %d bytes, should be capped at burst=%d", n, burst)
|
||||
}
|
||||
if !bytes.Equal(buf[:n], data[:n]) {
|
||||
t.Error("Read data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadEOF(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024})
|
||||
conn := newMockConn([]byte{})
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
buf := make([]byte, 100)
|
||||
_, err := lc.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("Expected io.EOF, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilLimiter(t *testing.T) {
|
||||
conn := newMockConn([]byte("hello"))
|
||||
lc := NewLimitedConn(context.Background(), conn, nil)
|
||||
|
||||
buf := make([]byte, 10)
|
||||
n, err := lc.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read with nil limiter failed: %v", err)
|
||||
}
|
||||
if n != 5 {
|
||||
t.Errorf("Read returned %d, want 5", n)
|
||||
}
|
||||
|
||||
n, err = lc.Write([]byte("world"))
|
||||
if err != nil {
|
||||
t.Fatalf("Write with nil limiter failed: %v", err)
|
||||
}
|
||||
if n != 5 {
|
||||
t.Errorf("Write returned %d, want 5", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentReadWrite(t *testing.T) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
limiter := NewLimiter(Config{Bandwidth: 100 * 1024, Burst: 100 * 1024})
|
||||
lc := NewLimitedConn(context.Background(), serverConn, limiter)
|
||||
|
||||
dataSize := 50 * 1024
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
data := make([]byte, dataSize)
|
||||
for i := range data {
|
||||
data[i] = 0xAA
|
||||
}
|
||||
lc.Write(data)
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, dataSize)
|
||||
total := 0
|
||||
for total < dataSize {
|
||||
n, err := clientConn.Read(buf[total:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
total += n
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestReadFromBasic(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
src := bytes.NewReader(make([]byte, 50*1024))
|
||||
n, err := lc.ReadFrom(src)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrom failed: %v", err)
|
||||
}
|
||||
if n != 50*1024 {
|
||||
t.Errorf("ReadFrom transferred %d bytes, want %d", n, 50*1024)
|
||||
}
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
if len(conn.writeBuf) != 50*1024 {
|
||||
t.Errorf("Underlying conn received %d bytes, want %d", len(conn.writeBuf), 50*1024)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteToBasic(t *testing.T) {
|
||||
data := make([]byte, 50*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
conn := newMockConn(data)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
var buf bytes.Buffer
|
||||
n, err := lc.WriteTo(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTo failed: %v", err)
|
||||
}
|
||||
if n != int64(len(data)) {
|
||||
t.Errorf("WriteTo transferred %d bytes, want %d", n, len(data))
|
||||
}
|
||||
if !bytes.Equal(buf.Bytes(), data) {
|
||||
t.Error("WriteTo data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnlimitedWrite(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 0})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
data := make([]byte, 1024*1024)
|
||||
start := time.Now()
|
||||
n, err := lc.Write(data)
|
||||
dur := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("Write returned %d, want %d", n, len(data))
|
||||
}
|
||||
if dur > 50*time.Millisecond {
|
||||
t.Errorf("Unlimited write took too long: %v", dur)
|
||||
}
|
||||
}
|
||||
34
internal/shared/qos/limiter.go
Normal file
34
internal/shared/qos/limiter.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Bandwidth int64
|
||||
Burst int
|
||||
}
|
||||
|
||||
type Limiter struct {
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func NewLimiter(cfg Config) *Limiter {
|
||||
l := &Limiter{}
|
||||
if cfg.Bandwidth > 0 {
|
||||
burst := cfg.Burst
|
||||
if burst <= 0 {
|
||||
burst = int(cfg.Bandwidth * 2)
|
||||
}
|
||||
l.limiter = rate.NewLimiter(rate.Limit(cfg.Bandwidth), burst)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Limiter) RateLimiter() *rate.Limiter {
|
||||
return l.limiter
|
||||
}
|
||||
|
||||
func (l *Limiter) IsLimited() bool {
|
||||
return l.limiter != nil
|
||||
}
|
||||
172
internal/shared/qos/limiter_test.go
Normal file
172
internal/shared/qos/limiter_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewLimiter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantLimit bool
|
||||
wantBurst int
|
||||
}{
|
||||
{
|
||||
name: "unlimited when bandwidth is 0",
|
||||
cfg: Config{Bandwidth: 0},
|
||||
wantLimit: false,
|
||||
},
|
||||
{
|
||||
name: "limited with default burst (2x)",
|
||||
cfg: Config{Bandwidth: 1024},
|
||||
wantLimit: true,
|
||||
wantBurst: 2048,
|
||||
},
|
||||
{
|
||||
name: "limited with custom burst",
|
||||
cfg: Config{Bandwidth: 1024, Burst: 4096},
|
||||
wantLimit: true,
|
||||
wantBurst: 4096,
|
||||
},
|
||||
{
|
||||
name: "1MB/s with 2x burst",
|
||||
cfg: Config{Bandwidth: 1024 * 1024},
|
||||
wantLimit: true,
|
||||
wantBurst: 2 * 1024 * 1024,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l := NewLimiter(tt.cfg)
|
||||
if l.IsLimited() != tt.wantLimit {
|
||||
t.Errorf("IsLimited() = %v, want %v", l.IsLimited(), tt.wantLimit)
|
||||
}
|
||||
if tt.wantLimit {
|
||||
if l.RateLimiter() == nil {
|
||||
t.Error("RateLimiter() should not be nil when limited")
|
||||
}
|
||||
if l.RateLimiter().Burst() != tt.wantBurst {
|
||||
t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiterBandwidthEnforcement(t *testing.T) {
|
||||
bandwidth := int64(10 * 1024)
|
||||
burst := int(bandwidth * 2)
|
||||
|
||||
l := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
|
||||
if !l.IsLimited() {
|
||||
t.Fatal("Limiter should be limited")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
start := time.Now()
|
||||
err := l.RateLimiter().WaitN(ctx, burst)
|
||||
if err != nil {
|
||||
t.Fatalf("WaitN failed: %v", err)
|
||||
}
|
||||
burstDuration := time.Since(start)
|
||||
if burstDuration > 100*time.Millisecond {
|
||||
t.Errorf("Burst should be instant, took %v", burstDuration)
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
err = l.RateLimiter().WaitN(ctx, int(bandwidth))
|
||||
if err != nil {
|
||||
t.Fatalf("WaitN failed: %v", err)
|
||||
}
|
||||
limitedDuration := time.Since(start)
|
||||
|
||||
if limitedDuration < 800*time.Millisecond {
|
||||
t.Errorf("Rate limiting not working, took only %v for 1 second worth of data", limitedDuration)
|
||||
}
|
||||
if limitedDuration > 1500*time.Millisecond {
|
||||
t.Errorf("Rate limiting too slow, took %v for 1 second worth of data", limitedDuration)
|
||||
}
|
||||
}
|
||||
|
||||
type mockConn struct {
|
||||
readBuf []byte
|
||||
readPos int
|
||||
writeBuf []byte
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMockConn(data []byte) *mockConn {
|
||||
return &mockConn{readBuf: data}
|
||||
}
|
||||
|
||||
func (c *mockConn) Read(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.readPos >= len(c.readBuf) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n = copy(b, c.readBuf[c.readPos:])
|
||||
c.readPos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Write(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writeBuf = append(c.writeBuf, b...)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Close() error { return nil }
|
||||
func (c *mockConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *mockConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *mockConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (c *mockConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (c *mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func TestBurstMultiplier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
multiplier float64
|
||||
wantBurst int
|
||||
}{
|
||||
{
|
||||
name: "2x multiplier",
|
||||
bandwidth: 1024 * 1024,
|
||||
multiplier: 2.0,
|
||||
wantBurst: 2 * 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "1x multiplier (no extra burst)",
|
||||
bandwidth: 1024 * 1024,
|
||||
multiplier: 1.0,
|
||||
wantBurst: 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "3x multiplier",
|
||||
bandwidth: 500 * 1024,
|
||||
multiplier: 3.0,
|
||||
wantBurst: 3 * 500 * 1024,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
burst := int(float64(tt.bandwidth) * tt.multiplier)
|
||||
l := NewLimiter(Config{Bandwidth: tt.bandwidth, Burst: burst})
|
||||
|
||||
if l.RateLimiter().Burst() != tt.wantBurst {
|
||||
t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,7 @@ type TunnelConfig struct {
|
||||
DenyIPs []string `yaml:"deny_ips,omitempty"` // Denied IPs/CIDRs
|
||||
Auth string `yaml:"auth,omitempty"` // Proxy authentication password (http/https only)
|
||||
AuthBearer string `yaml:"auth_bearer,omitempty"` // Proxy authentication bearer token (http/https only)
|
||||
Bandwidth string `yaml:"bandwidth,omitempty"` // Bandwidth limit (e.g., 1M, 500K, 1G)
|
||||
}
|
||||
|
||||
// Validate checks if the tunnel configuration is valid
|
||||
|
||||
@@ -41,6 +41,10 @@ type ServerConfig struct {
|
||||
|
||||
// Allowed tunnel types: "http", "https", "tcp" (default: all)
|
||||
AllowedTunnelTypes []string `yaml:"tunnel_types"`
|
||||
|
||||
// Bandwidth limiting
|
||||
Bandwidth string `yaml:"bandwidth,omitempty"`
|
||||
BurstMultiplier float64 `yaml:"burst_multiplier,omitempty"`
|
||||
}
|
||||
|
||||
// Validate checks if the server configuration is valid
|
||||
|
||||
63
pkg/config/config_test.go
Normal file
63
pkg/config/config_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestServerConfigBandwidth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
wantBandwidth string
|
||||
wantMultiplier float64
|
||||
}{
|
||||
{
|
||||
name: "bandwidth 1M with 2.5x burst",
|
||||
yaml: `
|
||||
port: 8443
|
||||
domain: example.com
|
||||
tcp_port_min: 10000
|
||||
tcp_port_max: 20000
|
||||
bandwidth: 1M
|
||||
burst_multiplier: 2.5
|
||||
`,
|
||||
wantBandwidth: "1M",
|
||||
wantMultiplier: 2.5,
|
||||
},
|
||||
{
|
||||
name: "no bandwidth limit",
|
||||
yaml: `
|
||||
port: 8443
|
||||
domain: example.com
|
||||
tcp_port_min: 10000
|
||||
tcp_port_max: 20000
|
||||
`,
|
||||
wantBandwidth: "",
|
||||
wantMultiplier: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
if err := os.WriteFile(configPath, []byte(tt.yaml), 0600); err != nil {
|
||||
t.Fatalf("Failed to write config file: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadServerConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadServerConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Bandwidth != tt.wantBandwidth {
|
||||
t.Errorf("Bandwidth = %q, want %q", cfg.Bandwidth, tt.wantBandwidth)
|
||||
}
|
||||
if cfg.BurstMultiplier != tt.wantMultiplier {
|
||||
t.Errorf("BurstMultiplier = %v, want %v", cfg.BurstMultiplier, tt.wantMultiplier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user