feat(cli): Add bandwidth limit function support

Added bandwidth limiting functionality, allowing users to limit the bandwidth of tunnel connections via the --bandwidth parameter.
Supported formats include: 1K/1KB (kilobytes), 1M/1MB (megabytes), 1G/1GB (gigabytes) or
Raw number (bytes).
This commit is contained in:
Gouryella
2026-02-14 14:20:21 +08:00
parent 3872bd9326
commit f90df37d7c
28 changed files with 2115 additions and 291 deletions

View File

@@ -27,6 +27,7 @@ type RegisterRequest struct {
PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"`
IPAccess *IPAccessControl `json:"ip_access,omitempty"`
ProxyAuth *ProxyAuth `json:"proxy_auth,omitempty"`
Bandwidth int64 `json:"bandwidth,omitempty"` // Bandwidth limit (bytes/sec), 0 = unlimited
}
type RegisterResponse struct {
@@ -37,6 +38,7 @@ type RegisterResponse struct {
TunnelID string `json:"tunnel_id,omitempty"`
SupportsDataConn bool `json:"supports_data_conn,omitempty"`
RecommendedConns int `json:"recommended_conns,omitempty"`
Bandwidth int64 `json:"bandwidth,omitempty"` // Applied bandwidth limit (bytes/sec)
}
type DataConnectRequest struct {

112
internal/shared/qos/conn.go Normal file
View File

@@ -0,0 +1,112 @@
package qos
import (
"context"
"io"
"net"
)
type LimitedConn struct {
net.Conn
limiter *Limiter
ctx context.Context
}
func NewLimitedConn(ctx context.Context, conn net.Conn, limiter *Limiter) *LimitedConn {
return &LimitedConn{
Conn: conn,
limiter: limiter,
ctx: ctx,
}
}
func (c *LimitedConn) Read(b []byte) (n int, err error) {
if c.limiter == nil || !c.limiter.IsLimited() {
return c.Conn.Read(b)
}
burst := c.limiter.RateLimiter().Burst()
if len(b) > burst {
b = b[:burst]
}
n, err = c.Conn.Read(b)
if n > 0 {
if waitErr := c.limiter.RateLimiter().WaitN(c.ctx, n); waitErr != nil {
if err == nil {
err = waitErr
}
}
}
return n, err
}
func (c *LimitedConn) Write(b []byte) (n int, err error) {
if c.limiter == nil || !c.limiter.IsLimited() {
return c.Conn.Write(b)
}
burst := c.limiter.RateLimiter().Burst()
total := 0
for len(b) > 0 {
chunk := min(len(b), burst)
if err := c.limiter.RateLimiter().WaitN(c.ctx, chunk); err != nil {
return total, err
}
nw, err := c.Conn.Write(b[:chunk])
total += nw
if err != nil {
return total, err
}
b = b[chunk:]
}
return total, nil
}
func (c *LimitedConn) ReadFrom(r io.Reader) (n int64, err error) {
buf := make([]byte, 32*1024)
for {
nr, er := r.Read(buf)
if nr > 0 {
nw, ew := c.Write(buf[:nr])
n += int64(nw)
if ew != nil {
err = ew
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return n, err
}
func (c *LimitedConn) WriteTo(w io.Writer) (n int64, err error) {
buf := make([]byte, 32*1024)
for {
nr, er := c.Read(buf)
if nr > 0 {
nw, ew := w.Write(buf[:nr])
n += int64(nw)
if ew != nil {
err = ew
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return n, err
}

View File

@@ -0,0 +1,484 @@
package qos
import (
"bytes"
"context"
"errors"
"io"
"net"
"sync"
"testing"
"time"
)
type errorAfterConn struct {
mockConn
writeLimit int
written int
}
func (c *errorAfterConn) Write(b []byte) (int, error) {
c.mu.Lock()
defer c.mu.Unlock()
remaining := c.writeLimit - c.written
if remaining <= 0 {
return 0, errors.New("write error")
}
if len(b) > remaining {
b = b[:remaining]
}
c.writeBuf = append(c.writeBuf, b...)
c.written += len(b)
return len(b), nil
}
func TestWriteLargerThanBurst(t *testing.T) {
// 10KB/s, burst=1KB — write 5KB should be chunked into 5 pieces
bandwidth := int64(10 * 1024)
burst := 1024
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
data := make([]byte, 5*1024)
for i := range data {
data[i] = byte(i % 256)
}
n, err := lc.Write(data)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != len(data) {
t.Errorf("Write returned %d, want %d", n, len(data))
}
conn.mu.Lock()
defer conn.mu.Unlock()
if !bytes.Equal(conn.writeBuf, data) {
t.Error("Written data does not match input")
}
}
func TestWriteExactBurstSize(t *testing.T) {
burst := 2048
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
data := make([]byte, burst)
start := time.Now()
n, err := lc.Write(data)
dur := time.Since(start)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != burst {
t.Errorf("Write returned %d, want %d", n, burst)
}
// Exact burst should be instant
if dur > 100*time.Millisecond {
t.Errorf("Exact burst write should be instant, took %v", dur)
}
}
func TestWriteZeroLength(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
n, err := lc.Write(nil)
if err != nil {
t.Fatalf("Write(nil) failed: %v", err)
}
if n != 0 {
t.Errorf("Write(nil) returned %d, want 0", n)
}
n, err = lc.Write([]byte{})
if err != nil {
t.Fatalf("Write([]) failed: %v", err)
}
if n != 0 {
t.Errorf("Write([]) returned %d, want 0", n)
}
}
func TestWriteContextCancelDuringChunking(t *testing.T) {
// Very slow rate, small burst so second chunk must wait
limiter := NewLimiter(Config{Bandwidth: 100, Burst: 100})
conn := newMockConn(nil)
ctx, cancel := context.WithCancel(context.Background())
lc := NewLimitedConn(ctx, conn, limiter)
// Use up burst
_, err := lc.Write(make([]byte, 100))
if err != nil {
t.Fatalf("First write failed: %v", err)
}
cancel()
// This write needs more tokens but context is cancelled
_, err = lc.Write(make([]byte, 200))
if err == nil {
t.Error("Write should fail after context cancellation")
}
}
func TestWritePartialErrorMidChunk(t *testing.T) {
// Underlying conn fails after 500 bytes
conn := &errorAfterConn{writeLimit: 500}
limiter := NewLimiter(Config{Bandwidth: 100000, Burst: 1024})
lc := NewLimitedConn(context.Background(), conn, limiter)
data := make([]byte, 2048)
n, err := lc.Write(data)
if err == nil {
t.Error("Expected write error")
}
if n < 500 {
t.Errorf("Expected at least 500 bytes written, got %d", n)
}
if n > 1024 {
t.Errorf("Expected at most 1024 bytes written (one chunk), got %d", n)
}
}
func TestReadCappedToBurst(t *testing.T) {
burst := 512
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
// Provide 4KB of data
data := make([]byte, 4096)
for i := range data {
data[i] = byte(i % 256)
}
conn := newMockConn(data)
lc := NewLimitedConn(context.Background(), conn, limiter)
// Request 4KB read, should get at most burst (512) bytes
buf := make([]byte, 4096)
n, err := lc.Read(buf)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if n > burst {
t.Errorf("Read returned %d bytes, should be capped at burst=%d", n, burst)
}
if !bytes.Equal(buf[:n], data[:n]) {
t.Error("Read data mismatch")
}
}
func TestReadSmallerThanBurst(t *testing.T) {
burst := 4096
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
data := make([]byte, 100)
conn := newMockConn(data)
lc := NewLimitedConn(context.Background(), conn, limiter)
buf := make([]byte, 100)
n, err := lc.Read(buf)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if n != 100 {
t.Errorf("Read returned %d, want 100", n)
}
}
func TestReadEOF(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024})
conn := newMockConn([]byte{}) // empty
lc := NewLimitedConn(context.Background(), conn, limiter)
buf := make([]byte, 100)
_, err := lc.Read(buf)
if err != io.EOF {
t.Errorf("Expected io.EOF, got %v", err)
}
}
func TestReadContextCancel(t *testing.T) {
// Slow rate, small burst
limiter := NewLimiter(Config{Bandwidth: 100, Burst: 100})
data := make([]byte, 200)
conn := newMockConn(data)
ctx, cancel := context.WithCancel(context.Background())
lc := NewLimitedConn(ctx, conn, limiter)
// First read uses burst
buf := make([]byte, 100)
_, err := lc.Read(buf)
if err != nil {
t.Fatalf("First read failed: %v", err)
}
cancel()
// Second read should fail on WaitN
_, err = lc.Read(buf)
if err == nil {
t.Error("Read should fail after context cancellation")
}
}
func TestReadFromInterface(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
var _ io.ReaderFrom = lc
}
func TestWriteToInterface(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
data := make([]byte, 100)
conn := newMockConn(data)
lc := NewLimitedConn(context.Background(), conn, limiter)
var _ io.WriterTo = lc
}
func TestReadFromBasic(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
src := bytes.NewReader(make([]byte, 50*1024))
n, err := lc.ReadFrom(src)
if err != nil {
t.Fatalf("ReadFrom failed: %v", err)
}
if n != 50*1024 {
t.Errorf("ReadFrom transferred %d bytes, want %d", n, 50*1024)
}
conn.mu.Lock()
defer conn.mu.Unlock()
if len(conn.writeBuf) != 50*1024 {
t.Errorf("Underlying conn received %d bytes, want %d", len(conn.writeBuf), 50*1024)
}
}
func TestReadFromReaderError(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
errReader := &failingReader{failAfter: 100}
n, err := lc.ReadFrom(errReader)
if err == nil {
t.Error("Expected error from ReadFrom")
}
if n != 100 {
t.Errorf("ReadFrom transferred %d bytes before error, want 100", n)
}
}
func TestWriteToBasic(t *testing.T) {
data := make([]byte, 50*1024)
for i := range data {
data[i] = byte(i % 256)
}
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
conn := newMockConn(data)
lc := NewLimitedConn(context.Background(), conn, limiter)
var buf bytes.Buffer
n, err := lc.WriteTo(&buf)
if err != nil {
t.Fatalf("WriteTo failed: %v", err)
}
if n != int64(len(data)) {
t.Errorf("WriteTo transferred %d bytes, want %d", n, len(data))
}
if !bytes.Equal(buf.Bytes(), data) {
t.Error("WriteTo data mismatch")
}
}
func TestWriteToWriterError(t *testing.T) {
data := make([]byte, 1024)
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
conn := newMockConn(data)
lc := NewLimitedConn(context.Background(), conn, limiter)
fw := &failingWriter{failAfter: 100}
_, err := lc.WriteTo(fw)
if err == nil {
t.Error("Expected error from WriteTo")
}
}
func TestReadFromRateLimited(t *testing.T) {
// 10KB/s, 10KB burst — 20KB transfer should take ~1s
limiter := NewLimiter(Config{Bandwidth: 10 * 1024, Burst: 10 * 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
src := bytes.NewReader(make([]byte, 20*1024))
start := time.Now()
n, err := lc.ReadFrom(src)
dur := time.Since(start)
if err != nil {
t.Fatalf("ReadFrom failed: %v", err)
}
if n != 20*1024 {
t.Errorf("ReadFrom transferred %d, want %d", n, 20*1024)
}
if dur < 800*time.Millisecond {
t.Errorf("ReadFrom too fast: %v (expected ~1s for 20KB at 10KB/s with 10KB burst)", dur)
}
}
// TestIoCopyUsesReadFrom verifies io.Copy goes through our ReadFrom,
// not the underlying conn's optimized path.
func TestIoCopyUsesReadFrom(t *testing.T) {
// Use a small burst so we can detect if rate limiting is applied
limiter := NewLimiter(Config{Bandwidth: 10 * 1024, Burst: 10 * 1024})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
src := bytes.NewReader(make([]byte, 20*1024))
start := time.Now()
n, err := io.Copy(lc, src)
dur := time.Since(start)
if err != nil {
t.Fatalf("io.Copy failed: %v", err)
}
if n != 20*1024 {
t.Errorf("io.Copy transferred %d, want %d", n, 20*1024)
}
// If io.Copy bypassed our ReadFrom, it would be instant
if dur < 800*time.Millisecond {
t.Errorf("io.Copy too fast (%v), rate limiting may be bypassed", dur)
}
}
func TestUnlimitedWrite(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 0})
conn := newMockConn(nil)
lc := NewLimitedConn(context.Background(), conn, limiter)
data := make([]byte, 1024*1024) // 1MB
start := time.Now()
n, err := lc.Write(data)
dur := time.Since(start)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != len(data) {
t.Errorf("Write returned %d, want %d", n, len(data))
}
if dur > 50*time.Millisecond {
t.Errorf("Unlimited write took too long: %v", dur)
}
}
func TestNilLimiter(t *testing.T) {
conn := newMockConn([]byte("hello"))
lc := NewLimitedConn(context.Background(), conn, nil)
buf := make([]byte, 10)
n, err := lc.Read(buf)
if err != nil {
t.Fatalf("Read with nil limiter failed: %v", err)
}
if n != 5 {
t.Errorf("Read returned %d, want 5", n)
}
n, err = lc.Write([]byte("world"))
if err != nil {
t.Fatalf("Write with nil limiter failed: %v", err)
}
if n != 5 {
t.Errorf("Write returned %d, want 5", n)
}
}
func TestConcurrentReadWrite(t *testing.T) {
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
limiter := NewLimiter(Config{Bandwidth: 100 * 1024, Burst: 100 * 1024})
lc := NewLimitedConn(context.Background(), serverConn, limiter)
dataSize := 50 * 1024
var wg sync.WaitGroup
// Writer
wg.Add(1)
go func() {
defer wg.Done()
data := make([]byte, dataSize)
for i := range data {
data[i] = 0xAA
}
lc.Write(data)
}()
// Reader on the other end
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, dataSize)
total := 0
for total < dataSize {
n, err := clientConn.Read(buf[total:])
if err != nil {
return
}
total += n
}
}()
wg.Wait()
}
type failingReader struct {
failAfter int
read int
}
func (r *failingReader) Read(b []byte) (int, error) {
remaining := r.failAfter - r.read
if remaining <= 0 {
return 0, errors.New("reader error")
}
n := len(b)
if n > remaining {
n = remaining
}
r.read += n
return n, nil
}
type failingWriter struct {
failAfter int
written int
}
func (w *failingWriter) Write(b []byte) (int, error) {
remaining := w.failAfter - w.written
if remaining <= 0 {
return 0, errors.New("writer error")
}
n := len(b)
if n > remaining {
w.written += remaining
return remaining, errors.New("writer error")
}
w.written += n
return n, nil
}

View File

@@ -0,0 +1,269 @@
package qos
import (
"context"
"io"
"net"
"sync"
"testing"
"time"
)
func TestEndToEndBandwidthLimiting(t *testing.T) {
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
bandwidth := int64(100 * 1024)
burstMultiplier := 2.0
burst := int(float64(bandwidth) * burstMultiplier)
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
ctx := context.Background()
limitedServerConn := NewLimitedConn(ctx, serverConn, limiter)
dataSize := 500 * 1024
testData := make([]byte, dataSize)
for i := range testData {
testData[i] = byte(i % 256)
}
var wg sync.WaitGroup
var writeErr, readErr error
var writeDuration time.Duration
receivedData := make([]byte, dataSize)
wg.Add(1)
go func() {
defer wg.Done()
start := time.Now()
chunkSize := 32 * 1024
for i := 0; i < dataSize; i += chunkSize {
end := i + chunkSize
if end > dataSize {
end = dataSize
}
_, err := limitedServerConn.Write(testData[i:end])
if err != nil {
writeErr = err
return
}
}
writeDuration = time.Since(start)
}()
wg.Add(1)
go func() {
defer wg.Done()
totalRead := 0
for totalRead < dataSize {
n, err := clientConn.Read(receivedData[totalRead:])
if err != nil {
if err != io.EOF {
readErr = err
}
return
}
totalRead += n
}
}()
wg.Wait()
if writeErr != nil {
t.Fatalf("Write error: %v", writeErr)
}
if readErr != nil {
t.Fatalf("Read error: %v", readErr)
}
for i := 0; i < dataSize; i++ {
if receivedData[i] != testData[i] {
t.Fatalf("Data mismatch at byte %d: got %d, want %d", i, receivedData[i], testData[i])
}
}
expectedMinDuration := 2500 * time.Millisecond
expectedMaxDuration := 4000 * time.Millisecond
if writeDuration < expectedMinDuration {
t.Errorf("Transfer too fast: %v (expected >= %v)", writeDuration, expectedMinDuration)
}
if writeDuration > expectedMaxDuration {
t.Errorf("Transfer too slow: %v (expected <= %v)", writeDuration, expectedMaxDuration)
}
t.Logf("Transferred %d bytes in %v (rate: %.2f KB/s)",
dataSize, writeDuration, float64(dataSize)/writeDuration.Seconds()/1024)
}
func TestBidirectionalBandwidthLimiting(t *testing.T) {
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
bandwidth := int64(50 * 1024)
burst := int(bandwidth * 2)
serverLimiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
clientLimiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
ctx := context.Background()
limitedServerConn := NewLimitedConn(ctx, serverConn, serverLimiter)
limitedClientConn := NewLimitedConn(ctx, clientConn, clientLimiter)
dataSize := 200 * 1024
serverData := make([]byte, dataSize)
clientData := make([]byte, dataSize)
for i := range serverData {
serverData[i] = byte(i % 256)
clientData[i] = byte((i + 128) % 256)
}
var wg sync.WaitGroup
receivedByClient := make([]byte, dataSize)
receivedByServer := make([]byte, dataSize)
// Server writes to client
wg.Add(1)
go func() {
defer wg.Done()
chunkSize := 16 * 1024
for i := 0; i < dataSize; i += chunkSize {
end := i + chunkSize
if end > dataSize {
end = dataSize
}
limitedServerConn.Write(serverData[i:end])
}
}()
// Client writes to server
wg.Add(1)
go func() {
defer wg.Done()
chunkSize := 16 * 1024
for i := 0; i < dataSize; i += chunkSize {
end := i + chunkSize
if end > dataSize {
end = dataSize
}
limitedClientConn.Write(clientData[i:end])
}
}()
// Client reads from server
wg.Add(1)
go func() {
defer wg.Done()
totalRead := 0
for totalRead < dataSize {
n, err := limitedClientConn.Read(receivedByClient[totalRead:])
if err != nil {
return
}
totalRead += n
}
}()
// Server reads from client
wg.Add(1)
go func() {
defer wg.Done()
totalRead := 0
for totalRead < dataSize {
n, err := limitedServerConn.Read(receivedByServer[totalRead:])
if err != nil {
return
}
totalRead += n
}
}()
wg.Wait()
for i := 0; i < dataSize; i++ {
if receivedByClient[i] != serverData[i] {
t.Fatalf("Client received wrong data at byte %d", i)
}
if receivedByServer[i] != clientData[i] {
t.Fatalf("Server received wrong data at byte %d", i)
}
}
t.Log("Bidirectional transfer completed successfully")
}
func TestBurstBehavior(t *testing.T) {
bandwidth := int64(10 * 1024)
burst := 50 * 1024
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
ctx := context.Background()
start := time.Now()
err := limiter.RateLimiter().WaitN(ctx, burst)
if err != nil {
t.Fatalf("WaitN failed: %v", err)
}
burstDuration := time.Since(start)
if burstDuration > 100*time.Millisecond {
t.Errorf("Burst should be instant, took %v", burstDuration)
}
start = time.Now()
err = limiter.RateLimiter().WaitN(ctx, 10*1024)
if err != nil {
t.Fatalf("WaitN failed: %v", err)
}
limitedDuration := time.Since(start)
if limitedDuration < 900*time.Millisecond || limitedDuration > 1200*time.Millisecond {
t.Errorf("Rate limiting not working correctly, took %v (expected ~1s)", limitedDuration)
}
t.Logf("Burst: %v, Rate-limited: %v", burstDuration, limitedDuration)
}
func TestMultipleBurstMultipliers(t *testing.T) {
tests := []struct {
name string
bandwidth int64
multiplier float64
}{
{"1x burst", 10 * 1024, 1.0},
{"1.5x burst", 10 * 1024, 1.5},
{"2x burst", 10 * 1024, 2.0},
{"2.5x burst", 10 * 1024, 2.5},
{"3x burst", 10 * 1024, 3.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
burst := int(float64(tt.bandwidth) * tt.multiplier)
limiter := NewLimiter(Config{Bandwidth: tt.bandwidth, Burst: burst})
if !limiter.IsLimited() {
t.Error("Limiter should be limited")
}
actualBurst := limiter.RateLimiter().Burst()
if actualBurst != burst {
t.Errorf("Burst = %d, want %d", actualBurst, burst)
}
ctx := context.Background()
start := time.Now()
err := limiter.RateLimiter().WaitN(ctx, burst)
if err != nil {
t.Fatalf("WaitN failed: %v", err)
}
duration := time.Since(start)
if duration > 50*time.Millisecond {
t.Errorf("Burst should be instant, took %v", duration)
}
})
}
}

View File

@@ -0,0 +1,34 @@
package qos
import (
"golang.org/x/time/rate"
)
type Config struct {
Bandwidth int64
Burst int
}
type Limiter struct {
limiter *rate.Limiter
}
func NewLimiter(cfg Config) *Limiter {
l := &Limiter{}
if cfg.Bandwidth > 0 {
burst := cfg.Burst
if burst <= 0 {
burst = int(cfg.Bandwidth * 2)
}
l.limiter = rate.NewLimiter(rate.Limit(cfg.Bandwidth), burst)
}
return l
}
func (l *Limiter) RateLimiter() *rate.Limiter {
return l.limiter
}
func (l *Limiter) IsLimited() bool {
return l.limiter != nil
}

View File

@@ -0,0 +1,313 @@
package qos
import (
"context"
"io"
"net"
"sync"
"testing"
"time"
)
func TestNewLimiter(t *testing.T) {
tests := []struct {
name string
cfg Config
wantLimit bool
wantBurst int
}{
{
name: "unlimited when bandwidth is 0",
cfg: Config{Bandwidth: 0},
wantLimit: false,
},
{
name: "limited with default burst (2x)",
cfg: Config{Bandwidth: 1024},
wantLimit: true,
wantBurst: 2048, // 2x bandwidth
},
{
name: "limited with custom burst",
cfg: Config{Bandwidth: 1024, Burst: 4096},
wantLimit: true,
wantBurst: 4096,
},
{
name: "1MB/s with 2x burst",
cfg: Config{Bandwidth: 1024 * 1024},
wantLimit: true,
wantBurst: 2 * 1024 * 1024,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := NewLimiter(tt.cfg)
if l.IsLimited() != tt.wantLimit {
t.Errorf("IsLimited() = %v, want %v", l.IsLimited(), tt.wantLimit)
}
if tt.wantLimit {
if l.RateLimiter() == nil {
t.Error("RateLimiter() should not be nil when limited")
}
if l.RateLimiter().Burst() != tt.wantBurst {
t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst)
}
}
})
}
}
func TestLimiterBandwidthEnforcement(t *testing.T) {
bandwidth := int64(10 * 1024)
burst := int(bandwidth * 2)
l := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
if !l.IsLimited() {
t.Fatal("Limiter should be limited")
}
ctx := context.Background()
start := time.Now()
err := l.RateLimiter().WaitN(ctx, burst)
if err != nil {
t.Fatalf("WaitN failed: %v", err)
}
burstDuration := time.Since(start)
if burstDuration > 100*time.Millisecond {
t.Errorf("Burst should be instant, took %v", burstDuration)
}
start = time.Now()
err = l.RateLimiter().WaitN(ctx, int(bandwidth)) // Request 1 second worth
if err != nil {
t.Fatalf("WaitN failed: %v", err)
}
limitedDuration := time.Since(start)
if limitedDuration < 800*time.Millisecond {
t.Errorf("Rate limiting not working, took only %v for 1 second worth of data", limitedDuration)
}
if limitedDuration > 1500*time.Millisecond {
t.Errorf("Rate limiting too slow, took %v for 1 second worth of data", limitedDuration)
}
}
type mockConn struct {
readBuf []byte
readPos int
writeBuf []byte
mu sync.Mutex
}
func newMockConn(data []byte) *mockConn {
return &mockConn{readBuf: data}
}
func (c *mockConn) Read(b []byte) (n int, err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.readPos >= len(c.readBuf) {
return 0, io.EOF
}
n = copy(b, c.readBuf[c.readPos:])
c.readPos += n
return n, nil
}
func (c *mockConn) Write(b []byte) (n int, err error) {
c.mu.Lock()
defer c.mu.Unlock()
c.writeBuf = append(c.writeBuf, b...)
return len(b), nil
}
func (c *mockConn) Close() error { return nil }
func (c *mockConn) LocalAddr() net.Addr { return nil }
func (c *mockConn) RemoteAddr() net.Addr { return nil }
func (c *mockConn) SetDeadline(t time.Time) error { return nil }
func (c *mockConn) SetReadDeadline(t time.Time) error { return nil }
func (c *mockConn) SetWriteDeadline(t time.Time) error { return nil }
func TestLimitedConnRead(t *testing.T) {
dataSize := 20 * 1024
testData := make([]byte, dataSize)
for i := range testData {
testData[i] = byte(i % 256)
}
// 10KB/s limit, 20KB burst
bandwidth := int64(10 * 1024)
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: int(bandwidth * 2)})
conn := newMockConn(testData)
ctx := context.Background()
limitedConn := NewLimitedConn(ctx, conn, limiter)
buf := make([]byte, dataSize)
start := time.Now()
totalRead := 0
for totalRead < dataSize {
n, err := limitedConn.Read(buf[totalRead:])
if err != nil && err != io.EOF {
t.Fatalf("Read failed: %v", err)
}
totalRead += n
if err == io.EOF {
break
}
}
duration := time.Since(start)
if totalRead != dataSize {
t.Errorf("Read %d bytes, want %d", totalRead, dataSize)
}
t.Logf("Read %d bytes in %v", totalRead, duration)
}
func TestLimitedConnWrite(t *testing.T) {
bandwidth := int64(10 * 1024)
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: int(bandwidth * 2)})
conn := newMockConn(nil)
ctx := context.Background()
limitedConn := NewLimitedConn(ctx, conn, limiter)
dataSize := 30 * 1024
testData := make([]byte, dataSize)
for i := range testData {
testData[i] = byte(i % 256)
}
start := time.Now()
chunkSize := 10 * 1024
for i := 0; i < dataSize; i += chunkSize {
end := i + chunkSize
if end > dataSize {
end = dataSize
}
n, err := limitedConn.Write(testData[i:end])
if err != nil {
t.Fatalf("Write failed: %v", err)
}
if n != end-i {
t.Errorf("Write returned %d, want %d", n, end-i)
}
}
duration := time.Since(start)
// 30KB data, 10KB/s rate, 20KB burst → ~1s for remaining 10KB
if duration < 800*time.Millisecond {
t.Errorf("Write too fast, took %v for 30KB with 10KB/s limit and 20KB burst", duration)
}
t.Logf("Wrote %d bytes in %v", dataSize, duration)
}
func TestLimitedConnUnlimited(t *testing.T) {
limiter := NewLimiter(Config{Bandwidth: 0})
testData := make([]byte, 100*1024)
conn := newMockConn(testData)
ctx := context.Background()
limitedConn := NewLimitedConn(ctx, conn, limiter)
buf := make([]byte, len(testData))
start := time.Now()
totalRead := 0
for totalRead < len(testData) {
n, err := limitedConn.Read(buf[totalRead:])
if err != nil && err != io.EOF {
t.Fatalf("Read failed: %v", err)
}
totalRead += n
if err == io.EOF {
break
}
}
duration := time.Since(start)
if totalRead != len(testData) {
t.Errorf("Read %d bytes, want %d", totalRead, len(testData))
}
if duration > 100*time.Millisecond {
t.Errorf("Unlimited read took too long: %v", duration)
}
}
func TestLimitedConnContextCancellation(t *testing.T) {
bandwidth := int64(100)
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: 100})
conn := newMockConn(nil)
ctx, cancel := context.WithCancel(context.Background())
limitedConn := NewLimitedConn(ctx, conn, limiter)
_, err := limitedConn.Write(make([]byte, 100))
if err != nil {
t.Fatalf("First write failed: %v", err)
}
cancel()
_, err = limitedConn.Write(make([]byte, 1000))
if err == nil {
t.Error("Write should fail after context cancellation")
}
}
func TestBurstMultiplier(t *testing.T) {
tests := []struct {
name string
bandwidth int64
multiplier float64
wantBurst int
}{
{
name: "2x multiplier",
bandwidth: 1024 * 1024, // 1MB/s
multiplier: 2.0,
wantBurst: 2 * 1024 * 1024,
},
{
name: "2.5x multiplier",
bandwidth: 1024 * 1024,
multiplier: 2.5,
wantBurst: int(float64(1024*1024) * 2.5),
},
{
name: "1x multiplier (no extra burst)",
bandwidth: 1024 * 1024,
multiplier: 1.0,
wantBurst: 1024 * 1024,
},
{
name: "3x multiplier",
bandwidth: 500 * 1024, // 500KB/s
multiplier: 3.0,
wantBurst: 3 * 500 * 1024,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
burst := int(float64(tt.bandwidth) * tt.multiplier)
l := NewLimiter(Config{Bandwidth: tt.bandwidth, Burst: burst})
if l.RateLimiter().Burst() != tt.wantBurst {
t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst)
}
})
}
}