mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-26 22:31:35 +00:00
feat(cli): Add bandwidth limit function support
Added bandwidth limiting functionality, allowing users to limit the bandwidth of tunnel connections via the --bandwidth parameter. Supported formats include: 1K/1KB (kilobytes), 1M/1MB (megabytes), 1G/1GB (gigabytes) or Raw number (bytes).
This commit is contained in:
@@ -27,6 +27,7 @@ type RegisterRequest struct {
|
||||
PoolCapabilities *PoolCapabilities `json:"pool_capabilities,omitempty"`
|
||||
IPAccess *IPAccessControl `json:"ip_access,omitempty"`
|
||||
ProxyAuth *ProxyAuth `json:"proxy_auth,omitempty"`
|
||||
Bandwidth int64 `json:"bandwidth,omitempty"` // Bandwidth limit (bytes/sec), 0 = unlimited
|
||||
}
|
||||
|
||||
type RegisterResponse struct {
|
||||
@@ -37,6 +38,7 @@ type RegisterResponse struct {
|
||||
TunnelID string `json:"tunnel_id,omitempty"`
|
||||
SupportsDataConn bool `json:"supports_data_conn,omitempty"`
|
||||
RecommendedConns int `json:"recommended_conns,omitempty"`
|
||||
Bandwidth int64 `json:"bandwidth,omitempty"` // Applied bandwidth limit (bytes/sec)
|
||||
}
|
||||
|
||||
type DataConnectRequest struct {
|
||||
|
||||
112
internal/shared/qos/conn.go
Normal file
112
internal/shared/qos/conn.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type LimitedConn struct {
|
||||
net.Conn
|
||||
limiter *Limiter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewLimitedConn(ctx context.Context, conn net.Conn, limiter *Limiter) *LimitedConn {
|
||||
return &LimitedConn{
|
||||
Conn: conn,
|
||||
limiter: limiter,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LimitedConn) Read(b []byte) (n int, err error) {
|
||||
if c.limiter == nil || !c.limiter.IsLimited() {
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
burst := c.limiter.RateLimiter().Burst()
|
||||
if len(b) > burst {
|
||||
b = b[:burst]
|
||||
}
|
||||
|
||||
n, err = c.Conn.Read(b)
|
||||
if n > 0 {
|
||||
if waitErr := c.limiter.RateLimiter().WaitN(c.ctx, n); waitErr != nil {
|
||||
if err == nil {
|
||||
err = waitErr
|
||||
}
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *LimitedConn) Write(b []byte) (n int, err error) {
|
||||
if c.limiter == nil || !c.limiter.IsLimited() {
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
burst := c.limiter.RateLimiter().Burst()
|
||||
total := 0
|
||||
|
||||
for len(b) > 0 {
|
||||
chunk := min(len(b), burst)
|
||||
|
||||
if err := c.limiter.RateLimiter().WaitN(c.ctx, chunk); err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
nw, err := c.Conn.Write(b[:chunk])
|
||||
total += nw
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
b = b[chunk:]
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (c *LimitedConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
nr, er := r.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := c.Write(buf[:nr])
|
||||
n += int64(nw)
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er != io.EOF {
|
||||
err = er
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *LimitedConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
nr, er := c.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := w.Write(buf[:nr])
|
||||
n += int64(nw)
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er != io.EOF {
|
||||
err = er
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
484
internal/shared/qos/conn_test.go
Normal file
484
internal/shared/qos/conn_test.go
Normal file
@@ -0,0 +1,484 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type errorAfterConn struct {
|
||||
mockConn
|
||||
writeLimit int
|
||||
written int
|
||||
}
|
||||
|
||||
func (c *errorAfterConn) Write(b []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
remaining := c.writeLimit - c.written
|
||||
if remaining <= 0 {
|
||||
return 0, errors.New("write error")
|
||||
}
|
||||
if len(b) > remaining {
|
||||
b = b[:remaining]
|
||||
}
|
||||
c.writeBuf = append(c.writeBuf, b...)
|
||||
c.written += len(b)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func TestWriteLargerThanBurst(t *testing.T) {
|
||||
// 10KB/s, burst=1KB — write 5KB should be chunked into 5 pieces
|
||||
bandwidth := int64(10 * 1024)
|
||||
burst := 1024
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
data := make([]byte, 5*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
n, err := lc.Write(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("Write returned %d, want %d", n, len(data))
|
||||
}
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
if !bytes.Equal(conn.writeBuf, data) {
|
||||
t.Error("Written data does not match input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteExactBurstSize(t *testing.T) {
|
||||
burst := 2048
|
||||
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
|
||||
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
data := make([]byte, burst)
|
||||
start := time.Now()
|
||||
n, err := lc.Write(data)
|
||||
dur := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
if n != burst {
|
||||
t.Errorf("Write returned %d, want %d", n, burst)
|
||||
}
|
||||
// Exact burst should be instant
|
||||
if dur > 100*time.Millisecond {
|
||||
t.Errorf("Exact burst write should be instant, took %v", dur)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteZeroLength(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
n, err := lc.Write(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Write(nil) failed: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("Write(nil) returned %d, want 0", n)
|
||||
}
|
||||
|
||||
n, err = lc.Write([]byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("Write([]) failed: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("Write([]) returned %d, want 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteContextCancelDuringChunking(t *testing.T) {
|
||||
// Very slow rate, small burst so second chunk must wait
|
||||
limiter := NewLimiter(Config{Bandwidth: 100, Burst: 100})
|
||||
conn := newMockConn(nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
lc := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
// Use up burst
|
||||
_, err := lc.Write(make([]byte, 100))
|
||||
if err != nil {
|
||||
t.Fatalf("First write failed: %v", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
// This write needs more tokens but context is cancelled
|
||||
_, err = lc.Write(make([]byte, 200))
|
||||
if err == nil {
|
||||
t.Error("Write should fail after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWritePartialErrorMidChunk(t *testing.T) {
|
||||
// Underlying conn fails after 500 bytes
|
||||
conn := &errorAfterConn{writeLimit: 500}
|
||||
limiter := NewLimiter(Config{Bandwidth: 100000, Burst: 1024})
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
data := make([]byte, 2048)
|
||||
n, err := lc.Write(data)
|
||||
if err == nil {
|
||||
t.Error("Expected write error")
|
||||
}
|
||||
if n < 500 {
|
||||
t.Errorf("Expected at least 500 bytes written, got %d", n)
|
||||
}
|
||||
if n > 1024 {
|
||||
t.Errorf("Expected at most 1024 bytes written (one chunk), got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCappedToBurst(t *testing.T) {
|
||||
burst := 512
|
||||
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
|
||||
|
||||
// Provide 4KB of data
|
||||
data := make([]byte, 4096)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
conn := newMockConn(data)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
// Request 4KB read, should get at most burst (512) bytes
|
||||
buf := make([]byte, 4096)
|
||||
n, err := lc.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
if n > burst {
|
||||
t.Errorf("Read returned %d bytes, should be capped at burst=%d", n, burst)
|
||||
}
|
||||
if !bytes.Equal(buf[:n], data[:n]) {
|
||||
t.Error("Read data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadSmallerThanBurst(t *testing.T) {
|
||||
burst := 4096
|
||||
limiter := NewLimiter(Config{Bandwidth: 10240, Burst: burst})
|
||||
|
||||
data := make([]byte, 100)
|
||||
conn := newMockConn(data)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
buf := make([]byte, 100)
|
||||
n, err := lc.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
if n != 100 {
|
||||
t.Errorf("Read returned %d, want 100", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadEOF(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024, Burst: 1024})
|
||||
conn := newMockConn([]byte{}) // empty
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
buf := make([]byte, 100)
|
||||
_, err := lc.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("Expected io.EOF, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadContextCancel(t *testing.T) {
|
||||
// Slow rate, small burst
|
||||
limiter := NewLimiter(Config{Bandwidth: 100, Burst: 100})
|
||||
data := make([]byte, 200)
|
||||
conn := newMockConn(data)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
lc := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
// First read uses burst
|
||||
buf := make([]byte, 100)
|
||||
_, err := lc.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("First read failed: %v", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
// Second read should fail on WaitN
|
||||
_, err = lc.Read(buf)
|
||||
if err == nil {
|
||||
t.Error("Read should fail after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFromInterface(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
var _ io.ReaderFrom = lc
|
||||
}
|
||||
|
||||
func TestWriteToInterface(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
data := make([]byte, 100)
|
||||
conn := newMockConn(data)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
var _ io.WriterTo = lc
|
||||
}
|
||||
|
||||
func TestReadFromBasic(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
src := bytes.NewReader(make([]byte, 50*1024))
|
||||
n, err := lc.ReadFrom(src)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrom failed: %v", err)
|
||||
}
|
||||
if n != 50*1024 {
|
||||
t.Errorf("ReadFrom transferred %d bytes, want %d", n, 50*1024)
|
||||
}
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
if len(conn.writeBuf) != 50*1024 {
|
||||
t.Errorf("Underlying conn received %d bytes, want %d", len(conn.writeBuf), 50*1024)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFromReaderError(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
errReader := &failingReader{failAfter: 100}
|
||||
n, err := lc.ReadFrom(errReader)
|
||||
if err == nil {
|
||||
t.Error("Expected error from ReadFrom")
|
||||
}
|
||||
if n != 100 {
|
||||
t.Errorf("ReadFrom transferred %d bytes before error, want 100", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteToBasic(t *testing.T) {
|
||||
data := make([]byte, 50*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
conn := newMockConn(data)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
var buf bytes.Buffer
|
||||
n, err := lc.WriteTo(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTo failed: %v", err)
|
||||
}
|
||||
if n != int64(len(data)) {
|
||||
t.Errorf("WriteTo transferred %d bytes, want %d", n, len(data))
|
||||
}
|
||||
if !bytes.Equal(buf.Bytes(), data) {
|
||||
t.Error("WriteTo data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteToWriterError(t *testing.T) {
|
||||
data := make([]byte, 1024)
|
||||
limiter := NewLimiter(Config{Bandwidth: 1024 * 1024, Burst: 1024 * 1024})
|
||||
conn := newMockConn(data)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
fw := &failingWriter{failAfter: 100}
|
||||
_, err := lc.WriteTo(fw)
|
||||
if err == nil {
|
||||
t.Error("Expected error from WriteTo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFromRateLimited(t *testing.T) {
|
||||
// 10KB/s, 10KB burst — 20KB transfer should take ~1s
|
||||
limiter := NewLimiter(Config{Bandwidth: 10 * 1024, Burst: 10 * 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
src := bytes.NewReader(make([]byte, 20*1024))
|
||||
start := time.Now()
|
||||
n, err := lc.ReadFrom(src)
|
||||
dur := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFrom failed: %v", err)
|
||||
}
|
||||
if n != 20*1024 {
|
||||
t.Errorf("ReadFrom transferred %d, want %d", n, 20*1024)
|
||||
}
|
||||
if dur < 800*time.Millisecond {
|
||||
t.Errorf("ReadFrom too fast: %v (expected ~1s for 20KB at 10KB/s with 10KB burst)", dur)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIoCopyUsesReadFrom verifies io.Copy goes through our ReadFrom,
|
||||
// not the underlying conn's optimized path.
|
||||
func TestIoCopyUsesReadFrom(t *testing.T) {
|
||||
// Use a small burst so we can detect if rate limiting is applied
|
||||
limiter := NewLimiter(Config{Bandwidth: 10 * 1024, Burst: 10 * 1024})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
src := bytes.NewReader(make([]byte, 20*1024))
|
||||
start := time.Now()
|
||||
n, err := io.Copy(lc, src)
|
||||
dur := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("io.Copy failed: %v", err)
|
||||
}
|
||||
if n != 20*1024 {
|
||||
t.Errorf("io.Copy transferred %d, want %d", n, 20*1024)
|
||||
}
|
||||
// If io.Copy bypassed our ReadFrom, it would be instant
|
||||
if dur < 800*time.Millisecond {
|
||||
t.Errorf("io.Copy too fast (%v), rate limiting may be bypassed", dur)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnlimitedWrite(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 0})
|
||||
conn := newMockConn(nil)
|
||||
lc := NewLimitedConn(context.Background(), conn, limiter)
|
||||
|
||||
data := make([]byte, 1024*1024) // 1MB
|
||||
start := time.Now()
|
||||
n, err := lc.Write(data)
|
||||
dur := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("Write returned %d, want %d", n, len(data))
|
||||
}
|
||||
if dur > 50*time.Millisecond {
|
||||
t.Errorf("Unlimited write took too long: %v", dur)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilLimiter(t *testing.T) {
|
||||
conn := newMockConn([]byte("hello"))
|
||||
lc := NewLimitedConn(context.Background(), conn, nil)
|
||||
|
||||
buf := make([]byte, 10)
|
||||
n, err := lc.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read with nil limiter failed: %v", err)
|
||||
}
|
||||
if n != 5 {
|
||||
t.Errorf("Read returned %d, want 5", n)
|
||||
}
|
||||
|
||||
n, err = lc.Write([]byte("world"))
|
||||
if err != nil {
|
||||
t.Fatalf("Write with nil limiter failed: %v", err)
|
||||
}
|
||||
if n != 5 {
|
||||
t.Errorf("Write returned %d, want 5", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentReadWrite(t *testing.T) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
limiter := NewLimiter(Config{Bandwidth: 100 * 1024, Burst: 100 * 1024})
|
||||
lc := NewLimitedConn(context.Background(), serverConn, limiter)
|
||||
|
||||
dataSize := 50 * 1024
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Writer
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
data := make([]byte, dataSize)
|
||||
for i := range data {
|
||||
data[i] = 0xAA
|
||||
}
|
||||
lc.Write(data)
|
||||
}()
|
||||
|
||||
// Reader on the other end
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, dataSize)
|
||||
total := 0
|
||||
for total < dataSize {
|
||||
n, err := clientConn.Read(buf[total:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
total += n
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type failingReader struct {
|
||||
failAfter int
|
||||
read int
|
||||
}
|
||||
|
||||
func (r *failingReader) Read(b []byte) (int, error) {
|
||||
remaining := r.failAfter - r.read
|
||||
if remaining <= 0 {
|
||||
return 0, errors.New("reader error")
|
||||
}
|
||||
n := len(b)
|
||||
if n > remaining {
|
||||
n = remaining
|
||||
}
|
||||
r.read += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
type failingWriter struct {
|
||||
failAfter int
|
||||
written int
|
||||
}
|
||||
|
||||
func (w *failingWriter) Write(b []byte) (int, error) {
|
||||
remaining := w.failAfter - w.written
|
||||
if remaining <= 0 {
|
||||
return 0, errors.New("writer error")
|
||||
}
|
||||
n := len(b)
|
||||
if n > remaining {
|
||||
w.written += remaining
|
||||
return remaining, errors.New("writer error")
|
||||
}
|
||||
w.written += n
|
||||
return n, nil
|
||||
}
|
||||
269
internal/shared/qos/integration_test.go
Normal file
269
internal/shared/qos/integration_test.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEndToEndBandwidthLimiting(t *testing.T) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
bandwidth := int64(100 * 1024)
|
||||
burstMultiplier := 2.0
|
||||
burst := int(float64(bandwidth) * burstMultiplier)
|
||||
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
ctx := context.Background()
|
||||
limitedServerConn := NewLimitedConn(ctx, serverConn, limiter)
|
||||
|
||||
dataSize := 500 * 1024
|
||||
testData := make([]byte, dataSize)
|
||||
for i := range testData {
|
||||
testData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var writeErr, readErr error
|
||||
var writeDuration time.Duration
|
||||
receivedData := make([]byte, dataSize)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
start := time.Now()
|
||||
chunkSize := 32 * 1024
|
||||
for i := 0; i < dataSize; i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > dataSize {
|
||||
end = dataSize
|
||||
}
|
||||
_, err := limitedServerConn.Write(testData[i:end])
|
||||
if err != nil {
|
||||
writeErr = err
|
||||
return
|
||||
}
|
||||
}
|
||||
writeDuration = time.Since(start)
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
totalRead := 0
|
||||
for totalRead < dataSize {
|
||||
n, err := clientConn.Read(receivedData[totalRead:])
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
readErr = err
|
||||
}
|
||||
return
|
||||
}
|
||||
totalRead += n
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if writeErr != nil {
|
||||
t.Fatalf("Write error: %v", writeErr)
|
||||
}
|
||||
if readErr != nil {
|
||||
t.Fatalf("Read error: %v", readErr)
|
||||
}
|
||||
|
||||
for i := 0; i < dataSize; i++ {
|
||||
if receivedData[i] != testData[i] {
|
||||
t.Fatalf("Data mismatch at byte %d: got %d, want %d", i, receivedData[i], testData[i])
|
||||
}
|
||||
}
|
||||
|
||||
expectedMinDuration := 2500 * time.Millisecond
|
||||
expectedMaxDuration := 4000 * time.Millisecond
|
||||
|
||||
if writeDuration < expectedMinDuration {
|
||||
t.Errorf("Transfer too fast: %v (expected >= %v)", writeDuration, expectedMinDuration)
|
||||
}
|
||||
if writeDuration > expectedMaxDuration {
|
||||
t.Errorf("Transfer too slow: %v (expected <= %v)", writeDuration, expectedMaxDuration)
|
||||
}
|
||||
|
||||
t.Logf("Transferred %d bytes in %v (rate: %.2f KB/s)",
|
||||
dataSize, writeDuration, float64(dataSize)/writeDuration.Seconds()/1024)
|
||||
}
|
||||
|
||||
func TestBidirectionalBandwidthLimiting(t *testing.T) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
bandwidth := int64(50 * 1024)
|
||||
burst := int(bandwidth * 2)
|
||||
|
||||
serverLimiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
clientLimiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
|
||||
ctx := context.Background()
|
||||
limitedServerConn := NewLimitedConn(ctx, serverConn, serverLimiter)
|
||||
limitedClientConn := NewLimitedConn(ctx, clientConn, clientLimiter)
|
||||
|
||||
dataSize := 200 * 1024
|
||||
serverData := make([]byte, dataSize)
|
||||
clientData := make([]byte, dataSize)
|
||||
for i := range serverData {
|
||||
serverData[i] = byte(i % 256)
|
||||
clientData[i] = byte((i + 128) % 256)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
receivedByClient := make([]byte, dataSize)
|
||||
receivedByServer := make([]byte, dataSize)
|
||||
|
||||
// Server writes to client
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
chunkSize := 16 * 1024
|
||||
for i := 0; i < dataSize; i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > dataSize {
|
||||
end = dataSize
|
||||
}
|
||||
limitedServerConn.Write(serverData[i:end])
|
||||
}
|
||||
}()
|
||||
|
||||
// Client writes to server
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
chunkSize := 16 * 1024
|
||||
for i := 0; i < dataSize; i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > dataSize {
|
||||
end = dataSize
|
||||
}
|
||||
limitedClientConn.Write(clientData[i:end])
|
||||
}
|
||||
}()
|
||||
|
||||
// Client reads from server
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
totalRead := 0
|
||||
for totalRead < dataSize {
|
||||
n, err := limitedClientConn.Read(receivedByClient[totalRead:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
totalRead += n
|
||||
}
|
||||
}()
|
||||
|
||||
// Server reads from client
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
totalRead := 0
|
||||
for totalRead < dataSize {
|
||||
n, err := limitedServerConn.Read(receivedByServer[totalRead:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
totalRead += n
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
for i := 0; i < dataSize; i++ {
|
||||
if receivedByClient[i] != serverData[i] {
|
||||
t.Fatalf("Client received wrong data at byte %d", i)
|
||||
}
|
||||
if receivedByServer[i] != clientData[i] {
|
||||
t.Fatalf("Server received wrong data at byte %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("Bidirectional transfer completed successfully")
|
||||
}
|
||||
|
||||
func TestBurstBehavior(t *testing.T) {
|
||||
bandwidth := int64(10 * 1024)
|
||||
burst := 50 * 1024
|
||||
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
ctx := context.Background()
|
||||
|
||||
start := time.Now()
|
||||
err := limiter.RateLimiter().WaitN(ctx, burst)
|
||||
if err != nil {
|
||||
t.Fatalf("WaitN failed: %v", err)
|
||||
}
|
||||
burstDuration := time.Since(start)
|
||||
|
||||
if burstDuration > 100*time.Millisecond {
|
||||
t.Errorf("Burst should be instant, took %v", burstDuration)
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
err = limiter.RateLimiter().WaitN(ctx, 10*1024)
|
||||
if err != nil {
|
||||
t.Fatalf("WaitN failed: %v", err)
|
||||
}
|
||||
limitedDuration := time.Since(start)
|
||||
|
||||
if limitedDuration < 900*time.Millisecond || limitedDuration > 1200*time.Millisecond {
|
||||
t.Errorf("Rate limiting not working correctly, took %v (expected ~1s)", limitedDuration)
|
||||
}
|
||||
|
||||
t.Logf("Burst: %v, Rate-limited: %v", burstDuration, limitedDuration)
|
||||
}
|
||||
|
||||
func TestMultipleBurstMultipliers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
multiplier float64
|
||||
}{
|
||||
{"1x burst", 10 * 1024, 1.0},
|
||||
{"1.5x burst", 10 * 1024, 1.5},
|
||||
{"2x burst", 10 * 1024, 2.0},
|
||||
{"2.5x burst", 10 * 1024, 2.5},
|
||||
{"3x burst", 10 * 1024, 3.0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
burst := int(float64(tt.bandwidth) * tt.multiplier)
|
||||
limiter := NewLimiter(Config{Bandwidth: tt.bandwidth, Burst: burst})
|
||||
|
||||
if !limiter.IsLimited() {
|
||||
t.Error("Limiter should be limited")
|
||||
}
|
||||
|
||||
actualBurst := limiter.RateLimiter().Burst()
|
||||
if actualBurst != burst {
|
||||
t.Errorf("Burst = %d, want %d", actualBurst, burst)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
start := time.Now()
|
||||
err := limiter.RateLimiter().WaitN(ctx, burst)
|
||||
if err != nil {
|
||||
t.Fatalf("WaitN failed: %v", err)
|
||||
}
|
||||
duration := time.Since(start)
|
||||
|
||||
if duration > 50*time.Millisecond {
|
||||
t.Errorf("Burst should be instant, took %v", duration)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
34
internal/shared/qos/limiter.go
Normal file
34
internal/shared/qos/limiter.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Bandwidth int64
|
||||
Burst int
|
||||
}
|
||||
|
||||
type Limiter struct {
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func NewLimiter(cfg Config) *Limiter {
|
||||
l := &Limiter{}
|
||||
if cfg.Bandwidth > 0 {
|
||||
burst := cfg.Burst
|
||||
if burst <= 0 {
|
||||
burst = int(cfg.Bandwidth * 2)
|
||||
}
|
||||
l.limiter = rate.NewLimiter(rate.Limit(cfg.Bandwidth), burst)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Limiter) RateLimiter() *rate.Limiter {
|
||||
return l.limiter
|
||||
}
|
||||
|
||||
func (l *Limiter) IsLimited() bool {
|
||||
return l.limiter != nil
|
||||
}
|
||||
313
internal/shared/qos/limiter_test.go
Normal file
313
internal/shared/qos/limiter_test.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package qos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewLimiter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantLimit bool
|
||||
wantBurst int
|
||||
}{
|
||||
{
|
||||
name: "unlimited when bandwidth is 0",
|
||||
cfg: Config{Bandwidth: 0},
|
||||
wantLimit: false,
|
||||
},
|
||||
{
|
||||
name: "limited with default burst (2x)",
|
||||
cfg: Config{Bandwidth: 1024},
|
||||
wantLimit: true,
|
||||
wantBurst: 2048, // 2x bandwidth
|
||||
},
|
||||
{
|
||||
name: "limited with custom burst",
|
||||
cfg: Config{Bandwidth: 1024, Burst: 4096},
|
||||
wantLimit: true,
|
||||
wantBurst: 4096,
|
||||
},
|
||||
{
|
||||
name: "1MB/s with 2x burst",
|
||||
cfg: Config{Bandwidth: 1024 * 1024},
|
||||
wantLimit: true,
|
||||
wantBurst: 2 * 1024 * 1024,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l := NewLimiter(tt.cfg)
|
||||
if l.IsLimited() != tt.wantLimit {
|
||||
t.Errorf("IsLimited() = %v, want %v", l.IsLimited(), tt.wantLimit)
|
||||
}
|
||||
if tt.wantLimit {
|
||||
if l.RateLimiter() == nil {
|
||||
t.Error("RateLimiter() should not be nil when limited")
|
||||
}
|
||||
if l.RateLimiter().Burst() != tt.wantBurst {
|
||||
t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiterBandwidthEnforcement(t *testing.T) {
|
||||
bandwidth := int64(10 * 1024)
|
||||
burst := int(bandwidth * 2)
|
||||
|
||||
l := NewLimiter(Config{Bandwidth: bandwidth, Burst: burst})
|
||||
|
||||
if !l.IsLimited() {
|
||||
t.Fatal("Limiter should be limited")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
start := time.Now()
|
||||
err := l.RateLimiter().WaitN(ctx, burst)
|
||||
if err != nil {
|
||||
t.Fatalf("WaitN failed: %v", err)
|
||||
}
|
||||
burstDuration := time.Since(start)
|
||||
if burstDuration > 100*time.Millisecond {
|
||||
t.Errorf("Burst should be instant, took %v", burstDuration)
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
err = l.RateLimiter().WaitN(ctx, int(bandwidth)) // Request 1 second worth
|
||||
if err != nil {
|
||||
t.Fatalf("WaitN failed: %v", err)
|
||||
}
|
||||
limitedDuration := time.Since(start)
|
||||
|
||||
if limitedDuration < 800*time.Millisecond {
|
||||
t.Errorf("Rate limiting not working, took only %v for 1 second worth of data", limitedDuration)
|
||||
}
|
||||
if limitedDuration > 1500*time.Millisecond {
|
||||
t.Errorf("Rate limiting too slow, took %v for 1 second worth of data", limitedDuration)
|
||||
}
|
||||
}
|
||||
|
||||
type mockConn struct {
|
||||
readBuf []byte
|
||||
readPos int
|
||||
writeBuf []byte
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMockConn(data []byte) *mockConn {
|
||||
return &mockConn{readBuf: data}
|
||||
}
|
||||
|
||||
func (c *mockConn) Read(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.readPos >= len(c.readBuf) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n = copy(b, c.readBuf[c.readPos:])
|
||||
c.readPos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Write(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writeBuf = append(c.writeBuf, b...)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Close() error { return nil }
|
||||
func (c *mockConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *mockConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *mockConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (c *mockConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (c *mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func TestLimitedConnRead(t *testing.T) {
|
||||
dataSize := 20 * 1024
|
||||
testData := make([]byte, dataSize)
|
||||
for i := range testData {
|
||||
testData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
// 10KB/s limit, 20KB burst
|
||||
bandwidth := int64(10 * 1024)
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: int(bandwidth * 2)})
|
||||
|
||||
conn := newMockConn(testData)
|
||||
ctx := context.Background()
|
||||
limitedConn := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
buf := make([]byte, dataSize)
|
||||
start := time.Now()
|
||||
|
||||
totalRead := 0
|
||||
for totalRead < dataSize {
|
||||
n, err := limitedConn.Read(buf[totalRead:])
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
totalRead += n
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
if totalRead != dataSize {
|
||||
t.Errorf("Read %d bytes, want %d", totalRead, dataSize)
|
||||
}
|
||||
|
||||
t.Logf("Read %d bytes in %v", totalRead, duration)
|
||||
}
|
||||
|
||||
func TestLimitedConnWrite(t *testing.T) {
|
||||
bandwidth := int64(10 * 1024)
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: int(bandwidth * 2)})
|
||||
|
||||
conn := newMockConn(nil)
|
||||
ctx := context.Background()
|
||||
limitedConn := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
dataSize := 30 * 1024
|
||||
testData := make([]byte, dataSize)
|
||||
for i := range testData {
|
||||
testData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
chunkSize := 10 * 1024
|
||||
for i := 0; i < dataSize; i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > dataSize {
|
||||
end = dataSize
|
||||
}
|
||||
n, err := limitedConn.Write(testData[i:end])
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
if n != end-i {
|
||||
t.Errorf("Write returned %d, want %d", n, end-i)
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
// 30KB data, 10KB/s rate, 20KB burst → ~1s for remaining 10KB
|
||||
if duration < 800*time.Millisecond {
|
||||
t.Errorf("Write too fast, took %v for 30KB with 10KB/s limit and 20KB burst", duration)
|
||||
}
|
||||
|
||||
t.Logf("Wrote %d bytes in %v", dataSize, duration)
|
||||
}
|
||||
|
||||
func TestLimitedConnUnlimited(t *testing.T) {
|
||||
limiter := NewLimiter(Config{Bandwidth: 0})
|
||||
|
||||
testData := make([]byte, 100*1024)
|
||||
conn := newMockConn(testData)
|
||||
ctx := context.Background()
|
||||
limitedConn := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
buf := make([]byte, len(testData))
|
||||
start := time.Now()
|
||||
|
||||
totalRead := 0
|
||||
for totalRead < len(testData) {
|
||||
n, err := limitedConn.Read(buf[totalRead:])
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
totalRead += n
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
if totalRead != len(testData) {
|
||||
t.Errorf("Read %d bytes, want %d", totalRead, len(testData))
|
||||
}
|
||||
|
||||
if duration > 100*time.Millisecond {
|
||||
t.Errorf("Unlimited read took too long: %v", duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitedConnContextCancellation(t *testing.T) {
|
||||
bandwidth := int64(100)
|
||||
limiter := NewLimiter(Config{Bandwidth: bandwidth, Burst: 100})
|
||||
|
||||
conn := newMockConn(nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
limitedConn := NewLimitedConn(ctx, conn, limiter)
|
||||
|
||||
_, err := limitedConn.Write(make([]byte, 100))
|
||||
if err != nil {
|
||||
t.Fatalf("First write failed: %v", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
_, err = limitedConn.Write(make([]byte, 1000))
|
||||
if err == nil {
|
||||
t.Error("Write should fail after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBurstMultiplier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bandwidth int64
|
||||
multiplier float64
|
||||
wantBurst int
|
||||
}{
|
||||
{
|
||||
name: "2x multiplier",
|
||||
bandwidth: 1024 * 1024, // 1MB/s
|
||||
multiplier: 2.0,
|
||||
wantBurst: 2 * 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "2.5x multiplier",
|
||||
bandwidth: 1024 * 1024,
|
||||
multiplier: 2.5,
|
||||
wantBurst: int(float64(1024*1024) * 2.5),
|
||||
},
|
||||
{
|
||||
name: "1x multiplier (no extra burst)",
|
||||
bandwidth: 1024 * 1024,
|
||||
multiplier: 1.0,
|
||||
wantBurst: 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "3x multiplier",
|
||||
bandwidth: 500 * 1024, // 500KB/s
|
||||
multiplier: 3.0,
|
||||
wantBurst: 3 * 500 * 1024,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
burst := int(float64(tt.bandwidth) * tt.multiplier)
|
||||
l := NewLimiter(Config{Bandwidth: tt.bandwidth, Burst: burst})
|
||||
|
||||
if l.RateLimiter().Burst() != tt.wantBurst {
|
||||
t.Errorf("Burst() = %v, want %v", l.RateLimiter().Burst(), tt.wantBurst)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user