mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
feat(cli): Add bandwidth limit function support
Added bandwidth limiting functionality, allowing users to limit the bandwidth of tunnel connections via the --bandwidth parameter. Supported formats include: 1K/1KB (kilobytes), 1M/1MB (megabytes), 1G/1GB (gigabytes) or Raw number (bytes).
This commit is contained in:
1
go.mod
1
go.mod
@@ -42,5 +42,6 @@ require (
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
)
|
||||
|
||||
2
go.sum
2
go.sum
@@ -95,6 +95,8 @@ golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@@ -151,6 +151,9 @@ func runConfigShow(_ *cobra.Command, _ []string) error {
|
||||
if t.Transport != "" {
|
||||
fmt.Printf(" transport=%s", t.Transport)
|
||||
}
|
||||
if t.Bandwidth != "" {
|
||||
fmt.Printf(" bandwidth=%s", t.Bandwidth)
|
||||
}
|
||||
if len(t.AllowIPs) > 0 {
|
||||
fmt.Printf(" allow=%s", strings.Join(t.AllowIPs, ","))
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ var (
|
||||
denyIPs []string
|
||||
authPass string
|
||||
transport string
|
||||
bandwidth string
|
||||
)
|
||||
|
||||
var httpCmd = &cobra.Command{
|
||||
@@ -35,6 +36,7 @@ Example:
|
||||
drip http 3000 --deny-ip 1.2.3.4 Block specific IP
|
||||
drip http 3000 --auth secret Enable proxy authentication with password
|
||||
drip http 3000 --transport wss Use WebSocket over TLS (CDN-friendly)
|
||||
drip http 3000 --bandwidth 1M Limit bandwidth to 1 MB/s
|
||||
|
||||
Configuration:
|
||||
First time: Run 'drip config init' to save server and token
|
||||
@@ -43,7 +45,13 @@ Configuration:
|
||||
Transport options:
|
||||
auto - Automatically select based on server address (default)
|
||||
tcp - Direct TLS 1.3 connection
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)
|
||||
|
||||
Bandwidth format:
|
||||
1K, 1KB - 1 kilobyte per second (1024 bytes/s)
|
||||
1M, 1MB - 1 megabyte per second (1048576 bytes/s)
|
||||
1G, 1GB - 1 gigabyte per second
|
||||
1024 - 1024 bytes per second (raw number)`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runHTTP,
|
||||
}
|
||||
@@ -56,6 +64,7 @@ func init() {
|
||||
httpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
|
||||
httpCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
|
||||
httpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
|
||||
httpCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)")
|
||||
httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(httpCmd)
|
||||
@@ -76,6 +85,11 @@ func runHTTP(_ *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
bw, err := parseBandwidth(bandwidth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
@@ -88,6 +102,7 @@ func runHTTP(_ *cobra.Command, args []string) error {
|
||||
DenyIPs: denyIPs,
|
||||
AuthPass: authPass,
|
||||
Transport: parseTransport(transport),
|
||||
Bandwidth: bw,
|
||||
}
|
||||
|
||||
var daemon *DaemonInfo
|
||||
@@ -108,3 +123,36 @@ func parseTransport(s string) tcp.TransportType {
|
||||
return tcp.TransportAuto
|
||||
}
|
||||
}
|
||||
|
||||
func parseBandwidth(s string) (int64, error) {
|
||||
if s == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
s = strings.TrimSpace(strings.ToUpper(s))
|
||||
if s == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var multiplier int64 = 1
|
||||
switch {
|
||||
case strings.HasSuffix(s, "GB") || strings.HasSuffix(s, "G"):
|
||||
multiplier = 1024 * 1024 * 1024
|
||||
s = strings.TrimSuffix(strings.TrimSuffix(s, "GB"), "G")
|
||||
case strings.HasSuffix(s, "MB") || strings.HasSuffix(s, "M"):
|
||||
multiplier = 1024 * 1024
|
||||
s = strings.TrimSuffix(strings.TrimSuffix(s, "MB"), "M")
|
||||
case strings.HasSuffix(s, "KB") || strings.HasSuffix(s, "K"):
|
||||
multiplier = 1024
|
||||
s = strings.TrimSuffix(strings.TrimSuffix(s, "KB"), "K")
|
||||
case strings.HasSuffix(s, "B"):
|
||||
s = strings.TrimSuffix(s, "B")
|
||||
}
|
||||
|
||||
val, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil || val < 0 {
|
||||
return 0, fmt.Errorf("invalid bandwidth value: %q (use format like 1M, 500K, 1G)", s)
|
||||
}
|
||||
|
||||
return val * multiplier, nil
|
||||
}
|
||||
|
||||
62
internal/client/cli/http_test.go
Normal file
62
internal/client/cli/http_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseBandwidth(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want int64
|
||||
wantErr bool
|
||||
}{
|
||||
// Valid cases
|
||||
{"", 0, false},
|
||||
{"0", 0, false},
|
||||
{"1024", 1024, false},
|
||||
{"1K", 1024, false},
|
||||
{"1KB", 1024, false},
|
||||
{"1k", 1024, false},
|
||||
{"1M", 1024 * 1024, false},
|
||||
{"1MB", 1024 * 1024, false},
|
||||
{"1m", 1024 * 1024, false},
|
||||
{"10M", 10 * 1024 * 1024, false},
|
||||
{"1G", 1024 * 1024 * 1024, false},
|
||||
{"1GB", 1024 * 1024 * 1024, false},
|
||||
{"500K", 500 * 1024, false},
|
||||
{"100M", 100 * 1024 * 1024, false},
|
||||
{" 1M ", 1024 * 1024, false},
|
||||
{"1B", 1, false},
|
||||
{"100B", 100, false},
|
||||
|
||||
// Error cases
|
||||
{"invalid", 0, true},
|
||||
{"abc", 0, true},
|
||||
{"-1M", 0, true},
|
||||
{"-100", 0, true},
|
||||
{"1.5M", 0, true},
|
||||
{"M", 0, true},
|
||||
{"K", 0, true},
|
||||
{"1T", 0, true},
|
||||
{"1KM", 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got, err := parseBandwidth(tt.input)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("parseBandwidth(%q) = %d, want error", tt.input, got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("parseBandwidth(%q) unexpected error: %v", tt.input, err)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("parseBandwidth(%q) = %d, want %d", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,7 @@ Example:
|
||||
drip https 443 --deny-ip 1.2.3.4 Block specific IP
|
||||
drip https 443 --auth secret Enable proxy authentication with password
|
||||
drip https 443 --transport wss Use WebSocket over TLS (CDN-friendly)
|
||||
drip https 443 --bandwidth 1M Limit bandwidth to 1 MB/s
|
||||
|
||||
Configuration:
|
||||
First time: Run 'drip config init' to save server and token
|
||||
@@ -31,7 +32,13 @@ Configuration:
|
||||
Transport options:
|
||||
auto - Automatically select based on server address (default)
|
||||
tcp - Direct TLS 1.3 connection
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)`,
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)
|
||||
|
||||
Bandwidth format:
|
||||
1K, 1KB - 1 kilobyte per second (1024 bytes/s)
|
||||
1M, 1MB - 1 megabyte per second (1048576 bytes/s)
|
||||
1G, 1GB - 1 gigabyte per second
|
||||
1024 - 1024 bytes per second (raw number)`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runHTTPS,
|
||||
}
|
||||
@@ -44,6 +51,7 @@ func init() {
|
||||
httpsCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
|
||||
httpsCmd.Flags().StringVar(&authPass, "auth", "", "Password for proxy authentication")
|
||||
httpsCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
|
||||
httpsCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)")
|
||||
httpsCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpsCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(httpsCmd)
|
||||
@@ -64,6 +72,11 @@ func runHTTPS(_ *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
bw, err := parseBandwidth(bandwidth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
@@ -76,6 +89,7 @@ func runHTTPS(_ *cobra.Command, args []string) error {
|
||||
DenyIPs: denyIPs,
|
||||
AuthPass: authPass,
|
||||
Transport: parseTransport(transport),
|
||||
Bandwidth: bw,
|
||||
}
|
||||
|
||||
var daemon *DaemonInfo
|
||||
|
||||
@@ -22,21 +22,21 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
serverPort int
|
||||
serverPublicPort int
|
||||
serverDomain string
|
||||
serverTunnelDomain string
|
||||
serverAuthToken string
|
||||
serverMetricsToken string
|
||||
serverDebug bool
|
||||
serverTCPPortMin int
|
||||
serverTCPPortMax int
|
||||
serverTLSCert string
|
||||
serverTLSKey string
|
||||
serverPprofPort int
|
||||
serverTransports string
|
||||
serverTunnelTypes string
|
||||
serverConfigFile string
|
||||
serverPort int
|
||||
serverPublicPort int
|
||||
serverDomain string
|
||||
serverTunnelDomain string
|
||||
serverAuthToken string
|
||||
serverMetricsToken string
|
||||
serverDebug bool
|
||||
serverTCPPortMin int
|
||||
serverTCPPortMax int
|
||||
serverTLSCert string
|
||||
serverTLSKey string
|
||||
serverPprofPort int
|
||||
serverTransports string
|
||||
serverTunnelTypes string
|
||||
serverConfigFile string
|
||||
)
|
||||
|
||||
var serverCmd = &cobra.Command{
|
||||
@@ -293,6 +293,24 @@ func runServer(cmd *cobra.Command, _ []string) error {
|
||||
listener.SetAllowedTransports(cfg.AllowedTransports)
|
||||
listener.SetAllowedTunnelTypes(cfg.AllowedTunnelTypes)
|
||||
|
||||
bandwidth, err := parseBandwidth(cfg.Bandwidth)
|
||||
if err != nil {
|
||||
logger.Fatal("Invalid bandwidth configuration", zap.Error(err))
|
||||
}
|
||||
burstMultiplier := cfg.BurstMultiplier
|
||||
if burstMultiplier <= 0 {
|
||||
burstMultiplier = 2.0
|
||||
}
|
||||
listener.SetBandwidth(bandwidth)
|
||||
listener.SetBurstMultiplier(burstMultiplier)
|
||||
if bandwidth > 0 {
|
||||
logger.Info("Bandwidth limit configured",
|
||||
zap.String("bandwidth", cfg.Bandwidth),
|
||||
zap.Int64("bandwidth_bytes_sec", bandwidth),
|
||||
zap.Float64("burst_multiplier", burstMultiplier),
|
||||
)
|
||||
}
|
||||
|
||||
if err := listener.Start(); err != nil {
|
||||
logger.Fatal("Failed to start TCP listener", zap.Error(err))
|
||||
}
|
||||
@@ -326,7 +344,6 @@ func runServer(cmd *cobra.Command, _ []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getEnvInt returns the environment variable value as int, or defaultVal if not set
|
||||
func getEnvInt(key string, defaultVal int) int {
|
||||
if val := os.Getenv(key); val != "" {
|
||||
if i, err := strconv.Atoi(val); err == nil {
|
||||
@@ -336,7 +353,6 @@ func getEnvInt(key string, defaultVal int) int {
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
// getEnvString returns the environment variable value, or defaultVal if not set
|
||||
func getEnvString(key string, defaultVal string) string {
|
||||
if val := os.Getenv(key); val != "" {
|
||||
return val
|
||||
@@ -344,7 +360,6 @@ func getEnvString(key string, defaultVal string) string {
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
// parseCommaSeparated splits a comma-separated string into a slice
|
||||
func parseCommaSeparated(s string) []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
|
||||
@@ -125,7 +125,10 @@ func formatTunnelInfo(t *config.TunnelConfig) string {
|
||||
}
|
||||
|
||||
func startSingleTunnel(cfg *config.ClientConfig, t *config.TunnelConfig) error {
|
||||
connConfig := buildConnectorConfig(cfg, t)
|
||||
connConfig, err := buildConnectorConfig(cfg, t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Starting tunnel '%s' (%s %s:%d)\n", t.Name, t.Type, getAddress(t), t.Port)
|
||||
|
||||
@@ -162,7 +165,11 @@ func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConf
|
||||
go func(tunnel *config.TunnelConfig) {
|
||||
defer wg.Done()
|
||||
|
||||
connConfig := buildConnectorConfig(cfg, tunnel)
|
||||
connConfig, err := buildConnectorConfig(cfg, tunnel)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("tunnel '%s': %w", tunnel.Name, err)
|
||||
return
|
||||
}
|
||||
fmt.Printf(" Starting %s (%s %s:%d)...\n", tunnel.Name, tunnel.Type, getAddress(tunnel), tunnel.Port)
|
||||
|
||||
client := tcp.NewTunnelClient(connConfig, logger)
|
||||
@@ -210,7 +217,7 @@ func startMultipleTunnels(cfg *config.ClientConfig, tunnels []*config.TunnelConf
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp.ConnectorConfig {
|
||||
func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) (*tcp.ConnectorConfig, error) {
|
||||
tunnelType := protocol.TunnelTypeHTTP
|
||||
switch t.Type {
|
||||
case "https":
|
||||
@@ -219,12 +226,11 @@ func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp
|
||||
tunnelType = protocol.TunnelTypeTCP
|
||||
}
|
||||
|
||||
transport := tcp.TransportAuto
|
||||
switch strings.ToLower(t.Transport) {
|
||||
case "tcp", "tls":
|
||||
transport = tcp.TransportTCP
|
||||
case "wss":
|
||||
transport = tcp.TransportWebSocket
|
||||
transport := parseTransport(t.Transport)
|
||||
|
||||
bw, err := parseBandwidth(t.Bandwidth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tcp.ConnectorConfig{
|
||||
@@ -239,7 +245,8 @@ func buildConnectorConfig(cfg *config.ClientConfig, t *config.TunnelConfig) *tcp
|
||||
DenyIPs: t.DenyIPs,
|
||||
AuthPass: t.Auth,
|
||||
Transport: transport,
|
||||
}
|
||||
Bandwidth: bw,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getAddress(t *config.TunnelConfig) string {
|
||||
|
||||
@@ -24,6 +24,7 @@ Example:
|
||||
drip tcp 22 --allow-ip 10.0.0.1 Allow single IP
|
||||
drip tcp 22 --deny-ip 1.2.3.4 Block specific IP
|
||||
drip tcp 22 --transport wss Use WebSocket over TLS (CDN-friendly)
|
||||
drip tcp 22 --bandwidth 1M Limit bandwidth to 1 MB/s
|
||||
|
||||
Supported Services:
|
||||
- Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017)
|
||||
@@ -39,6 +40,12 @@ Transport options:
|
||||
tcp - Direct TLS 1.3 connection
|
||||
wss - WebSocket over TLS (works through CDN like Cloudflare)
|
||||
|
||||
Bandwidth format:
|
||||
1K, 1KB - 1 kilobyte per second (1024 bytes/s)
|
||||
1M, 1MB - 1 megabyte per second (1048576 bytes/s)
|
||||
1G, 1GB - 1 gigabyte per second
|
||||
1024 - 1024 bytes per second (raw number)
|
||||
|
||||
Note: TCP tunnels require dynamic port allocation on the server.
|
||||
When using CDN (--transport wss), the server must still expose the allocated port directly.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
@@ -52,6 +59,7 @@ func init() {
|
||||
tcpCmd.Flags().StringSliceVar(&allowIPs, "allow-ip", nil, "Allow only these IPs or CIDR ranges (e.g., 192.168.1.1,10.0.0.0/8)")
|
||||
tcpCmd.Flags().StringSliceVar(&denyIPs, "deny-ip", nil, "Deny these IPs or CIDR ranges (e.g., 1.2.3.4,192.168.1.0/24)")
|
||||
tcpCmd.Flags().StringVar(&transport, "transport", "auto", "Transport protocol: auto, tcp, wss (WebSocket over TLS)")
|
||||
tcpCmd.Flags().StringVar(&bandwidth, "bandwidth", "", "Bandwidth limit (e.g., 1M, 500K, 1G)")
|
||||
tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
tcpCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(tcpCmd)
|
||||
@@ -72,6 +80,11 @@ func runTCP(_ *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
bw, err := parseBandwidth(bandwidth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
@@ -83,6 +96,7 @@ func runTCP(_ *cobra.Command, args []string) error {
|
||||
AllowIPs: allowIPs,
|
||||
DenyIPs: denyIPs,
|
||||
Transport: parseTransport(transport),
|
||||
Bandwidth: bw,
|
||||
}
|
||||
|
||||
var daemon *DaemonInfo
|
||||
|
||||
@@ -45,6 +45,9 @@ type ConnectorConfig struct {
|
||||
|
||||
// Transport protocol selection
|
||||
Transport TransportType
|
||||
|
||||
// Bandwidth limiting (bytes/sec), 0 = unlimited
|
||||
Bandwidth int64
|
||||
}
|
||||
|
||||
type TunnelClient interface {
|
||||
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/hashicorp/yamux"
|
||||
"go.uber.org/zap"
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
"drip/pkg/config"
|
||||
)
|
||||
|
||||
// PoolClient manages a pool of yamux sessions for tunnel connections.
|
||||
type PoolClient struct {
|
||||
serverAddr string
|
||||
tlsConfig *tls.Config
|
||||
@@ -76,11 +75,12 @@ type PoolClient struct {
|
||||
// Transport protocol selection
|
||||
transport TransportType
|
||||
insecure bool
|
||||
|
||||
// Bandwidth limit requested from server (bytes/sec), 0 = unlimited
|
||||
bandwidth int64
|
||||
}
|
||||
|
||||
// NewPoolClient creates a new pool client.
|
||||
func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
// Parse server address to get host for TLS config
|
||||
serverAddr := cfg.ServerAddr
|
||||
host := serverAddr
|
||||
|
||||
@@ -96,7 +96,6 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
}
|
||||
}
|
||||
|
||||
// Extract hostname without port for TLS
|
||||
hostOnly, _, _ := net.SplitHostPort(host)
|
||||
if hostOnly == "" {
|
||||
hostOnly = host
|
||||
@@ -140,7 +139,6 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
}
|
||||
initialSessions = min(max(initialSessions, minSessions), maxSessions)
|
||||
|
||||
// Determine transport type
|
||||
transport := cfg.Transport
|
||||
if transport == "" {
|
||||
transport = TransportAuto
|
||||
@@ -171,6 +169,7 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
authPass: cfg.AuthPass,
|
||||
transport: transport,
|
||||
insecure: cfg.Insecure,
|
||||
bandwidth: cfg.Bandwidth,
|
||||
}
|
||||
|
||||
if tunnelType == protocol.TunnelTypeHTTP || tunnelType == protocol.TunnelTypeHTTPS {
|
||||
@@ -181,7 +180,6 @@ func NewPoolClient(cfg *ConnectorConfig, logger *zap.Logger) *PoolClient {
|
||||
return c
|
||||
}
|
||||
|
||||
// Connect establishes the primary connection and starts background workers.
|
||||
func (c *PoolClient) Connect() error {
|
||||
primaryConn, err := c.dial()
|
||||
if err != nil {
|
||||
@@ -215,6 +213,10 @@ func (c *PoolClient) Connect() error {
|
||||
}
|
||||
}
|
||||
|
||||
if c.bandwidth > 0 {
|
||||
req.Bandwidth = c.bandwidth
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
_ = primaryConn.Close()
|
||||
@@ -261,6 +263,10 @@ func (c *PoolClient) Connect() error {
|
||||
c.tunnelID = resp.TunnelID
|
||||
}
|
||||
|
||||
if resp.Bandwidth > 0 {
|
||||
c.bandwidth = resp.Bandwidth
|
||||
}
|
||||
|
||||
yamuxCfg := mux.NewClientConfig()
|
||||
|
||||
session, err := yamux.Server(primaryConn, yamuxCfg)
|
||||
@@ -335,13 +341,11 @@ func (c *PoolClient) dialTLS() (net.Conn, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// serverCapabilities holds the discovered server capabilities
|
||||
type serverCapabilities struct {
|
||||
Transports []string `json:"transports"`
|
||||
Preferred string `json:"preferred"`
|
||||
}
|
||||
|
||||
// dial selects the appropriate transport and establishes a connection
|
||||
func (c *PoolClient) dial() (net.Conn, error) {
|
||||
switch c.transport {
|
||||
case TransportWebSocket:
|
||||
@@ -377,7 +381,6 @@ func (c *PoolClient) dial() (net.Conn, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// discoverServerCapabilities queries the server for its capabilities
|
||||
func (c *PoolClient) discoverServerCapabilities() *serverCapabilities {
|
||||
host, port, err := net.SplitHostPort(c.serverAddr)
|
||||
if err != nil {
|
||||
@@ -420,9 +423,7 @@ func (c *PoolClient) discoverServerCapabilities() *serverCapabilities {
|
||||
return &caps
|
||||
}
|
||||
|
||||
// dialWebSocket establishes a WebSocket connection to the server over TLS
|
||||
func (c *PoolClient) dialWebSocket() (net.Conn, error) {
|
||||
// Build WebSocket URL
|
||||
host, port, err := net.SplitHostPort(c.serverAddr)
|
||||
if err != nil {
|
||||
// No port specified, use default
|
||||
@@ -574,7 +575,6 @@ func (c *PoolClient) pingLoop(h *sessionHandle) {
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the client and all sessions.
|
||||
func (c *PoolClient) Close() error {
|
||||
var closeErr error
|
||||
|
||||
@@ -623,12 +623,12 @@ func (c *PoolClient) Close() error {
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (c *PoolClient) Wait() { <-c.doneCh }
|
||||
func (c *PoolClient) GetURL() string { return c.assignedURL }
|
||||
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
|
||||
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
|
||||
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
|
||||
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
|
||||
func (c *PoolClient) Wait() { <-c.doneCh }
|
||||
func (c *PoolClient) GetURL() string { return c.assignedURL }
|
||||
func (c *PoolClient) GetSubdomain() string { return c.subdomain }
|
||||
func (c *PoolClient) GetLatency() time.Duration { return time.Duration(c.latencyNanos.Load()) }
|
||||
func (c *PoolClient) GetStats() *stats.TrafficStats { return c.stats }
|
||||
func (c *PoolClient) IsClosed() bool { return c.closed.Load() }
|
||||
|
||||
func (c *PoolClient) SetLatencyCallback(cb LatencyCallback) {
|
||||
if cb == nil {
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/qos"
|
||||
"drip/internal/shared/wsutil"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
@@ -171,24 +172,22 @@ func (h *Handler) SetAllowedTunnelTypes(types []string) {
|
||||
|
||||
// IsTransportAllowed checks if a transport is allowed
|
||||
func (h *Handler) IsTransportAllowed(transport string) bool {
|
||||
if len(h.allowedTransports) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range h.allowedTransports {
|
||||
if strings.EqualFold(t, transport) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return containsFold(h.allowedTransports, transport)
|
||||
}
|
||||
|
||||
// IsTunnelTypeAllowed checks if a tunnel type is allowed
|
||||
func (h *Handler) IsTunnelTypeAllowed(tunnelType string) bool {
|
||||
if len(h.allowedTunnelTypes) == 0 {
|
||||
return containsFold(h.allowedTunnelTypes, tunnelType)
|
||||
}
|
||||
|
||||
// containsFold returns true if the slice is empty (allow all) or contains the
|
||||
// value in a case-insensitive comparison.
|
||||
func containsFold(allowed []string, value string) bool {
|
||||
if len(allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range h.allowedTunnelTypes {
|
||||
if strings.EqualFold(t, tunnelType) {
|
||||
for _, a := range allowed {
|
||||
if strings.EqualFold(a, value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -299,7 +298,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
tconn.IncActiveConnections()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
countingStream := netutil.NewCountingConn(stream,
|
||||
var limitedStream net.Conn = stream
|
||||
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
|
||||
limitedStream = qos.NewLimitedConn(r.Context(), stream, limiter)
|
||||
}
|
||||
|
||||
countingStream := netutil.NewCountingConn(limitedStream,
|
||||
tconn.AddBytesOut,
|
||||
tconn.AddBytesIn,
|
||||
)
|
||||
@@ -428,6 +432,11 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
|
||||
return
|
||||
}
|
||||
|
||||
var limitedStream net.Conn = stream
|
||||
if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
|
||||
limitedStream = qos.NewLimitedConn(context.Background(), stream, limiter)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
defer clientConn.Close()
|
||||
@@ -441,7 +450,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
|
||||
}
|
||||
}
|
||||
|
||||
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
|
||||
_ = netutil.PipeWithCallbacks(context.Background(), limitedStream, clientRW,
|
||||
func(n int64) { tconn.AddBytesOut(n) },
|
||||
func(n int64) { tconn.AddBytesIn(n) },
|
||||
)
|
||||
|
||||
168
internal/server/tcp/bandwidth_test.go
Normal file
168
internal/server/tcp/bandwidth_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEffectiveBandwidthSelection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverBW int64
|
||||
clientBW int64
|
||||
wantEffective int64
|
||||
}{
|
||||
{
|
||||
name: "server only",
|
||||
serverBW: 1024 * 1024,
|
||||
clientBW: 0,
|
||||
wantEffective: 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "client only",
|
||||
serverBW: 0,
|
||||
clientBW: 512 * 1024,
|
||||
wantEffective: 512 * 1024,
|
||||
},
|
||||
{
|
||||
name: "both unlimited",
|
||||
serverBW: 0,
|
||||
clientBW: 0,
|
||||
wantEffective: 0,
|
||||
},
|
||||
{
|
||||
name: "client lower than server",
|
||||
serverBW: 10 * 1024 * 1024,
|
||||
clientBW: 1 * 1024 * 1024,
|
||||
wantEffective: 1 * 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "client higher than server - server wins",
|
||||
serverBW: 1 * 1024 * 1024,
|
||||
clientBW: 10 * 1024 * 1024,
|
||||
wantEffective: 1 * 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "client equal to server",
|
||||
serverBW: 5 * 1024 * 1024,
|
||||
clientBW: 5 * 1024 * 1024,
|
||||
wantEffective: 5 * 1024 * 1024,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
effectiveBandwidth := tt.serverBW
|
||||
if tt.clientBW > 0 {
|
||||
if effectiveBandwidth == 0 || tt.clientBW < effectiveBandwidth {
|
||||
effectiveBandwidth = tt.clientBW
|
||||
}
|
||||
}
|
||||
|
||||
if effectiveBandwidth != tt.wantEffective {
|
||||
t.Errorf("effectiveBandwidth = %d, want %d", effectiveBandwidth, tt.wantEffective)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionSetBandwidthConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
wantBandwidth int64
|
||||
wantMultiplier float64
|
||||
}{
|
||||
{
|
||||
name: "1MB/s with 2x burst",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 2.0,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantMultiplier: 2.0,
|
||||
},
|
||||
{
|
||||
name: "1MB/s with 2.5x burst",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 2.5,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantMultiplier: 2.5,
|
||||
},
|
||||
{
|
||||
name: "default multiplier when 0",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 0,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantMultiplier: 2.0,
|
||||
},
|
||||
{
|
||||
name: "default multiplier when negative",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: -1.0,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantMultiplier: 2.0,
|
||||
},
|
||||
{
|
||||
name: "unlimited bandwidth",
|
||||
bandwidth: 0,
|
||||
burstMultiplier: 2.5,
|
||||
wantBandwidth: 0,
|
||||
wantMultiplier: 2.5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
conn := &Connection{}
|
||||
conn.SetBandwidthConfig(tt.bandwidth, tt.burstMultiplier)
|
||||
|
||||
if conn.bandwidth != tt.wantBandwidth {
|
||||
t.Errorf("bandwidth = %v, want %v", conn.bandwidth, tt.wantBandwidth)
|
||||
}
|
||||
|
||||
if conn.burstMultiplier != tt.wantMultiplier {
|
||||
t.Errorf("burstMultiplier = %v, want %v", conn.burstMultiplier, tt.wantMultiplier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerBandwidthConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
wantBandwidth int64
|
||||
wantMultiplier float64
|
||||
}{
|
||||
{
|
||||
name: "set bandwidth and multiplier",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 2.5,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantMultiplier: 2.5,
|
||||
},
|
||||
{
|
||||
name: "default multiplier",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 0,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantMultiplier: 2.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l := &Listener{}
|
||||
l.SetBandwidth(tt.bandwidth)
|
||||
l.SetBurstMultiplier(tt.burstMultiplier)
|
||||
|
||||
if l.bandwidth != tt.wantBandwidth {
|
||||
t.Errorf("bandwidth = %v, want %v", l.bandwidth, tt.wantBandwidth)
|
||||
}
|
||||
|
||||
if l.burstMultiplier != tt.wantMultiplier {
|
||||
t.Errorf("burstMultiplier = %v, want %v", l.burstMultiplier, tt.wantMultiplier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -59,12 +59,12 @@ type Connection struct {
|
||||
httpListener *connQueueListener
|
||||
handedOff bool
|
||||
|
||||
// Server capabilities
|
||||
allowedTunnelTypes []string
|
||||
allowedTransports []string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c := &Connection{
|
||||
@@ -99,22 +99,11 @@ func (c *Connection) Handle() error {
|
||||
return fmt.Errorf("failed to peek connection: %w", err)
|
||||
}
|
||||
|
||||
peekStr := string(peek)
|
||||
httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"}
|
||||
isHTTP := false
|
||||
for _, method := range httpMethods {
|
||||
if strings.HasPrefix(peekStr, method) {
|
||||
isHTTP = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isHTTP {
|
||||
if isHTTPMethod(string(peek)) {
|
||||
c.logger.Info("Detected HTTP request on TCP port, handling as HTTP")
|
||||
return c.handleHTTPRequest(reader)
|
||||
}
|
||||
|
||||
// Check if TCP transport is allowed (only for Drip protocol connections, not HTTP)
|
||||
if !c.isTransportAllowed("tcp") {
|
||||
c.logger.Warn("TCP transport not allowed, rejecting Drip protocol connection")
|
||||
return fmt.Errorf("TCP transport not allowed")
|
||||
@@ -142,7 +131,6 @@ func (c *Connection) Handle() error {
|
||||
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
// Check if tunnel type is allowed
|
||||
if !c.isTunnelTypeAllowed(string(req.TunnelType)) {
|
||||
c.sendError("tunnel_type_not_allowed", fmt.Sprintf("Tunnel type '%s' is not allowed on this server", req.TunnelType))
|
||||
return fmt.Errorf("tunnel type not allowed: %s", req.TunnelType)
|
||||
@@ -197,7 +185,6 @@ func (c *Connection) Handle() error {
|
||||
|
||||
c.tunnelConn.Conn = nil
|
||||
c.tunnelConn.SetTunnelType(req.TunnelType)
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
if req.IPAccess != nil && (len(req.IPAccess.AllowIPs) > 0 || len(req.IPAccess.DenyIPs) > 0) {
|
||||
c.tunnelConn.SetIPAccessControl(req.IPAccess.AllowIPs, req.IPAccess.DenyIPs)
|
||||
@@ -215,6 +202,31 @@ func (c *Connection) Handle() error {
|
||||
)
|
||||
}
|
||||
|
||||
effectiveBandwidth := c.bandwidth
|
||||
if req.Bandwidth > 0 {
|
||||
if effectiveBandwidth == 0 || req.Bandwidth < effectiveBandwidth {
|
||||
effectiveBandwidth = req.Bandwidth
|
||||
}
|
||||
}
|
||||
if effectiveBandwidth > 0 {
|
||||
burstMultiplier := c.burstMultiplier
|
||||
if burstMultiplier <= 0 {
|
||||
burstMultiplier = 2.0
|
||||
}
|
||||
c.tunnelConn.SetBandwidthWithBurst(effectiveBandwidth, burstMultiplier)
|
||||
|
||||
source := "server"
|
||||
if req.Bandwidth > 0 && (c.bandwidth == 0 || req.Bandwidth < c.bandwidth) {
|
||||
source = "client"
|
||||
}
|
||||
c.logger.Info("Bandwidth limit configured",
|
||||
zap.String("subdomain", subdomain),
|
||||
zap.Int64("bandwidth_bytes_sec", effectiveBandwidth),
|
||||
zap.Float64("burst_multiplier", burstMultiplier),
|
||||
zap.String("source", source),
|
||||
)
|
||||
}
|
||||
|
||||
c.logger.Info("Tunnel registered",
|
||||
zap.String("subdomain", subdomain),
|
||||
zap.String("tunnel_type", string(req.TunnelType)),
|
||||
@@ -258,6 +270,7 @@ func (c *Connection) Handle() error {
|
||||
TunnelID: tunnelID,
|
||||
SupportsDataConn: supportsDataConn,
|
||||
RecommendedConns: recommendedConns,
|
||||
Bandwidth: c.tunnelConn.GetBandwidth(),
|
||||
}
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
@@ -389,7 +402,6 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
|
||||
zap.String("host", req.Host),
|
||||
)
|
||||
|
||||
// Get writer from pool to reduce GC pressure
|
||||
pooledWriter := bufioWriterPool.Get().(*bufio.Writer)
|
||||
pooledWriter.Reset(c.conn)
|
||||
|
||||
@@ -405,30 +417,17 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
|
||||
c.logger.Debug("Failed to flush HTTP response", zap.Error(err))
|
||||
}
|
||||
|
||||
// Return writer to pool
|
||||
pooledWriter.Reset(nil) // Clear reference to connection
|
||||
pooledWriter.Reset(nil)
|
||||
bufioWriterPool.Put(pooledWriter)
|
||||
|
||||
// Keep TCP_NODELAY enabled for low latency HTTP responses
|
||||
// (removed the toggle that was disabling it)
|
||||
|
||||
c.logger.Debug("HTTP request processing completed",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("url", req.URL.String()),
|
||||
)
|
||||
|
||||
shouldClose := false
|
||||
if req.Close {
|
||||
shouldClose = true
|
||||
} else if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
||||
if req.Header.Get("Connection") != "keep-alive" {
|
||||
shouldClose = true
|
||||
}
|
||||
}
|
||||
|
||||
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
|
||||
shouldClose = true
|
||||
}
|
||||
shouldClose := req.Close ||
|
||||
(req.ProtoMajor == 1 && req.ProtoMinor == 0 && req.Header.Get("Connection") != "keep-alive") ||
|
||||
(respWriter.headerWritten && respWriter.header.Get("Connection") == "close")
|
||||
|
||||
if shouldClose {
|
||||
c.logger.Debug("Closing connection as requested by client or server")
|
||||
@@ -636,7 +635,7 @@ func (w *httpResponseWriter) WriteHeader(statusCode int) {
|
||||
}
|
||||
|
||||
w.writer.WriteString("HTTP/1.1 ")
|
||||
w.writer.WriteString(fmt.Sprintf("%d", statusCode))
|
||||
w.writer.WriteString(strconv.Itoa(statusCode))
|
||||
w.writer.WriteByte(' ')
|
||||
w.writer.WriteString(statusText)
|
||||
w.writer.WriteString("\r\n")
|
||||
@@ -755,6 +754,14 @@ func isTimeoutError(err error) bool {
|
||||
return strings.Contains(err.Error(), "i/o timeout")
|
||||
}
|
||||
|
||||
func isHTTPMethod(peek string) bool {
|
||||
switch peek {
|
||||
case "GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Connection) sendDataConnectError(code, message string) {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: false,
|
||||
@@ -769,38 +776,40 @@ func (c *Connection) sendDataConnectError(code, message string) {
|
||||
_ = protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
|
||||
// SetAllowedTunnelTypes sets the allowed tunnel types for this connection
|
||||
func (c *Connection) SetAllowedTunnelTypes(types []string) {
|
||||
c.allowedTunnelTypes = types
|
||||
}
|
||||
|
||||
// SetAllowedTransports sets the allowed transports for this connection
|
||||
func (c *Connection) SetAllowedTransports(transports []string) {
|
||||
c.allowedTransports = transports
|
||||
}
|
||||
|
||||
// isTransportAllowed checks if a transport is allowed
|
||||
func (c *Connection) isTransportAllowed(transport string) bool {
|
||||
if len(c.allowedTransports) == 0 {
|
||||
return containsFold(c.allowedTransports, transport)
|
||||
}
|
||||
|
||||
func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
|
||||
return containsFold(c.allowedTunnelTypes, tunnelType)
|
||||
}
|
||||
|
||||
// containsFold returns true if the slice is empty (allow all) or contains the
|
||||
// value case-insensitively.
|
||||
func containsFold(allowed []string, value string) bool {
|
||||
if len(allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, t := range c.allowedTransports {
|
||||
if strings.EqualFold(t, transport) {
|
||||
for _, a := range allowed {
|
||||
if strings.EqualFold(a, value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTunnelTypeAllowed checks if a tunnel type is allowed
|
||||
func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
|
||||
if len(c.allowedTunnelTypes) == 0 {
|
||||
return true // Allow all by default
|
||||
func (c *Connection) SetBandwidthConfig(bandwidth int64, burstMultiplier float64) {
|
||||
c.bandwidth = bandwidth
|
||||
if burstMultiplier <= 0 {
|
||||
burstMultiplier = 2.0
|
||||
}
|
||||
for _, t := range c.allowedTunnelTypes {
|
||||
if strings.EqualFold(t, tunnelType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
c.burstMultiplier = burstMultiplier
|
||||
}
|
||||
|
||||
@@ -43,9 +43,10 @@ type Listener struct {
|
||||
httpServer *http.Server
|
||||
httpListener *connQueueListener
|
||||
|
||||
// Server capabilities
|
||||
allowedTransports []string
|
||||
allowedTunnelTypes []string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
}
|
||||
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, tunnelDomain string, publicPort int, httpHandler http.Handler) *Listener {
|
||||
@@ -63,7 +64,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
panicMetrics := recovery.NewPanicMetrics(logger, nil)
|
||||
recoverer := recovery.NewRecoverer(logger, panicMetrics)
|
||||
|
||||
// Initialize worker pool metrics
|
||||
metrics.WorkerPoolSize.Set(float64(workers))
|
||||
|
||||
l := &Listener{
|
||||
@@ -85,7 +85,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
groupManager: NewConnectionGroupManager(logger),
|
||||
}
|
||||
|
||||
// Set up WebSocket connection handler if httpHandler supports it
|
||||
if h, ok := httpHandler.(*proxy.Handler); ok {
|
||||
h.SetWSConnectionHandler(l)
|
||||
h.SetPublicPort(publicPort)
|
||||
@@ -97,7 +96,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
func (l *Listener) Start() error {
|
||||
var err error
|
||||
|
||||
// Support both TLS and plain TCP modes
|
||||
if l.tlsConfig != nil {
|
||||
l.listener, err = tls.Listen("tcp", l.address, l.tlsConfig)
|
||||
if err != nil {
|
||||
@@ -269,57 +267,13 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
)
|
||||
}
|
||||
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
|
||||
conn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
|
||||
conn.SetAllowedTransports(l.allowedTransports)
|
||||
conn := l.newConfiguredConnection(netConn)
|
||||
|
||||
connID := netConn.RemoteAddr().String()
|
||||
l.connMu.Lock()
|
||||
l.connections[connID] = conn
|
||||
l.connMu.Unlock()
|
||||
l.trackConnection(connID, conn)
|
||||
defer l.untrackConnection(connID, conn, netConn)
|
||||
|
||||
// Update connection metrics
|
||||
metrics.TotalConnections.Inc()
|
||||
metrics.ActiveConnections.Inc()
|
||||
|
||||
defer func() {
|
||||
l.connMu.Lock()
|
||||
delete(l.connections, connID)
|
||||
l.connMu.Unlock()
|
||||
|
||||
metrics.ActiveConnections.Dec()
|
||||
|
||||
if !conn.IsHandedOff() {
|
||||
netConn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := conn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
if strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(errStr, "payload too large") ||
|
||||
strings.Contains(errStr, "failed to read registration frame") ||
|
||||
strings.Contains(errStr, "expected register frame") ||
|
||||
strings.Contains(errStr, "failed to parse registration request") ||
|
||||
strings.Contains(errStr, "failed to parse HTTP request") {
|
||||
l.logger.Warn("Protocol validation failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
l.logger.Error("Connection handling failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
l.logConnectionError(conn.Handle(), connID, "Connection")
|
||||
}
|
||||
|
||||
func (l *Listener) Stop() error {
|
||||
@@ -372,7 +326,6 @@ func (l *Listener) GetActiveConnections() int {
|
||||
return len(l.connections)
|
||||
}
|
||||
|
||||
// HandleWSConnection implements proxy.WSConnectionHandler for WebSocket tunnel connections
|
||||
func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
|
||||
l.wg.Add(1)
|
||||
defer l.wg.Done()
|
||||
@@ -386,77 +339,103 @@ func (l *Listener) HandleWSConnection(conn net.Conn, remoteAddr string) {
|
||||
zap.String("remote_addr", connID),
|
||||
)
|
||||
|
||||
// Create connection handler (no TLS verification needed - already done by HTTP server)
|
||||
tcpConn := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
|
||||
tcpConn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
|
||||
tcpConn := l.newConfiguredConnection(conn)
|
||||
|
||||
l.connMu.Lock()
|
||||
l.connections[connID] = tcpConn
|
||||
l.connMu.Unlock()
|
||||
l.trackConnection(connID, tcpConn)
|
||||
defer l.untrackConnection(connID, tcpConn, conn)
|
||||
|
||||
metrics.TotalConnections.Inc()
|
||||
metrics.ActiveConnections.Inc()
|
||||
|
||||
defer func() {
|
||||
l.connMu.Lock()
|
||||
delete(l.connections, connID)
|
||||
l.connMu.Unlock()
|
||||
|
||||
metrics.ActiveConnections.Dec()
|
||||
|
||||
if !tcpConn.IsHandedOff() {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := tcpConn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
if strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") ||
|
||||
strings.Contains(errStr, "websocket: close") {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(errStr, "payload too large") ||
|
||||
strings.Contains(errStr, "failed to read registration frame") ||
|
||||
strings.Contains(errStr, "expected register frame") ||
|
||||
strings.Contains(errStr, "failed to parse registration request") ||
|
||||
strings.Contains(errStr, "tunnel type not allowed") {
|
||||
l.logger.Warn("WebSocket tunnel protocol validation failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
l.logger.Error("WebSocket tunnel connection handling failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
l.logConnectionError(tcpConn.Handle(), connID, "WebSocket tunnel")
|
||||
}
|
||||
|
||||
// SetAllowedTransports sets the allowed transport protocols
|
||||
func (l *Listener) SetAllowedTransports(transports []string) {
|
||||
l.allowedTransports = transports
|
||||
}
|
||||
|
||||
// SetAllowedTunnelTypes sets the allowed tunnel types
|
||||
func (l *Listener) SetAllowedTunnelTypes(types []string) {
|
||||
l.allowedTunnelTypes = types
|
||||
}
|
||||
|
||||
// IsTransportAllowed checks if a transport is allowed
|
||||
func (l *Listener) IsTransportAllowed(transport string) bool {
|
||||
if len(l.allowedTransports) == 0 {
|
||||
return true
|
||||
return containsFold(l.allowedTransports, transport)
|
||||
}
|
||||
|
||||
func (l *Listener) SetBurstMultiplier(multiplier float64) {
|
||||
if multiplier <= 0 {
|
||||
multiplier = 2.0
|
||||
}
|
||||
for _, t := range l.allowedTransports {
|
||||
if strings.EqualFold(t, transport) {
|
||||
return true
|
||||
l.burstMultiplier = multiplier
|
||||
}
|
||||
|
||||
func (l *Listener) SetBandwidth(bandwidth int64) {
|
||||
l.bandwidth = bandwidth
|
||||
}
|
||||
|
||||
func (l *Listener) newConfiguredConnection(conn net.Conn) *Connection {
|
||||
c := NewConnection(conn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
|
||||
c.SetAllowedTunnelTypes(l.allowedTunnelTypes)
|
||||
c.SetAllowedTransports(l.allowedTransports)
|
||||
c.SetBandwidthConfig(l.bandwidth, l.burstMultiplier)
|
||||
return c
|
||||
}
|
||||
|
||||
func (l *Listener) trackConnection(connID string, conn *Connection) {
|
||||
l.connMu.Lock()
|
||||
l.connections[connID] = conn
|
||||
l.connMu.Unlock()
|
||||
|
||||
metrics.TotalConnections.Inc()
|
||||
metrics.ActiveConnections.Inc()
|
||||
}
|
||||
|
||||
func (l *Listener) untrackConnection(connID string, conn *Connection, netConn net.Conn) {
|
||||
l.connMu.Lock()
|
||||
delete(l.connections, connID)
|
||||
l.connMu.Unlock()
|
||||
|
||||
metrics.ActiveConnections.Dec()
|
||||
|
||||
if !conn.IsHandedOff() {
|
||||
netConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// logConnectionError classifies and logs a connection handling error.
|
||||
// Transient network errors are silently ignored, protocol errors are warned,
|
||||
// and everything else is logged as an error.
|
||||
func (l *Listener) logConnectionError(err error, connID, label string) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
|
||||
// Transient / expected disconnects — ignore silently
|
||||
for _, substr := range []string{
|
||||
"EOF", "connection reset by peer", "broken pipe",
|
||||
"connection refused", "websocket: close",
|
||||
} {
|
||||
if strings.Contains(errStr, substr) {
|
||||
return
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
// Protocol-level validation failures — warn
|
||||
for _, substr := range []string{
|
||||
"payload too large", "failed to read registration frame",
|
||||
"expected register frame", "failed to parse registration request",
|
||||
"failed to parse HTTP request", "tunnel type not allowed",
|
||||
} {
|
||||
if strings.Contains(errStr, substr) {
|
||||
l.logger.Warn(label+" protocol validation failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
l.logger.Error(label+" handling failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/qos"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -34,6 +35,7 @@ type Proxy struct {
|
||||
cancel context.CancelFunc
|
||||
|
||||
checkIPAccess func(ip string) bool
|
||||
limiter *qos.Limiter
|
||||
}
|
||||
|
||||
type trafficStats interface {
|
||||
@@ -49,12 +51,6 @@ func NewProxy(ctx context.Context, port int, subdomain string, openStream func()
|
||||
}
|
||||
cctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
const maxConcurrentConnections = 10000
|
||||
var sem chan struct{}
|
||||
if maxConcurrentConnections > 0 {
|
||||
sem = make(chan struct{}, maxConcurrentConnections)
|
||||
}
|
||||
|
||||
return &Proxy{
|
||||
port: port,
|
||||
subdomain: subdomain,
|
||||
@@ -62,17 +58,20 @@ func NewProxy(ctx context.Context, port int, subdomain string, openStream func()
|
||||
stopCh: make(chan struct{}),
|
||||
openStream: openStream,
|
||||
stats: stats,
|
||||
sem: sem,
|
||||
sem: make(chan struct{}, 10000),
|
||||
ctx: cctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// SetIPAccessCheck sets the IP access control check function.
|
||||
func (p *Proxy) SetIPAccessCheck(check func(ip string) bool) {
|
||||
p.checkIPAccess = check
|
||||
}
|
||||
|
||||
func (p *Proxy) SetLimiter(limiter *qos.Limiter) {
|
||||
p.limiter = limiter
|
||||
}
|
||||
|
||||
func (p *Proxy) Start() error {
|
||||
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
|
||||
|
||||
@@ -174,13 +173,11 @@ func (p *Proxy) handleConn(conn net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
if p.sem != nil {
|
||||
select {
|
||||
case p.sem <- struct{}{}:
|
||||
defer func() { <-p.sem }()
|
||||
default:
|
||||
return
|
||||
}
|
||||
select {
|
||||
case p.sem <- struct{}{}:
|
||||
defer func() { <-p.sem }()
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
if p.stats != nil {
|
||||
@@ -243,7 +240,7 @@ func (p *Proxy) handleConn(conn net.Conn) {
|
||||
_ = netutil.PipeWithCallbacksAndBufferSize(
|
||||
p.ctx,
|
||||
conn,
|
||||
stream,
|
||||
qos.NewLimitedConn(p.ctx, stream, p.limiter),
|
||||
pool.SizeLarge,
|
||||
func(n int64) {
|
||||
if p.stats != nil {
|
||||
|
||||
@@ -19,19 +19,17 @@ func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
// initMuxSession creates a yamux session over the buffered connection and
|
||||
// returns the openStream function (possibly group-aware).
|
||||
func (c *Connection) initMuxSession(reader *bufio.Reader) (func() (net.Conn, error), *yamux.Session, error) {
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
// Use optimized mux config for server
|
||||
cfg := mux.NewServerConfig()
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
session, err := yamux.Client(bc, mux.NewServerConfig())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
@@ -43,10 +41,22 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
}
|
||||
}
|
||||
|
||||
return openStream, session, nil
|
||||
}
|
||||
|
||||
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
openStream, session, err := c.initMuxSession(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.proxy = NewProxy(c.ctx, c.port, c.subdomain, openStream, c.tunnelConn, c.logger)
|
||||
if c.tunnelConn != nil && c.tunnelConn.HasIPAccessControl() {
|
||||
c.proxy.SetIPAccessCheck(c.tunnelConn.IsIPAllowed)
|
||||
}
|
||||
if c.tunnelConn != nil {
|
||||
c.proxy.SetLimiter(c.tunnelConn.GetLimiter())
|
||||
}
|
||||
|
||||
if err := c.proxy.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start tcp proxy: %w", err)
|
||||
@@ -61,27 +71,9 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
}
|
||||
|
||||
func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
// Use optimized mux config for server
|
||||
cfg := mux.NewServerConfig()
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
openStream, session, err := c.initMuxSession(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
openStream := session.Open
|
||||
if c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
group.AddSession("primary", session)
|
||||
openStream = group.OpenStream
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if c.tunnelConn != nil {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"drip/internal/server/metrics"
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/qos"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -32,6 +33,9 @@ type Connection struct {
|
||||
|
||||
ipAccessChecker *netutil.IPAccessChecker
|
||||
proxyAuth *protocol.ProxyAuth
|
||||
|
||||
bandwidth int64
|
||||
limiter *qos.Limiter
|
||||
}
|
||||
|
||||
func NewConnection(subdomain string, conn *websocket.Conn, logger *zap.Logger) *Connection {
|
||||
@@ -214,6 +218,34 @@ func (c *Connection) ValidateProxyAuth(password string) bool {
|
||||
return auth.Password == password
|
||||
}
|
||||
|
||||
func (c *Connection) SetBandwidth(bandwidth int64) {
|
||||
c.SetBandwidthWithBurst(bandwidth, 2.0)
|
||||
}
|
||||
|
||||
func (c *Connection) SetBandwidthWithBurst(bandwidth int64, burstMultiplier float64) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.bandwidth = bandwidth
|
||||
if bandwidth > 0 {
|
||||
burst := int(float64(bandwidth) * burstMultiplier)
|
||||
c.limiter = qos.NewLimiter(qos.Config{Bandwidth: bandwidth, Burst: burst})
|
||||
} else {
|
||||
c.limiter = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) GetBandwidth() int64 {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.bandwidth
|
||||
}
|
||||
|
||||
func (c *Connection) GetLimiter() *qos.Limiter {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.limiter
|
||||
}
|
||||
|
||||
func (c *Connection) StartWritePump() {
|
||||
if c.Conn == nil {
|
||||
go func() {
|
||||
|
||||
117
internal/server/tunnel/connection_test.go
Normal file
117
internal/server/tunnel/connection_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConnectionBandwidthWithBurst(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
burstMultiplier float64
|
||||
wantBandwidth int64
|
||||
wantBurst int
|
||||
}{
|
||||
{
|
||||
name: "1MB/s with 2x burst",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 2.0,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantBurst: 2 * 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "1MB/s with 2.5x burst",
|
||||
bandwidth: 1024 * 1024,
|
||||
burstMultiplier: 2.5,
|
||||
wantBandwidth: 1024 * 1024,
|
||||
wantBurst: int(float64(1024*1024) * 2.5),
|
||||
},
|
||||
{
|
||||
name: "500KB/s with 3x burst",
|
||||
bandwidth: 500 * 1024,
|
||||
burstMultiplier: 3.0,
|
||||
wantBandwidth: 500 * 1024,
|
||||
wantBurst: 3 * 500 * 1024,
|
||||
},
|
||||
{
|
||||
name: "10MB/s with 1.5x burst",
|
||||
bandwidth: 10 * 1024 * 1024,
|
||||
burstMultiplier: 1.5,
|
||||
wantBandwidth: 10 * 1024 * 1024,
|
||||
wantBurst: int(float64(10*1024*1024) * 1.5),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
conn := NewConnection("test-subdomain", nil, logger)
|
||||
|
||||
conn.SetBandwidthWithBurst(tt.bandwidth, tt.burstMultiplier)
|
||||
|
||||
if conn.GetBandwidth() != tt.wantBandwidth {
|
||||
t.Errorf("GetBandwidth() = %v, want %v", conn.GetBandwidth(), tt.wantBandwidth)
|
||||
}
|
||||
|
||||
limiter := conn.GetLimiter()
|
||||
if limiter == nil {
|
||||
t.Fatal("GetLimiter() should not be nil")
|
||||
}
|
||||
|
||||
if !limiter.IsLimited() {
|
||||
t.Error("Limiter should be limited")
|
||||
}
|
||||
|
||||
if limiter.RateLimiter().Burst() != tt.wantBurst {
|
||||
t.Errorf("Burst() = %v, want %v", limiter.RateLimiter().Burst(), tt.wantBurst)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionBandwidthUnlimited(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test-subdomain", nil, logger)
|
||||
|
||||
if conn.GetBandwidth() != 0 {
|
||||
t.Errorf("Default bandwidth should be 0, got %v", conn.GetBandwidth())
|
||||
}
|
||||
|
||||
if conn.GetLimiter() != nil {
|
||||
t.Error("Default limiter should be nil")
|
||||
}
|
||||
|
||||
conn.SetBandwidth(0)
|
||||
if conn.GetLimiter() != nil {
|
||||
t.Error("Limiter should be nil when bandwidth is 0")
|
||||
}
|
||||
|
||||
conn.SetBandwidthWithBurst(0, 2.0)
|
||||
if conn.GetLimiter() != nil {
|
||||
t.Error("Limiter should be nil when bandwidth is 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionSetBandwidth(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test-subdomain", nil, logger)
|
||||
|
||||
conn.SetBandwidth(1024 * 1024)
|
||||
|
||||
if conn.GetBandwidth() != 1024*1024 {
|
||||
t.Errorf("GetBandwidth() = %v, want %v", conn.GetBandwidth(), 1024*1024)
|
||||
}
|
||||
|
||||
limiter := conn.GetLimiter()
|
||||
if limiter == nil {
|
||||
t.Fatal("GetLimiter() should not be nil")
|
||||
}
|
||||
|
||||
expectedBurst := 2 * 1024 * 1024
|
||||
if limiter.RateLimiter().Burst() != expectedBurst {
|
||||
t.Errorf("Burst() = %v, want %v", limiter.RateLimiter().Burst(), expectedBurst)
|
||||
}
|
||||
}
|
||||
@@ -27,6 +27,7 @@ type RegisterRequest struct {
|
||||
PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"`
|
||||
IPAccess *IPAccessControl `json:"ip_access,omitempty"`
|
||||
ProxyAuth *ProxyAuth `json:"proxy_auth,omitempty"`
|
||||
Bandwidth int64 `json:"bandwidth,omitempty"` // Bandwidth limit (bytes/sec), 0 = unlimited
|
||||
}
|
||||
|
||||
type RegisterResponse struct {
|
||||
@@ -37,6 +38,7 @@ type RegisterResponse struct {
|
||||
TunnelID string `json:"tunnel_id,omitempty"`
|
||||
SupportsDataConn bool `json:"supports_data_conn,omitempty"`
|
||||
RecommendedConns int `json:"recommended_conns,omitempty"`
|
||||
Bandwidth int64 `json:"bandwidth,omitempty"` // Applied bandwidth limit (bytes/sec)
|
||||
}
|
||||
|
||||
type DataConnectRequest struct {
|
||||
|
||||
112
internal/shared/qos/conn.go
Normal file
112
internal/shared/qos/conn.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type LimitedConn struct {
|
||||
net.Conn
|
||||
limiter *Limiter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewLimitedConn(ctx context.Context, conn net.Conn, limiter *Limiter) *LimitedConn {
|
||||
return &LimitedConn{
|
||||
Conn: conn,
|
||||
limiter: limiter,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LimitedConn) Read(b []byte) (n int, err error) {
|
||||
if c.limiter == nil || !c.limiter.IsLimited() {
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
burst := c.limiter.RateLimiter().Burst()
|
||||
if len(b) > burst {
|
||||
b = b[:burst]
|
||||
}
|
||||
|
||||
n, err = c.Conn.Read(b)
|
||||
if n > 0 {
|
||||
if waitErr := c.limiter.RateLimiter().WaitN(c.ctx, n); waitErr != nil {
|
||||
if err == nil {
|
||||
err = waitErr
|
||||
}
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *LimitedConn) Write(b []byte) (n int, err error) {
|
||||
if c.limiter == nil || !c.limiter.IsLimited() {
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
burst := c.limiter.RateLimiter().Burst()
|
||||
total := 0
|
||||
|
||||
for len(b) > 0 {
|
||||
chunk := min(len(b), burst)
|
||||
|
||||
if err := c.limiter.RateLimiter().WaitN(c.ctx, chunk); err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
nw, err := c.Conn.Write(b[:chunk])
|
||||
total += nw
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
b = b[chunk:]
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (c *LimitedConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
nr, er := r.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := c.Write(buf[:nr])
|
||||
n += int64(nw)
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er != io.EOF {
|
||||
err = er
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *LimitedConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
nr, er := c.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := w.Write(buf[:nr])
|
||||
n += int64(nw)
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er != io.EOF {
|
||||
err = er
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
484
internal/shared/qos/conn_test.go
Normal file
484
internal/shared/qos/conn_test.go
Normal file
@@ -0,0 +1,484 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type errorAfterConn struct {
|
||||
mockConn
|
||||
writeLimit int
|
||||
written int
|
||||
}
|
||||
|
||||
func (c *errorAfterConn) Write(b []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
remaining := c.writeLimit - c.written
|
||||
if remaining <= 0 {
|
||||
return 0, errors.New("write error")
|
||||
}
|
||||
if len(b) > remaining {
|
||||
b = b[:remaining]
|
||||
}
|
||||
c.writeBuf = append(c.writeBuf, b...)
|
||||
c.written += len(b)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func TestWriteLargerThanBurst(t *testing.T) {
|
||||
// 10KB/s, burst=1KB — write 5KB should be chunked into 5 pieces
|
||||
bandwidth := int64(10 * 1024)
|
||||
burst := 1024
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
data := make([]byte, 5*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
n, err := lc.Write(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("Write returned %d, want %d", n, len(data))
|
||||
}
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
if !bytes.Equal(conn.writeBuf, data) {
|
||||
t.Error("Written data does not match input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteExactBurstSize(t *testing.T) {
|
||||
burst := 2048
|
||||
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
|
||||
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
data := make([]byte, burst)
|
||||
start := time.Now()
|
||||
n, err := lc.Write(data)
|
||||
dur := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
if n != burst {
|
||||
t.Errorf("Write returned %d, want %d", n, burst)
|
||||
}
|
||||
// Exact burst should be instant
|
||||
if dur > 100*time.Millisecond {
|
||||
t.Errorf("Exact burst write should be instant, took %v", dur)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteZeroLength(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
n, err := lc.Write(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Write(nil) failed: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("Write(nil) returned %d, want 0", n)
|
||||
}
|
||||
|
||||
n, err = lc.Write([]byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("Write([]) failed: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("Write([]) returned %d, want 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteContextCancelDuringChunking(t *testing.T) {
|
||||
// Very slow rate, small burst so second chunk must wait
|
||||
limiter := NewLimiter(Config{Bandwidth: 100, Burst: 100})
|
||||
conn := newMockConn(nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
lc := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
// Use up burst
|
||||
_, err := lc.Write(make([]byte, 100))
|
||||
if err != nil {
|
||||
t.Fatalf("First write failed: %v", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
// This write needs more tokens but context is cancelled
|
||||
_, err = lc.Write(make([]byte, 200))
|
||||
if err == nil {
|
||||
t.Error("Write should fail after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWritePartialErrorMidChunk(t *testing.T) {
|
||||
// Underlying conn fails after 500 bytes
|
||||
conn := &errorAfterConn{writeLimit: 500}
|
||||
limiter := NewLimiter(Config{Bandwidth: 100000, Burst: 1024})
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
data := make([]byte, 2048)
|
||||
n, err := lc.Write(data)
|
||||
if err == nil {
|
||||
t.Error("Expected write error")
|
||||
}
|
||||
if n < 500 {
|
||||
t.Errorf("Expected at least 500 bytes written, got %d", n)
|
||||
}
|
||||
if n > 1024 {
|
||||
t.Errorf("Expected at most 1024 bytes written (one chunk), got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCappedToBurst(t *testing.T) {
|
||||
burst := 512
|
||||
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
|
||||
|
||||
// Provide 4KB of data
|
||||
data := make([]byte, 4096)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
conn := newMockConn(data)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
// Request 4KB read, should get at most burst (512) bytes
|
||||
buf := make([]byte, 4096)
|
||||
n, err := lc.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
if n > burst {
|
||||
t.Errorf("Read returned %d bytes, should be capped at burst=%d", n, burst)
|
||||
}
|
||||
if !bytes.Equal(buf[:n], data[:n]) {
|
||||
t.Error("Read data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadSmallerThanBurst(t *testing.T) {
|
||||
burst := 4096
|
||||
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
|
||||
|
||||
data := make([]byte, 100)
|
||||
conn := newMockConn(data)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
buf := make([]byte, 100)
|
||||
n, err := lc.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
if n != 100 {
|
||||
t.Errorf("Read returned %d, want 100", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadEOF(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024})
|
||||
conn := newMockConn([]byte{}) // empty
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
buf := make([]byte, 100)
|
||||
_, err := lc.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("Expected io.EOF, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadContextCancel(t *testing.T) {
|
||||
// Slow rate, small burst
|
||||
limiter := NewLimiter(Config{Bandwidth: 100, Burst: 100})
|
||||
data := make([]byte, 200)
|
||||
conn := newMockConn(data)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
lc := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
// First read uses burst
|
||||
buf := make([]byte, 100)
|
||||
_, err := lc.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("First read failed: %v", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
// Second read should fail on WaitN
|
||||
_, err = lc.Read(buf)
|
||||
if err == nil {
|
||||
t.Error("Read should fail after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFromInterface(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
var _ io.ReaderFrom = lc
|
||||
}
|
||||
|
||||
func TestWriteToInterface(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
data := make([]byte, 100)
|
||||
conn := newMockConn(data)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
var _ io.WriterTo = lc
|
||||
}
|
||||
|
||||
func TestReadFromBasic(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
src := bytes.NewReader(make([]byte, 50*1024))
|
||||
n, err := lc.ReadFrom(src)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrom failed: %v", err)
|
||||
}
|
||||
if n != 50*1024 {
|
||||
t.Errorf("ReadFrom transferred %d bytes, want %d", n, 50*1024)
|
||||
}
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
if len(conn.writeBuf) != 50*1024 {
|
||||
t.Errorf("Underlying conn received %d bytes, want %d", len(conn.writeBuf), 50*1024)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFromReaderError(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
errReader := &failingReader{failAfter: 100}
|
||||
n, err := lc.ReadFrom(errReader)
|
||||
if err == nil {
|
||||
t.Error("Expected error from ReadFrom")
|
||||
}
|
||||
if n != 100 {
|
||||
t.Errorf("ReadFrom transferred %d bytes before error, want 100", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteToBasic(t *testing.T) {
|
||||
data := make([]byte, 50*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
conn := newMockConn(data)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
var buf bytes.Buffer
|
||||
n, err := lc.WriteTo(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTo failed: %v", err)
|
||||
}
|
||||
if n != int64(len(data)) {
|
||||
t.Errorf("WriteTo transferred %d bytes, want %d", n, len(data))
|
||||
}
|
||||
if !bytes.Equal(buf.Bytes(), data) {
|
||||
t.Error("WriteTo data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteToWriterError(t *testing.T) {
|
||||
data := make([]byte, 1024)
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
conn := newMockConn(data)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
fw := &failingWriter{failAfter: 100}
|
||||
_, err := lc.WriteTo(fw)
|
||||
if err == nil {
|
||||
t.Error("Expected error from WriteTo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFromRateLimited(t *testing.T) {
|
||||
// 10KB/s, 10KB burst — 20KB transfer should take ~1s
|
||||
limiter := NewLimiter(Config{Bandwidth: 10 * 1024, Burst: 10 * 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
src := bytes.NewReader(make([]byte, 20*1024))
|
||||
start := time.Now()
|
||||
n, err := lc.ReadFrom(src)
|
||||
dur := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrom failed: %v", err)
|
||||
}
|
||||
if n != 20*1024 {
|
||||
t.Errorf("ReadFrom transferred %d, want %d", n, 20*1024)
|
||||
}
|
||||
if dur < 800*time.Millisecond {
|
||||
t.Errorf("ReadFrom too fast: %v (expected ~1s for 20KB at 10KB/s with 10KB burst)", dur)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIoCopyUsesReadFrom verifies io.Copy goes through our ReadFrom,
|
||||
// not the underlying conn's optimized path.
|
||||
func TestIoCopyUsesReadFrom(t *testing.T) {
|
||||
// Use a small burst so we can detect if rate limiting is applied
|
||||
limiter := NewLimiter(Config{Bandwidth: 10 * 1024, Burst: 10 * 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
src := bytes.NewReader(make([]byte, 20*1024))
|
||||
start := time.Now()
|
||||
n, err := io.Copy(lc, src)
|
||||
dur := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("io.Copy failed: %v", err)
|
||||
}
|
||||
if n != 20*1024 {
|
||||
t.Errorf("io.Copy transferred %d, want %d", n, 20*1024)
|
||||
}
|
||||
// If io.Copy bypassed our ReadFrom, it would be instant
|
||||
if dur < 800*time.Millisecond {
|
||||
t.Errorf("io.Copy too fast (%v), rate limiting may be bypassed", dur)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnlimitedWrite(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 0})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
data := make([]byte, 1024*1024) // 1MB
|
||||
start := time.Now()
|
||||
n, err := lc.Write(data)
|
||||
dur := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("Write returned %d, want %d", n, len(data))
|
||||
}
|
||||
if dur > 50*time.Millisecond {
|
||||
t.Errorf("Unlimited write took too long: %v", dur)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilLimiter(t *testing.T) {
|
||||
conn := newMockConn([]byte("hello"))
|
||||
lc := NewLimitedConn(context.Background(), conn, nil)
|
||||
|
||||
buf := make([]byte, 10)
|
||||
n, err := lc.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read with nil limiter failed: %v", err)
|
||||
}
|
||||
if n != 5 {
|
||||
t.Errorf("Read returned %d, want 5", n)
|
||||
}
|
||||
|
||||
n, err = lc.Write([]byte("world"))
|
||||
if err != nil {
|
||||
t.Fatalf("Write with nil limiter failed: %v", err)
|
||||
}
|
||||
if n != 5 {
|
||||
t.Errorf("Write returned %d, want 5", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentReadWrite(t *testing.T) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
limiter := NewLimiter(Config{Bandwidth: 100 * 1024, Burst: 100 * 1024})
|
||||
lc := NewLimitedConn(context.Background(), serverConn, limiter)
|
||||
|
||||
dataSize := 50 * 1024
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Writer
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
data := make([]byte, dataSize)
|
||||
for i := range data {
|
||||
data[i] = 0xAA
|
||||
}
|
||||
lc.Write(data)
|
||||
}()
|
||||
|
||||
// Reader on the other end
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, dataSize)
|
||||
total := 0
|
||||
for total < dataSize {
|
||||
n, err := clientConn.Read(buf[total:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
total += n
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type failingReader struct {
|
||||
failAfter int
|
||||
read int
|
||||
}
|
||||
|
||||
func (r *failingReader) Read(b []byte) (int, error) {
|
||||
remaining := r.failAfter - r.read
|
||||
if remaining <= 0 {
|
||||
return 0, errors.New("reader error")
|
||||
}
|
||||
n := len(b)
|
||||
if n > remaining {
|
||||
n = remaining
|
||||
}
|
||||
r.read += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
type failingWriter struct {
|
||||
failAfter int
|
||||
written int
|
||||
}
|
||||
|
||||
func (w *failingWriter) Write(b []byte) (int, error) {
|
||||
remaining := w.failAfter - w.written
|
||||
if remaining <= 0 {
|
||||
return 0, errors.New("writer error")
|
||||
}
|
||||
n := len(b)
|
||||
if n > remaining {
|
||||
w.written += remaining
|
||||
return remaining, errors.New("writer error")
|
||||
}
|
||||
w.written += n
|
||||
return n, nil
|
||||
}
|
||||
269
internal/shared/qos/integration_test.go
Normal file
269
internal/shared/qos/integration_test.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEndToEndBandwidthLimiting(t *testing.T) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
bandwidth := int64(100 * 1024)
|
||||
burstMultiplier := 2.0
|
||||
burst := int(float64(bandwidth) * burstMultiplier)
|
||||
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
ctx := context.Background()
|
||||
limitedServerConn := NewLimitedConn(ctx, serverConn, limiter)
|
||||
|
||||
dataSize := 500 * 1024
|
||||
testData := make([]byte, dataSize)
|
||||
for i := range testData {
|
||||
testData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var writeErr, readErr error
|
||||
var writeDuration time.Duration
|
||||
receivedData := make([]byte, dataSize)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
start := time.Now()
|
||||
chunkSize := 32 * 1024
|
||||
for i := 0; i < dataSize; i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > dataSize {
|
||||
end = dataSize
|
||||
}
|
||||
_, err := limitedServerConn.Write(testData[i:end])
|
||||
if err != nil {
|
||||
writeErr = err
|
||||
return
|
||||
}
|
||||
}
|
||||
writeDuration = time.Since(start)
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
totalRead := 0
|
||||
for totalRead < dataSize {
|
||||
n, err := clientConn.Read(receivedData[totalRead:])
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
readErr = err
|
||||
}
|
||||
return
|
||||
}
|
||||
totalRead += n
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if writeErr != nil {
|
||||
t.Fatalf("Write error: %v", writeErr)
|
||||
}
|
||||
if readErr != nil {
|
||||
t.Fatalf("Read error: %v", readErr)
|
||||
}
|
||||
|
||||
for i := 0; i < dataSize; i++ {
|
||||
if receivedData[i] != testData[i] {
|
||||
t.Fatalf("Data mismatch at byte %d: got %d, want %d", i, receivedData[i], testData[i])
|
||||
}
|
||||
}
|
||||
|
||||
expectedMinDuration := 2500 * time.Millisecond
|
||||
expectedMaxDuration := 4000 * time.Millisecond
|
||||
|
||||
if writeDuration < expectedMinDuration {
|
||||
t.Errorf("Transfer too fast: %v (expected >= %v)", writeDuration, expectedMinDuration)
|
||||
}
|
||||
if writeDuration > expectedMaxDuration {
|
||||
t.Errorf("Transfer too slow: %v (expected <= %v)", writeDuration, expectedMaxDuration)
|
||||
}
|
||||
|
||||
t.Logf("Transferred %d bytes in %v (rate: %.2f KB/s)",
|
||||
dataSize, writeDuration, float64(dataSize)/writeDuration.Seconds()/1024)
|
||||
}
|
||||
|
||||
func TestBidirectionalBandwidthLimiting(t *testing.T) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
bandwidth := int64(50 * 1024)
|
||||
burst := int(bandwidth * 2)
|
||||
|
||||
serverLimiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
clientLimiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
|
||||
ctx := context.Background()
|
||||
limitedServerConn := NewLimitedConn(ctx, serverConn, serverLimiter)
|
||||
limitedClientConn := NewLimitedConn(ctx, clientConn, clientLimiter)
|
||||
|
||||
dataSize := 200 * 1024
|
||||
serverData := make([]byte, dataSize)
|
||||
clientData := make([]byte, dataSize)
|
||||
for i := range serverData {
|
||||
serverData[i] = byte(i % 256)
|
||||
clientData[i] = byte((i + 128) % 256)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
receivedByClient := make([]byte, dataSize)
|
||||
receivedByServer := make([]byte, dataSize)
|
||||
|
||||
// Server writes to client
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
chunkSize := 16 * 1024
|
||||
for i := 0; i < dataSize; i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > dataSize {
|
||||
end = dataSize
|
||||
}
|
||||
limitedServerConn.Write(serverData[i:end])
|
||||
}
|
||||
}()
|
||||
|
||||
// Client writes to server
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
chunkSize := 16 * 1024
|
||||
for i := 0; i < dataSize; i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > dataSize {
|
||||
end = dataSize
|
||||
}
|
||||
limitedClientConn.Write(clientData[i:end])
|
||||
}
|
||||
}()
|
||||
|
||||
// Client reads from server
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
totalRead := 0
|
||||
for totalRead < dataSize {
|
||||
n, err := limitedClientConn.Read(receivedByClient[totalRead:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
totalRead += n
|
||||
}
|
||||
}()
|
||||
|
||||
// Server reads from client
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
totalRead := 0
|
||||
for totalRead < dataSize {
|
||||
n, err := limitedServerConn.Read(receivedByServer[totalRead:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
totalRead += n
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
for i := 0; i < dataSize; i++ {
|
||||
if receivedByClient[i] != serverData[i] {
|
||||
t.Fatalf("Client received wrong data at byte %d", i)
|
||||
}
|
||||
if receivedByServer[i] != clientData[i] {
|
||||
t.Fatalf("Server received wrong data at byte %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("Bidirectional transfer completed successfully")
|
||||
}
|
||||
|
||||
func TestBurstBehavior(t *testing.T) {
|
||||
bandwidth := int64(10 * 1024)
|
||||
burst := 50 * 1024
|
||||
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
ctx := context.Background()
|
||||
|
||||
start := time.Now()
|
||||
err := limiter.RateLimiter().WaitN(ctx, burst)
|
||||
if err != nil {
|
||||
t.Fatalf("WaitN failed: %v", err)
|
||||
}
|
||||
burstDuration := time.Since(start)
|
||||
|
||||
if burstDuration > 100*time.Millisecond {
|
||||
t.Errorf("Burst should be instant, took %v", burstDuration)
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
err = limiter.RateLimiter().WaitN(ctx, 10*1024)
|
||||
if err != nil {
|
||||
t.Fatalf("WaitN failed: %v", err)
|
||||
}
|
||||
limitedDuration := time.Since(start)
|
||||
|
||||
if limitedDuration < 900*time.Millisecond || limitedDuration > 1200*time.Millisecond {
|
||||
t.Errorf("Rate limiting not working correctly, took %v (expected ~1s)", limitedDuration)
|
||||
}
|
||||
|
||||
t.Logf("Burst: %v, Rate-limited: %v", burstDuration, limitedDuration)
|
||||
}
|
||||
|
||||
func TestMultipleBurstMultipliers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
multiplier float64
|
||||
}{
|
||||
{"1x burst", 10 * 1024, 1.0},
|
||||
{"1.5x burst", 10 * 1024, 1.5},
|
||||
{"2x burst", 10 * 1024, 2.0},
|
||||
{"2.5x burst", 10 * 1024, 2.5},
|
||||
{"3x burst", 10 * 1024, 3.0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
burst := int(float64(tt.bandwidth) * tt.multiplier)
|
||||
limiter := NewLimiter(Config{Bandwidth: tt.bandwidth, Burst: burst})
|
||||
|
||||
if !limiter.IsLimited() {
|
||||
t.Error("Limiter should be limited")
|
||||
}
|
||||
|
||||
actualBurst := limiter.RateLimiter().Burst()
|
||||
if actualBurst != burst {
|
||||
t.Errorf("Burst = %d, want %d", actualBurst, burst)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
start := time.Now()
|
||||
err := limiter.RateLimiter().WaitN(ctx, burst)
|
||||
if err != nil {
|
||||
t.Fatalf("WaitN failed: %v", err)
|
||||
}
|
||||
duration := time.Since(start)
|
||||
|
||||
if duration > 50*time.Millisecond {
|
||||
t.Errorf("Burst should be instant, took %v", duration)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
34
internal/shared/qos/limiter.go
Normal file
34
internal/shared/qos/limiter.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Bandwidth int64
|
||||
Burst int
|
||||
}
|
||||
|
||||
type Limiter struct {
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func NewLimiter(cfg Config) *Limiter {
|
||||
l := &Limiter{}
|
||||
if cfg.Bandwidth > 0 {
|
||||
burst := cfg.Burst
|
||||
if burst <= 0 {
|
||||
burst = int(cfg.Bandwidth * 2)
|
||||
}
|
||||
l.limiter = rate.NewLimiter(rate.Limit(cfg.Bandwidth), burst)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Limiter) RateLimiter() *rate.Limiter {
|
||||
return l.limiter
|
||||
}
|
||||
|
||||
func (l *Limiter) IsLimited() bool {
|
||||
return l.limiter != nil
|
||||
}
|
||||
313
internal/shared/qos/limiter_test.go
Normal file
313
internal/shared/qos/limiter_test.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewLimiter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantLimit bool
|
||||
wantBurst int
|
||||
}{
|
||||
{
|
||||
name: "unlimited when bandwidth is 0",
|
||||
cfg: Config{Bandwidth: 0},
|
||||
wantLimit: false,
|
||||
},
|
||||
{
|
||||
name: "limited with default burst (2x)",
|
||||
cfg: Config{Bandwidth: 1024},
|
||||
wantLimit: true,
|
||||
wantBurst: 2048, // 2x bandwidth
|
||||
},
|
||||
{
|
||||
name: "limited with custom burst",
|
||||
cfg: Config{Bandwidth: 1024, Burst: 4096},
|
||||
wantLimit: true,
|
||||
wantBurst: 4096,
|
||||
},
|
||||
{
|
||||
name: "1MB/s with 2x burst",
|
||||
cfg: Config{Bandwidth: 1024 * 1024},
|
||||
wantLimit: true,
|
||||
wantBurst: 2 * 1024 * 1024,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l := NewLimiter(tt.cfg)
|
||||
if l.IsLimited() != tt.wantLimit {
|
||||
t.Errorf("IsLimited() = %v, want %v", l.IsLimited(), tt.wantLimit)
|
||||
}
|
||||
if tt.wantLimit {
|
||||
if l.RateLimiter() == nil {
|
||||
t.Error("RateLimiter() should not be nil when limited")
|
||||
}
|
||||
if l.RateLimiter().Burst() != tt.wantBurst {
|
||||
t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiterBandwidthEnforcement(t *testing.T) {
|
||||
bandwidth := int64(10 * 1024)
|
||||
burst := int(bandwidth * 2)
|
||||
|
||||
l := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
|
||||
if !l.IsLimited() {
|
||||
t.Fatal("Limiter should be limited")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
start := time.Now()
|
||||
err := l.RateLimiter().WaitN(ctx, burst)
|
||||
if err != nil {
|
||||
t.Fatalf("WaitN failed: %v", err)
|
||||
}
|
||||
burstDuration := time.Since(start)
|
||||
if burstDuration > 100*time.Millisecond {
|
||||
t.Errorf("Burst should be instant, took %v", burstDuration)
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
err = l.RateLimiter().WaitN(ctx, int(bandwidth)) // Request 1 second worth
|
||||
if err != nil {
|
||||
t.Fatalf("WaitN failed: %v", err)
|
||||
}
|
||||
limitedDuration := time.Since(start)
|
||||
|
||||
if limitedDuration < 800*time.Millisecond {
|
||||
t.Errorf("Rate limiting not working, took only %v for 1 second worth of data", limitedDuration)
|
||||
}
|
||||
if limitedDuration > 1500*time.Millisecond {
|
||||
t.Errorf("Rate limiting too slow, took %v for 1 second worth of data", limitedDuration)
|
||||
}
|
||||
}
|
||||
|
||||
type mockConn struct {
|
||||
readBuf []byte
|
||||
readPos int
|
||||
writeBuf []byte
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMockConn(data []byte) *mockConn {
|
||||
return &mockConn{readBuf: data}
|
||||
}
|
||||
|
||||
func (c *mockConn) Read(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.readPos >= len(c.readBuf) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n = copy(b, c.readBuf[c.readPos:])
|
||||
c.readPos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Write(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writeBuf = append(c.writeBuf, b...)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Close() error { return nil }
|
||||
func (c *mockConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *mockConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *mockConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (c *mockConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (c *mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func TestLimitedConnRead(t *testing.T) {
|
||||
dataSize := 20 * 1024
|
||||
testData := make([]byte, dataSize)
|
||||
for i := range testData {
|
||||
testData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
// 10KB/s limit, 20KB burst
|
||||
bandwidth := int64(10 * 1024)
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: int(bandwidth * 2)})
|
||||
|
||||
conn := newMockConn(testData)
|
||||
ctx := context.Background()
|
||||
limitedConn := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
buf := make([]byte, dataSize)
|
||||
start := time.Now()
|
||||
|
||||
totalRead := 0
|
||||
for totalRead < dataSize {
|
||||
n, err := limitedConn.Read(buf[totalRead:])
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
totalRead += n
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
if totalRead != dataSize {
|
||||
t.Errorf("Read %d bytes, want %d", totalRead, dataSize)
|
||||
}
|
||||
|
||||
t.Logf("Read %d bytes in %v", totalRead, duration)
|
||||
}
|
||||
|
||||
func TestLimitedConnWrite(t *testing.T) {
|
||||
bandwidth := int64(10 * 1024)
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: int(bandwidth * 2)})
|
||||
|
||||
conn := newMockConn(nil)
|
||||
ctx := context.Background()
|
||||
limitedConn := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
dataSize := 30 * 1024
|
||||
testData := make([]byte, dataSize)
|
||||
for i := range testData {
|
||||
testData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
chunkSize := 10 * 1024
|
||||
for i := 0; i < dataSize; i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > dataSize {
|
||||
end = dataSize
|
||||
}
|
||||
n, err := limitedConn.Write(testData[i:end])
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
if n != end-i {
|
||||
t.Errorf("Write returned %d, want %d", n, end-i)
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
// 30KB data, 10KB/s rate, 20KB burst → ~1s for remaining 10KB
|
||||
if duration < 800*time.Millisecond {
|
||||
t.Errorf("Write too fast, took %v for 30KB with 10KB/s limit and 20KB burst", duration)
|
||||
}
|
||||
|
||||
t.Logf("Wrote %d bytes in %v", dataSize, duration)
|
||||
}
|
||||
|
||||
func TestLimitedConnUnlimited(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 0})
|
||||
|
||||
testData := make([]byte, 100*1024)
|
||||
conn := newMockConn(testData)
|
||||
ctx := context.Background()
|
||||
limitedConn := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
buf := make([]byte, len(testData))
|
||||
start := time.Now()
|
||||
|
||||
totalRead := 0
|
||||
for totalRead < len(testData) {
|
||||
n, err := limitedConn.Read(buf[totalRead:])
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
totalRead += n
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
if totalRead != len(testData) {
|
||||
t.Errorf("Read %d bytes, want %d", totalRead, len(testData))
|
||||
}
|
||||
|
||||
if duration > 100*time.Millisecond {
|
||||
t.Errorf("Unlimited read took too long: %v", duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitedConnContextCancellation(t *testing.T) {
|
||||
bandwidth := int64(100)
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: 100})
|
||||
|
||||
conn := newMockConn(nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
limitedConn := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
_, err := limitedConn.Write(make([]byte, 100))
|
||||
if err != nil {
|
||||
t.Fatalf("First write failed: %v", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
_, err = limitedConn.Write(make([]byte, 1000))
|
||||
if err == nil {
|
||||
t.Error("Write should fail after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBurstMultiplier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
multiplier float64
|
||||
wantBurst int
|
||||
}{
|
||||
{
|
||||
name: "2x multiplier",
|
||||
bandwidth: 1024 * 1024, // 1MB/s
|
||||
multiplier: 2.0,
|
||||
wantBurst: 2 * 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "2.5x multiplier",
|
||||
bandwidth: 1024 * 1024,
|
||||
multiplier: 2.5,
|
||||
wantBurst: int(float64(1024*1024) * 2.5),
|
||||
},
|
||||
{
|
||||
name: "1x multiplier (no extra burst)",
|
||||
bandwidth: 1024 * 1024,
|
||||
multiplier: 1.0,
|
||||
wantBurst: 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "3x multiplier",
|
||||
bandwidth: 500 * 1024, // 500KB/s
|
||||
multiplier: 3.0,
|
||||
wantBurst: 3 * 500 * 1024,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
burst := int(float64(tt.bandwidth) * tt.multiplier)
|
||||
l := NewLimiter(Config{Bandwidth: tt.bandwidth, Burst: burst})
|
||||
|
||||
if l.RateLimiter().Burst() != tt.wantBurst {
|
||||
t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TunnelConfig holds configuration for a predefined tunnel
|
||||
type TunnelConfig struct {
|
||||
Name string `yaml:"name"` // Tunnel name (required, unique identifier)
|
||||
Type string `yaml:"type"` // Tunnel type: http, https, tcp (required)
|
||||
@@ -21,9 +20,9 @@ type TunnelConfig struct {
|
||||
AllowIPs []string `yaml:"allow_ips,omitempty"` // Allowed IPs/CIDRs
|
||||
DenyIPs []string `yaml:"deny_ips,omitempty"` // Denied IPs/CIDRs
|
||||
Auth string `yaml:"auth,omitempty"` // Proxy authentication password (http/https only)
|
||||
Bandwidth string `yaml:"bandwidth,omitempty"` // Bandwidth limit (e.g., 1M, 500K, 1G)
|
||||
}
|
||||
|
||||
// Validate checks if the tunnel configuration is valid
|
||||
func (t *TunnelConfig) Validate() error {
|
||||
if t.Name == "" {
|
||||
return fmt.Errorf("tunnel name is required")
|
||||
@@ -47,7 +46,6 @@ func (t *TunnelConfig) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClientConfig represents the client configuration
|
||||
type ClientConfig struct {
|
||||
Server string `yaml:"server"` // Server address (e.g., tunnel.example.com:443)
|
||||
Token string `yaml:"token"` // Authentication token
|
||||
@@ -55,7 +53,6 @@ type ClientConfig struct {
|
||||
Tunnels []*TunnelConfig `yaml:"tunnels,omitempty"` // Predefined tunnels
|
||||
}
|
||||
|
||||
// Validate checks if the client configuration is valid
|
||||
func (c *ClientConfig) Validate() error {
|
||||
if c.Server == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
@@ -92,7 +89,6 @@ func (c *ClientConfig) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTunnel returns a tunnel by name
|
||||
func (c *ClientConfig) GetTunnel(name string) *TunnelConfig {
|
||||
for _, t := range c.Tunnels {
|
||||
if t.Name == name {
|
||||
@@ -102,7 +98,6 @@ func (c *ClientConfig) GetTunnel(name string) *TunnelConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTunnelNames returns all tunnel names
|
||||
func (c *ClientConfig) GetTunnelNames() []string {
|
||||
names := make([]string, len(c.Tunnels))
|
||||
for i, t := range c.Tunnels {
|
||||
@@ -111,7 +106,6 @@ func (c *ClientConfig) GetTunnelNames() []string {
|
||||
return names
|
||||
}
|
||||
|
||||
// DefaultClientConfig returns the default configuration path
|
||||
func DefaultClientConfigPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
@@ -120,7 +114,6 @@ func DefaultClientConfigPath() string {
|
||||
return filepath.Join(home, ".drip", "config.yaml")
|
||||
}
|
||||
|
||||
// LoadClientConfig loads configuration from file
|
||||
func LoadClientConfig(path string) (*ClientConfig, error) {
|
||||
if path == "" {
|
||||
path = DefaultClientConfigPath()
|
||||
@@ -146,7 +139,6 @@ func LoadClientConfig(path string) (*ClientConfig, error) {
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// SaveClientConfig saves configuration to file
|
||||
func SaveClientConfig(config *ClientConfig, path string) error {
|
||||
if path == "" {
|
||||
path = DefaultClientConfigPath()
|
||||
@@ -162,7 +154,6 @@ func SaveClientConfig(config *ClientConfig, path string) error {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Write to file with secure permissions
|
||||
if err := os.WriteFile(path, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write config file: %w", err)
|
||||
}
|
||||
@@ -170,7 +161,6 @@ func SaveClientConfig(config *ClientConfig, path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigExists checks if config file exists
|
||||
func ConfigExists(path string) bool {
|
||||
if path == "" {
|
||||
path = DefaultClientConfigPath()
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ServerConfig holds the server configuration
|
||||
type ServerConfig struct {
|
||||
Port int `yaml:"port"`
|
||||
PublicPort int `yaml:"public_port"` // Port to display in URLs (for reverse proxy scenarios)
|
||||
@@ -37,13 +36,12 @@ type ServerConfig struct {
|
||||
PprofPort int `yaml:"pprof_port"`
|
||||
|
||||
// Allowed transports: "tcp", "wss", or "tcp,wss" (default: "tcp,wss")
|
||||
AllowedTransports []string `yaml:"transports"`
|
||||
|
||||
// Allowed tunnel types: "http", "https", "tcp" (default: all)
|
||||
AllowedTransports []string `yaml:"transports"`
|
||||
AllowedTunnelTypes []string `yaml:"tunnel_types"`
|
||||
Bandwidth string `yaml:"bandwidth,omitempty"`
|
||||
BurstMultiplier float64 `yaml:"burst_multiplier,omitempty"`
|
||||
}
|
||||
|
||||
// Validate checks if the server configuration is valid
|
||||
func (c *ServerConfig) Validate() error {
|
||||
// Validate port
|
||||
if c.Port < 1 || c.Port > 65535 {
|
||||
@@ -92,7 +90,6 @@ func (c *ServerConfig) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadTLSConfig loads TLS configuration
|
||||
func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
|
||||
if !c.TLSEnabled {
|
||||
return nil, nil
|
||||
@@ -131,7 +128,6 @@ func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// GetClientTLSConfig returns TLS config for client connections
|
||||
func GetClientTLSConfig(serverName string) *tls.Config {
|
||||
return &tls.Config{
|
||||
ServerName: serverName,
|
||||
@@ -147,8 +143,6 @@ func GetClientTLSConfig(serverName string) *tls.Config {
|
||||
}
|
||||
}
|
||||
|
||||
// GetClientTLSConfigInsecure returns TLS config for client with InsecureSkipVerify
|
||||
// WARNING: Only use for testing!
|
||||
func GetClientTLSConfigInsecure() *tls.Config {
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
@@ -164,7 +158,6 @@ func GetClientTLSConfigInsecure() *tls.Config {
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultServerConfigPath returns the default server configuration path
|
||||
func DefaultServerConfigPath() string {
|
||||
// Check /etc/drip/config.yaml first (system-wide)
|
||||
systemPath := "/etc/drip/config.yaml"
|
||||
@@ -180,7 +173,6 @@ func DefaultServerConfigPath() string {
|
||||
return filepath.Join(home, ".drip", "server.yaml")
|
||||
}
|
||||
|
||||
// LoadServerConfig loads server configuration from file
|
||||
func LoadServerConfig(path string) (*ServerConfig, error) {
|
||||
if path == "" {
|
||||
path = DefaultServerConfigPath()
|
||||
@@ -202,7 +194,6 @@ func LoadServerConfig(path string) (*ServerConfig, error) {
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// SaveServerConfig saves server configuration to file
|
||||
func SaveServerConfig(config *ServerConfig, path string) error {
|
||||
if path == "" {
|
||||
path = DefaultServerConfigPath()
|
||||
@@ -225,7 +216,6 @@ func SaveServerConfig(config *ServerConfig, path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServerConfigExists checks if server config file exists
|
||||
func ServerConfigExists(path string) bool {
|
||||
if path == "" {
|
||||
path = DefaultServerConfigPath()
|
||||
|
||||
158
pkg/config/config_test.go
Normal file
158
pkg/config/config_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestServerConfigBandwidth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
wantBandwidth string
|
||||
wantMultiplier float64
|
||||
}{
|
||||
{
|
||||
name: "bandwidth 1M with 2.5x burst",
|
||||
yaml: `
|
||||
port: 8443
|
||||
domain: example.com
|
||||
tcp_port_min: 10000
|
||||
tcp_port_max: 20000
|
||||
bandwidth: 1M
|
||||
burst_multiplier: 2.5
|
||||
`,
|
||||
wantBandwidth: "1M",
|
||||
wantMultiplier: 2.5,
|
||||
},
|
||||
{
|
||||
name: "bandwidth 10M with default burst",
|
||||
yaml: `
|
||||
port: 8443
|
||||
domain: example.com
|
||||
tcp_port_min: 10000
|
||||
tcp_port_max: 20000
|
||||
bandwidth: 10M
|
||||
`,
|
||||
wantBandwidth: "10M",
|
||||
wantMultiplier: 0, // not set, will use default 2.0 in code
|
||||
},
|
||||
{
|
||||
name: "no bandwidth limit",
|
||||
yaml: `
|
||||
port: 8443
|
||||
domain: example.com
|
||||
tcp_port_min: 10000
|
||||
tcp_port_max: 20000
|
||||
`,
|
||||
wantBandwidth: "",
|
||||
wantMultiplier: 0,
|
||||
},
|
||||
{
|
||||
name: "bandwidth 500K with 3x burst",
|
||||
yaml: `
|
||||
port: 8443
|
||||
domain: example.com
|
||||
tcp_port_min: 10000
|
||||
tcp_port_max: 20000
|
||||
bandwidth: 500K
|
||||
burst_multiplier: 3.0
|
||||
`,
|
||||
wantBandwidth: "500K",
|
||||
wantMultiplier: 3.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
if err := os.WriteFile(configPath, []byte(tt.yaml), 0600); err != nil {
|
||||
t.Fatalf("Failed to write config file: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadServerConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadServerConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Bandwidth != tt.wantBandwidth {
|
||||
t.Errorf("Bandwidth = %q, want %q", cfg.Bandwidth, tt.wantBandwidth)
|
||||
}
|
||||
|
||||
if cfg.BurstMultiplier != tt.wantMultiplier {
|
||||
t.Errorf("BurstMultiplier = %v, want %v", cfg.BurstMultiplier, tt.wantMultiplier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseBandwidth(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want int64
|
||||
}{
|
||||
{"", 0},
|
||||
{"0", 0},
|
||||
{"1024", 1024},
|
||||
{"1K", 1024},
|
||||
{"1KB", 1024},
|
||||
{"1k", 1024},
|
||||
{"1M", 1024 * 1024},
|
||||
{"1MB", 1024 * 1024},
|
||||
{"1m", 1024 * 1024},
|
||||
{"10M", 10 * 1024 * 1024},
|
||||
{"1G", 1024 * 1024 * 1024},
|
||||
{"1GB", 1024 * 1024 * 1024},
|
||||
{"500K", 500 * 1024},
|
||||
{"100M", 100 * 1024 * 1024},
|
||||
{" 1M ", 1024 * 1024}, // with spaces
|
||||
{"invalid", 0},
|
||||
{"-1M", 0}, // negative
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := parseBandwidthString(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("parseBandwidthString(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func parseBandwidthString(s string) int64 {
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
s = strings.TrimSpace(strings.ToUpper(s))
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
var multiplier int64 = 1
|
||||
switch {
|
||||
case strings.HasSuffix(s, "GB") || strings.HasSuffix(s, "G"):
|
||||
multiplier = 1024 * 1024 * 1024
|
||||
s = strings.TrimSuffix(strings.TrimSuffix(s, "GB"), "G")
|
||||
case strings.HasSuffix(s, "MB") || strings.HasSuffix(s, "M"):
|
||||
multiplier = 1024 * 1024
|
||||
s = strings.TrimSuffix(strings.TrimSuffix(s, "MB"), "M")
|
||||
case strings.HasSuffix(s, "KB") || strings.HasSuffix(s, "K"):
|
||||
multiplier = 1024
|
||||
s = strings.TrimSuffix(strings.TrimSuffix(s, "KB"), "K")
|
||||
case strings.HasSuffix(s, "B"):
|
||||
s = strings.TrimSuffix(s, "B")
|
||||
}
|
||||
|
||||
val, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil || val < 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return val * multiplier
|
||||
}
|
||||
Reference in New Issue
Block a user