Files
drip/internal/server/tcp/connection.go
Gouryella aead68bb62 feat: Add HTTP streaming, compression support, and Docker deployment
enhancements

  - Add adaptive HTTP response handling with automatic streaming for large
  responses (>1MB)
  - Implement zero-copy streaming using buffer pools for better performance
  - Add compression module for reduced bandwidth usage
  - Add GitHub Container Registry workflow for automated Docker builds
  - Add production-optimized Dockerfile and docker-compose configuration
  - Simplify background mode with -d flag and improved daemon management
  - Update documentation with new command syntax and deployment guides
  - Clean up unused code and improve error handling
  - Fix lipgloss style usage (remove unnecessary .Copy() calls)
2025-12-05 22:09:07 +08:00

676 lines
18 KiB
Go

package tcp
import (
	"bufio"
	"crypto/subtle"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	json "github.com/goccy/go-json"
	"go.uber.org/zap"

	"drip/internal/server/tunnel"
	"drip/internal/shared/constants"
	"drip/internal/shared/protocol"
)
// Connection is the server side of one accepted client TCP connection. It
// either serves plain HTTP that arrived on the port, or, for drip clients,
// owns the registered tunnel: its control socket, the optional public-port
// TCP proxy, heartbeat tracking and frame I/O.
type Connection struct {
// conn is the accepted client socket.
conn net.Conn
// authToken, when non-empty, must equal the token in the registration request.
authToken string
// manager is the tunnel registry the subdomain is registered with and removed from.
manager *tunnel.Manager
logger *zap.Logger
// subdomain is assigned on successful registration; "" before that.
subdomain string
// port is the public port allocated for a TCP tunnel; 0 when no port is held
// (HTTP/HTTPS tunnels, or after a failed registration).
port int
// domain and publicPort are used to build the tunnel URL returned to the client.
domain string
publicPort int
// portAlloc hands out public ports for TCP tunnels; may be nil.
portAlloc *PortAllocator
// tunnelConn is this tunnel's registry entry once registration succeeds.
tunnelConn *tunnel.Connection
// proxy forwards public TCP traffic; only set for TCP tunnels.
proxy *TunnelProxy
// stopCh is closed by Close to stop handleFrames and heartbeatChecker.
stopCh chan struct{}
// once makes Close run its teardown at most once.
once sync.Once
// lastHeartbeat is when the last heartbeat frame arrived; guarded by mu.
lastHeartbeat time.Time
mu sync.RWMutex
// frameWriter is the async frame writer created after the registration ack;
// while it is nil, frames are written synchronously to conn.
frameWriter *protocol.FrameWriter
// httpHandler serves HTTP requests detected on this TCP port; may be nil.
httpHandler http.Handler
// responseChans receives the HTTP responses the client sends back; may be nil.
responseChans HTTPResponseHandler
tunnelType protocol.TunnelType // Track tunnel type
}
// HTTPResponseHandler is the sink for HTTP responses that tunnel clients send
// back over a Connection, addressed by request ID. It supports both one-shot
// responses and responses streamed as a head followed by body chunks.
type HTTPResponseHandler interface {
	// One-shot responses: per-request channel lifecycle and delivery of a
	// complete response.
	CreateResponseChan(requestID string) chan *protocol.HTTPResponse
	GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
	SendResponse(requestID string, resp *protocol.HTTPResponse)
	CleanupResponseChan(requestID string)

	// Streamed responses: the status/header block first, then body chunks;
	// isLast flags the final chunk.
	SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error
	SendStreamingChunk(requestID string, chunk []byte, isLast bool) error
}
// NewConnection wraps an accepted client socket in a Connection.
//
// authToken, when non-empty, is the token clients must present at
// registration. portAlloc supplies public ports for TCP tunnels; domain and
// publicPort shape the URLs reported back to clients. httpHandler serves plain
// HTTP arriving on this port, and responseChans receives HTTP responses the
// client sends back. The heartbeat clock starts at construction time.
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
	c := new(Connection)

	c.conn = conn
	c.authToken = authToken
	c.manager = manager
	c.logger = logger

	c.portAlloc = portAlloc
	c.domain = domain
	c.publicPort = publicPort

	c.httpHandler = httpHandler
	c.responseChans = responseChans

	c.stopCh = make(chan struct{})
	c.lastHeartbeat = time.Now()

	return c
}
// Handle drives the whole lifecycle of an accepted client connection.
//
// It sniffs the first bytes to tell plain HTTP (served by handleHTTPRequest)
// apart from the drip framing protocol. For drip clients it reads the
// registration frame, checks the auth token, allocates a public port for TCP
// tunnels, registers the tunnel with the manager, replies with the public URL,
// and then processes frames until the peer goes away.
//
// Close is deferred, so the control connection, TCP proxy, allocated port and
// registry entry are released however Handle returns.
func (c *Connection) Handle() error {
	// Register connection for adaptive load tracking (undone in Close).
	protocol.RegisterConnection()
	defer c.Close()

	// Bound how long a peer may take to identify itself.
	c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))

	// Buffered reader so the protocol sniff below does not consume bytes.
	reader := bufio.NewReader(c.conn)
	peek, err := reader.Peek(8)
	if err != nil {
		return fmt.Errorf("failed to peek connection: %w", err)
	}

	peekStr := string(peek)
	httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"}
	for _, method := range httpMethods {
		if strings.HasPrefix(peekStr, method) {
			c.logger.Info("Detected HTTP request on TCP port, handling as HTTP")
			return c.handleHTTPRequest(reader)
		}
	}

	// Drip protocol: the first frame must be a registration request.
	frame, err := protocol.ReadFrame(reader)
	if err != nil {
		return fmt.Errorf("failed to read registration frame: %w", err)
	}
	frameType := frame.Type
	var req protocol.RegisterRequest
	var parseErr error
	if frameType == protocol.FrameTypeRegister {
		parseErr = json.Unmarshal(frame.Payload, &req)
	}
	// The request is fully decoded into req at this point, so hand the pooled
	// payload back now instead of pinning it for the lifetime of the tunnel.
	frame.Release()

	if frameType != protocol.FrameTypeRegister {
		return fmt.Errorf("expected register frame, got %s", frameType)
	}
	if parseErr != nil {
		return fmt.Errorf("failed to parse registration request: %w", parseErr)
	}

	c.tunnelType = req.TunnelType

	// Constant-time comparison so response timing does not reveal how much of
	// a guessed token was correct.
	if c.authToken != "" && subtle.ConstantTimeCompare([]byte(req.Token), []byte(c.authToken)) != 1 {
		c.sendError("authentication_failed", "Invalid authentication token")
		return fmt.Errorf("authentication failed")
	}

	// Only TCP tunnels get a dedicated public port.
	if req.TunnelType == protocol.TunnelTypeTCP {
		if c.portAlloc == nil {
			c.sendError("port_allocation_failed", "port allocator not configured")
			return fmt.Errorf("port allocator not configured")
		}
		port, err := c.portAlloc.Allocate()
		if err != nil {
			c.sendError("port_allocation_failed", err.Error())
			return fmt.Errorf("failed to allocate port: %w", err)
		}
		c.port = port
		// Default to a deterministic subdomain derived from the port.
		if req.CustomSubdomain == "" {
			req.CustomSubdomain = fmt.Sprintf("tcp-%d", port)
		}
	}

	subdomain, err := c.manager.Register(nil, req.CustomSubdomain)
	if err != nil {
		c.sendError("registration_failed", err.Error())
		// Return the port now and zero it so Close does not release it again.
		// HTTP/HTTPS tunnels hold no port and may have no allocator at all,
		// so the call must be guarded.
		if c.port > 0 && c.portAlloc != nil {
			c.portAlloc.Release(c.port)
		}
		c.port = 0
		return fmt.Errorf("tunnel registration failed: %w", err)
	}
	c.subdomain = subdomain

	tunnelConn, ok := c.manager.Get(subdomain)
	if !ok {
		return fmt.Errorf("failed to get registered tunnel")
	}
	c.tunnelConn = tunnelConn
	// Route HTTP proxy traffic through this TCP control connection.
	c.tunnelConn.Conn = nil // We're using TCP, not WebSocket
	c.tunnelConn.SetTransport(c, req.TunnelType)
	c.tunnelConn.SetTunnelType(req.TunnelType)

	c.logger.Info("Tunnel registered",
		zap.String("subdomain", subdomain),
		zap.String("tunnel_type", string(req.TunnelType)),
		zap.Int("local_port", req.LocalPort),
		zap.Int("remote_port", c.port),
	)

	// HTTP/HTTPS tunnels are reached over HTTPS by subdomain (publicPort comes
	// from --public-port); TCP tunnels by their allocated port.
	var tunnelURL string
	if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
		if c.publicPort == 443 {
			tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.domain)
		} else {
			tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.domain, c.publicPort)
		}
	} else {
		tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
	}

	respData, err := json.Marshal(protocol.RegisterResponse{
		Subdomain: subdomain,
		Port:      c.port,
		URL:       tunnelURL,
		Message:   "Tunnel registered successfully",
	})
	if err != nil {
		return fmt.Errorf("failed to encode registration ack: %w", err)
	}
	// Synchronous write: the async frame writer is created only afterwards.
	ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
	if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
		return fmt.Errorf("failed to send registration ack: %w", err)
	}

	c.frameWriter = protocol.NewFrameWriter(c.conn)
	// Registration is done; handleFrames sets its own per-read deadlines.
	c.conn.SetReadDeadline(time.Time{})

	if req.TunnelType == protocol.TunnelTypeTCP {
		c.proxy = NewTunnelProxy(c.port, subdomain, c.conn, c.logger)
		if err := c.proxy.Start(); err != nil {
			return fmt.Errorf("failed to start TCP proxy: %w", err)
		}
	}

	go c.heartbeatChecker()

	// Keep using the same buffered reader so no peeked/buffered bytes are lost.
	return c.handleFrames(reader)
}
// handleHTTPRequest serves plain HTTP/1.x traffic that arrived on the TCP
// tunnel port instead of the drip protocol. reader may already hold bytes
// peeked during protocol detection, so requests are parsed from it rather
// than from c.conn directly.
//
// Requests are handled one after another on the same connection (HTTP/1.1
// keep-alive) until the client or handler asks to close, the connection idles
// past its read deadline, the peer goes away, or a response cannot be
// delimited without closing. Those endings return nil. An error is returned
// only when no HTTP handler is configured or a request fails to parse for an
// unexpected reason.
func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
	if c.httpHandler == nil {
		c.logger.Warn("HTTP request received but no HTTP handler configured")
		const body = "HTTP handler not configured for this TCP port\r\n"
		// Content-Length is derived from body so the two cannot drift apart;
		// Connection: close announces that we hang up after this reply.
		fmt.Fprintf(c.conn, "HTTP/1.1 503 Service Unavailable\r\n"+
			"Content-Type: text/plain\r\n"+
			"Content-Length: %d\r\n"+
			"Connection: close\r\n"+
			"\r\n"+
			"%s", len(body), body)
		return fmt.Errorf("HTTP handler not configured")
	}

	for {
		// Bound the wait for each request, including idle keep-alive time.
		c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))

		req, err := http.ReadRequest(reader)
		if err != nil {
			// EOF or timeout is normal when the client is done with the connection.
			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
				c.logger.Debug("Client closed HTTP connection")
				return nil
			}
			var netErr net.Error
			if errors.As(err, &netErr) && netErr.Timeout() {
				c.logger.Debug("HTTP keep-alive timeout")
				return nil
			}
			errStr := err.Error()
			if errors.Is(err, net.ErrClosed) || strings.Contains(errStr, "use of closed network connection") {
				c.logger.Debug("HTTP connection closed during read", zap.Error(err))
				return nil
			}
			// Resets and broken pipes are ordinary abrupt client disconnects.
			if strings.Contains(errStr, "connection reset by peer") ||
				strings.Contains(errStr, "broken pipe") ||
				strings.Contains(errStr, "connection refused") {
				c.logger.Debug("Client disconnected abruptly", zap.Error(err))
				return nil
			}
			if strings.Contains(errStr, "malformed HTTP") {
				c.logger.Warn("Received malformed HTTP request, possibly due to pipelined requests or protocol error",
					zap.Error(err),
					zap.String("error_snippet", errStr[:min(len(errStr), 100)]),
				)
				// Tell the client why before hanging up, as net/http's server does.
				io.WriteString(c.conn, "HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\nConnection: close\r\n\r\n")
				return nil
			}
			c.logger.Error("Failed to parse HTTP request", zap.Error(err))
			return fmt.Errorf("failed to parse HTTP request: %w", err)
		}

		c.logger.Info("Processing HTTP request on TCP port",
			zap.String("method", req.Method),
			zap.String("url", req.URL.String()),
			zap.String("host", req.Host),
		)

		respWriter := &httpResponseWriter{
			conn:   c.conn,
			header: make(http.Header),
		}
		// Blocks until the handler has produced the complete response.
		c.httpHandler.ServeHTTP(respWriter, req)

		// A handler that returns without writing still owes the client a
		// response; like net/http, send an empty 200.
		if !respWriter.headerWritten {
			if respWriter.header.Get("Content-Length") == "" && respWriter.header.Get("Transfer-Encoding") == "" {
				respWriter.header.Set("Content-Length", "0")
			}
			respWriter.WriteHeader(http.StatusOK)
		}

		// Consume any request body the handler left unread so the next
		// ReadRequest starts at a message boundary rather than mid-body.
		// ErrBodyReadAfterClose only means the handler already closed the body,
		// which consumes it as well.
		_, drainErr := io.Copy(io.Discard, req.Body)
		closeErr := req.Body.Close()
		bodyConsumed := (drainErr == nil || errors.Is(drainErr, http.ErrBodyReadAfterClose)) && closeErr == nil

		c.logger.Debug("HTTP request processing completed",
			zap.String("method", req.Method),
			zap.String("url", req.URL.String()),
		)

		// httpResponseWriter sends the body verbatim (no chunking). A response
		// that may carry a body but declares neither Content-Length nor chunked
		// Transfer-Encoding can therefore only be delimited by closing the
		// connection (RFC 9112 section 6.3).
		status := respWriter.statusCode
		mayHaveBody := req.Method != http.MethodHead &&
			status >= 200 && status != http.StatusNoContent && status != http.StatusNotModified
		lengthDeclared := respWriter.header.Get("Content-Length") != "" ||
			strings.Contains(strings.ToLower(respWriter.header.Get("Transfer-Encoding")), "chunked")

		// HTTP/1.0 defaults to close unless keep-alive is explicitly requested.
		http10Close := req.ProtoMajor == 1 && req.ProtoMinor == 0 &&
			!strings.EqualFold(strings.TrimSpace(req.Header.Get("Connection")), "keep-alive")
		serverClose := strings.EqualFold(strings.TrimSpace(respWriter.header.Get("Connection")), "close")

		shouldClose := req.Close || http10Close || serverClose ||
			!bodyConsumed || (mayHaveBody && !lengthDeclared)
		if shouldClose {
			c.logger.Debug("Closing connection as requested by client or server")
			return nil
		}
		// Continue to the next request on the same connection.
	}
}
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// handleFrames is the main read loop for a registered tunnel. It reads frames
// from reader (the same buffered reader used during registration, so nothing
// already buffered is lost) and dispatches them. It stops when the client
// disconnects, sends a Close frame, stays silent past constants.RequestTimeout,
// or Close is called.
//
// Orderly endings (EOF, Close frame, shutdown) return nil; a read timeout or
// any other read failure is returned as an error.
func (c *Connection) handleFrames(reader *bufio.Reader) error {
	for {
		select {
		case <-c.stopCh:
			return nil
		default:
		}

		// Refresh the deadline on every read so a dead peer is detected.
		c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
		frame, err := protocol.ReadFrame(reader)
		if err != nil {
			// errors.As/errors.Is see through any wrapping ReadFrame adds. The
			// string comparisons stay as a fallback for causes it flattens.
			var netErr net.Error
			if errors.As(err, &netErr) && netErr.Timeout() {
				c.logger.Warn("Read timeout, connection may be dead")
				return fmt.Errorf("read timeout")
			}
			if errors.Is(err, io.EOF) ||
				err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" {
				c.logger.Info("Client disconnected")
				return nil
			}
			// Close() shuts the socket, so a read error after stopCh is closed
			// is expected and not worth reporting.
			select {
			case <-c.stopCh:
				c.logger.Debug("Connection closed during shutdown")
				return nil
			default:
				return fmt.Errorf("failed to read frame: %w", err)
			}
		}

		switch frame.Type {
		case protocol.FrameTypeHeartbeat:
			c.handleHeartbeat()
			frame.Release()
		case protocol.FrameTypeData:
			// Data frame from client (response to forwarded request);
			// released only after processing.
			c.handleDataFrame(frame)
			frame.Release()
		case protocol.FrameTypeClose:
			frame.Release()
			c.logger.Info("Client requested close")
			return nil
		default:
			frame.Release()
			c.logger.Warn("Unexpected frame type",
				zap.String("type", frame.Type.String()),
			)
		}
	}
}
// handleHeartbeat records that the client is alive (read by heartbeatChecker)
// and answers with a heartbeat ack through the async frame writer. A failed
// ack is logged, not returned.
func (c *Connection) handleHeartbeat() {
	func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		c.lastHeartbeat = time.Now()
	}()

	if err := c.frameWriter.WriteFrame(protocol.NewFrame(protocol.FrameTypeHeartbeatAck, nil)); err != nil {
		c.logger.Error("Failed to send heartbeat ack", zap.Error(err))
	}
}
// handleDataFrame decodes a data frame sent by the client and routes its
// payload by data type:
//   - DataTypeResponse: raw TCP bytes for the TCP proxy (dropped if none)
//   - DataTypeHTTPResponse: a complete HTTP response for responseChans
//   - DataTypeHTTPHead / DataTypeHTTPBodyChunk: parts of a streamed HTTP response
//   - DataTypeClose: the client closed a proxied TCP stream
//
// HTTP payloads are routed by the header's request ID, falling back to the
// stream ID when the client left it empty. Failures are logged, not returned.
//
// NOTE(review): data aliases frame.Payload, and handleFrames releases the
// frame as soon as this returns. Confirm that proxy.HandleResponse and
// SendStreamingChunk copy or fully consume the slice before returning.
func (c *Connection) handleDataFrame(frame *protocol.Frame) {
	hdr, data, err := protocol.DecodeDataPayload(frame.Payload)
	if err != nil {
		c.logger.Error("Failed to decode data payload",
			zap.Error(err),
		)
		return
	}

	c.logger.Debug("Received data frame",
		zap.String("stream_id", hdr.StreamID),
		zap.String("type", hdr.Type.String()),
		zap.Int("data_size", len(data)),
	)

	// Keep HTTP requests and responses aligned by request ID when available.
	routeID := hdr.RequestID
	if routeID == "" {
		routeID = hdr.StreamID
	}

	switch hdr.Type {
	case protocol.DataTypeResponse:
		if c.proxy == nil {
			return
		}
		if err := c.proxy.HandleResponse(hdr.StreamID, data); err != nil {
			c.logger.Error("Failed to handle response",
				zap.String("stream_id", hdr.StreamID),
				zap.Error(err),
			)
		}

	case protocol.DataTypeHTTPResponse:
		if c.responseChans == nil {
			c.logger.Warn("No response channel handler for HTTP response",
				zap.String("stream_id", hdr.StreamID),
			)
			return
		}
		// Decoding auto-detects the response encoding.
		httpResp, err := protocol.DecodeHTTPResponse(data)
		if err != nil {
			c.logger.Error("Failed to decode HTTP response",
				zap.String("stream_id", hdr.StreamID),
				zap.Error(err),
			)
			return
		}
		c.responseChans.SendResponse(routeID, httpResp)

	case protocol.DataTypeHTTPHead:
		if c.responseChans == nil {
			c.logger.Warn("No response handler for streaming HTTP head",
				zap.String("stream_id", hdr.StreamID),
			)
			return
		}
		head, err := protocol.DecodeHTTPResponseHead(data)
		if err != nil {
			c.logger.Error("Failed to decode HTTP response head",
				zap.String("stream_id", hdr.StreamID),
				zap.Error(err),
			)
			return
		}
		if err := c.responseChans.SendStreamingHead(routeID, head); err != nil {
			c.logger.Error("Failed to send streaming head",
				zap.String("request_id", routeID),
				zap.Error(err),
			)
		}

	case protocol.DataTypeHTTPBodyChunk:
		if c.responseChans == nil {
			c.logger.Warn("No response handler for streaming HTTP chunk",
				zap.String("stream_id", hdr.StreamID),
			)
			return
		}
		if err := c.responseChans.SendStreamingChunk(routeID, data, hdr.IsLast); err != nil {
			c.logger.Error("Failed to send streaming chunk",
				zap.String("request_id", routeID),
				zap.Error(err),
			)
		}

	case protocol.DataTypeClose:
		if c.proxy != nil {
			c.proxy.CloseStream(hdr.StreamID)
		}

	default:
		c.logger.Warn("Unknown data frame type",
			zap.String("type", hdr.Type.String()),
			zap.String("stream_id", hdr.StreamID),
		)
	}
}
// heartbeatChecker wakes every constants.HeartbeatInterval and closes the
// connection once no heartbeat has been recorded for longer than
// constants.HeartbeatTimeout. It exits when stopCh is closed or after it has
// triggered Close itself.
func (c *Connection) heartbeatChecker() {
	ticker := time.NewTicker(constants.HeartbeatInterval)
	defer ticker.Stop()

	for {
		select {
		case <-c.stopCh:
			return
		case <-ticker.C:
		}

		c.mu.RLock()
		last := c.lastHeartbeat
		c.mu.RUnlock()

		silence := time.Since(last)
		if silence <= constants.HeartbeatTimeout {
			continue
		}

		c.logger.Warn("Heartbeat timeout",
			zap.String("subdomain", c.subdomain),
			zap.Duration("last_heartbeat", silence),
		)
		c.Close()
		return
	}
}
// SendFrame delivers frame to the client. Before the async frame writer exists
// it writes synchronously to the socket. Afterwards it queues the frame and
// flushes right away so the frame is not held back by batching.
func (c *Connection) SendFrame(frame *protocol.Frame) error {
	fw := c.frameWriter
	if fw == nil {
		return protocol.WriteFrame(c.conn, frame)
	}

	err := fw.WriteFrame(frame)
	if err == nil {
		fw.Flush()
	}
	return err
}
// sendError notifies the client of a failure with an error frame carrying code
// and message. Delivery is best-effort: encoding and write errors are ignored.
func (c *Connection) sendError(code, message string) {
	payload, _ := json.Marshal(protocol.ErrorMessage{
		Code:    code,
		Message: message,
	})
	errFrame := protocol.NewFrame(protocol.FrameTypeError, payload)

	if c.frameWriter != nil {
		c.frameWriter.WriteFrame(errFrame)
		return
	}
	// Early failures happen before the frame writer exists; write directly.
	protocol.WriteFrame(c.conn, errFrame)
}
// Close tears the connection down. sync.Once makes it idempotent: only the
// first call runs the teardown, whether it comes from Handle's deferred call
// or from heartbeatChecker on a timeout.
//
// Teardown order: leave adaptive load tracking, close stopCh so the frame loop
// and heartbeat checker stop, close the frame writer and stop the TCP proxy,
// close the socket (so a pending read in handleFrames returns), release the
// public port if one is held, and remove the subdomain from the tunnel manager.
func (c *Connection) Close() {
c.once.Do(func() {
// Unregister connection from adaptive load tracking
protocol.UnregisterConnection()
// Signal handleFrames and heartbeatChecker to exit.
close(c.stopCh)
// frameWriter is nil if registration never completed.
if c.frameWriter != nil {
c.frameWriter.Close()
}
// proxy is only set for TCP tunnels.
if c.proxy != nil {
c.proxy.Stop()
}
c.conn.Close()
// port is non-zero only while a TCP tunnel holds an allocated port.
if c.port > 0 && c.portAlloc != nil {
c.portAlloc.Release(c.port)
}
// subdomain is set only after successful registration.
if c.subdomain != "" {
c.manager.Unregister(c.subdomain)
}
c.logger.Info("Connection closed",
zap.String("subdomain", c.subdomain),
)
})
}
// GetSubdomain reports the subdomain assigned during registration, or "" if
// registration has not completed.
func (c *Connection) GetSubdomain() (subdomain string) {
	subdomain = c.subdomain
	return subdomain
}
// httpResponseWriter is a minimal http.ResponseWriter that writes an HTTP/1.1
// response straight to a net.Conn. There is no buffering and no transfer
// encoding: body bytes given to Write go to the wire verbatim. Callers must
// therefore delimit the body themselves (e.g. with Content-Length) or close
// the connection after the response.
type httpResponseWriter struct {
	conn          net.Conn
	header        http.Header
	statusCode    int
	headerWritten bool
	// headerErr records a failure to send the status line and headers so that
	// later Write calls report it instead of emitting a headerless body.
	headerErr error
}

// Header returns the header map WriteHeader will send. Changes made after the
// header has been written have no effect.
func (w *httpResponseWriter) Header() http.Header {
	return w.header
}

// WriteHeader sends the status line and headers. Only the first call has any
// effect; later calls are ignored.
func (w *httpResponseWriter) WriteHeader(statusCode int) {
	if w.headerWritten {
		return
	}
	w.statusCode = statusCode
	w.headerWritten = true

	statusText := http.StatusText(statusCode)
	if statusText == "" {
		statusText = "Unknown"
	}

	// Build the whole head in memory and send it with one write instead of
	// one syscall per header line. http.Header.Write also drops invalid field
	// names and replaces CR/LF inside values. That matters because header
	// values may be relayed from tunnel clients, and this stops them from
	// splitting the response.
	var head strings.Builder
	fmt.Fprintf(&head, "HTTP/1.1 %d %s\r\n", statusCode, statusText)
	_ = w.header.Write(&head) // writes to a strings.Builder cannot fail
	head.WriteString("\r\n")
	_, w.headerErr = io.WriteString(w.conn, head.String())
}

// Write sends body bytes, writing a 200 OK header first if none was sent.
func (w *httpResponseWriter) Write(data []byte) (int, error) {
	if !w.headerWritten {
		w.WriteHeader(http.StatusOK)
	}
	if w.headerErr != nil {
		return 0, w.headerErr
	}
	return w.conn.Write(data)
}

// Flush implements http.Flusher so streaming handlers can push data early.
// Write already sends straight to the connection, so the only pending work is
// a header that has not been written yet.
func (w *httpResponseWriter) Flush() {
	if !w.headerWritten {
		w.WriteHeader(http.StatusOK)
	}
}