mirror of
https://github.com/Gouryella/drip.git
synced 2026-03-04 04:43:47 +00:00
feat(tunnel): switch to yamux stream proxying and connection pooling
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server - Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly - Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/hashicorp/yamux"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/constants"
|
||||
@@ -33,36 +34,27 @@ type Connection struct {
|
||||
publicPort int
|
||||
portAlloc *PortAllocator
|
||||
tunnelConn *tunnel.Connection
|
||||
proxy *TunnelProxy
|
||||
stopCh chan struct{}
|
||||
once sync.Once
|
||||
lastHeartbeat time.Time
|
||||
mu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
httpHandler http.Handler
|
||||
responseChans HTTPResponseHandler
|
||||
tunnelType protocol.TunnelType // Track tunnel type
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Flow control
|
||||
paused bool
|
||||
pauseCond *sync.Cond
|
||||
}
|
||||
// gost-like TCP tunnel (yamux)
|
||||
session *yamux.Session
|
||||
proxy *Proxy
|
||||
|
||||
// HTTPResponseHandler interface for response channel operations
|
||||
type HTTPResponseHandler interface {
|
||||
CreateResponseChan(requestID string) chan *protocol.HTTPResponse
|
||||
GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
|
||||
CleanupResponseChan(requestID string)
|
||||
SendResponse(requestID string, resp *protocol.HTTPResponse)
|
||||
// Streaming response methods
|
||||
SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error
|
||||
SendStreamingChunk(requestID string, chunk []byte, isLast bool) error
|
||||
// Multi-connection support
|
||||
tunnelID string
|
||||
groupManager *ConnectionGroupManager
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager) *Connection {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c := &Connection{
|
||||
conn: conn,
|
||||
@@ -73,13 +65,12 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
|
||||
domain: domain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
responseChans: responseChans,
|
||||
stopCh: make(chan struct{}),
|
||||
lastHeartbeat: time.Now(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
groupManager: groupManager,
|
||||
}
|
||||
c.pauseCond = sync.NewCond(&c.mu)
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -97,8 +88,8 @@ func (c *Connection) Handle() error {
|
||||
// Use buffered reader to support peeking
|
||||
reader := bufio.NewReader(c.conn)
|
||||
|
||||
// Peek first 8 bytes to detect protocol
|
||||
peek, err := reader.Peek(8)
|
||||
// Peek first 4 bytes to detect protocol (HTTP methods are 4 bytes).
|
||||
peek, err := reader.Peek(4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to peek connection: %w", err)
|
||||
}
|
||||
@@ -127,6 +118,11 @@ func (c *Connection) Handle() error {
|
||||
sf := protocol.WithFrame(frame)
|
||||
defer sf.Close()
|
||||
|
||||
// Handle data connection request (for multi-connection pool)
|
||||
if sf.Frame.Type == protocol.FrameTypeDataConnect {
|
||||
return c.handleDataConnect(sf.Frame, reader)
|
||||
}
|
||||
|
||||
if sf.Frame.Type != protocol.FrameTypeRegister {
|
||||
return fmt.Errorf("expected register frame, got %s", sf.Frame.Type)
|
||||
}
|
||||
@@ -180,7 +176,6 @@ func (c *Connection) Handle() error {
|
||||
|
||||
// Store TCP connection reference and metadata for HTTP proxy routing
|
||||
c.tunnelConn.Conn = nil // We're using TCP, not WebSocket
|
||||
c.tunnelConn.SetTransport(c, req.TunnelType)
|
||||
c.tunnelConn.SetTunnelType(req.TunnelType)
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
@@ -208,11 +203,33 @@ func (c *Connection) Handle() error {
|
||||
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
|
||||
}
|
||||
|
||||
// Generate TunnelID for multi-connection support if client supports it
|
||||
var tunnelID string
|
||||
var supportsDataConn bool
|
||||
recommendedConns := 0
|
||||
|
||||
if req.PoolCapabilities != nil && req.ConnectionType == "primary" && c.groupManager != nil {
|
||||
// Client supports connection pooling
|
||||
group := c.groupManager.CreateGroup(subdomain, req.Token, c, req.TunnelType)
|
||||
tunnelID = group.TunnelID
|
||||
c.tunnelID = tunnelID
|
||||
supportsDataConn = true
|
||||
recommendedConns = 4 // Recommend 4 data connections
|
||||
|
||||
c.logger.Info("Created connection group for multi-connection support",
|
||||
zap.String("tunnel_id", tunnelID),
|
||||
zap.Int("max_data_conns", req.PoolCapabilities.MaxDataConns),
|
||||
)
|
||||
}
|
||||
|
||||
resp := protocol.RegisterResponse{
|
||||
Subdomain: subdomain,
|
||||
Port: c.port,
|
||||
URL: tunnelURL,
|
||||
Message: "Tunnel registered successfully",
|
||||
Subdomain: subdomain,
|
||||
Port: c.port,
|
||||
URL: tunnelURL,
|
||||
Message: "Tunnel registered successfully",
|
||||
TunnelID: tunnelID,
|
||||
SupportsDataConn: supportsDataConn,
|
||||
RecommendedConns: recommendedConns,
|
||||
}
|
||||
|
||||
respData, _ := json.Marshal(resp)
|
||||
@@ -224,6 +241,17 @@ func (c *Connection) Handle() error {
|
||||
return fmt.Errorf("failed to send registration ack: %w", err)
|
||||
}
|
||||
|
||||
// Clear deadline for tunnel data-plane.
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// gost-like tunnels: switch to yamux after RegisterAck.
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
return c.handleTCPTunnel(reader)
|
||||
}
|
||||
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
|
||||
return c.handleHTTPProxyTunnel(reader)
|
||||
}
|
||||
|
||||
c.frameWriter = protocol.NewFrameWriter(c.conn)
|
||||
|
||||
c.frameWriter.SetWriteErrorHandler(func(err error) {
|
||||
@@ -231,15 +259,6 @@ func (c *Connection) Handle() error {
|
||||
c.Close()
|
||||
})
|
||||
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
c.proxy = NewTunnelProxy(c.port, subdomain, c.conn, c.logger)
|
||||
if err := c.proxy.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start TCP proxy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
go c.heartbeatChecker()
|
||||
|
||||
return c.handleFrames(reader)
|
||||
@@ -376,7 +395,7 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
if isTimeoutError(err) {
|
||||
c.logger.Warn("Read timeout, connection may be dead")
|
||||
return fmt.Errorf("read timeout")
|
||||
}
|
||||
@@ -404,15 +423,6 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
c.handleHeartbeat()
|
||||
sf.Close()
|
||||
|
||||
case protocol.FrameTypeData:
|
||||
// Data frame from client (response to forwarded request)
|
||||
c.handleDataFrame(sf.Frame)
|
||||
sf.Close()
|
||||
|
||||
case protocol.FrameTypeFlowControl:
|
||||
c.handleFlowControl(sf.Frame)
|
||||
sf.Close()
|
||||
|
||||
case protocol.FrameTypeClose:
|
||||
sf.Close()
|
||||
c.logger.Info("Client requested close")
|
||||
@@ -436,127 +446,12 @@ func (c *Connection) handleHeartbeat() {
|
||||
// Send heartbeat ack
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeHeartbeatAck, nil)
|
||||
|
||||
err := c.frameWriter.WriteFrame(ackFrame)
|
||||
err := c.frameWriter.WriteControl(ackFrame)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to send heartbeat ack", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// handleDataFrame handles data frame (response from client)
|
||||
func (c *Connection) handleDataFrame(frame *protocol.Frame) {
|
||||
// Decode payload (auto-detects protocol version)
|
||||
header, data, err := protocol.DecodeDataPayload(frame.Payload)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode data payload",
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Debug("Received data frame",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.String("type", header.Type.String()),
|
||||
zap.Int("data_size", len(data)),
|
||||
)
|
||||
|
||||
switch header.Type {
|
||||
case protocol.DataTypeResponse:
|
||||
// TCP tunnel response, forward to proxy
|
||||
if c.proxy != nil {
|
||||
if err := c.proxy.HandleResponse(header.StreamID, data); err != nil {
|
||||
c.logger.Error("Failed to handle response",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
case protocol.DataTypeHTTPResponse:
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response channel handler for HTTP response",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Decode HTTP response (auto-detects JSON vs msgpack)
|
||||
httpResp, err := protocol.DecodeHTTPResponse(data)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode HTTP response",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Route by request ID when provided to keep request/response aligned.
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
c.responseChans.SendResponse(reqID, httpResp)
|
||||
case protocol.DataTypeHTTPHead:
|
||||
// Streaming HTTP response headers
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response handler for streaming HTTP head",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
httpHead, err := protocol.DecodeHTTPResponseHead(data)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode HTTP response head",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
if err := c.responseChans.SendStreamingHead(reqID, httpHead); err != nil {
|
||||
c.logger.Error("Failed to send streaming head",
|
||||
zap.String("request_id", reqID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
case protocol.DataTypeHTTPBodyChunk:
|
||||
// Streaming HTTP response body chunk
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response handler for streaming HTTP chunk",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
if err := c.responseChans.SendStreamingChunk(reqID, data, header.IsLast); err != nil {
|
||||
c.logger.Error("Failed to send streaming chunk",
|
||||
zap.String("request_id", reqID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
case protocol.DataTypeClose:
|
||||
// Client is closing the stream
|
||||
if c.proxy != nil {
|
||||
c.proxy.CloseStream(header.StreamID)
|
||||
}
|
||||
default:
|
||||
c.logger.Warn("Unknown data frame type",
|
||||
zap.String("type", header.Type.String()),
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeatChecker checks for heartbeat timeout
|
||||
func (c *Connection) heartbeatChecker() {
|
||||
ticker := time.NewTicker(constants.HeartbeatInterval)
|
||||
@@ -583,16 +478,6 @@ func (c *Connection) heartbeatChecker() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) SendFrame(frame *protocol.Frame) error {
|
||||
if c.frameWriter == nil {
|
||||
return protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
if frame.Type == protocol.FrameTypeData {
|
||||
return c.sendWithBackpressure(frame)
|
||||
}
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
|
||||
func (c *Connection) sendError(code, message string) {
|
||||
errMsg := protocol.ErrorMessage{
|
||||
Code: code,
|
||||
@@ -618,8 +503,12 @@ func (c *Connection) Close() {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
// Ensure any in-flight writes return quickly on shutdown to avoid hanging.
|
||||
if c.conn != nil {
|
||||
_ = c.conn.SetDeadline(time.Now())
|
||||
}
|
||||
|
||||
if c.frameWriter != nil {
|
||||
c.frameWriter.Flush()
|
||||
c.frameWriter.Close()
|
||||
}
|
||||
|
||||
@@ -627,7 +516,13 @@ func (c *Connection) Close() {
|
||||
c.proxy.Stop()
|
||||
}
|
||||
|
||||
c.conn.Close()
|
||||
if c.session != nil {
|
||||
_ = c.session.Close()
|
||||
}
|
||||
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
if c.port > 0 && c.portAlloc != nil {
|
||||
c.portAlloc.Release(c.port)
|
||||
@@ -635,6 +530,12 @@ func (c *Connection) Close() {
|
||||
|
||||
if c.subdomain != "" {
|
||||
c.manager.Unregister(c.subdomain)
|
||||
|
||||
// Clean up connection group when PRIMARY connection closes
|
||||
// (only primary connections have subdomain set)
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
c.groupManager.RemoveGroup(c.tunnelID)
|
||||
}
|
||||
}
|
||||
|
||||
c.logger.Info("Connection closed",
|
||||
@@ -643,11 +544,6 @@ func (c *Connection) Close() {
|
||||
})
|
||||
}
|
||||
|
||||
// GetSubdomain returns the assigned subdomain
|
||||
func (c *Connection) GetSubdomain() string {
|
||||
return c.subdomain
|
||||
}
|
||||
|
||||
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
|
||||
type httpResponseWriter struct {
|
||||
conn net.Conn
|
||||
@@ -698,39 +594,196 @@ func (w *httpResponseWriter) Write(data []byte) (int, error) {
|
||||
return w.writer.Write(data)
|
||||
}
|
||||
|
||||
func (c *Connection) handleFlowControl(frame *protocol.Frame) {
|
||||
msg, err := protocol.DecodeFlowControlMessage(frame.Payload)
|
||||
// handleDataConnect handles a data connection join request
|
||||
func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Reader) error {
|
||||
var req protocol.DataConnectRequest
|
||||
if err := json.Unmarshal(frame.Payload, &req); err != nil {
|
||||
c.sendError("invalid_request", "Failed to parse data connect request")
|
||||
return fmt.Errorf("failed to parse data connect request: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("Data connection request received",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Validate the request
|
||||
if c.groupManager == nil {
|
||||
c.sendDataConnectError("not_supported", "Multi-connection not supported")
|
||||
return fmt.Errorf("group manager not available")
|
||||
}
|
||||
|
||||
// Validate auth token
|
||||
if c.authToken != "" && req.Token != c.authToken {
|
||||
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed for data connection")
|
||||
}
|
||||
|
||||
group, ok := c.groupManager.GetGroup(req.TunnelID)
|
||||
if !ok || group == nil {
|
||||
c.sendDataConnectError("join_failed", "Tunnel not found")
|
||||
return fmt.Errorf("tunnel not found: %s", req.TunnelID)
|
||||
}
|
||||
|
||||
// Validate token against the primary registration token.
|
||||
if group.Token != "" && req.Token != group.Token {
|
||||
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed for data connection")
|
||||
}
|
||||
|
||||
// Store tunnelID for cleanup
|
||||
c.tunnelID = req.TunnelID
|
||||
|
||||
// For TCP tunnels, the data connection is upgraded to a yamux session and used for
|
||||
// stream forwarding, not framed request/response routing.
|
||||
if group.TunnelType == protocol.TunnelTypeTCP {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: true,
|
||||
ConnectionID: req.ConnectionID,
|
||||
Message: "Data connection accepted",
|
||||
}
|
||||
|
||||
respData, _ := json.Marshal(resp)
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
|
||||
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
|
||||
return fmt.Errorf("failed to send data connect ack: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("TCP data connection established",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Clear deadline for yamux data-plane.
|
||||
_ = c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
group.AddSession(req.ConnectionID, session)
|
||||
defer group.RemoveSession(req.ConnectionID)
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Add data connection to group
|
||||
dataConn, err := c.groupManager.AddDataConnection(&req, c.conn)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode flow control", zap.Error(err))
|
||||
return
|
||||
c.sendDataConnectError("join_failed", err.Error())
|
||||
return fmt.Errorf("failed to join connection group: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
// Send success response
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: true,
|
||||
ConnectionID: req.ConnectionID,
|
||||
Message: "Data connection accepted",
|
||||
}
|
||||
|
||||
switch msg.Action {
|
||||
case protocol.FlowControlPause:
|
||||
c.paused = true
|
||||
c.logger.Warn("Client requested pause",
|
||||
zap.String("stream", msg.StreamID))
|
||||
respData, _ := json.Marshal(resp)
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
|
||||
case protocol.FlowControlResume:
|
||||
c.paused = false
|
||||
c.pauseCond.Broadcast()
|
||||
c.logger.Info("Client requested resume",
|
||||
zap.String("stream", msg.StreamID))
|
||||
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
|
||||
return fmt.Errorf("failed to send data connect ack: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
c.logger.Warn("Unknown flow control action",
|
||||
zap.String("action", string(msg.Action)))
|
||||
c.logger.Info("Data connection established",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Handle data frames on this connection
|
||||
return c.handleDataConnectionFrames(dataConn, reader)
|
||||
}
|
||||
|
||||
// handleDataConnectionFrames handles frames on a data connection
|
||||
func (c *Connection) handleDataConnectionFrames(dataConn *DataConnection, reader *bufio.Reader) error {
|
||||
defer func() {
|
||||
// Get the group and remove this data connection
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok {
|
||||
group.RemoveDataConnection(dataConn.ID)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-dataConn.stopCh:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
// Timeout is OK, continue
|
||||
if isTimeoutError(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
dataConn.mu.Lock()
|
||||
dataConn.LastActive = time.Now()
|
||||
dataConn.mu.Unlock()
|
||||
|
||||
sf := protocol.WithFrame(frame)
|
||||
|
||||
switch sf.Frame.Type {
|
||||
case protocol.FrameTypeClose:
|
||||
sf.Close()
|
||||
c.logger.Info("Data connection closed by client",
|
||||
zap.String("connection_id", dataConn.ID))
|
||||
return nil
|
||||
|
||||
default:
|
||||
sf.Close()
|
||||
c.logger.Warn("Unexpected frame type on data connection",
|
||||
zap.String("type", sf.Frame.Type.String()),
|
||||
zap.String("connection_id", dataConn.ID),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) sendWithBackpressure(frame *protocol.Frame) error {
|
||||
c.mu.Lock()
|
||||
for c.paused {
|
||||
c.pauseCond.Wait()
|
||||
func isTimeoutError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
// Fallback for wrapped errors without net.Error
|
||||
return strings.Contains(err.Error(), "i/o timeout")
|
||||
}
|
||||
|
||||
// sendDataConnectError sends a data connect error response
|
||||
func (c *Connection) sendDataConnectError(code, message string) {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: false,
|
||||
Message: fmt.Sprintf("%s: %s", code, message),
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
|
||||
438
internal/server/tcp/connection_group.go
Normal file
438
internal/server/tcp/connection_group.go
Normal file
@@ -0,0 +1,438 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
|
||||
type DataConnection struct {
|
||||
ID string
|
||||
Conn net.Conn
|
||||
LastActive time.Time
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
stopCh chan struct{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type ConnectionGroup struct {
|
||||
TunnelID string
|
||||
Subdomain string
|
||||
Token string
|
||||
PrimaryConn *Connection
|
||||
DataConns map[string]*DataConnection
|
||||
Sessions map[string]*yamux.Session
|
||||
TunnelType protocol.TunnelType
|
||||
RegisteredAt time.Time
|
||||
LastActivity time.Time
|
||||
sessionIdx uint32
|
||||
mu sync.RWMutex
|
||||
stopCh chan struct{}
|
||||
logger *zap.Logger
|
||||
|
||||
heartbeatStarted bool
|
||||
}
|
||||
|
||||
func NewConnectionGroup(tunnelID, subdomain, token string, primaryConn *Connection, tunnelType protocol.TunnelType, logger *zap.Logger) *ConnectionGroup {
|
||||
return &ConnectionGroup{
|
||||
TunnelID: tunnelID,
|
||||
Subdomain: subdomain,
|
||||
Token: token,
|
||||
PrimaryConn: primaryConn,
|
||||
DataConns: make(map[string]*DataConnection),
|
||||
Sessions: make(map[string]*yamux.Session),
|
||||
TunnelType: tunnelType,
|
||||
RegisteredAt: time.Now(),
|
||||
LastActivity: time.Now(),
|
||||
stopCh: make(chan struct{}),
|
||||
logger: logger.With(zap.String("tunnel_id", tunnelID)),
|
||||
}
|
||||
}
|
||||
|
||||
// StartHeartbeat starts a goroutine that periodically pings all sessions
|
||||
// and removes dead ones. The caller should ensure this is only called once.
|
||||
func (g *ConnectionGroup) StartHeartbeat(interval, timeout time.Duration) {
|
||||
go g.heartbeatLoop(interval, timeout)
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
const maxConsecutiveFailures = 3
|
||||
failureCount := make(map[string]int)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
g.mu.RLock()
|
||||
sessions := make(map[string]*yamux.Session, len(g.Sessions))
|
||||
for id, s := range g.Sessions {
|
||||
sessions[id] = s
|
||||
}
|
||||
g.mu.RUnlock()
|
||||
|
||||
for id, session := range sessions {
|
||||
if session == nil || session.IsClosed() {
|
||||
g.RemoveSession(id)
|
||||
delete(failureCount, id)
|
||||
continue
|
||||
}
|
||||
|
||||
// Ping with timeout
|
||||
done := make(chan error, 1)
|
||||
go func(s *yamux.Session) {
|
||||
_, err := s.Ping()
|
||||
done <- err
|
||||
}(session)
|
||||
|
||||
var err error
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(timeout):
|
||||
err = fmt.Errorf("ping timeout")
|
||||
case <-g.stopCh:
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
failureCount[id]++
|
||||
g.logger.Debug("Session ping failed",
|
||||
zap.String("session_id", id),
|
||||
zap.Int("consecutive_failures", failureCount[id]),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
if failureCount[id] >= maxConsecutiveFailures {
|
||||
g.logger.Warn("Session ping failed too many times, removing",
|
||||
zap.String("session_id", id),
|
||||
zap.Int("failures", failureCount[id]),
|
||||
)
|
||||
g.RemoveSession(id)
|
||||
delete(failureCount, id)
|
||||
}
|
||||
} else {
|
||||
// Reset on success
|
||||
failureCount[id] = 0
|
||||
g.mu.Lock()
|
||||
g.LastActivity = time.Now()
|
||||
g.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all sessions are gone
|
||||
g.mu.RLock()
|
||||
sessionCount := len(g.Sessions)
|
||||
g.mu.RUnlock()
|
||||
|
||||
if sessionCount == 0 {
|
||||
g.logger.Info("All sessions closed, tunnel will be cleaned up")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) AddDataConnection(connID string, conn net.Conn) *DataConnection {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
dataConn := &DataConnection{
|
||||
ID: connID,
|
||||
Conn: conn,
|
||||
LastActive: time.Now(),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
g.DataConns[connID] = dataConn
|
||||
g.LastActivity = time.Now()
|
||||
return dataConn
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) RemoveDataConnection(connID string) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if dataConn, ok := g.DataConns[connID]; ok {
|
||||
dataConn.closedMu.Lock()
|
||||
if !dataConn.closed {
|
||||
dataConn.closed = true
|
||||
close(dataConn.stopCh)
|
||||
if dataConn.Conn != nil {
|
||||
_ = dataConn.Conn.SetDeadline(time.Now())
|
||||
dataConn.Conn.Close()
|
||||
}
|
||||
}
|
||||
dataConn.closedMu.Unlock()
|
||||
delete(g.DataConns, connID)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) DataConnectionCount() int {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
return len(g.DataConns)
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) Close() {
|
||||
g.mu.Lock()
|
||||
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
g.mu.Unlock()
|
||||
return
|
||||
default:
|
||||
close(g.stopCh)
|
||||
}
|
||||
|
||||
dataConns := make([]*DataConnection, 0, len(g.DataConns))
|
||||
for _, dataConn := range g.DataConns {
|
||||
dataConns = append(dataConns, dataConn)
|
||||
}
|
||||
g.DataConns = make(map[string]*DataConnection)
|
||||
|
||||
sessions := make([]*yamux.Session, 0, len(g.Sessions))
|
||||
for _, session := range g.Sessions {
|
||||
if session != nil {
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
}
|
||||
g.Sessions = make(map[string]*yamux.Session)
|
||||
|
||||
g.mu.Unlock()
|
||||
|
||||
for _, dataConn := range dataConns {
|
||||
dataConn.closedMu.Lock()
|
||||
if !dataConn.closed {
|
||||
dataConn.closed = true
|
||||
close(dataConn.stopCh)
|
||||
if dataConn.Conn != nil {
|
||||
_ = dataConn.Conn.SetDeadline(time.Now())
|
||||
_ = dataConn.Conn.Close()
|
||||
}
|
||||
}
|
||||
dataConn.closedMu.Unlock()
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
_ = session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) IsStale(timeout time.Duration) bool {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
return time.Since(g.LastActivity) > timeout
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) AddSession(connID string, session *yamux.Session) {
|
||||
if connID == "" || session == nil {
|
||||
return
|
||||
}
|
||||
|
||||
g.mu.Lock()
|
||||
if g.Sessions == nil {
|
||||
g.Sessions = make(map[string]*yamux.Session)
|
||||
}
|
||||
g.Sessions[connID] = session
|
||||
g.LastActivity = time.Now()
|
||||
|
||||
// Start heartbeat on first session
|
||||
shouldStartHeartbeat := !g.heartbeatStarted
|
||||
if shouldStartHeartbeat {
|
||||
g.heartbeatStarted = true
|
||||
}
|
||||
g.mu.Unlock()
|
||||
|
||||
if shouldStartHeartbeat {
|
||||
g.StartHeartbeat(constants.HeartbeatInterval, constants.HeartbeatTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) RemoveSession(connID string) {
|
||||
if connID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var session *yamux.Session
|
||||
|
||||
g.mu.Lock()
|
||||
if g.Sessions != nil {
|
||||
session = g.Sessions[connID]
|
||||
delete(g.Sessions, connID)
|
||||
}
|
||||
g.mu.Unlock()
|
||||
|
||||
if session != nil {
|
||||
_ = session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) SessionCount() int {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
return len(g.Sessions)
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
|
||||
const (
|
||||
maxStreamsPerSession = 256
|
||||
maxRetries = 3
|
||||
backoffBase = 25 * time.Millisecond
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
return nil, net.ErrClosed
|
||||
default:
|
||||
}
|
||||
|
||||
sessions := g.sessionsSnapshot()
|
||||
if len(sessions) == 0 {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
|
||||
tried := make([]bool, len(sessions))
|
||||
anyUnderCap := false
|
||||
start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
|
||||
|
||||
for range sessions {
|
||||
bestIdx := -1
|
||||
minStreams := int(^uint(0) >> 1)
|
||||
|
||||
for i := 0; i < len(sessions); i++ {
|
||||
idx := (start + i) % len(sessions)
|
||||
if tried[idx] {
|
||||
continue
|
||||
}
|
||||
|
||||
session := sessions[idx]
|
||||
if session == nil || session.IsClosed() {
|
||||
tried[idx] = true
|
||||
continue
|
||||
}
|
||||
|
||||
n := session.NumStreams()
|
||||
if n >= maxStreamsPerSession {
|
||||
continue
|
||||
}
|
||||
anyUnderCap = true
|
||||
|
||||
if n < minStreams {
|
||||
minStreams = n
|
||||
bestIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
if bestIdx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
tried[bestIdx] = true
|
||||
session := sessions[bestIdx]
|
||||
if session == nil || session.IsClosed() {
|
||||
continue
|
||||
}
|
||||
|
||||
stream, err := session.Open()
|
||||
if err == nil {
|
||||
return stream, nil
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
if session.IsClosed() {
|
||||
g.deleteClosedSessions()
|
||||
}
|
||||
}
|
||||
|
||||
if !anyUnderCap {
|
||||
lastErr = fmt.Errorf("all sessions are at stream capacity (%d)", maxStreamsPerSession)
|
||||
}
|
||||
|
||||
if attempt < maxRetries-1 {
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
return nil, net.ErrClosed
|
||||
case <-time.After(backoffBase * time.Duration(attempt+1)):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr == nil {
|
||||
lastErr = fmt.Errorf("failed to open stream")
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) selectSession() *yamux.Session {
|
||||
sessions := g.sessionsSnapshot()
|
||||
if len(sessions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
|
||||
minStreams := int(^uint(0) >> 1)
|
||||
var best *yamux.Session
|
||||
|
||||
for i := 0; i < len(sessions); i++ {
|
||||
session := sessions[(start+i)%len(sessions)]
|
||||
if session == nil || session.IsClosed() {
|
||||
continue
|
||||
}
|
||||
if n := session.NumStreams(); n < minStreams {
|
||||
minStreams = n
|
||||
best = session
|
||||
}
|
||||
}
|
||||
|
||||
return best
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if len(g.Sessions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sessions := make([]*yamux.Session, 0, len(g.Sessions))
|
||||
for id, session := range g.Sessions {
|
||||
if session == nil || session.IsClosed() {
|
||||
delete(g.Sessions, id)
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
|
||||
if len(sessions) > 0 {
|
||||
g.LastActivity = time.Now()
|
||||
}
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) deleteClosedSessions() {
|
||||
g.mu.Lock()
|
||||
for id, session := range g.Sessions {
|
||||
if session == nil || session.IsClosed() {
|
||||
delete(g.Sessions, id)
|
||||
}
|
||||
}
|
||||
g.mu.Unlock()
|
||||
}
|
||||
163
internal/server/tcp/connection_group_manager.go
Normal file
163
internal/server/tcp/connection_group_manager.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ConnectionGroupManager manages all connection groups
|
||||
type ConnectionGroupManager struct {
|
||||
groups map[string]*ConnectionGroup // TunnelID -> ConnectionGroup
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
|
||||
// Cleanup
|
||||
cleanupInterval time.Duration
|
||||
staleTimeout time.Duration
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewConnectionGroupManager creates a new connection group manager
|
||||
func NewConnectionGroupManager(logger *zap.Logger) *ConnectionGroupManager {
|
||||
m := &ConnectionGroupManager{
|
||||
groups: make(map[string]*ConnectionGroup),
|
||||
logger: logger,
|
||||
cleanupInterval: 60 * time.Second,
|
||||
staleTimeout: 5 * time.Minute,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
go m.cleanupLoop()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// GenerateTunnelID generates a unique tunnel ID
|
||||
func GenerateTunnelID() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// CreateGroup creates a new connection group
|
||||
func (m *ConnectionGroupManager) CreateGroup(subdomain, token string, primaryConn *Connection, tunnelType protocol.TunnelType) *ConnectionGroup {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
tunnelID := GenerateTunnelID()
|
||||
|
||||
group := NewConnectionGroup(tunnelID, subdomain, token, primaryConn, tunnelType, m.logger)
|
||||
|
||||
m.groups[tunnelID] = group
|
||||
|
||||
return group
|
||||
}
|
||||
|
||||
// GetGroup returns a connection group by tunnel ID
|
||||
func (m *ConnectionGroupManager) GetGroup(tunnelID string) (*ConnectionGroup, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
group, ok := m.groups[tunnelID]
|
||||
return group, ok
|
||||
}
|
||||
|
||||
// RemoveGroup removes and closes a connection group
|
||||
func (m *ConnectionGroupManager) RemoveGroup(tunnelID string) {
|
||||
m.mu.Lock()
|
||||
group, ok := m.groups[tunnelID]
|
||||
if ok {
|
||||
delete(m.groups, tunnelID)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if ok && group != nil {
|
||||
group.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// AddDataConnection adds a data connection to a group
|
||||
func (m *ConnectionGroupManager) AddDataConnection(req *protocol.DataConnectRequest, conn net.Conn) (*DataConnection, error) {
|
||||
m.mu.RLock()
|
||||
group, ok := m.groups[req.TunnelID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tunnel not found: %s", req.TunnelID)
|
||||
}
|
||||
|
||||
// Validate token
|
||||
if group.Token != "" && req.Token != group.Token {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
dataConn := group.AddDataConnection(req.ConnectionID, conn)
|
||||
|
||||
return dataConn, nil
|
||||
}
|
||||
|
||||
// cleanupLoop periodically cleans up stale groups
|
||||
func (m *ConnectionGroupManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(m.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.cleanupStaleGroups()
|
||||
case <-m.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ConnectionGroupManager) cleanupStaleGroups() {
|
||||
// Collect stale groups under lock
|
||||
m.mu.Lock()
|
||||
var staleGroups []*ConnectionGroup
|
||||
var staleIDs []string
|
||||
for tunnelID, group := range m.groups {
|
||||
if group.IsStale(m.staleTimeout) {
|
||||
staleIDs = append(staleIDs, tunnelID)
|
||||
staleGroups = append(staleGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove from map while holding lock
|
||||
for _, tunnelID := range staleIDs {
|
||||
delete(m.groups, tunnelID)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Close groups without holding lock to avoid blocking other operations
|
||||
for _, group := range staleGroups {
|
||||
group.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the manager
|
||||
func (m *ConnectionGroupManager) Close() {
|
||||
close(m.stopCh)
|
||||
|
||||
// Collect all groups under lock
|
||||
m.mu.Lock()
|
||||
groups := make([]*ConnectionGroup, 0, len(m.groups))
|
||||
for _, group := range m.groups {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
m.groups = make(map[string]*ConnectionGroup)
|
||||
m.mu.Unlock()
|
||||
|
||||
// Close groups without holding lock
|
||||
for _, group := range groups {
|
||||
group.Close()
|
||||
}
|
||||
}
|
||||
@@ -12,32 +12,34 @@ import (
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/recovery"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Listener handles TCP connections with TLS 1.3
|
||||
type Listener struct {
|
||||
address string
|
||||
tlsConfig *tls.Config
|
||||
authToken string
|
||||
manager *tunnel.Manager
|
||||
portAlloc *PortAllocator
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
publicPort int
|
||||
httpHandler http.Handler
|
||||
responseChans HTTPResponseHandler
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
connections map[string]*Connection
|
||||
connMu sync.RWMutex
|
||||
workerPool *pool.WorkerPool // Worker pool for connection handling
|
||||
recoverer *recovery.Recoverer
|
||||
address string
|
||||
tlsConfig *tls.Config
|
||||
authToken string
|
||||
manager *tunnel.Manager
|
||||
portAlloc *PortAllocator
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
publicPort int
|
||||
httpHandler http.Handler
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
connections map[string]*Connection
|
||||
connMu sync.RWMutex
|
||||
workerPool *pool.WorkerPool // Worker pool for connection handling
|
||||
recoverer *recovery.Recoverer
|
||||
panicMetrics *recovery.PanicMetrics
|
||||
|
||||
groupManager *ConnectionGroupManager
|
||||
}
|
||||
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Listener {
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler) *Listener {
|
||||
numCPU := pool.NumCPU()
|
||||
workers := numCPU * 5
|
||||
queueSize := workers * 20
|
||||
@@ -53,21 +55,21 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
recoverer := recovery.NewRecoverer(logger, panicMetrics)
|
||||
|
||||
return &Listener{
|
||||
address: address,
|
||||
tlsConfig: tlsConfig,
|
||||
authToken: authToken,
|
||||
manager: manager,
|
||||
portAlloc: portAlloc,
|
||||
logger: logger,
|
||||
domain: domain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
responseChans: responseChans,
|
||||
stopCh: make(chan struct{}),
|
||||
connections: make(map[string]*Connection),
|
||||
workerPool: workerPool,
|
||||
recoverer: recoverer,
|
||||
panicMetrics: panicMetrics,
|
||||
address: address,
|
||||
tlsConfig: tlsConfig,
|
||||
authToken: authToken,
|
||||
manager: manager,
|
||||
portAlloc: portAlloc,
|
||||
logger: logger,
|
||||
domain: domain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
stopCh: make(chan struct{}),
|
||||
connections: make(map[string]*Connection),
|
||||
workerPool: workerPool,
|
||||
recoverer: recoverer,
|
||||
panicMetrics: panicMetrics,
|
||||
groupManager: NewConnectionGroupManager(logger),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,7 +208,7 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.responseChans)
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager)
|
||||
|
||||
connID := netConn.RemoteAddr().String()
|
||||
l.connMu.Lock()
|
||||
@@ -222,14 +224,11 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
if err := conn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
// Client disconnection errors - normal network behavior, log as DEBUG
|
||||
if strings.Contains(errStr, "connection reset by peer") ||
|
||||
// Client disconnection errors - normal network behavior, ignore
|
||||
if strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
l.logger.Debug("Client disconnected",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -277,6 +276,10 @@ func (l *Listener) Stop() error {
|
||||
l.workerPool.Close()
|
||||
}
|
||||
|
||||
if l.groupManager != nil {
|
||||
l.groupManager.Close()
|
||||
}
|
||||
|
||||
l.logger.Info("TCP listener stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,64 +1,79 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TunnelProxy handles TCP connections for a specific tunnel
|
||||
type TunnelProxy struct {
|
||||
port int
|
||||
subdomain string
|
||||
tcpConn net.Conn // The tunnel control connection
|
||||
listener net.Listener
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
clientAddr string
|
||||
streams map[string]*proxyStream // streamID -> stream info
|
||||
streamMu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
bufferPool *pool.BufferPool
|
||||
// Proxy exposes a public TCP port and forwards each incoming
|
||||
// connection over a dedicated mux stream.
|
||||
type Proxy struct {
|
||||
port int
|
||||
subdomain string
|
||||
logger *zap.Logger
|
||||
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
once sync.Once
|
||||
wg sync.WaitGroup
|
||||
|
||||
openStream func() (net.Conn, error)
|
||||
stats trafficStats
|
||||
sem chan struct{}
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// proxyStream holds connection info with close state
|
||||
type proxyStream struct {
|
||||
conn net.Conn
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
type trafficStats interface {
|
||||
AddBytesIn(n int64)
|
||||
AddBytesOut(n int64)
|
||||
IncActiveConnections()
|
||||
DecActiveConnections()
|
||||
}
|
||||
|
||||
// NewTunnelProxy creates a new TCP tunnel proxy
|
||||
func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
|
||||
return &TunnelProxy{
|
||||
port: port,
|
||||
subdomain: subdomain,
|
||||
tcpConn: tcpConn,
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
clientAddr: tcpConn.RemoteAddr().String(),
|
||||
streams: make(map[string]*proxyStream),
|
||||
bufferPool: pool.NewBufferPool(),
|
||||
frameWriter: protocol.NewFrameWriter(tcpConn),
|
||||
func NewProxy(ctx context.Context, port int, subdomain string, openStream func() (net.Conn, error), stats trafficStats, logger *zap.Logger) *Proxy {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
cctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
const maxConcurrentConnections = 10000
|
||||
var sem chan struct{}
|
||||
if maxConcurrentConnections > 0 {
|
||||
sem = make(chan struct{}, maxConcurrentConnections)
|
||||
}
|
||||
|
||||
return &Proxy{
|
||||
port: port,
|
||||
subdomain: subdomain,
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
openStream: openStream,
|
||||
stats: stats,
|
||||
sem: sem,
|
||||
ctx: cctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts listening on the allocated port
|
||||
func (p *TunnelProxy) Start() error {
|
||||
func (p *Proxy) Start() error {
|
||||
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
|
||||
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
|
||||
}
|
||||
|
||||
p.listener = listener
|
||||
p.listener = ln
|
||||
|
||||
p.logger.Info("TCP proxy started",
|
||||
zap.Int("port", p.port),
|
||||
@@ -67,14 +82,47 @@ func (p *TunnelProxy) Start() error {
|
||||
|
||||
p.wg.Add(1)
|
||||
go p.acceptLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop accepts incoming TCP connections
|
||||
func (p *TunnelProxy) acceptLoop() {
|
||||
func (p *Proxy) Stop() {
|
||||
p.once.Do(func() {
|
||||
close(p.stopCh)
|
||||
p.cancel()
|
||||
|
||||
if p.listener != nil {
|
||||
_ = p.listener.Close()
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
p.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
const stopTimeout = 30 * time.Second
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
p.logger.Info("TCP proxy stopped",
|
||||
zap.Int("port", p.port),
|
||||
zap.String("subdomain", p.subdomain),
|
||||
)
|
||||
case <-time.After(stopTimeout):
|
||||
p.logger.Warn("TCP proxy stop timed out",
|
||||
zap.Int("port", p.port),
|
||||
zap.String("subdomain", p.subdomain),
|
||||
zap.Duration("timeout", stopTimeout),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Proxy) acceptLoop() {
|
||||
defer p.wg.Done()
|
||||
|
||||
tcpLn, _ := p.listener.(*net.TCPListener)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
@@ -82,11 +130,13 @@ func (p *TunnelProxy) acceptLoop() {
|
||||
default:
|
||||
}
|
||||
|
||||
p.listener.(*net.TCPListener).SetDeadline(time.Now().Add(1 * time.Second))
|
||||
if tcpLn != nil {
|
||||
_ = tcpLn.SetDeadline(time.Now().Add(1 * time.Second))
|
||||
}
|
||||
|
||||
conn, err := p.listener.Accept()
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
@@ -98,187 +148,86 @@ func (p *TunnelProxy) acceptLoop() {
|
||||
}
|
||||
|
||||
p.wg.Add(1)
|
||||
go p.handleConnection(conn)
|
||||
go p.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) handleConnection(conn net.Conn) {
|
||||
func (p *Proxy) handleConn(conn net.Conn) {
|
||||
defer p.wg.Done()
|
||||
defer conn.Close()
|
||||
|
||||
streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
|
||||
|
||||
stream := &proxyStream{
|
||||
conn: conn,
|
||||
closed: false,
|
||||
if p.sem != nil {
|
||||
select {
|
||||
case p.sem <- struct{}{}:
|
||||
defer func() { <-p.sem }()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
p.streamMu.Lock()
|
||||
p.streams[streamID] = stream
|
||||
p.streamMu.Unlock()
|
||||
if p.stats != nil {
|
||||
p.stats.IncActiveConnections()
|
||||
defer p.stats.DecActiveConnections()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
p.streamMu.Lock()
|
||||
delete(p.streams, streamID)
|
||||
p.streamMu.Unlock()
|
||||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||
_ = tcpConn.SetNoDelay(true)
|
||||
_ = tcpConn.SetKeepAlive(true)
|
||||
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
_ = tcpConn.SetReadBuffer(256 * 1024)
|
||||
_ = tcpConn.SetWriteBuffer(256 * 1024)
|
||||
}
|
||||
|
||||
if p.openStream == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Open stream with timeout to prevent goroutine leak
|
||||
const openStreamTimeout = 10 * time.Second
|
||||
type streamResult struct {
|
||||
stream net.Conn
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan streamResult, 1)
|
||||
|
||||
go func() {
|
||||
s, err := p.openStream()
|
||||
resultCh <- streamResult{s, err}
|
||||
}()
|
||||
|
||||
bufPtr := p.bufferPool.Get(pool.SizeMedium)
|
||||
defer p.bufferPool.Put(bufPtr)
|
||||
|
||||
buffer := (*bufPtr)[:pool.SizeMedium]
|
||||
|
||||
for {
|
||||
// Check if stream is closed
|
||||
stream.mu.Lock()
|
||||
closed := stream.closed
|
||||
stream.mu.Unlock()
|
||||
if closed {
|
||||
break
|
||||
}
|
||||
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
|
||||
p.logger.Debug("Send to tunnel failed", zap.Error(err))
|
||||
break
|
||||
var stream net.Conn
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
if result.err != nil {
|
||||
if !errors.Is(result.err, net.ErrClosed) {
|
||||
p.logger.Debug("Open stream failed", zap.Error(result.err))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
stream = result.stream
|
||||
case <-time.After(openStreamTimeout):
|
||||
p.logger.Debug("Open stream timeout")
|
||||
return
|
||||
case <-p.stopCh:
|
||||
default:
|
||||
p.sendCloseToTunnel(streamID)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) sendDataToTunnel(streamID string, data []byte) error {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return fmt.Errorf("tunnel proxy stopped")
|
||||
default:
|
||||
}
|
||||
|
||||
header := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: streamID,
|
||||
Type: protocol.DataTypeData,
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode payload: %w", err)
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
|
||||
err = p.frameWriter.WriteFrame(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write frame: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
|
||||
header := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: streamID,
|
||||
Type: protocol.DataTypeClose,
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
p.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
|
||||
p.streamMu.RLock()
|
||||
stream, ok := p.streams[streamID]
|
||||
p.streamMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
// Stream may have been closed by client, this is normal
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if stream is closed
|
||||
stream.mu.Lock()
|
||||
if stream.closed {
|
||||
stream.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
stream.mu.Unlock()
|
||||
|
||||
if _, err := stream.conn.Write(data); err != nil {
|
||||
p.logger.Debug("Write to client failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseStream closes a stream
|
||||
func (p *TunnelProxy) CloseStream(streamID string) {
|
||||
p.streamMu.RLock()
|
||||
stream, ok := p.streams[streamID]
|
||||
p.streamMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Mark as closed first
|
||||
stream.mu.Lock()
|
||||
if stream.closed {
|
||||
stream.mu.Unlock()
|
||||
return
|
||||
}
|
||||
stream.closed = true
|
||||
stream.mu.Unlock()
|
||||
|
||||
// Now close the connection
|
||||
stream.conn.Close()
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) Stop() {
|
||||
p.logger.Info("Stopping TCP proxy",
|
||||
zap.Int("port", p.port),
|
||||
zap.String("subdomain", p.subdomain),
|
||||
_ = netutil.PipeWithCallbacksAndBufferSize(
|
||||
p.ctx,
|
||||
conn,
|
||||
stream,
|
||||
pool.SizeLarge,
|
||||
func(n int64) {
|
||||
if p.stats != nil {
|
||||
p.stats.AddBytesIn(n)
|
||||
}
|
||||
},
|
||||
func(n int64) {
|
||||
if p.stats != nil {
|
||||
p.stats.AddBytesOut(n)
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
close(p.stopCh)
|
||||
|
||||
if p.listener != nil {
|
||||
p.listener.Close()
|
||||
}
|
||||
|
||||
p.streamMu.Lock()
|
||||
for _, stream := range p.streams {
|
||||
stream.mu.Lock()
|
||||
stream.closed = true
|
||||
stream.mu.Unlock()
|
||||
stream.conn.Close()
|
||||
}
|
||||
p.streams = make(map[string]*proxyStream)
|
||||
p.streamMu.Unlock()
|
||||
|
||||
p.wg.Wait()
|
||||
|
||||
if p.frameWriter != nil {
|
||||
p.frameWriter.Close()
|
||||
}
|
||||
|
||||
p.logger.Info("TCP proxy stopped", zap.Int("port", p.port))
|
||||
}
|
||||
|
||||
98
internal/server/tcp/tunnel.go
Normal file
98
internal/server/tcp/tunnel.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
)
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
openStream := session.Open
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
group.AddSession("primary", session)
|
||||
openStream = group.OpenStream
|
||||
}
|
||||
}
|
||||
|
||||
c.proxy = NewProxy(c.ctx, c.port, c.subdomain, openStream, c.tunnelConn, c.logger)
|
||||
if err := c.proxy.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start tcp proxy: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
openStream := session.Open
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
group.AddSession("primary", session)
|
||||
openStream = group.OpenStream
|
||||
}
|
||||
}
|
||||
|
||||
if c.tunnelConn != nil {
|
||||
c.tunnelConn.SetOpenStream(openStream)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user