Merge pull request #23 from Gouryella/feat/qos-bandwidth-limiting-v2

feat(client): Add bandwidth limit function support
This commit is contained in:
Gouryella
2026-02-15 03:04:28 +08:00
committed by GitHub
30 changed files with 1185 additions and 9 deletions

View File

@@ -35,6 +35,22 @@
## Recent Changes
### 2025-02-14
- **Bandwidth Limiting (QoS)** - Per-tunnel bandwidth control with token bucket algorithm, server enforces `min(client, server)` as effective limit
- **Transport Protocol Control** - Support independent configuration for service domain and tunnel domain
```bash
# Client: limit to 1MB/s
drip http 3000 --bandwidth 1M
```
```yaml
# Server: global limit (config.yaml)
bandwidth: 10M
burst_multiplier: 2.5
```
### 2025-01-29
- **Bearer Token Authentication** - Added bearer token authentication support for tunnel access control

View File

@@ -35,6 +35,22 @@
## 最近更新
### 2025-02-14
- **带宽限速 (QoS)** - 支持按隧道粒度进行带宽控制,使用令牌桶算法,服务端按 `min(client, server)` 作为实际生效限速
- **传输协议控制** - 支持服务域名与隧道域名的独立配置
```bash
# Client: limit to 1MB/s
drip http 3000 --bandwidth 1M
```
```yaml
# Server: global limit (config.yaml)
bandwidth: 10M
burst_multiplier: 2.5
```
### 2025-01-29
- **Bearer Token 认证** - 新增 Bearer Token 认证支持,用于隧道访问控制

1
go.mod
View File

@@ -42,5 +42,6 @@ require (
go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/text v0.33.0 // indirect
golang.org/x/time v0.14.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
)

2
go.sum
View File

@@ -95,6 +95,8 @@ golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -21,6 +21,7 @@ var (
authPass string
authBearer string
transport string
bandwidth string
)
var httpCmd = &cobra.Command{
@@ -37,6 +38,7 @@ Example:
drip http 3000 --auth secret Enable proxy authentication with password
drip http 3000 --auth-bearer sk-xxx Enable proxy authentication with bearer token
drip http 3000 --transport wss Use WebSocket over TLS (CDN-friendly)
drip http 3000 --bandwidth 1M Limit bandwidth to 1 MB/s
Configuration:
First time: Run 'drip config init' to save server and token
@@ -45,7 +47,13 @@ Configuration:
Transport options:
auto - Automatically select based on server address (default)
tcp - Direct TLS 1.3 connection
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
wss - WebSocket over TLS (works through CDN like Cloudflare)
Bandwidth format:
1K, 1KB - 1 kilobyte per second (1024 bytes/s)
1M, 1MB - 1 megabyte per second (1048576 bytes/s)
1G, 1GB - 1 gigabyte per second
1024 - 1024 bytes per second (raw number)`,
Args: cobra.ExactArgs(1),
RunE: runHTTP,
SilenceUsage: true,
@@ -61,6 +69,7 @@ func init() {
httpCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
httpCmd.Flags().StringVar(&authBearer, "auth-bearer", "", "Bearer token for proxy authentication")
httpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
httpCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)")
httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(httpCmd)
@@ -85,6 +94,11 @@ func runHTTP(_ *cobra.Command, args []string) error {
return err
}
bw, err := parseBandwidth(bandwidth)
if err != nil {
return err
}
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
@@ -98,6 +112,7 @@ func runHTTP(_ *cobra.Command, args []string) error {
AuthPass: authPass,
AuthBearer: authBearer,
Transport: parseTransport(transport),
Bandwidth: bw,
}
var daemon *DaemonInfo
@@ -118,3 +133,41 @@ func parseTransport(s string) tcp.TransportType {
return tcp.TransportAuto
}
}
func parseBandwidth(s string) (int64, error) {
if s == "" {
return 0, nil
}
s = strings.TrimSpace(strings.ToUpper(s))
if s == "" {
return 0, nil
}
var multiplier int64 = 1
switch {
case strings.HasSuffix(s, "GB") || strings.HasSuffix(s, "G"):
multiplier = 1024 * 1024 * 1024
s = strings.TrimSuffix(strings.TrimSuffix(s, "GB"), "G")
case strings.HasSuffix(s, "MB") || strings.HasSuffix(s, "M"):
multiplier = 1024 * 1024
s = strings.TrimSuffix(strings.TrimSuffix(s, "MB"), "M")
case strings.HasSuffix(s, "KB") || strings.HasSuffix(s, "K"):
multiplier = 1024
s = strings.TrimSuffix(strings.TrimSuffix(s, "KB"), "K")
case strings.HasSuffix(s, "B"):
s = strings.TrimSuffix(s, "B")
}
val, err := strconv.ParseInt(s, 10, 64)
if err != nil || val < 0 {
return 0, fmt.Errorf("invalid bandwidth value: %q (use format like 1M, 500K, 1G)", s)
}
result := val * multiplier
if val > 0 && result/multiplier != val {
return 0, fmt.Errorf("bandwidth value overflow: %q", s)
}
return result, nil
}

View File

@@ -0,0 +1,59 @@
package cli
import (
"testing"
)
func TestParseBandwidth(t *testing.T) {
tests := []struct {
input string
want int64
wantErr bool
}{
{"", 0, false},
{"0", 0, false},
{"1024", 1024, false},
{"1K", 1024, false},
{"1KB", 1024, false},
{"1k", 1024, false},
{"1M", 1024 * 1024, false},
{"1MB", 1024 * 1024, false},
{"1m", 1024 * 1024, false},
{"10M", 10 * 1024 * 1024, false},
{"1G", 1024 * 1024 * 1024, false},
{"1GB", 1024 * 1024 * 1024, false},
{"500K", 500 * 1024, false},
{"100M", 100 * 1024 * 1024, false},
{" 1M ", 1024 * 1024, false},
{"1B", 1, false},
{"100B", 100, false},
{"invalid", 0, true},
{"abc", 0, true},
{"-1M", 0, true},
{"-100", 0, true},
{"1.5M", 0, true},
{"M", 0, true},
{"K", 0, true},
{"9223372036854775807K", 0, true},
{"9999999999999999999G", 0, true},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got, err := parseBandwidth(tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("parseBandwidth(%q) = %d, want error", tt.input, got)
}
return
}
if err != nil {
t.Errorf("parseBandwidth(%q) unexpected error: %v", tt.input, err)
return
}
if got != tt.want {
t.Errorf("parseBandwidth(%q) = %d, want %d", tt.input, got, tt.want)
}
})
}
}

View File

@@ -24,6 +24,7 @@ Example:
drip https 443 --auth secret Enable proxy authentication with password
drip https 443 --auth-bearer sk-xxx Enable proxy authentication with bearer token
drip https 443 --transport wss Use WebSocket over TLS (CDN-friendly)
drip https 443 --bandwidth 1M Limit bandwidth to 1 MB/s
Configuration:
First time: Run 'drip config init' to save server and token
@@ -32,7 +33,13 @@ Configuration:
Transport options:
auto - Automatically select based on server address (default)
tcp - Direct TLS 1.3 connection
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
wss - WebSocket over TLS (works through CDN like Cloudflare)
Bandwidth format:
1K, 1KB - 1 kilobyte per second (1024 bytes/s)
1M, 1MB - 1 megabyte per second (1048576 bytes/s)
1G, 1GB - 1 gigabyte per second
1024 - 1024 bytes per second (raw number)`,
Args: cobra.ExactArgs(1),
RunE: runHTTPS,
SilenceUsage: true,
@@ -48,6 +55,7 @@ func init() {
httpsCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
httpsCmd.Flags().StringVar(&authBearer, "auth-bearer", "", "Bearer token for proxy authentication")
httpsCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
httpsCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)")
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpsCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(httpsCmd)
@@ -72,6 +80,11 @@ func runHTTPS(_ *cobra.Command, args []string) error {
return err
}
bw, err := parseBandwidth(bandwidth)
if err != nil {
return err
}
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
@@ -85,6 +98,7 @@ func runHTTPS(_ *cobra.Command, args []string) error {
AuthPass: authPass,
AuthBearer: authBearer,
Transport: parseTransport(transport),
Bandwidth: bw,
}
var daemon *DaemonInfo

View File

@@ -313,6 +313,24 @@ func runServer(cmd *cobra.Command, _ []string) error {
listener.SetAllowedTransports(cfg.AllowedTransports)
listener.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes)
bandwidth, err := parseBandwidth(cfg.Bandwidth)
if err != nil {
logger.Fatal("Invalid bandwidth configuration", zap.Error(err))
}
burstMultiplier := cfg.BurstMultiplier
if burstMultiplier <= 0 {
burstMultiplier = 2.0
}
listener.SetBandwidth(bandwidth)
listener.SetBurstMultiplier(burstMultiplier)
if bandwidth > 0 {
logger.Info("Bandwidth limit configured",
zap.String("bandwidth", cfg.Bandwidth),
zap.Int64("bandwidth_bytes_sec", bandwidth),
zap.Float64("burst_multiplier", burstMultiplier),
)
}
if err := listener.Start(); err != nil {
logger.Fatal("Failed to start TCP listener", zap.Error(err))
}

View File

@@ -106,6 +106,12 @@ func runStart(_ *cobra.Command, args []string) error {
return fmt.Errorf("no tunnels to start")
}
for _, t := range tunnelsToStart {
if err := validateTunnelBandwidth(t); err != nil {
return err
}
}
// Start tunnels
if len(tunnelsToStart) == 1 {
return startSingleTunnel(cfg, tunnelsToStart[0])
@@ -127,7 +133,10 @@ func formatTunnelInfo(t *config.TunnelConfig) string {
}
func startSingleTunnel(cfg *config.ClientConfig, t *config.TunnelConfig) error {
connConfig := buildConnectorConfig(cfg, t)
connConfig, err := buildConnectorConfig(cfg, t)
if err != nil {
return err
}
fmt.Printf("Starting tunnel '%s' (%s %s:%d)\n", t.Name, t.Type, getAddress(t), t.Port)
@@ -164,7 +173,11 @@ func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConf
go func(tunnel *config.TunnelConfig) {
defer wg.Done()
connConfig := buildConnectorConfig(cfg, tunnel)
connConfig, err := buildConnectorConfig(cfg, tunnel)
if err != nil {
errChan <- err
return
}
fmt.Printf(" Starting %s (%s %s:%d)...\n", tunnel.Name, tunnel.Type, getAddress(tunnel), tunnel.Port)
client := tcp.NewTunnelClient(connConfig, logger)
@@ -212,7 +225,12 @@ func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConf
return nil
}
func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp.ConnectorConfig {
func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) (*tcp.ConnectorConfig, error) {
bw, err := parseBandwidth(t.Bandwidth)
if err != nil {
return nil, fmt.Errorf("invalid bandwidth for tunnel '%s': %w", t.Name, err)
}
tunnelType := protocol.TunnelTypeHTTP
switch t.Type {
case "https":
@@ -242,7 +260,8 @@ func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp
AuthPass: t.Auth,
AuthBearer: t.AuthBearer,
Transport: transport,
}
Bandwidth: bw,
}, nil
}
func getAddress(t *config.TunnelConfig) string {
@@ -251,3 +270,11 @@ func getAddress(t *config.TunnelConfig) string {
}
return "127.0.0.1"
}
func validateTunnelBandwidth(t *config.TunnelConfig) error {
_, err := parseBandwidth(t.Bandwidth)
if err != nil {
return fmt.Errorf("invalid bandwidth for tunnel '%s': %w", t.Name, err)
}
return nil
}

View File

@@ -24,6 +24,7 @@ Example:
drip tcp 22 --allow-ip 10.0.0.1 Allow single IP
drip tcp 22 --deny-ip 1.2.3.4 Block specific IP
drip tcp 22 --transport wss Use WebSocket over TLS (CDN-friendly)
drip tcp 22 --bandwidth 1M Limit bandwidth to 1 MB/s
Supported Services:
- Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017)
@@ -54,6 +55,7 @@ func init() {
tcpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
tcpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
tcpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
tcpCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)")
tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
tcpCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(tcpCmd)
@@ -74,6 +76,11 @@ func runTCP(_ *cobra.Command, args []string) error {
return err
}
bw, err := parseBandwidth(bandwidth)
if err != nil {
return err
}
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
@@ -85,6 +92,7 @@ func runTCP(_ *cobra.Command, args []string) error {
AllowIPs: allowIPs,
DenyIPs: denyIPs,
Transport: parseTransport(transport),
Bandwidth: bw,
}
var daemon *DaemonInfo

View File

@@ -30,6 +30,9 @@ func buildDaemonArgs(tunnelType string, args []string, subdomain string, localAd
if authBearer != "" {
daemonArgs = append(daemonArgs, "--auth-bearer", authBearer)
}
if bandwidth != "" {
daemonArgs = append(daemonArgs, "--bandwidth", bandwidth)
}
if insecure {
daemonArgs = append(daemonArgs, "--insecure")
}

View File

@@ -46,6 +46,9 @@ type ConnectorConfig struct {
// Transport protocol selection
Transport TransportType
// Bandwidth limit (bytes/sec), 0 = unlimited
Bandwidth int64
}
type TunnelClient interface {

View File

@@ -81,6 +81,9 @@ type PoolClient struct {
// Session scaler
scaler *SessionScaler
// Bandwidth limit requested from server (bytes/sec), 0 = unlimited
bandwidth int64
}
// NewPoolClient creates a new pool client.
@@ -178,6 +181,7 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
transport: transport,
insecure: cfg.Insecure,
dialer: NewConnectionDialer(serverAddr, tlsConfig, cfg.Token, transport, logger),
bandwidth: cfg.Bandwidth,
}
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
@@ -229,6 +233,10 @@ func (c *PoolClient) Connect() error {
}
}
if c.bandwidth > 0 {
req.Bandwidth = c.bandwidth
}
payload, err := json.Marshal(req)
if err != nil {
_ = primaryConn.Close()
@@ -275,6 +283,10 @@ func (c *PoolClient) Connect() error {
c.tunnelID = resp.TunnelID
}
if resp.Bandwidth > 0 {
c.bandwidth = resp.Bandwidth
}
yamuxCfg := mux.NewClientConfig()
session, err := yamux.Server(primaryConn, yamuxCfg)

View File

@@ -20,6 +20,7 @@ import (
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/qos"
)
// bufio.Reader pool to reduce allocations on hot path
@@ -247,7 +248,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
tconn.IncActiveConnections()
defer tconn.DecActiveConnections()
countingStream := netutil.NewCountingConn(stream,
var limitedStream net.Conn = stream
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
if l, ok := limiter.(*qos.Limiter); ok {
limitedStream = qos.NewLimitedConn(r.Context(), stream, l)
}
}
countingStream := netutil.NewCountingConn(limitedStream,
tconn.AddBytesOut,
tconn.AddBytesIn,
)

View File

@@ -14,6 +14,7 @@ import (
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/protocol"
"drip/internal/shared/qos"
"drip/internal/shared/wsutil"
)
@@ -58,6 +59,13 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
return
}
var limitedStream net.Conn = stream
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
if l, ok := limiter.(*qos.Limiter); ok {
limitedStream = qos.NewLimitedConn(context.Background(), stream, l)
}
}
go func() {
defer stream.Close()
defer clientConn.Close()
@@ -71,7 +79,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
}
}
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
_ = netutil.PipeWithCallbacks(context.Background(), limitedStream, clientRW,
func(n int64) { tconn.AddBytesOut(n) },
func(n int64) { tconn.AddBytesIn(n) },
)

View File

@@ -0,0 +1,93 @@
package tcp
import (
"testing"
)
func TestEffectiveBandwidthSelection(t *testing.T) {
tests := []struct {
name string
serverBW int64
clientBW int64
wantEffective int64
}{
{"server only", 1024 * 1024, 0, 1024 * 1024},
{"client only", 0, 512 * 1024, 512 * 1024},
{"both unlimited", 0, 0, 0},
{"client lower than server", 10 * 1024 * 1024, 1 * 1024 * 1024, 1 * 1024 * 1024},
{"client higher than server - server wins", 1 * 1024 * 1024, 10 * 1024 * 1024, 1 * 1024 * 1024},
{"client equal to server", 5 * 1024 * 1024, 5 * 1024 * 1024, 5 * 1024 * 1024},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
effectiveBandwidth := tt.serverBW
if tt.clientBW > 0 {
if effectiveBandwidth == 0 || tt.clientBW < effectiveBandwidth {
effectiveBandwidth = tt.clientBW
}
}
if effectiveBandwidth != tt.wantEffective {
t.Errorf("effectiveBandwidth = %d, want %d", effectiveBandwidth, tt.wantEffective)
}
})
}
}
func TestConnectionSetBandwidthConfig(t *testing.T) {
tests := []struct {
name string
bandwidth int64
burstMultiplier float64
wantBandwidth int64
wantMultiplier float64
}{
{"1MB/s with 2x burst", 1024 * 1024, 2.0, 1024 * 1024, 2.0},
{"default multiplier when 0", 1024 * 1024, 0, 1024 * 1024, 2.0},
{"default multiplier when negative", 1024 * 1024, -1.0, 1024 * 1024, 2.0},
{"unlimited bandwidth", 0, 2.5, 0, 2.5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conn := &Connection{}
conn.SetBandwidthConfig(tt.bandwidth, tt.burstMultiplier)
if conn.bandwidth != tt.wantBandwidth {
t.Errorf("bandwidth = %v, want %v", conn.bandwidth, tt.wantBandwidth)
}
if conn.burstMultiplier != tt.wantMultiplier {
t.Errorf("burstMultiplier = %v, want %v", conn.burstMultiplier, tt.wantMultiplier)
}
})
}
}
func TestListenerBandwidthConfig(t *testing.T) {
tests := []struct {
name string
bandwidth int64
burstMultiplier float64
wantBandwidth int64
wantMultiplier float64
}{
{"set bandwidth and multiplier", 1024 * 1024, 2.5, 1024 * 1024, 2.5},
{"default multiplier", 1024 * 1024, 0, 1024 * 1024, 2.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &Listener{}
l.SetBandwidth(tt.bandwidth)
l.SetBurstMultiplier(tt.burstMultiplier)
if l.bandwidth != tt.wantBandwidth {
t.Errorf("bandwidth = %v, want %v", l.bandwidth, tt.wantBandwidth)
}
if l.burstMultiplier != tt.wantMultiplier {
t.Errorf("burstMultiplier = %v, want %v", l.burstMultiplier, tt.wantMultiplier)
}
})
}
}

View File

@@ -5,6 +5,7 @@ import (
"context"
"errors"
"fmt"
"math"
"net"
"net/http"
"strconv"
@@ -19,6 +20,7 @@ import (
"drip/internal/shared/constants"
"drip/internal/shared/httputil"
"drip/internal/shared/protocol"
"drip/internal/shared/qos"
"go.uber.org/zap"
)
@@ -69,6 +71,8 @@ type Connection struct {
// Server capabilities
allowedTunnelTypes []string
allowedTransports []string
bandwidth int64
burstMultiplier float64
}
// NewConnection creates a new connection handler
@@ -231,11 +235,46 @@ func (c *Connection) Handle() error {
)
}
// Configure bandwidth limiting
effectiveBandwidth := c.bandwidth
if req.Bandwidth > 0 {
if effectiveBandwidth == 0 || req.Bandwidth < effectiveBandwidth {
effectiveBandwidth = req.Bandwidth
}
}
if effectiveBandwidth > 0 {
burstMultiplier := c.burstMultiplier
if burstMultiplier <= 0 {
burstMultiplier = 2.0
}
c.tunnelConn.SetBandwidthWithBurst(effectiveBandwidth, burstMultiplier)
burst := limiterBurst(effectiveBandwidth, burstMultiplier)
limiter := qos.NewLimiter(qos.Config{
Bandwidth: effectiveBandwidth,
Burst: burst,
})
c.tunnelConn.SetLimiter(limiter)
source := "server"
if req.Bandwidth > 0 && (c.bandwidth == 0 || req.Bandwidth < c.bandwidth) {
source = "client"
}
c.logger.Info("Bandwidth limit configured",
zap.String("subdomain", c.subdomain),
zap.Int64("bandwidth_bytes_sec", effectiveBandwidth),
zap.Float64("burst_multiplier", burstMultiplier),
zap.Int("burst_bytes", burst),
zap.String("source", source),
)
}
// Build and send registration response
resp, err := regHandler.BuildRegistrationResponse(result)
if err != nil {
return fmt.Errorf("failed to build registration response: %w", err)
}
resp.Bandwidth = c.tunnelConn.GetBandwidth()
if err := regHandler.SendRegistrationResponse(c.conn, resp); err != nil {
return fmt.Errorf("failed to send registration ack: %w", err)
@@ -483,3 +522,36 @@ func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
}
return false
}
func (c *Connection) SetBandwidthConfig(bandwidth int64, burstMultiplier float64) {
c.bandwidth = bandwidth
if burstMultiplier <= 0 {
burstMultiplier = 2.0
}
c.burstMultiplier = burstMultiplier
}
func limiterBurst(bandwidth int64, burstMultiplier float64) int {
if bandwidth <= 0 {
return 0
}
if burstMultiplier <= 0 || math.IsNaN(burstMultiplier) || math.IsInf(burstMultiplier, 0) {
burstMultiplier = 2.0
}
maxBurst := int64(^uint(0) >> 1)
rawBurst := float64(bandwidth) * burstMultiplier
if math.IsNaN(rawBurst) || rawBurst <= 0 {
return 1
}
if rawBurst >= float64(maxBurst) {
return int(maxBurst)
}
burst := int(rawBurst)
if burst <= 0 {
return 1
}
return burst
}

View File

@@ -60,6 +60,8 @@ type Listener struct {
// Server capabilities
allowedTransports []string
allowedTunnelTypes []string
bandwidth int64
burstMultiplier float64
}
func NewListener(cfg ListenerConfig) *Listener {
@@ -298,6 +300,7 @@ func (l *Listener) handleConnection(netConn net.Conn) {
})
conn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
conn.SetAllowedTransports(l.allowedTransports)
conn.SetBandwidthConfig(l.bandwidth, l.burstMultiplier)
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
@@ -420,6 +423,8 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
HTTPListener: l.httpListener,
})
tcpConn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
tcpConn.SetAllowedTransports(l.allowedTransports)
tcpConn.SetBandwidthConfig(l.bandwidth, l.burstMultiplier)
l.connMu.Lock()
l.connections[connID] = tcpConn
@@ -471,6 +476,17 @@ func (l *Listener) SetAllowedTunnelTypes(types []string) {
l.allowedTunnelTypes = types
}
func (l *Listener) SetBandwidth(bandwidth int64) {
l.bandwidth = bandwidth
}
func (l *Listener) SetBurstMultiplier(multiplier float64) {
if multiplier <= 0 {
multiplier = 2.0
}
l.burstMultiplier = multiplier
}
// IsTransportAllowed checks if a transport is allowed
func (l *Listener) IsTransportAllowed(transport string) bool {
if len(l.allowedTransports) == 0 {

View File

@@ -10,6 +10,7 @@ import (
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/qos"
"go.uber.org/zap"
)
@@ -34,6 +35,7 @@ type Proxy struct {
cancel context.CancelFunc
checkIPAccess func(ip string) bool
limiter interface{ IsLimited() bool }
}
type trafficStats interface {
@@ -73,6 +75,11 @@ func (p *Proxy) SetIPAccessCheck(check func(ip string) bool) {
p.checkIPAccess = check
}
// SetLimiter sets the bandwidth limiter for this proxy.
func (p *Proxy) SetLimiter(limiter interface{ IsLimited() bool }) {
p.limiter = limiter
}
func (p *Proxy) Start() error {
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
@@ -240,10 +247,17 @@ func (p *Proxy) handleConn(conn net.Conn) {
defer stream.Close()
var limitedStream net.Conn = stream
if p.limiter != nil && p.limiter.IsLimited() {
if l, ok := p.limiter.(*qos.Limiter); ok {
limitedStream = qos.NewLimitedConn(p.ctx, stream, l)
}
}
_ = netutil.PipeWithCallbacksAndBufferSize(
p.ctx,
conn,
stream,
limitedStream,
pool.SizeLarge,
func(n int64) {
if p.stats != nil {

View File

@@ -52,6 +52,9 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
if c.tunnelConn != nil && c.tunnelConn.HasIPAccessControl() {
c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed)
}
if c.tunnelConn != nil {
c.proxy.SetLimiter(c.tunnelConn.GetLimiter())
}
// Update lifecycle manager with proxy
if c.lifecycleManager != nil {

View File

@@ -32,6 +32,10 @@ type Connection struct {
ipAccessChecker *netutil.IPAccessChecker
proxyAuth *protocol.ProxyAuth
bandwidth int64
burstMultiplier float64
limiter interface{ IsLimited() bool }
}
func NewConnection(subdomain string, conn *websocket.Conn, logger *zap.Logger) *Connection {
@@ -214,6 +218,31 @@ func (c *Connection) ValidateProxyAuth(password string) bool {
return auth.Password == password
}
func (c *Connection) SetBandwidthWithBurst(bandwidth int64, burstMultiplier float64) {
c.mu.Lock()
defer c.mu.Unlock()
c.bandwidth = bandwidth
c.burstMultiplier = burstMultiplier
}
func (c *Connection) GetBandwidth() int64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.bandwidth
}
func (c *Connection) SetLimiter(limiter interface{ IsLimited() bool }) {
c.mu.Lock()
defer c.mu.Unlock()
c.limiter = limiter
}
func (c *Connection) GetLimiter() interface{ IsLimited() bool } {
c.mu.RLock()
defer c.mu.RUnlock()
return c.limiter
}
func (c *Connection) StartWritePump() {
if c.Conn == nil {
go func() {

View File

@@ -0,0 +1,59 @@
package tunnel
import (
"testing"
"drip/internal/shared/qos"
"go.uber.org/zap"
)
func TestConnectionBandwidthWithBurst(t *testing.T) {
logger := zap.NewNop()
tests := []struct {
name string
bandwidth int64
burstMultiplier float64
wantBandwidth int64
}{
{"1MB/s with 2x burst", 1024 * 1024, 2.0, 1024 * 1024},
{"500KB/s with 3x burst", 500 * 1024, 3.0, 500 * 1024},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conn := NewConnection("test-subdomain", nil, logger)
conn.SetBandwidthWithBurst(tt.bandwidth, tt.burstMultiplier)
if conn.GetBandwidth() != tt.wantBandwidth {
t.Errorf("GetBandwidth() = %v, want %v", conn.GetBandwidth(), tt.wantBandwidth)
}
burst := int(float64(tt.bandwidth) * tt.burstMultiplier)
limiter := qos.NewLimiter(qos.Config{Bandwidth: tt.bandwidth, Burst: burst})
conn.SetLimiter(limiter)
got := conn.GetLimiter()
if got == nil {
t.Fatal("GetLimiter() should not be nil")
}
if !got.IsLimited() {
t.Error("Limiter should be limited")
}
})
}
}
func TestConnectionBandwidthUnlimited(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test-subdomain", nil, logger)
if conn.GetBandwidth() != 0 {
t.Errorf("Default bandwidth should be 0, got %v", conn.GetBandwidth())
}
if conn.GetLimiter() != nil {
t.Error("Default limiter should be nil")
}
}

View File

@@ -29,6 +29,7 @@ type RegisterRequest struct {
PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"`
IPAccess *IPAccessControl `json:"ip_access,omitempty"`
ProxyAuth *ProxyAuth `json:"proxy_auth,omitempty"`
Bandwidth int64 `json:"bandwidth,omitempty"`
}
type RegisterResponse struct {
@@ -39,6 +40,7 @@ type RegisterResponse struct {
TunnelID string `json:"tunnel_id,omitempty"`
SupportsDataConn bool `json:"supports_data_conn,omitempty"`
RecommendedConns int `json:"recommended_conns,omitempty"`
Bandwidth int64 `json:"bandwidth,omitempty"`
}
type DataConnectRequest struct {

112
internal/shared/qos/conn.go Normal file
View File

@@ -0,0 +1,112 @@
package qos
import (
"context"
"io"
"net"
)
type LimitedConn struct {
net.Conn
limiter *Limiter
ctx context.Context
}
func NewLimitedConn(ctx context.Context, conn net.Conn, limiter *Limiter) *LimitedConn {
return &LimitedConn{
Conn: conn,
limiter: limiter,
ctx: ctx,
}
}
func (c *LimitedConn) Read(b []byte) (n int, err error) {
if c.limiter == nil || !c.limiter.IsLimited() {
return c.Conn.Read(b)
}
burst := c.limiter.RateLimiter().Burst()
if len(b) > burst {
b = b[:burst]
}
n, err = c.Conn.Read(b)
if n > 0 {
if waitErr := c.limiter.RateLimiter().WaitN(c.ctx, n); waitErr != nil {
if err == nil {
err = waitErr
}
}
}
return n, err
}
func (c *LimitedConn) Write(b []byte) (n int, err error) {
if c.limiter == nil || !c.limiter.IsLimited() {
return c.Conn.Write(b)
}
burst := c.limiter.RateLimiter().Burst()
total := 0
for len(b) > 0 {
chunk := min(len(b), burst)
if err := c.limiter.RateLimiter().WaitN(c.ctx, chunk); err != nil {
return total, err
}
nw, err := c.Conn.Write(b[:chunk])
total += nw
if err != nil {
return total, err
}
b = b[chunk:]
}
return total, nil
}
func (c *LimitedConn) ReadFrom(r io.Reader) (n int64, err error) {
buf := make([]byte, 32*1024)
for {
nr, er := r.Read(buf)
if nr > 0 {
nw, ew := c.Write(buf[:nr])
n += int64(nw)
if ew != nil {
err = ew
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return n, err
}
func (c *LimitedConn) WriteTo(w io.Writer) (n int64, err error) {
buf := make([]byte, 32*1024)
for {
nr, er := c.Read(buf)
if nr > 0 {
nw, ew := w.Write(buf[:nr])
n += int64(nw)
if ew != nil {
err = ew
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return n, err
}

View File

@@ -0,0 +1,254 @@
package qos
import (
"bytes"
"context"
"errors"
"io"
"net"
"sync"
"testing"
"time"
)
type errorAfterConn struct {
mockConn
writeLimit int
written int
}
func (c *errorAfterConn) Write(b []byte) (int, error) {
c.mu.Lock()
defer c.mu.Unlock()
remaining := c.writeLimit - c.written
if remaining <= 0 {
return 0, errors.New("write error")
}
if len(b) > remaining {
b = b[:remaining]
}
c.writeBuf = append(c.writeBuf, b...)
c.written += len(b)
return len(b), nil
}
func TestWriteLargerThanBurst(t *testing.T) {
bandwidth := int64(10 * 1024)
burst := 1024
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
data := make([]byte, 5*1024)
for i := range data {
data[i] = byte(i % 256)
}
n, err := lc.Write(data)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != len(data) {
t.Errorf("Write returned %d, want %d", n, len(data))
}
conn.mu.Lock()
defer conn.mu.Unlock()
if !bytes.Equal(conn.writeBuf, data) {
t.Error("Written data does not match input")
}
}
func TestWriteZeroLength(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
n, err := lc.Write(nil)
if err != nil {
t.Fatalf("Write(nil) failed: %v", err)
}
if n != 0 {
t.Errorf("Write(nil) returned %d, want 0", n)
}
}
func TestWriteContextCancelDuringChunking(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 100, Burst: 100})
conn := newMockConn(nil)
ctx, cancel := context.WithCancel(context.Background())
lc := NewLimitedConn(ctx, conn, limiter)
_, err := lc.Write(make([]byte, 100))
if err != nil {
t.Fatalf("First write failed: %v", err)
}
cancel()
_, err = lc.Write(make([]byte, 200))
if err == nil {
t.Error("Write should fail after context cancellation")
}
}
func TestReadCappedToBurst(t *testing.T) {
burst := 512
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
data := make([]byte, 4096)
for i := range data {
data[i] = byte(i % 256)
}
conn := newMockConn(data)
lc := NewLimitedConn(context.Background(), conn, limiter)
buf := make([]byte, 4096)
n, err := lc.Read(buf)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if n > burst {
t.Errorf("Read returned %d bytes, should be capped at burst=%d", n, burst)
}
if !bytes.Equal(buf[:n], data[:n]) {
t.Error("Read data mismatch")
}
}
func TestReadEOF(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024})
conn := newMockConn([]byte{})
lc := NewLimitedConn(context.Background(), conn, limiter)
buf := make([]byte, 100)
_, err := lc.Read(buf)
if err != io.EOF {
t.Errorf("Expected io.EOF, got %v", err)
}
}
func TestNilLimiter(t *testing.T) {
conn := newMockConn([]byte("hello"))
lc := NewLimitedConn(context.Background(), conn, nil)
buf := make([]byte, 10)
n, err := lc.Read(buf)
if err != nil {
t.Fatalf("Read with nil limiter failed: %v", err)
}
if n != 5 {
t.Errorf("Read returned %d, want 5", n)
}
n, err = lc.Write([]byte("world"))
if err != nil {
t.Fatalf("Write with nil limiter failed: %v", err)
}
if n != 5 {
t.Errorf("Write returned %d, want 5", n)
}
}
func TestConcurrentReadWrite(t *testing.T) {
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
limiter := NewLimiter(Config{Bandwidth: 100 * 1024, Burst: 100 * 1024})
lc := NewLimitedConn(context.Background(), serverConn, limiter)
dataSize := 50 * 1024
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
data := make([]byte, dataSize)
for i := range data {
data[i] = 0xAA
}
lc.Write(data)
}()
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, dataSize)
total := 0
for total < dataSize {
n, err := clientConn.Read(buf[total:])
if err != nil {
return
}
total += n
}
}()
wg.Wait()
}
func TestReadFromBasic(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
src := bytes.NewReader(make([]byte, 50*1024))
n, err := lc.ReadFrom(src)
if err != nil {
t.Fatalf("ReadFrom failed: %v", err)
}
if n != 50*1024 {
t.Errorf("ReadFrom transferred %d bytes, want %d", n, 50*1024)
}
conn.mu.Lock()
defer conn.mu.Unlock()
if len(conn.writeBuf) != 50*1024 {
t.Errorf("Underlying conn received %d bytes, want %d", len(conn.writeBuf), 50*1024)
}
}
func TestWriteToBasic(t *testing.T) {
data := make([]byte, 50*1024)
for i := range data {
data[i] = byte(i % 256)
}
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
conn := newMockConn(data)
lc := NewLimitedConn(context.Background(), conn, limiter)
var buf bytes.Buffer
n, err := lc.WriteTo(&buf)
if err != nil {
t.Fatalf("WriteTo failed: %v", err)
}
if n != int64(len(data)) {
t.Errorf("WriteTo transferred %d bytes, want %d", n, len(data))
}
if !bytes.Equal(buf.Bytes(), data) {
t.Error("WriteTo data mismatch")
}
}
func TestUnlimitedWrite(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 0})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
data := make([]byte, 1024*1024)
start := time.Now()
n, err := lc.Write(data)
dur := time.Since(start)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != len(data) {
t.Errorf("Write returned %d, want %d", n, len(data))
}
if dur > 50*time.Millisecond {
t.Errorf("Unlimited write took too long: %v", dur)
}
}

View File

@@ -0,0 +1,34 @@
package qos
import (
"golang.org/x/time/rate"
)
type Config struct {
Bandwidth int64
Burst int
}
type Limiter struct {
limiter *rate.Limiter
}
func NewLimiter(cfg Config) *Limiter {
l := &Limiter{}
if cfg.Bandwidth > 0 {
burst := cfg.Burst
if burst <= 0 {
burst = int(cfg.Bandwidth * 2)
}
l.limiter = rate.NewLimiter(rate.Limit(cfg.Bandwidth), burst)
}
return l
}
func (l *Limiter) RateLimiter() *rate.Limiter {
return l.limiter
}
func (l *Limiter) IsLimited() bool {
return l.limiter != nil
}

View File

@@ -0,0 +1,172 @@
package qos
import (
"context"
"io"
"net"
"sync"
"testing"
"time"
)
func TestNewLimiter(t *testing.T) {
tests := []struct {
name string
cfg Config
wantLimit bool
wantBurst int
}{
{
name: "unlimited when bandwidth is 0",
cfg: Config{Bandwidth: 0},
wantLimit: false,
},
{
name: "limited with default burst (2x)",
cfg: Config{Bandwidth: 1024},
wantLimit: true,
wantBurst: 2048,
},
{
name: "limited with custom burst",
cfg: Config{Bandwidth: 1024, Burst: 4096},
wantLimit: true,
wantBurst: 4096,
},
{
name: "1MB/s with 2x burst",
cfg: Config{Bandwidth: 1024 * 1024},
wantLimit: true,
wantBurst: 2 * 1024 * 1024,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := NewLimiter(tt.cfg)
if l.IsLimited() != tt.wantLimit {
t.Errorf("IsLimited() = %v, want %v", l.IsLimited(), tt.wantLimit)
}
if tt.wantLimit {
if l.RateLimiter() == nil {
t.Error("RateLimiter() should not be nil when limited")
}
if l.RateLimiter().Burst() != tt.wantBurst {
t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst)
}
}
})
}
}
func TestLimiterBandwidthEnforcement(t *testing.T) {
bandwidth := int64(10 * 1024)
burst := int(bandwidth * 2)
l := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
if !l.IsLimited() {
t.Fatal("Limiter should be limited")
}
ctx := context.Background()
start := time.Now()
err := l.RateLimiter().WaitN(ctx, burst)
if err != nil {
t.Fatalf("WaitN failed: %v", err)
}
burstDuration := time.Since(start)
if burstDuration > 100*time.Millisecond {
t.Errorf("Burst should be instant, took %v", burstDuration)
}
start = time.Now()
err = l.RateLimiter().WaitN(ctx, int(bandwidth))
if err != nil {
t.Fatalf("WaitN failed: %v", err)
}
limitedDuration := time.Since(start)
if limitedDuration < 800*time.Millisecond {
t.Errorf("Rate limiting not working, took only %v for 1 second worth of data", limitedDuration)
}
if limitedDuration > 1500*time.Millisecond {
t.Errorf("Rate limiting too slow, took %v for 1 second worth of data", limitedDuration)
}
}
type mockConn struct {
readBuf []byte
readPos int
writeBuf []byte
mu sync.Mutex
}
func newMockConn(data []byte) *mockConn {
return &mockConn{readBuf: data}
}
func (c *mockConn) Read(b []byte) (n int, err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.readPos >= len(c.readBuf) {
return 0, io.EOF
}
n = copy(b, c.readBuf[c.readPos:])
c.readPos += n
return n, nil
}
func (c *mockConn) Write(b []byte) (n int, err error) {
c.mu.Lock()
defer c.mu.Unlock()
c.writeBuf = append(c.writeBuf, b...)
return len(b), nil
}
func (c *mockConn) Close() error { return nil }
func (c *mockConn) LocalAddr() net.Addr { return nil }
func (c *mockConn) RemoteAddr() net.Addr { return nil }
func (c *mockConn) SetDeadline(t time.Time) error { return nil }
func (c *mockConn) SetReadDeadline(t time.Time) error { return nil }
func (c *mockConn) SetWriteDeadline(t time.Time) error { return nil }
func TestBurstMultiplier(t *testing.T) {
tests := []struct {
name string
bandwidth int64
multiplier float64
wantBurst int
}{
{
name: "2x multiplier",
bandwidth: 1024 * 1024,
multiplier: 2.0,
wantBurst: 2 * 1024 * 1024,
},
{
name: "1x multiplier (no extra burst)",
bandwidth: 1024 * 1024,
multiplier: 1.0,
wantBurst: 1024 * 1024,
},
{
name: "3x multiplier",
bandwidth: 500 * 1024,
multiplier: 3.0,
wantBurst: 3 * 500 * 1024,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
burst := int(float64(tt.bandwidth) * tt.multiplier)
l := NewLimiter(Config{Bandwidth: tt.bandwidth, Burst: burst})
if l.RateLimiter().Burst() != tt.wantBurst {
t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst)
}
})
}
}

View File

@@ -22,6 +22,7 @@ type TunnelConfig struct {
DenyIPs []string `yaml:"deny_ips,omitempty"` // Denied IPs/CIDRs
Auth string `yaml:"auth,omitempty"` // Proxy authentication password (http/https only)
AuthBearer string `yaml:"auth_bearer,omitempty"` // Proxy authentication bearer token (http/https only)
Bandwidth string `yaml:"bandwidth,omitempty"` // Bandwidth limit (e.g., 1M, 500K, 1G)
}
// Validate checks if the tunnel configuration is valid

View File

@@ -41,6 +41,10 @@ type ServerConfig struct {
// Allowed tunnel types: "http", "https", "tcp" (default: all)
AllowedTunnelTypes []string `yaml:"tunnel_types"`
// Bandwidth limiting
Bandwidth string `yaml:"bandwidth,omitempty"`
BurstMultiplier float64 `yaml:"burst_multiplier,omitempty"`
}
// Validate checks if the server configuration is valid

63
pkg/config/config_test.go Normal file
View File

@@ -0,0 +1,63 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestServerConfigBandwidth(t *testing.T) {
tests := []struct {
name string
yaml string
wantBandwidth string
wantMultiplier float64
}{
{
name: "bandwidth 1M with 2.5x burst",
yaml: `
port: 8443
domain: example.com
tcp_port_min: 10000
tcp_port_max: 20000
bandwidth: 1M
burst_multiplier: 2.5
`,
wantBandwidth: "1M",
wantMultiplier: 2.5,
},
{
name: "no bandwidth limit",
yaml: `
port: 8443
domain: example.com
tcp_port_min: 10000
tcp_port_max: 20000
`,
wantBandwidth: "",
wantMultiplier: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte(tt.yaml), 0600); err != nil {
t.Fatalf("Failed to write config file: %v", err)
}
cfg, err := LoadServerConfig(configPath)
if err != nil {
t.Fatalf("LoadServerConfig failed: %v", err)
}
if cfg.Bandwidth != tt.wantBandwidth {
t.Errorf("Bandwidth = %q, want %q", cfg.Bandwidth, tt.wantBandwidth)
}
if cfg.BurstMultiplier != tt.wantMultiplier {
t.Errorf("BurstMultiplier = %v, want %v", cfg.BurstMultiplier, tt.wantMultiplier)
}
})
}
}