mirror of
https://github.com/Gouryella/drip.git
synced 2026-04-28 13:20:36 +00:00
Added the `--short` flag to the `version` command for printing version information without styles. In this mode, only the version, Git commit hash, and build time in plain text format will be output, facilitating script parsing. Optimized Windows process detection logic to improve runtime accuracy. Removed redundant comments and simplified signal checking methods, making the code clearer and easier to maintain. refactor(protocol): Replaced string matching of data frame types with enumeration types. Unified the representation of data frame types in the protocol, using the `DataType` enumeration to improve performance and readability. Introduced a pooled buffer mechanism to improve memory efficiency in high-load scenarios. refactor(ui): Adjusted style definitions, removing hard-coded color values. Removed fixed color settings from some lipgloss styles, providing flexibility for future theme customization. ``` docs(install): Improved the version extraction function in the installation script. Added the `get_version_from_binary` function to enhance version identification capabilities, prioritizing plain mode output, ensuring accurate version number acquisition for the drip client or server across different terminal environments. perf(tcp): Improved TCP processing performance and connection management capabilities. Adjusted HTTP client transmission parameter configuration, increasing the maximum number of idle connections to accommodate higher concurrent requests. Improved error handling logic, adding special checks for common cases such as closing network connections to avoid log pollution. chore(writer): Expanded the FrameWriter queue length to improve batch write stability. Increased the FrameWriter queue size from 1024 to 2048, and released pooled resources after flushing, better handling sudden traffic spikes and reducing memory usage fluctuations.
590 lines
16 KiB
Go
590 lines
16 KiB
Go
package tcp
|
|
|
|
import (
|
|
"bufio"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
json "github.com/goccy/go-json"
|
|
|
|
"drip/internal/server/tunnel"
|
|
"drip/internal/shared/constants"
|
|
"drip/internal/shared/protocol"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Connection represents a client TCP connection
|
|
type Connection struct {
|
|
conn net.Conn
|
|
authToken string
|
|
manager *tunnel.Manager
|
|
logger *zap.Logger
|
|
subdomain string
|
|
port int
|
|
domain string
|
|
publicPort int
|
|
portAlloc *PortAllocator
|
|
tunnelConn *tunnel.Connection
|
|
proxy *TunnelProxy
|
|
stopCh chan struct{}
|
|
once sync.Once
|
|
lastHeartbeat time.Time
|
|
mu sync.RWMutex
|
|
frameWriter *protocol.FrameWriter
|
|
httpHandler http.Handler
|
|
responseChans HTTPResponseHandler
|
|
tunnelType protocol.TunnelType // Track tunnel type
|
|
}
|
|
|
|
// HTTPResponseHandler interface for response channel operations
|
|
type HTTPResponseHandler interface {
|
|
CreateResponseChan(requestID string) chan *protocol.HTTPResponse
|
|
GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
|
|
CleanupResponseChan(requestID string)
|
|
SendResponse(requestID string, resp *protocol.HTTPResponse)
|
|
}
|
|
|
|
// NewConnection creates a new connection handler
|
|
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
|
|
return &Connection{
|
|
conn: conn,
|
|
authToken: authToken,
|
|
manager: manager,
|
|
logger: logger,
|
|
portAlloc: portAlloc,
|
|
domain: domain,
|
|
publicPort: publicPort,
|
|
httpHandler: httpHandler,
|
|
responseChans: responseChans,
|
|
stopCh: make(chan struct{}),
|
|
lastHeartbeat: time.Now(),
|
|
}
|
|
}
|
|
|
|
// Handle handles the connection lifecycle
|
|
func (c *Connection) Handle() error {
|
|
// Register connection for adaptive load tracking
|
|
protocol.RegisterConnection()
|
|
|
|
// Ensure cleanup of control connection, proxy, port, and registry on exit.
|
|
defer c.Close()
|
|
|
|
// Set initial read timeout for protocol detection
|
|
c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
|
|
|
// Use buffered reader to support peeking
|
|
reader := bufio.NewReader(c.conn)
|
|
|
|
// Peek first 8 bytes to detect protocol
|
|
peek, err := reader.Peek(8)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to peek connection: %w", err)
|
|
}
|
|
|
|
peekStr := string(peek)
|
|
httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"}
|
|
isHTTP := false
|
|
for _, method := range httpMethods {
|
|
if strings.HasPrefix(peekStr, method) {
|
|
isHTTP = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if isHTTP {
|
|
c.logger.Info("Detected HTTP request on TCP port, handling as HTTP")
|
|
return c.handleHTTPRequest(reader)
|
|
}
|
|
|
|
// Continue with drip protocol
|
|
// Wait for registration frame
|
|
frame, err := protocol.ReadFrame(reader)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read registration frame: %w", err)
|
|
}
|
|
defer frame.Release() // Return pool buffer when done
|
|
|
|
if frame.Type != protocol.FrameTypeRegister {
|
|
return fmt.Errorf("expected register frame, got %s", frame.Type)
|
|
}
|
|
|
|
var req protocol.RegisterRequest
|
|
if err := json.Unmarshal(frame.Payload, &req); err != nil {
|
|
return fmt.Errorf("failed to parse registration request: %w", err)
|
|
}
|
|
|
|
c.tunnelType = req.TunnelType
|
|
|
|
if c.authToken != "" && req.Token != c.authToken {
|
|
c.sendError("authentication_failed", "Invalid authentication token")
|
|
return fmt.Errorf("authentication failed")
|
|
}
|
|
|
|
// Allocate TCP port only for TCP tunnels
|
|
if req.TunnelType == protocol.TunnelTypeTCP {
|
|
if c.portAlloc == nil {
|
|
return fmt.Errorf("port allocator not configured")
|
|
}
|
|
|
|
port, err := c.portAlloc.Allocate()
|
|
if err != nil {
|
|
c.sendError("port_allocation_failed", err.Error())
|
|
return fmt.Errorf("failed to allocate port: %w", err)
|
|
}
|
|
c.port = port
|
|
|
|
// For TCP tunnels, prefer deterministic subdomain tied to port when not provided by client.
|
|
if req.CustomSubdomain == "" {
|
|
req.CustomSubdomain = fmt.Sprintf("tcp-%d", port)
|
|
}
|
|
}
|
|
|
|
subdomain, err := c.manager.Register(nil, req.CustomSubdomain)
|
|
if err != nil {
|
|
c.sendError("registration_failed", err.Error())
|
|
c.portAlloc.Release(c.port)
|
|
c.port = 0
|
|
return fmt.Errorf("tunnel registration failed: %w", err)
|
|
}
|
|
|
|
c.subdomain = subdomain
|
|
|
|
tunnelConn, ok := c.manager.Get(subdomain)
|
|
if !ok {
|
|
return fmt.Errorf("failed to get registered tunnel")
|
|
}
|
|
c.tunnelConn = tunnelConn
|
|
|
|
// Store TCP connection reference and metadata for HTTP proxy routing
|
|
c.tunnelConn.Conn = nil // We're using TCP, not WebSocket
|
|
c.tunnelConn.SetTransport(c, req.TunnelType)
|
|
c.tunnelConn.SetTunnelType(req.TunnelType)
|
|
c.tunnelType = req.TunnelType
|
|
|
|
c.logger.Info("Tunnel registered",
|
|
zap.String("subdomain", subdomain),
|
|
zap.String("tunnel_type", string(req.TunnelType)),
|
|
zap.Int("local_port", req.LocalPort),
|
|
zap.Int("remote_port", c.port),
|
|
)
|
|
|
|
// Send registration acknowledgment
|
|
// Generate appropriate URL based on tunnel type
|
|
var tunnelURL string
|
|
|
|
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
|
|
// HTTP/HTTPS tunnels use HTTPS with subdomain
|
|
// Use publicPort for URL generation (configured via --public-port flag)
|
|
if c.publicPort == 443 {
|
|
tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.domain)
|
|
} else {
|
|
tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.domain, c.publicPort)
|
|
}
|
|
} else {
|
|
// TCP tunnels use tcp:// with port
|
|
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
|
|
}
|
|
|
|
resp := protocol.RegisterResponse{
|
|
Subdomain: subdomain,
|
|
Port: c.port,
|
|
URL: tunnelURL,
|
|
Message: "Tunnel registered successfully",
|
|
}
|
|
|
|
respData, _ := json.Marshal(resp)
|
|
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
|
|
|
|
// Send registration ack (sync write before frameWriter is created)
|
|
err = protocol.WriteFrame(c.conn, ackFrame)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to send registration ack: %w", err)
|
|
}
|
|
|
|
// Create frame writer for async writes
|
|
c.frameWriter = protocol.NewFrameWriter(c.conn)
|
|
|
|
c.conn.SetReadDeadline(time.Time{})
|
|
|
|
// Start TCP proxy only for TCP tunnels
|
|
if req.TunnelType == protocol.TunnelTypeTCP {
|
|
c.proxy = NewTunnelProxy(c.port, subdomain, c.conn, c.logger)
|
|
if err := c.proxy.Start(); err != nil {
|
|
return fmt.Errorf("failed to start TCP proxy: %w", err)
|
|
}
|
|
}
|
|
|
|
go c.heartbeatChecker()
|
|
|
|
// Handle frames (pass reader for consistent buffering)
|
|
return c.handleFrames(reader)
|
|
}
|
|
|
|
// handleHTTPRequest handles HTTP requests that arrive on the TCP port
|
|
func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
|
// If no HTTP handler is configured, return error
|
|
if c.httpHandler == nil {
|
|
c.logger.Warn("HTTP request received but no HTTP handler configured")
|
|
response := "HTTP/1.1 503 Service Unavailable\r\n" +
|
|
"Content-Type: text/plain\r\n" +
|
|
"Content-Length: 47\r\n" +
|
|
"\r\n" +
|
|
"HTTP handler not configured for this TCP port\r\n"
|
|
c.conn.Write([]byte(response))
|
|
return fmt.Errorf("HTTP handler not configured")
|
|
}
|
|
|
|
// Clear read deadline for HTTP processing
|
|
c.conn.SetReadDeadline(time.Time{})
|
|
|
|
// Handle multiple HTTP requests on the same connection (HTTP/1.1 keep-alive)
|
|
for {
|
|
// Set a read deadline for each request to avoid hanging forever
|
|
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
|
|
// Parse HTTP request
|
|
req, err := http.ReadRequest(reader)
|
|
if err != nil {
|
|
// EOF or timeout is normal when client closes connection or no more requests
|
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
|
c.logger.Debug("Client closed HTTP connection")
|
|
return nil
|
|
}
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
c.logger.Debug("HTTP keep-alive timeout")
|
|
return nil
|
|
}
|
|
// Connection reset by peer is normal - client closed connection abruptly
|
|
errStr := err.Error()
|
|
if errors.Is(err, net.ErrClosed) || strings.Contains(errStr, "use of closed network connection") {
|
|
c.logger.Debug("HTTP connection closed during read", zap.Error(err))
|
|
return nil
|
|
}
|
|
if strings.Contains(errStr, "connection reset by peer") ||
|
|
strings.Contains(errStr, "broken pipe") ||
|
|
strings.Contains(errStr, "connection refused") {
|
|
c.logger.Debug("Client disconnected abruptly", zap.Error(err))
|
|
return nil
|
|
}
|
|
c.logger.Error("Failed to parse HTTP request", zap.Error(err))
|
|
return fmt.Errorf("failed to parse HTTP request: %w", err)
|
|
}
|
|
|
|
c.logger.Info("Processing HTTP request on TCP port",
|
|
zap.String("method", req.Method),
|
|
zap.String("url", req.URL.String()),
|
|
zap.String("host", req.Host),
|
|
)
|
|
|
|
// Create a response writer that writes directly to the connection
|
|
respWriter := &httpResponseWriter{
|
|
conn: c.conn,
|
|
header: make(http.Header),
|
|
}
|
|
|
|
// Handle the request
|
|
c.httpHandler.ServeHTTP(respWriter, req)
|
|
|
|
// Check if we should close the connection
|
|
// Close if: Connection: close header, or HTTP/1.0 without Connection: keep-alive
|
|
shouldClose := false
|
|
if req.Close {
|
|
shouldClose = true
|
|
} else if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
|
// HTTP/1.0 defaults to close unless keep-alive is explicitly requested
|
|
if req.Header.Get("Connection") != "keep-alive" {
|
|
shouldClose = true
|
|
}
|
|
}
|
|
|
|
if shouldClose {
|
|
c.logger.Debug("Closing connection as requested by client")
|
|
return nil
|
|
}
|
|
|
|
// Continue to next request on the same connection
|
|
}
|
|
}
|
|
|
|
// handleFrames handles incoming frames
|
|
func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
|
for {
|
|
select {
|
|
case <-c.stopCh:
|
|
return nil
|
|
default:
|
|
}
|
|
|
|
// Read frame with timeout
|
|
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
|
frame, err := protocol.ReadFrame(reader)
|
|
if err != nil {
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
c.logger.Warn("Read timeout, connection may be dead")
|
|
return fmt.Errorf("read timeout")
|
|
}
|
|
// EOF is normal when client closes connection gracefully
|
|
if err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" {
|
|
c.logger.Info("Client disconnected")
|
|
return nil
|
|
}
|
|
// Check if connection was closed (during shutdown)
|
|
select {
|
|
case <-c.stopCh:
|
|
// Connection was closed intentionally, don't log as error
|
|
c.logger.Debug("Connection closed during shutdown")
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("failed to read frame: %w", err)
|
|
}
|
|
}
|
|
|
|
// Handle frame based on type
|
|
switch frame.Type {
|
|
case protocol.FrameTypeHeartbeat:
|
|
c.handleHeartbeat()
|
|
frame.Release()
|
|
|
|
case protocol.FrameTypeData:
|
|
// Data frame from client (response to forwarded request)
|
|
c.handleDataFrame(frame)
|
|
frame.Release() // Release after processing
|
|
|
|
case protocol.FrameTypeClose:
|
|
frame.Release()
|
|
c.logger.Info("Client requested close")
|
|
return nil
|
|
|
|
default:
|
|
frame.Release()
|
|
c.logger.Warn("Unexpected frame type",
|
|
zap.String("type", frame.Type.String()),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleHeartbeat handles heartbeat frame
|
|
func (c *Connection) handleHeartbeat() {
|
|
c.mu.Lock()
|
|
c.lastHeartbeat = time.Now()
|
|
c.mu.Unlock()
|
|
|
|
// Send heartbeat ack
|
|
ackFrame := protocol.NewFrame(protocol.FrameTypeHeartbeatAck, nil)
|
|
|
|
err := c.frameWriter.WriteFrame(ackFrame)
|
|
if err != nil {
|
|
c.logger.Error("Failed to send heartbeat ack", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
// handleDataFrame handles data frame (response from client)
|
|
func (c *Connection) handleDataFrame(frame *protocol.Frame) {
|
|
// Decode payload (auto-detects protocol version)
|
|
header, data, err := protocol.DecodeDataPayload(frame.Payload)
|
|
if err != nil {
|
|
c.logger.Error("Failed to decode data payload",
|
|
zap.Error(err),
|
|
)
|
|
return
|
|
}
|
|
|
|
c.logger.Debug("Received data frame",
|
|
zap.String("stream_id", header.StreamID),
|
|
zap.String("type", header.Type.String()),
|
|
zap.Int("data_size", len(data)),
|
|
)
|
|
|
|
switch header.Type {
|
|
case protocol.DataTypeResponse:
|
|
// TCP tunnel response, forward to proxy
|
|
if c.proxy != nil {
|
|
if err := c.proxy.HandleResponse(header.StreamID, data); err != nil {
|
|
c.logger.Error("Failed to handle response",
|
|
zap.String("stream_id", header.StreamID),
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
}
|
|
case protocol.DataTypeHTTPResponse:
|
|
if c.responseChans == nil {
|
|
c.logger.Warn("No response channel handler for HTTP response",
|
|
zap.String("stream_id", header.StreamID),
|
|
)
|
|
return
|
|
}
|
|
|
|
// Decode HTTP response (auto-detects JSON vs msgpack)
|
|
httpResp, err := protocol.DecodeHTTPResponse(data)
|
|
if err != nil {
|
|
c.logger.Error("Failed to decode HTTP response",
|
|
zap.String("stream_id", header.StreamID),
|
|
zap.Error(err),
|
|
)
|
|
return
|
|
}
|
|
|
|
// Route by request ID when provided to keep request/response aligned.
|
|
reqID := header.RequestID
|
|
if reqID == "" {
|
|
reqID = header.StreamID
|
|
}
|
|
|
|
c.responseChans.SendResponse(reqID, httpResp)
|
|
|
|
c.logger.Debug("Routed HTTP response to channel",
|
|
zap.String("request_id", reqID),
|
|
)
|
|
case protocol.DataTypeClose:
|
|
// Client is closing the stream
|
|
if c.proxy != nil {
|
|
c.proxy.CloseStream(header.StreamID)
|
|
}
|
|
default:
|
|
c.logger.Warn("Unknown data frame type",
|
|
zap.String("type", header.Type.String()),
|
|
zap.String("stream_id", header.StreamID),
|
|
)
|
|
}
|
|
}
|
|
|
|
// heartbeatChecker checks for heartbeat timeout
|
|
func (c *Connection) heartbeatChecker() {
|
|
ticker := time.NewTicker(constants.HeartbeatInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-c.stopCh:
|
|
return
|
|
case <-ticker.C:
|
|
c.mu.RLock()
|
|
lastHB := c.lastHeartbeat
|
|
c.mu.RUnlock()
|
|
|
|
if time.Since(lastHB) > constants.HeartbeatTimeout {
|
|
c.logger.Warn("Heartbeat timeout",
|
|
zap.String("subdomain", c.subdomain),
|
|
zap.Duration("last_heartbeat", time.Since(lastHB)),
|
|
)
|
|
c.Close()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// SendFrame sends a frame to the client
|
|
func (c *Connection) SendFrame(frame *protocol.Frame) error {
|
|
if c.frameWriter == nil {
|
|
return protocol.WriteFrame(c.conn, frame)
|
|
}
|
|
return c.frameWriter.WriteFrame(frame)
|
|
}
|
|
|
|
// sendError sends an error frame to the client
|
|
func (c *Connection) sendError(code, message string) {
|
|
errMsg := protocol.ErrorMessage{
|
|
Code: code,
|
|
Message: message,
|
|
}
|
|
data, _ := json.Marshal(errMsg)
|
|
errFrame := protocol.NewFrame(protocol.FrameTypeError, data)
|
|
|
|
if c.frameWriter == nil {
|
|
// Fallback if frameWriter not initialized (early errors)
|
|
protocol.WriteFrame(c.conn, errFrame)
|
|
} else {
|
|
c.frameWriter.WriteFrame(errFrame)
|
|
}
|
|
}
|
|
|
|
// Close closes the connection
|
|
func (c *Connection) Close() {
|
|
c.once.Do(func() {
|
|
// Unregister connection from adaptive load tracking
|
|
protocol.UnregisterConnection()
|
|
|
|
close(c.stopCh)
|
|
|
|
if c.frameWriter != nil {
|
|
c.frameWriter.Close()
|
|
}
|
|
|
|
if c.proxy != nil {
|
|
c.proxy.Stop()
|
|
}
|
|
|
|
c.conn.Close()
|
|
|
|
if c.port > 0 && c.portAlloc != nil {
|
|
c.portAlloc.Release(c.port)
|
|
}
|
|
|
|
if c.subdomain != "" {
|
|
c.manager.Unregister(c.subdomain)
|
|
}
|
|
|
|
c.logger.Info("Connection closed",
|
|
zap.String("subdomain", c.subdomain),
|
|
)
|
|
})
|
|
}
|
|
|
|
// GetSubdomain returns the assigned subdomain
|
|
func (c *Connection) GetSubdomain() string {
|
|
return c.subdomain
|
|
}
|
|
|
|
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
|
|
type httpResponseWriter struct {
|
|
conn net.Conn
|
|
header http.Header
|
|
statusCode int
|
|
headerWritten bool
|
|
}
|
|
|
|
func (w *httpResponseWriter) Header() http.Header {
|
|
return w.header
|
|
}
|
|
|
|
func (w *httpResponseWriter) WriteHeader(statusCode int) {
|
|
if w.headerWritten {
|
|
return
|
|
}
|
|
w.statusCode = statusCode
|
|
w.headerWritten = true
|
|
|
|
// Write status line
|
|
statusText := http.StatusText(statusCode)
|
|
if statusText == "" {
|
|
statusText = "Unknown"
|
|
}
|
|
fmt.Fprintf(w.conn, "HTTP/1.1 %d %s\r\n", statusCode, statusText)
|
|
|
|
// Write headers
|
|
for key, values := range w.header {
|
|
for _, value := range values {
|
|
fmt.Fprintf(w.conn, "%s: %s\r\n", key, value)
|
|
}
|
|
}
|
|
|
|
// Write empty line to end headers
|
|
fmt.Fprintf(w.conn, "\r\n")
|
|
}
|
|
|
|
func (w *httpResponseWriter) Write(data []byte) (int, error) {
|
|
if !w.headerWritten {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
return w.conn.Write(data)
|
|
}
|