mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
Merge pull request #6 from Gouryella/perf/improve-performance
perf: Improve performance and stability with buffer pooling, context management, and security enhancements
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
<p align="center">
|
||||
<img src="images/logo.png" alt="Drip Logo" width="200" />
|
||||
<img src="assets/logo.png" alt="Drip Logo" width="200" />
|
||||
</p>
|
||||
|
||||
<h1 align="center">Drip</h1>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<p align="center">
|
||||
<img src="images/logo.png" alt="Drip Logo" width="200" />
|
||||
<img src="assets/logo.png" alt="Drip Logo" width="200" />
|
||||
</p>
|
||||
|
||||
<h1 align="center">Drip</h1>
|
||||
|
||||
|
Before Width: | Height: | Size: 1.4 MiB After Width: | Height: | Size: 1.4 MiB |
@@ -3,6 +3,8 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
|
||||
"drip/internal/client/cli"
|
||||
)
|
||||
@@ -14,6 +16,9 @@ var (
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Performance optimizations
|
||||
setupPerformanceOptimizations()
|
||||
|
||||
cli.SetVersion(Version, GitCommit, BuildTime)
|
||||
|
||||
if err := cli.Execute(); err != nil {
|
||||
@@ -21,3 +26,19 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// setupPerformanceOptimizations configures runtime settings for optimal performance
|
||||
func setupPerformanceOptimizations() {
|
||||
// Set GOMAXPROCS to use all available CPU cores
|
||||
numCPU := runtime.NumCPU()
|
||||
runtime.GOMAXPROCS(numCPU)
|
||||
|
||||
// Reduce GC frequency for high-throughput scenarios
|
||||
// Default is 100, setting to 200 reduces GC overhead at cost of more memory
|
||||
// This is beneficial since we now use buffer pools (less garbage)
|
||||
debug.SetGCPercent(200)
|
||||
|
||||
// Set memory limit to prevent OOM (adjust based on your server)
|
||||
// This is a soft limit - Go will try to stay under this
|
||||
debug.SetMemoryLimit(8 * 1024 * 1024 * 1024) // 8GB limit
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ services:
|
||||
|
||||
volumes:
|
||||
- ./certs:/app/data/certs:ro
|
||||
- drip-data:/app/data
|
||||
- ./data:/app/data
|
||||
|
||||
environment:
|
||||
TZ: ${TZ:-UTC}
|
||||
@@ -63,10 +63,6 @@ services:
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
|
||||
volumes:
|
||||
drip-data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
drip-net:
|
||||
driver: bridge
|
||||
|
||||
@@ -117,9 +117,15 @@ func runConfigShow(cmd *cobra.Command, args []string) error {
|
||||
|
||||
var displayToken string
|
||||
if cfg.Token != "" {
|
||||
if len(cfg.Token) > 10 {
|
||||
displayToken = cfg.Token[:3] + "***" + cfg.Token[len(cfg.Token)-3:]
|
||||
tokenLen := len(cfg.Token)
|
||||
if tokenLen <= 3 {
|
||||
// For very short tokens, just show asterisks
|
||||
displayToken = "***"
|
||||
} else if tokenLen > 10 {
|
||||
// For long tokens, show first 3 and last 3 characters
|
||||
displayToken = cfg.Token[:3] + "***" + cfg.Token[tokenLen-3:]
|
||||
} else {
|
||||
// For medium tokens (4-10 chars), show first 3 characters
|
||||
displayToken = cfg.Token[:3] + "***"
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -14,6 +14,7 @@ var stopCmd = &cobra.Command{
|
||||
|
||||
Examples:
|
||||
drip stop http 3000 Stop HTTP tunnel on port 3000
|
||||
drip stop https 8443 Stop HTTPS tunnel on port 8443
|
||||
drip stop tcp 5432 Stop TCP tunnel on port 5432
|
||||
drip stop all Stop all running tunnels
|
||||
|
||||
@@ -37,8 +38,8 @@ func runStop(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
tunnelType := args[0]
|
||||
if tunnelType != "http" && tunnelType != "tcp" {
|
||||
return fmt.Errorf("invalid tunnel type: %s (must be 'http' or 'tcp')", tunnelType)
|
||||
if tunnelType != "http" && tunnelType != "https" && tunnelType != "tcp" {
|
||||
return fmt.Errorf("invalid tunnel type: %s (must be 'http', 'https', or 'tcp')", tunnelType)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(args[1])
|
||||
|
||||
@@ -43,6 +43,10 @@ type Connector struct {
|
||||
handlerWg sync.WaitGroup // Tracks active data frame handlers
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
|
||||
// Worker pool for handling data frames
|
||||
dataFrameQueue chan *protocol.Frame
|
||||
workerCount int
|
||||
}
|
||||
|
||||
// ConnectorConfig holds connector configuration
|
||||
@@ -71,16 +75,21 @@ func NewConnector(cfg *ConnectorConfig, logger *zap.Logger) *Connector {
|
||||
localHost = "127.0.0.1"
|
||||
}
|
||||
|
||||
numCPU := pool.NumCPU()
|
||||
workerCount := max(numCPU+numCPU/2, 4)
|
||||
|
||||
return &Connector{
|
||||
serverAddr: cfg.ServerAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
token: cfg.Token,
|
||||
tunnelType: cfg.TunnelType,
|
||||
localHost: localHost,
|
||||
localPort: cfg.LocalPort,
|
||||
subdomain: cfg.Subdomain,
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
serverAddr: cfg.ServerAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
token: cfg.Token,
|
||||
tunnelType: cfg.TunnelType,
|
||||
localHost: localHost,
|
||||
localPort: cfg.LocalPort,
|
||||
subdomain: cfg.Subdomain,
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
dataFrameQueue: make(chan *protocol.Frame, workerCount*100),
|
||||
workerCount: workerCount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,6 +144,11 @@ func (c *Connector) Connect() error {
|
||||
|
||||
c.frameWriter.EnableHeartbeat(constants.HeartbeatInterval, c.createHeartbeatFrame)
|
||||
|
||||
for i := 0; i < c.workerCount; i++ {
|
||||
c.handlerWg.Add(1)
|
||||
go c.dataFrameWorker(i)
|
||||
}
|
||||
|
||||
go c.frameHandler.WarmupConnectionPool(3)
|
||||
go c.handleFrames()
|
||||
|
||||
@@ -200,6 +214,29 @@ func (c *Connector) register() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connector) dataFrameWorker(workerID int) {
|
||||
defer c.handlerWg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case frame, ok := <-c.dataFrameQueue:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.frameHandler.HandleDataFrame(frame); err != nil {
|
||||
c.logger.Error("Failed to handle data frame",
|
||||
zap.Int("worker_id", workerID),
|
||||
zap.Error(err))
|
||||
}
|
||||
frame.Release()
|
||||
|
||||
case <-c.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleFrames handles incoming frames from server
|
||||
func (c *Connector) handleFrames() {
|
||||
defer c.Close()
|
||||
@@ -246,14 +283,15 @@ func (c *Connector) handleFrames() {
|
||||
frame.Release()
|
||||
|
||||
case protocol.FrameTypeData:
|
||||
c.handlerWg.Add(1)
|
||||
go func(f *protocol.Frame) {
|
||||
defer c.handlerWg.Done()
|
||||
defer f.Release()
|
||||
if err := c.frameHandler.HandleDataFrame(f); err != nil {
|
||||
c.logger.Error("Failed to handle data frame", zap.Error(err))
|
||||
}
|
||||
}(frame)
|
||||
select {
|
||||
case c.dataFrameQueue <- frame:
|
||||
case <-c.stopCh:
|
||||
frame.Release()
|
||||
return
|
||||
default:
|
||||
c.logger.Warn("Data frame queue full, dropping frame")
|
||||
frame.Release()
|
||||
}
|
||||
|
||||
case protocol.FrameTypeClose:
|
||||
frame.Release()
|
||||
@@ -280,7 +318,6 @@ func (c *Connector) handleFrames() {
|
||||
}
|
||||
}
|
||||
|
||||
// createHeartbeatFrame creates a heartbeat frame to be sent by the write loop.
|
||||
func (c *Connector) createHeartbeatFrame() *protocol.Frame {
|
||||
c.closedMu.RLock()
|
||||
if c.closed {
|
||||
@@ -293,7 +330,6 @@ func (c *Connector) createHeartbeatFrame() *protocol.Frame {
|
||||
c.heartbeatSentAt = time.Now()
|
||||
c.heartbeatMu.Unlock()
|
||||
|
||||
c.logger.Debug("Heartbeat sent")
|
||||
return protocol.NewFrame(protocol.FrameTypeHeartbeat, nil)
|
||||
}
|
||||
|
||||
@@ -306,7 +342,6 @@ func (c *Connector) SendFrame(frame *protocol.Frame) error {
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
|
||||
// Close closes the connection
|
||||
func (c *Connector) Close() error {
|
||||
c.once.Do(func() {
|
||||
c.closedMu.Lock()
|
||||
@@ -314,9 +349,8 @@ func (c *Connector) Close() error {
|
||||
c.closedMu.Unlock()
|
||||
|
||||
close(c.stopCh)
|
||||
close(c.dataFrameQueue)
|
||||
|
||||
// Wait for active handlers with timeout
|
||||
c.logger.Debug("Waiting for active handlers to complete")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.handlerWg.Wait()
|
||||
@@ -325,7 +359,6 @@ func (c *Connector) Close() error {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
c.logger.Debug("All handlers completed")
|
||||
case <-time.After(3 * time.Second):
|
||||
c.logger.Warn("Force closing: some handlers are still active")
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -89,19 +90,21 @@ func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost
|
||||
httpClient: &http.Client{
|
||||
// No overall timeout - streaming responses can take arbitrary time
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 1000,
|
||||
MaxIdleConnsPerHost: 500,
|
||||
MaxConnsPerHost: 0,
|
||||
IdleConnTimeout: 180 * time.Second,
|
||||
DisableCompression: true,
|
||||
DisableKeepAlives: false,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
MaxIdleConns: 2000, // Increased from 1000 for better connection reuse
|
||||
MaxIdleConnsPerHost: 1000, // Increased from 500 for high concurrency
|
||||
MaxConnsPerHost: 0, // Unlimited connections per host
|
||||
IdleConnTimeout: 180 * time.Second, // Keep connections alive for reuse
|
||||
DisableCompression: true, // Disable compression for better CPU efficiency
|
||||
DisableKeepAlives: false, // Enable keep-alive for connection reuse
|
||||
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s for faster failure detection
|
||||
TLSClientConfig: tlsConfig,
|
||||
ResponseHeaderTimeout: 30 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 15 * time.Second, // Reduced from 30s for faster timeout
|
||||
ExpectContinueTimeout: 500 * time.Millisecond, // Reduced from 1s for better responsiveness
|
||||
WriteBufferSize: 32 * 1024, // 32KB write buffer
|
||||
ReadBufferSize: 32 * 1024, // 32KB read buffer
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Timeout: 3 * time.Second, // Reduced from 5s for faster connection attempts
|
||||
KeepAlive: 30 * time.Second, // Keep TCP keepalive
|
||||
}).DialContext,
|
||||
},
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
@@ -224,9 +227,6 @@ func (h *FrameHandler) handleLocalResponse(stream *Stream) {
|
||||
}
|
||||
|
||||
n, err := stream.LocalConn.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
if h.isClosedCheck != nil && h.isClosedCheck() {
|
||||
@@ -262,6 +262,10 @@ func (h *FrameHandler) handleLocalResponse(stream *Stream) {
|
||||
|
||||
h.stats.AddBytesOut(int64(len(payload)))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,7 +356,9 @@ func (h *FrameHandler) adaptiveHTTPResponse(streamID, requestID string, resp *ht
|
||||
|
||||
// Buffer for initial read
|
||||
buffer := make([]byte, 0, threshold)
|
||||
tempBuf := make([]byte, 32*1024) // 32KB read chunks
|
||||
tempBufPtr := h.bufferPool.Get(pool.SizeMedium)
|
||||
defer h.bufferPool.Put(tempBufPtr)
|
||||
tempBuf := (*tempBufPtr)[:pool.SizeMedium]
|
||||
|
||||
var totalRead int64
|
||||
var hitThreshold bool
|
||||
@@ -516,7 +522,12 @@ func (h *FrameHandler) adaptiveHTTPResponse(streamID, requestID string, resp *ht
|
||||
break
|
||||
}
|
||||
if readErr != nil {
|
||||
if errors.Is(readErr, context.Canceled) || errors.Is(readErr, context.DeadlineExceeded) || errors.Is(readErr, http.ErrBodyReadAfterClose) || errors.Is(readErr, net.ErrClosed) {
|
||||
// Check for expected errors that indicate connection/body closure
|
||||
if errors.Is(readErr, context.Canceled) ||
|
||||
errors.Is(readErr, context.DeadlineExceeded) ||
|
||||
errors.Is(readErr, http.ErrBodyReadAfterClose) ||
|
||||
errors.Is(readErr, net.ErrClosed) ||
|
||||
strings.Contains(readErr.Error(), "read on closed response body") {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read response body: %w", readErr)
|
||||
@@ -652,7 +663,12 @@ func (h *FrameHandler) streamHTTPResponse(streamID, requestID string, resp *http
|
||||
break
|
||||
}
|
||||
if readErr != nil {
|
||||
if errors.Is(readErr, context.Canceled) || errors.Is(readErr, context.DeadlineExceeded) || errors.Is(readErr, http.ErrBodyReadAfterClose) || errors.Is(readErr, net.ErrClosed) {
|
||||
// Check for expected errors that indicate connection/body closure
|
||||
if errors.Is(readErr, context.Canceled) ||
|
||||
errors.Is(readErr, context.DeadlineExceeded) ||
|
||||
errors.Is(readErr, http.ErrBodyReadAfterClose) ||
|
||||
errors.Is(readErr, net.ErrClosed) ||
|
||||
strings.Contains(readErr.Error(), "read on closed response body") {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read response body: %w", readErr)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
@@ -25,6 +27,7 @@ type Handler struct {
|
||||
domain string
|
||||
authToken string
|
||||
headerPool *pool.HeaderPool
|
||||
bufferPool *pool.AdaptiveBufferPool
|
||||
}
|
||||
|
||||
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, responses *ResponseHandler, domain string, authToken string) *Handler {
|
||||
@@ -35,10 +38,22 @@ func NewHandler(manager *tunnel.Manager, logger *zap.Logger, responses *Response
|
||||
domain: domain,
|
||||
authToken: authToken,
|
||||
headerPool: pool.NewHeaderPool(),
|
||||
bufferPool: pool.NewAdaptiveBufferPool(),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Always handle /health and /stats directly, regardless of subdomain
|
||||
if r.URL.Path == "/health" {
|
||||
h.serveHealth(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/stats" {
|
||||
h.serveStats(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
subdomain := h.extractSubdomain(r.Host)
|
||||
|
||||
if subdomain == "" {
|
||||
@@ -77,8 +92,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
|
||||
const streamingThreshold int64 = 1 * 1024 * 1024
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
var cancelTransport func()
|
||||
if transport != nil {
|
||||
h.responses.RegisterCancelFunc(requestID, func() {
|
||||
cancelOnce := sync.Once{}
|
||||
cancelFunc := func() {
|
||||
header := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
@@ -98,13 +117,26 @@ func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request,
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
cancelTransport = func() {
|
||||
cancelOnce.Do(cancelFunc)
|
||||
}
|
||||
|
||||
h.responses.RegisterCancelFunc(requestID, cancelTransport)
|
||||
defer h.responses.CleanupCancelFunc(requestID)
|
||||
}
|
||||
|
||||
buffer := make([]byte, 0, streamingThreshold)
|
||||
tempBuf := make([]byte, 32*1024)
|
||||
largeBufferPtr := h.bufferPool.GetLarge()
|
||||
tempBufPtr := h.bufferPool.GetMedium()
|
||||
|
||||
defer func() {
|
||||
h.bufferPool.PutLarge(largeBufferPtr)
|
||||
h.bufferPool.PutMedium(tempBufPtr)
|
||||
}()
|
||||
|
||||
buffer := (*largeBufferPtr)[:0]
|
||||
tempBuf := (*tempBufPtr)[:pool.MediumBufferSize]
|
||||
|
||||
var totalRead int64
|
||||
var hitThreshold bool
|
||||
@@ -117,7 +149,7 @@ func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
if err == io.EOF {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
|
||||
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -134,14 +166,14 @@ func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request,
|
||||
|
||||
if !hitThreshold {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
|
||||
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
return
|
||||
}
|
||||
|
||||
h.streamLargeRequest(w, r, transport, requestID, subdomain, buffer)
|
||||
h.streamLargeRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
}
|
||||
|
||||
func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, body []byte) {
|
||||
func (h *Handler) sendBufferedRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), body []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
@@ -199,6 +231,15 @@ func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, tr
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("HTTP request context cancelled",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
@@ -208,7 +249,7 @@ func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, tr
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, bufferedData []byte) {
|
||||
func (h *Handler) streamLargeRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), bufferedData []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
@@ -302,8 +343,23 @@ func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, tra
|
||||
}
|
||||
}
|
||||
|
||||
buffer := make([]byte, 32*1024)
|
||||
streamBufPtr := h.bufferPool.GetMedium()
|
||||
defer h.bufferPool.PutMedium(streamBufPtr)
|
||||
buffer := (*streamBufPtr)[:pool.MediumBufferSize]
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("Streaming request cancelled via context",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
n, readErr := r.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
isLast := readErr == io.EOF
|
||||
@@ -399,6 +455,15 @@ func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, tra
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("Streaming HTTP request context cancelled",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Streaming request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
@@ -421,12 +486,12 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
|
||||
|
||||
// Skip hop-by-hop headers completely using canonical key comparison
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -511,16 +576,6 @@ func (h *Handler) extractSubdomain(host string) string {
|
||||
}
|
||||
|
||||
func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/health" {
|
||||
h.serveHealth(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/stats" {
|
||||
h.serveStats(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
html := `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
@@ -560,8 +615,15 @@ func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(health)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(health)
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -594,6 +656,13 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
data, err := json.Marshal(stats)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(stats)
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -162,12 +163,12 @@ func (h *ResponseHandler) SendStreamingHead(requestID string, head *protocol.HTT
|
||||
|
||||
// Skip ALL hop-by-hop headers
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -349,7 +350,7 @@ func (h *ResponseHandler) cleanupLoop() {
|
||||
|
||||
func (h *ResponseHandler) cleanupExpiredChannels() {
|
||||
now := time.Now()
|
||||
timeout := 30 * time.Second
|
||||
timeout := 5 * time.Minute
|
||||
streamingTimeout := 5 * time.Minute
|
||||
|
||||
h.mu.Lock()
|
||||
|
||||
@@ -2,6 +2,7 @@ package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -41,6 +42,8 @@ type Connection struct {
|
||||
httpHandler http.Handler
|
||||
responseChans HTTPResponseHandler
|
||||
tunnelType protocol.TunnelType // Track tunnel type
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// HTTPResponseHandler interface for response channel operations
|
||||
@@ -56,6 +59,7 @@ type HTTPResponseHandler interface {
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Connection{
|
||||
conn: conn,
|
||||
authToken: authToken,
|
||||
@@ -68,6 +72,8 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
|
||||
responseChans: responseChans,
|
||||
stopCh: make(chan struct{}),
|
||||
lastHeartbeat: time.Now(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,12 +217,15 @@ func (c *Connection) Handle() error {
|
||||
return fmt.Errorf("failed to send registration ack: %w", err)
|
||||
}
|
||||
|
||||
// Create frame writer for async writes
|
||||
c.frameWriter = protocol.NewFrameWriter(c.conn)
|
||||
|
||||
c.frameWriter.SetWriteErrorHandler(func(err error) {
|
||||
c.logger.Error("Write error detected, closing connection", zap.Error(err))
|
||||
c.Close()
|
||||
})
|
||||
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Start TCP proxy only for TCP tunnels
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
c.proxy = NewTunnelProxy(c.port, subdomain, c.conn, c.logger)
|
||||
if err := c.proxy.Start(); err != nil {
|
||||
@@ -226,13 +235,10 @@ func (c *Connection) Handle() error {
|
||||
|
||||
go c.heartbeatChecker()
|
||||
|
||||
// Handle frames (pass reader for consistent buffering)
|
||||
return c.handleFrames(reader)
|
||||
}
|
||||
|
||||
// handleHTTPRequest handles HTTP requests that arrive on the TCP port
|
||||
func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
// If no HTTP handler is configured, return error
|
||||
if c.httpHandler == nil {
|
||||
c.logger.Warn("HTTP request received but no HTTP handler configured")
|
||||
response := "HTTP/1.1 503 Service Unavailable\r\n" +
|
||||
@@ -289,24 +295,29 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
return fmt.Errorf("failed to parse HTTP request: %w", err)
|
||||
}
|
||||
|
||||
if c.ctx != nil {
|
||||
req = req.WithContext(c.ctx)
|
||||
}
|
||||
|
||||
c.logger.Info("Processing HTTP request on TCP port",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("url", req.URL.String()),
|
||||
zap.String("host", req.Host),
|
||||
)
|
||||
|
||||
// Create a response writer that writes directly to the connection
|
||||
respWriter := &httpResponseWriter{
|
||||
conn: c.conn,
|
||||
writer: bufio.NewWriterSize(c.conn, 4096),
|
||||
header: make(http.Header),
|
||||
}
|
||||
|
||||
// Handle the request - this blocks until response is complete
|
||||
c.httpHandler.ServeHTTP(respWriter, req)
|
||||
|
||||
// Ensure response is flushed to client
|
||||
if err := respWriter.writer.Flush(); err != nil {
|
||||
c.logger.Debug("Failed to flush HTTP response", zap.Error(err))
|
||||
}
|
||||
|
||||
if tcpConn, ok := c.conn.(*net.TCPConn); ok {
|
||||
// Force flush TCP buffers
|
||||
tcpConn.SetNoDelay(true)
|
||||
tcpConn.SetNoDelay(false)
|
||||
}
|
||||
@@ -316,19 +327,15 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
zap.String("url", req.URL.String()),
|
||||
)
|
||||
|
||||
// Check if we should close the connection
|
||||
// Close if: Connection: close header, or HTTP/1.0 without Connection: keep-alive
|
||||
shouldClose := false
|
||||
if req.Close {
|
||||
shouldClose = true
|
||||
} else if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
||||
// HTTP/1.0 defaults to close unless keep-alive is explicitly requested
|
||||
if req.Header.Get("Connection") != "keep-alive" {
|
||||
shouldClose = true
|
||||
}
|
||||
}
|
||||
|
||||
// Also check if response indicated connection should close
|
||||
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
|
||||
shouldClose = true
|
||||
}
|
||||
@@ -563,20 +570,13 @@ func (c *Connection) heartbeatChecker() {
|
||||
}
|
||||
}
|
||||
|
||||
// SendFrame sends a frame to the client
|
||||
func (c *Connection) SendFrame(frame *protocol.Frame) error {
|
||||
if c.frameWriter == nil {
|
||||
return protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
if err := c.frameWriter.WriteFrame(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
// Flush immediately to ensure the frame is sent without batching delay
|
||||
c.frameWriter.Flush()
|
||||
return nil
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
|
||||
// sendError sends an error frame to the client
|
||||
func (c *Connection) sendError(code, message string) {
|
||||
errMsg := protocol.ErrorMessage{
|
||||
Code: code,
|
||||
@@ -586,22 +586,24 @@ func (c *Connection) sendError(code, message string) {
|
||||
errFrame := protocol.NewFrame(protocol.FrameTypeError, data)
|
||||
|
||||
if c.frameWriter == nil {
|
||||
// Fallback if frameWriter not initialized (early errors)
|
||||
protocol.WriteFrame(c.conn, errFrame)
|
||||
} else {
|
||||
c.frameWriter.WriteFrame(errFrame)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the connection
|
||||
func (c *Connection) Close() {
|
||||
c.once.Do(func() {
|
||||
// Unregister connection from adaptive load tracking
|
||||
protocol.UnregisterConnection()
|
||||
|
||||
close(c.stopCh)
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
if c.frameWriter != nil {
|
||||
c.frameWriter.Flush()
|
||||
c.frameWriter.Close()
|
||||
}
|
||||
|
||||
@@ -633,6 +635,7 @@ func (c *Connection) GetSubdomain() string {
|
||||
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
|
||||
type httpResponseWriter struct {
|
||||
conn net.Conn
|
||||
writer *bufio.Writer // Buffered writer for efficient I/O
|
||||
header http.Header
|
||||
statusCode int
|
||||
headerWritten bool
|
||||
@@ -649,27 +652,32 @@ func (w *httpResponseWriter) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
w.headerWritten = true
|
||||
|
||||
// Write status line
|
||||
statusText := http.StatusText(statusCode)
|
||||
if statusText == "" {
|
||||
statusText = "Unknown"
|
||||
}
|
||||
fmt.Fprintf(w.conn, "HTTP/1.1 %d %s\r\n", statusCode, statusText)
|
||||
|
||||
// Write headers
|
||||
w.writer.WriteString("HTTP/1.1 ")
|
||||
w.writer.WriteString(fmt.Sprintf("%d", statusCode))
|
||||
w.writer.WriteByte(' ')
|
||||
w.writer.WriteString(statusText)
|
||||
w.writer.WriteString("\r\n")
|
||||
|
||||
for key, values := range w.header {
|
||||
for _, value := range values {
|
||||
fmt.Fprintf(w.conn, "%s: %s\r\n", key, value)
|
||||
w.writer.WriteString(key)
|
||||
w.writer.WriteString(": ")
|
||||
w.writer.WriteString(value)
|
||||
w.writer.WriteString("\r\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Write empty line to end headers
|
||||
fmt.Fprintf(w.conn, "\r\n")
|
||||
w.writer.WriteString("\r\n")
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) Write(data []byte) (int, error) {
|
||||
if !w.headerWritten {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.conn.Write(data)
|
||||
return w.writer.Write(data)
|
||||
}
|
||||
|
||||
@@ -34,11 +34,17 @@ type Listener struct {
|
||||
workerPool *pool.WorkerPool // Worker pool for connection handling
|
||||
}
|
||||
|
||||
// NewListener creates a new TCP listener
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Listener {
|
||||
// Create worker pool with 50 workers and queue size of 1000
|
||||
// This reduces goroutine creation overhead for connection handling
|
||||
workerPool := pool.NewWorkerPool(50, 1000)
|
||||
numCPU := pool.NumCPU()
|
||||
workers := numCPU * 5
|
||||
queueSize := workers * 20
|
||||
workerPool := pool.NewWorkerPool(workers, queueSize)
|
||||
|
||||
logger.Info("Worker pool configured",
|
||||
zap.Int("cpu_cores", numCPU),
|
||||
zap.Int("workers", workers),
|
||||
zap.Int("queue_size", queueSize),
|
||||
)
|
||||
|
||||
return &Listener{
|
||||
address: address,
|
||||
@@ -107,14 +113,11 @@ func (l *Listener) acceptLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle connection using worker pool instead of creating new goroutine
|
||||
// This reduces goroutine creation overhead and improves performance
|
||||
l.wg.Add(1)
|
||||
submitted := l.workerPool.Submit(func() {
|
||||
l.handleConnection(conn)
|
||||
})
|
||||
|
||||
// If pool is full or closed, fall back to direct goroutine
|
||||
if !submitted {
|
||||
go l.handleConnection(conn)
|
||||
}
|
||||
@@ -132,9 +135,16 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
// Set read deadline before handshake to prevent slow handshake attacks
|
||||
if err := tlsConn.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||
l.logger.Warn("Failed to set read deadline",
|
||||
zap.String("remote_addr", netConn.RemoteAddr().String()),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
// TLS handshake failures are common (HTTP clients, scanners, etc.)
|
||||
// Log as WARN instead of ERROR
|
||||
l.logger.Warn("TLS handshake failed",
|
||||
zap.String("remote_addr", netConn.RemoteAddr().String()),
|
||||
zap.Error(err),
|
||||
@@ -142,6 +152,23 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear the read deadline after successful handshake
|
||||
if err := tlsConn.SetReadDeadline(time.Time{}); err != nil {
|
||||
l.logger.Warn("Failed to clear read deadline",
|
||||
zap.String("remote_addr", netConn.RemoteAddr().String()),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if tcpConn, ok := tlsConn.NetConn().(*net.TCPConn); ok {
|
||||
tcpConn.SetNoDelay(true)
|
||||
tcpConn.SetKeepAlive(true)
|
||||
tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
tcpConn.SetReadBuffer(256 * 1024)
|
||||
tcpConn.SetWriteBuffer(256 * 1024)
|
||||
}
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
l.logger.Info("New connection",
|
||||
zap.String("remote_addr", netConn.RemoteAddr().String()),
|
||||
|
||||
73
internal/shared/pool/adaptive_buffer_pool.go
Normal file
73
internal/shared/pool/adaptive_buffer_pool.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// AdaptiveBufferPool manages reusable buffers of different sizes
|
||||
// This eliminates the massive memory allocation overhead seen in profiling
|
||||
type AdaptiveBufferPool struct {
|
||||
// Large buffers for streaming threshold (1MB)
|
||||
largePool *sync.Pool
|
||||
|
||||
// Medium buffers for temporary reads (32KB)
|
||||
mediumPool *sync.Pool
|
||||
}
|
||||
|
||||
const (
|
||||
// LargeBufferSize is 1MB for streaming threshold
|
||||
LargeBufferSize = 1 * 1024 * 1024
|
||||
|
||||
// MediumBufferSize is 32KB for temporary reads
|
||||
MediumBufferSize = 32 * 1024
|
||||
)
|
||||
|
||||
// NewAdaptiveBufferPool creates a new adaptive buffer pool
|
||||
func NewAdaptiveBufferPool() *AdaptiveBufferPool {
|
||||
return &AdaptiveBufferPool{
|
||||
largePool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, LargeBufferSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
mediumPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, MediumBufferSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetLarge returns a large buffer (1MB) from the pool
|
||||
// The returned buffer should be returned via PutLarge when done
|
||||
func (p *AdaptiveBufferPool) GetLarge() *[]byte {
|
||||
return p.largePool.Get().(*[]byte)
|
||||
}
|
||||
|
||||
// PutLarge returns a large buffer to the pool for reuse
|
||||
func (p *AdaptiveBufferPool) PutLarge(buf *[]byte) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
// Reset to full capacity to allow reuse
|
||||
*buf = (*buf)[:cap(*buf)]
|
||||
p.largePool.Put(buf)
|
||||
}
|
||||
|
||||
// GetMedium returns a medium buffer (32KB) from the pool
|
||||
// The returned buffer should be returned via PutMedium when done
|
||||
func (p *AdaptiveBufferPool) GetMedium() *[]byte {
|
||||
return p.mediumPool.Get().(*[]byte)
|
||||
}
|
||||
|
||||
// PutMedium returns a medium buffer to the pool for reuse
|
||||
func (p *AdaptiveBufferPool) PutMedium(buf *[]byte) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
// Reset to full capacity to allow reuse
|
||||
*buf = (*buf)[:cap(*buf)]
|
||||
p.mediumPool.Put(buf)
|
||||
}
|
||||
@@ -1,9 +1,15 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// NumCPU returns the number of logical CPUs available
|
||||
func NumCPU() int {
|
||||
return runtime.NumCPU()
|
||||
}
|
||||
|
||||
// WorkerPool is a fixed-size goroutine pool for handling tasks
|
||||
type WorkerPool struct {
|
||||
workers int
|
||||
|
||||
@@ -22,12 +22,21 @@ type FrameWriter struct {
|
||||
heartbeatCallback func() *Frame
|
||||
heartbeatEnabled bool
|
||||
heartbeatControl chan struct{}
|
||||
|
||||
// Error handling
|
||||
writeErr error
|
||||
errOnce sync.Once
|
||||
onWriteError func(error) // Callback for write errors
|
||||
|
||||
// Adaptive flushing
|
||||
adaptiveFlush bool // Enable adaptive flush based on queue depth
|
||||
lowConcurrencyThreshold int // Queue depth threshold for immediate flush
|
||||
}
|
||||
|
||||
func NewFrameWriter(conn io.Writer) *FrameWriter {
|
||||
// Larger queue size for better burst handling across all load scenarios
|
||||
// With adaptive buffer pool, memory pressure is well controlled
|
||||
return NewFrameWriterWithConfig(conn, 128, 2*time.Millisecond, 2048)
|
||||
w := NewFrameWriterWithConfig(conn, 256, 2*time.Millisecond, 4096)
|
||||
w.EnableAdaptiveFlush(16)
|
||||
return w
|
||||
}
|
||||
|
||||
func NewFrameWriterWithConfig(conn io.Writer, maxBatch int, maxBatchWait time.Duration, queueSize int) *FrameWriter {
|
||||
@@ -77,7 +86,10 @@ func (w *FrameWriter) writeLoop() {
|
||||
w.mu.Lock()
|
||||
w.batch = append(w.batch, frame)
|
||||
|
||||
if len(w.batch) >= w.maxBatch {
|
||||
shouldFlushNow := len(w.batch) >= w.maxBatch ||
|
||||
(w.adaptiveFlush && len(w.queue) <= w.lowConcurrencyThreshold)
|
||||
|
||||
if shouldFlushNow {
|
||||
w.flushBatchLocked()
|
||||
}
|
||||
w.mu.Unlock()
|
||||
@@ -127,7 +139,15 @@ func (w *FrameWriter) flushBatchLocked() {
|
||||
}
|
||||
|
||||
for _, frame := range w.batch {
|
||||
_ = WriteFrame(w.conn, frame)
|
||||
if err := WriteFrame(w.conn, frame); err != nil {
|
||||
w.errOnce.Do(func() {
|
||||
w.writeErr = err
|
||||
if w.onWriteError != nil {
|
||||
go w.onWriteError(err)
|
||||
}
|
||||
w.closed = true
|
||||
})
|
||||
}
|
||||
frame.Release()
|
||||
}
|
||||
|
||||
@@ -138,6 +158,9 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
if w.writeErr != nil {
|
||||
return w.writeErr
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
}
|
||||
w.mu.Unlock()
|
||||
@@ -146,6 +169,12 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
case w.queue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
w.mu.Lock()
|
||||
err := w.writeErr
|
||||
w.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("writer closed")
|
||||
}
|
||||
}
|
||||
@@ -177,7 +206,6 @@ func (w *FrameWriter) Flush() {
|
||||
return
|
||||
}
|
||||
|
||||
// First, drain the queue into batch
|
||||
for {
|
||||
select {
|
||||
case frame, ok := <-w.queue:
|
||||
@@ -190,7 +218,6 @@ func (w *FrameWriter) Flush() {
|
||||
}
|
||||
}
|
||||
done:
|
||||
// Then flush the batch
|
||||
w.flushBatchLocked()
|
||||
w.mu.Unlock()
|
||||
}
|
||||
@@ -218,3 +245,22 @@ func (w *FrameWriter) DisableHeartbeat() {
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (w *FrameWriter) SetWriteErrorHandler(handler func(error)) {
|
||||
w.mu.Lock()
|
||||
w.onWriteError = handler
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
func (w *FrameWriter) EnableAdaptiveFlush(lowConcurrencyThreshold int) {
|
||||
w.mu.Lock()
|
||||
w.adaptiveFlush = true
|
||||
w.lowConcurrencyThreshold = lowConcurrencyThreshold
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
func (w *FrameWriter) DisableAdaptiveFlush() {
|
||||
w.mu.Lock()
|
||||
w.adaptiveFlush = false
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
@@ -55,9 +55,10 @@ func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
|
||||
|
||||
// Force TLS 1.3 only
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
PreferServerCipherSuites: true, // Prefer server cipher suites (ignored in TLS 1.3 but set for consistency)
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
tls.TLS_AES_256_GCM_SHA384,
|
||||
@@ -71,9 +72,11 @@ func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
|
||||
// GetClientTLSConfig returns TLS config for client connections
|
||||
func GetClientTLSConfig(serverName string) *tls.Config {
|
||||
return &tls.Config{
|
||||
ServerName: serverName,
|
||||
MinVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
ServerName: serverName,
|
||||
MinVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
ClientSessionCache: tls.NewLRUClientSessionCache(0), // Enable session resumption (0 = default size)
|
||||
PreferServerCipherSuites: true, // Prefer server cipher suites (ignored in TLS 1.3 but set for consistency)
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
tls.TLS_AES_256_GCM_SHA384,
|
||||
@@ -86,9 +89,11 @@ func GetClientTLSConfig(serverName string) *tls.Config {
|
||||
// WARNING: Only use for testing!
|
||||
func GetClientTLSConfigInsecure() *tls.Config {
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
ClientSessionCache: tls.NewLRUClientSessionCache(0), // Enable session resumption (0 = default size)
|
||||
PreferServerCipherSuites: true, // Prefer server cipher suites (ignored in TLS 1.3 but set for consistency)
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
tls.TLS_AES_256_GCM_SHA384,
|
||||
|
||||
434
scripts/test/profile-test.sh
Executable file
434
scripts/test/profile-test.sh
Executable file
@@ -0,0 +1,434 @@
|
||||
#!/bin/bash
|
||||
# Drip Performance Profiling Script
|
||||
# Runs performance test while collecting CPU and memory profiles
|
||||
|
||||
set -e
|
||||
|
||||
# Color definitions
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Configuration
|
||||
RESULTS_DIR="benchmark-results"
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
LOG_DIR="/tmp/drip-profile-${TIMESTAMP}"
|
||||
PROFILE_DIR="${RESULTS_DIR}/profiles-${TIMESTAMP}"
|
||||
|
||||
# Port configuration
|
||||
HTTP_TEST_PORT=3000
|
||||
DRIP_SERVER_PORT=8443
|
||||
PPROF_PORT=6060
|
||||
|
||||
# PID file
|
||||
PIDS_FILE="${LOG_DIR}/pids.txt"
|
||||
|
||||
# Create directories
|
||||
mkdir -p "$RESULTS_DIR"
|
||||
mkdir -p "$LOG_DIR"
|
||||
mkdir -p "$PROFILE_DIR"
|
||||
|
||||
# ============================================
|
||||
# Helper functions
|
||||
# ============================================
|
||||
|
||||
log_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1" >&2
|
||||
}
|
||||
|
||||
log_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1" >&2
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1" >&2
|
||||
}
|
||||
|
||||
log_step() {
|
||||
echo -e "\n${BLUE}==>${NC} $1\n" >&2
|
||||
}
|
||||
|
||||
# Cleanup function
|
||||
cleanup() {
|
||||
log_step "Cleaning up..."
|
||||
|
||||
if [ -f "$PIDS_FILE" ]; then
|
||||
while read -r pid; do
|
||||
if ps -p "$pid" > /dev/null 2>&1; then
|
||||
kill "$pid" 2>/dev/null || true
|
||||
fi
|
||||
done < "$PIDS_FILE"
|
||||
rm -f "$PIDS_FILE"
|
||||
fi
|
||||
|
||||
pkill -f "python.*${HTTP_TEST_PORT}" 2>/dev/null || true
|
||||
pkill -f "drip server.*${DRIP_SERVER_PORT}" 2>/dev/null || true
|
||||
pkill -f "drip http ${HTTP_TEST_PORT}" 2>/dev/null || true
|
||||
|
||||
log_info "Cleanup completed"
|
||||
}
|
||||
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
# Wait for port to be available
|
||||
wait_for_port() {
|
||||
local port=$1
|
||||
local max_wait=${2:-30}
|
||||
local waited=0
|
||||
|
||||
while ! nc -z localhost "$port" 2>/dev/null; do
|
||||
if [ "$waited" -ge "$max_wait" ]; then
|
||||
return 1
|
||||
fi
|
||||
sleep 1
|
||||
waited=$((waited + 1))
|
||||
done
|
||||
return 0
|
||||
}
|
||||
|
||||
# Generate test certificate
|
||||
generate_test_certs() {
|
||||
log_step "Generating test TLS certificate..."
|
||||
|
||||
local cert_dir="${LOG_DIR}/certs"
|
||||
mkdir -p "$cert_dir"
|
||||
|
||||
openssl ecparam -name prime256v1 -genkey -noout \
|
||||
-out "${cert_dir}/server.key" >/dev/null 2>&1
|
||||
|
||||
openssl req -new -x509 \
|
||||
-key "${cert_dir}/server.key" \
|
||||
-out "${cert_dir}/server.crt" \
|
||||
-days 1 \
|
||||
-subj "/C=US/ST=Test/L=Test/O=Test/CN=localhost" \
|
||||
>/dev/null 2>&1
|
||||
|
||||
log_info "✓ Test certificate generated"
|
||||
echo "${cert_dir}/server.crt ${cert_dir}/server.key"
|
||||
}
|
||||
|
||||
# Start HTTP test server
|
||||
start_http_server() {
|
||||
log_step "Starting HTTP test server..."
|
||||
|
||||
cat > "${LOG_DIR}/test-server.py" << 'EOF'
|
||||
import http.server
|
||||
import socketserver
|
||||
import json
|
||||
from datetime import datetime
|
||||
import sys
|
||||
|
||||
PORT = int(sys.argv[1]) if len(sys.argv) > 1 else 3000
|
||||
|
||||
class TestHandler(http.server.SimpleHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
response = {
|
||||
"status": "ok",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": "Test server response"
|
||||
}
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(response).encode())
|
||||
|
||||
def log_message(self, format, *args):
|
||||
pass
|
||||
|
||||
with socketserver.TCPServer(("", PORT), TestHandler) as httpd:
|
||||
print(f"Server started on port {PORT}", flush=True)
|
||||
httpd.serve_forever()
|
||||
EOF
|
||||
|
||||
python3 "${LOG_DIR}/test-server.py" "$HTTP_TEST_PORT" \
|
||||
> "${LOG_DIR}/http-server.log" 2>&1 &
|
||||
local pid=$!
|
||||
echo "$pid" >> "$PIDS_FILE"
|
||||
|
||||
if wait_for_port "$HTTP_TEST_PORT" 10; then
|
||||
log_info "✓ HTTP test server started (PID: $pid)"
|
||||
else
|
||||
log_error "HTTP test server failed to start"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Start Drip server with pprof
|
||||
start_drip_server() {
|
||||
log_step "Starting Drip server with pprof on port $PPROF_PORT..."
|
||||
|
||||
local cert_path=$1
|
||||
local key_path=$2
|
||||
|
||||
./bin/drip server \
|
||||
--port "$DRIP_SERVER_PORT" \
|
||||
--domain localhost \
|
||||
--tls-cert "$cert_path" \
|
||||
--tls-key "$key_path" \
|
||||
--pprof "$PPROF_PORT" \
|
||||
> "${LOG_DIR}/drip-server.log" 2>&1 &
|
||||
local pid=$!
|
||||
echo "$pid" >> "$PIDS_FILE"
|
||||
|
||||
if wait_for_port "$DRIP_SERVER_PORT" 10; then
|
||||
log_info "✓ Drip server started (PID: $pid)"
|
||||
else
|
||||
log_error "Drip server failed to start"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Wait for pprof to be available
|
||||
if wait_for_port "$PPROF_PORT" 10; then
|
||||
log_info "✓ pprof endpoint available at http://localhost:$PPROF_PORT/debug/pprof"
|
||||
else
|
||||
log_warn "pprof endpoint not available"
|
||||
fi
|
||||
}
|
||||
|
||||
# Start Drip client
|
||||
start_drip_client() {
|
||||
log_step "Starting Drip client..."
|
||||
|
||||
./bin/drip http "$HTTP_TEST_PORT" \
|
||||
--server "localhost:${DRIP_SERVER_PORT}" \
|
||||
--insecure \
|
||||
> "${LOG_DIR}/drip-client.log" 2>&1 &
|
||||
local pid=$!
|
||||
echo "$pid" >> "$PIDS_FILE"
|
||||
|
||||
sleep 3
|
||||
|
||||
local tunnel_url=""
|
||||
local max_attempts=10
|
||||
local attempt=0
|
||||
|
||||
while [ "$attempt" -lt "$max_attempts" ]; do
|
||||
tunnel_url=$(grep -oE 'https://[a-zA-Z0-9.-]+:[0-9]+' "${LOG_DIR}/drip-client.log" 2>/dev/null | head -1)
|
||||
if [ -n "$tunnel_url" ]; then
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
attempt=$((attempt + 1))
|
||||
done
|
||||
|
||||
if [ -z "$tunnel_url" ]; then
|
||||
log_error "Cannot get tunnel URL"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_info "✓ Drip client started (PID: $pid)"
|
||||
log_info "✓ Tunnel URL: $tunnel_url"
|
||||
|
||||
echo "$tunnel_url"
|
||||
}
|
||||
|
||||
# Collect CPU profile
|
||||
collect_cpu_profile() {
|
||||
local duration=$1
|
||||
local profile_file="${PROFILE_DIR}/cpu.prof"
|
||||
|
||||
log_step "Collecting CPU profile for ${duration}s..."
|
||||
|
||||
curl -s "http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=${duration}" \
|
||||
-o "$profile_file"
|
||||
|
||||
if [ -f "$profile_file" ] && [ -s "$profile_file" ]; then
|
||||
log_info "✓ CPU profile saved to $profile_file"
|
||||
return 0
|
||||
else
|
||||
log_error "Failed to collect CPU profile"
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Collect heap profile
|
||||
collect_heap_profile() {
|
||||
local profile_file="${PROFILE_DIR}/heap.prof"
|
||||
|
||||
log_step "Collecting heap profile..."
|
||||
|
||||
curl -s "http://localhost:${PPROF_PORT}/debug/pprof/heap" \
|
||||
-o "$profile_file"
|
||||
|
||||
if [ -f "$profile_file" ] && [ -s "$profile_file" ]; then
|
||||
log_info "✓ Heap profile saved to $profile_file"
|
||||
return 0
|
||||
else
|
||||
log_error "Failed to collect heap profile"
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Collect goroutine profile
|
||||
collect_goroutine_profile() {
|
||||
local profile_file="${PROFILE_DIR}/goroutine.txt"
|
||||
|
||||
log_step "Collecting goroutine profile..."
|
||||
|
||||
curl -s "http://localhost:${PPROF_PORT}/debug/pprof/goroutine?debug=2" \
|
||||
-o "$profile_file"
|
||||
|
||||
if [ -f "$profile_file" ] && [ -s "$profile_file" ]; then
|
||||
log_info "✓ Goroutine profile saved to $profile_file"
|
||||
return 0
|
||||
else
|
||||
log_error "Failed to collect goroutine profile"
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Run performance test with profiling
|
||||
run_profiling_test() {
|
||||
local url=$1
|
||||
|
||||
log_step "Starting profiling test..."
|
||||
|
||||
# Warm up
|
||||
log_info "Warming up (5s)..."
|
||||
for _ in {1..5}; do
|
||||
curl -sk "$url" > /dev/null 2>&1 || true
|
||||
sleep 1
|
||||
done
|
||||
|
||||
# Start CPU profiling in background
|
||||
log_info "Starting CPU profile collection (45s)..."
|
||||
collect_cpu_profile 45 &
|
||||
local profile_pid=$!
|
||||
|
||||
# Wait a moment for profiler to start
|
||||
sleep 2
|
||||
|
||||
# Run load test during profiling
|
||||
log_info "Running load test (30s, 100 connections)..."
|
||||
wrk -t 8 -c 100 -d 30s --latency "$url" \
|
||||
> "${RESULTS_DIR}/profile-test-${TIMESTAMP}.txt" 2>&1
|
||||
|
||||
# Wait for profiling to complete
|
||||
wait $profile_pid
|
||||
|
||||
# Collect heap profile after test
|
||||
collect_heap_profile
|
||||
|
||||
# Collect goroutine profile
|
||||
collect_goroutine_profile
|
||||
|
||||
log_info "✓ Profiling test completed"
|
||||
}
|
||||
|
||||
# Analyze CPU profile
|
||||
analyze_cpu_profile() {
|
||||
local profile_file="${PROFILE_DIR}/cpu.prof"
|
||||
|
||||
if [ ! -f "$profile_file" ]; then
|
||||
log_error "CPU profile not found"
|
||||
return 1
|
||||
fi
|
||||
|
||||
log_step "Analyzing CPU profile..."
|
||||
|
||||
# Generate text report
|
||||
log_info "Generating top functions report..."
|
||||
go tool pprof -text -nodecount=20 "$profile_file" \
|
||||
> "${PROFILE_DIR}/cpu-top20.txt" 2>/dev/null || true
|
||||
|
||||
# Generate list report
|
||||
log_info "Generating function list..."
|
||||
go tool pprof -list=. "$profile_file" \
|
||||
> "${PROFILE_DIR}/cpu-list.txt" 2>/dev/null || true
|
||||
|
||||
log_info "✓ Analysis complete"
|
||||
log_info " - Top functions: ${PROFILE_DIR}/cpu-top20.txt"
|
||||
log_info " - Detailed list: ${PROFILE_DIR}/cpu-list.txt"
|
||||
}
|
||||
|
||||
# Show summary
|
||||
show_summary() {
|
||||
log_step "Profiling Results Summary"
|
||||
|
||||
echo ""
|
||||
echo "========================================="
|
||||
echo " CPU Profile - Top 20 Functions"
|
||||
echo "========================================="
|
||||
if [ -f "${PROFILE_DIR}/cpu-top20.txt" ]; then
|
||||
head -30 "${PROFILE_DIR}/cpu-top20.txt"
|
||||
fi
|
||||
echo ""
|
||||
echo "========================================="
|
||||
echo " Performance Test Results"
|
||||
echo "========================================="
|
||||
if [ -f "${RESULTS_DIR}/profile-test-${TIMESTAMP}.txt" ]; then
|
||||
grep "Requests/sec:" "${RESULTS_DIR}/profile-test-${TIMESTAMP}.txt"
|
||||
grep "Transfer/sec:" "${RESULTS_DIR}/profile-test-${TIMESTAMP}.txt"
|
||||
echo ""
|
||||
grep "50%" "${RESULTS_DIR}/profile-test-${TIMESTAMP}.txt"
|
||||
grep "99%" "${RESULTS_DIR}/profile-test-${TIMESTAMP}.txt"
|
||||
fi
|
||||
echo "========================================="
|
||||
echo ""
|
||||
|
||||
log_info "Profile files location: $PROFILE_DIR"
|
||||
log_info ""
|
||||
log_info "To analyze interactively, run:"
|
||||
log_info " go tool pprof ${PROFILE_DIR}/cpu.prof"
|
||||
log_info ""
|
||||
log_info "Web UI:"
|
||||
log_info " go tool pprof -http=:8080 ${PROFILE_DIR}/cpu.prof"
|
||||
}
|
||||
|
||||
# ============================================
|
||||
# Main flow
|
||||
# ============================================
|
||||
|
||||
main() {
|
||||
clear
|
||||
echo "========================================="
|
||||
echo " Drip Performance Profiling Test"
|
||||
echo "========================================="
|
||||
echo ""
|
||||
|
||||
# Check dependencies
|
||||
if ! command -v wrk &> /dev/null; then
|
||||
log_error "wrk is required (brew install wrk)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "./bin/drip" ]; then
|
||||
log_error "Cannot find drip executable. Run: make build"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Generate certificate
|
||||
CERT_PATHS=$(generate_test_certs)
|
||||
CERT_FILE=$(echo "$CERT_PATHS" | awk '{print $1}')
|
||||
KEY_FILE=$(echo "$CERT_PATHS" | awk '{print $2}')
|
||||
|
||||
# Start services
|
||||
start_http_server
|
||||
start_drip_server "$CERT_FILE" "$KEY_FILE"
|
||||
TUNNEL_URL=$(start_drip_client)
|
||||
|
||||
# Verify connectivity
|
||||
log_info "Verifying connectivity..."
|
||||
if ! curl -sk --max-time 5 "$TUNNEL_URL" > /dev/null 2>&1; then
|
||||
log_error "Tunnel not accessible"
|
||||
exit 1
|
||||
fi
|
||||
log_info "✓ Tunnel connectivity OK"
|
||||
|
||||
# Run profiling test
|
||||
run_profiling_test "$TUNNEL_URL"
|
||||
|
||||
# Analyze profiles
|
||||
analyze_cpu_profile
|
||||
|
||||
# Show summary
|
||||
show_summary
|
||||
|
||||
log_step "Profiling completed!"
|
||||
}
|
||||
|
||||
# Run main
|
||||
main
|
||||
Reference in New Issue
Block a user