From 7431d821d8828a884b0062455aea02a8f7292a89 Mon Sep 17 00:00:00 2001 From: Gouryella Date: Mon, 15 Dec 2025 16:49:42 +0800 Subject: [PATCH] refactor(tcp): Optimizes the TCP tunnel data connection processing logic. The data connection processing logic was refactored, and a successful response was sent in advance before upgrading to a yamux session. Redundant DataConnection structures and related management methods were removed. Adjustments were also made. In ConnectionGroup, the session selection logic prioritizes using non-primary sessions for data transmission. Only fall back to the main session when no data session is available, in order to improve forwarding efficiency and stability. --- internal/server/tcp/connection.go | 129 ++++-------------- internal/server/tcp/connection_group.go | 90 ++---------- .../server/tcp/connection_group_manager.go | 22 --- internal/server/tcp/tunnel.go | 4 +- 4 files changed, 41 insertions(+), 204 deletions(-) diff --git a/internal/server/tcp/connection.go b/internal/server/tcp/connection.go index bffb941..99e0cab 100644 --- a/internal/server/tcp/connection.go +++ b/internal/server/tcp/connection.go @@ -634,66 +634,7 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read // Store tunnelID for cleanup c.tunnelID = req.TunnelID - // For TCP tunnels, the data connection is upgraded to a yamux session and used for - // stream forwarding, not framed request/response routing. - if group.TunnelType == protocol.TunnelTypeTCP { - resp := protocol.DataConnectResponse{ - Accepted: true, - ConnectionID: req.ConnectionID, - Message: "Data connection accepted", - } - - respData, _ := json.Marshal(resp) - ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData) - - if err := protocol.WriteFrame(c.conn, ackFrame); err != nil { - return fmt.Errorf("failed to send data connect ack: %w", err) - } - - c.logger.Info("TCP data connection established", - zap.String("tunnel_id", req.TunnelID), - zap.String("connection_id", req.ConnectionID), - ) - - // Clear deadline for yamux data-plane. - _ = c.conn.SetReadDeadline(time.Time{}) - - // Public server acts as yamux Client, client connector acts as yamux Server. - bc := &bufferedConn{ - Conn: c.conn, - reader: reader, - } - - cfg := yamux.DefaultConfig() - cfg.EnableKeepAlive = false - cfg.LogOutput = io.Discard - cfg.AcceptBacklog = constants.YamuxAcceptBacklog - - session, err := yamux.Client(bc, cfg) - if err != nil { - return fmt.Errorf("failed to init yamux session: %w", err) - } - c.session = session - - group.AddSession(req.ConnectionID, session) - defer group.RemoveSession(req.ConnectionID) - - select { - case <-c.stopCh: - return nil - case <-session.CloseChan(): - return nil - } - } - - // Add data connection to group - dataConn, err := c.groupManager.AddDataConnection(&req, c.conn) - if err != nil { - c.sendDataConnectError("join_failed", err.Error()) - return fmt.Errorf("failed to join connection group: %w", err) - } - - // Send success response + // Send success response before upgrading the connection to yamux. resp := protocol.DataConnectResponse{ Accepted: true, ConnectionID: req.ConnectionID, @@ -712,56 +653,34 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read zap.String("connection_id", req.ConnectionID), ) - // Handle data frames on this connection - return c.handleDataConnectionFrames(dataConn, reader) -} + // Clear deadline for yamux data-plane. + _ = c.conn.SetReadDeadline(time.Time{}) -// handleDataConnectionFrames handles frames on a data connection -func (c *Connection) handleDataConnectionFrames(dataConn *DataConnection, reader *bufio.Reader) error { - defer func() { - // Get the group and remove this data connection - if group, ok := c.groupManager.GetGroup(c.tunnelID); ok { - group.RemoveDataConnection(dataConn.ID) - } - }() + // Public server acts as yamux Client, client connector acts as yamux Server. + bc := &bufferedConn{ + Conn: c.conn, + reader: reader, + } - for { - select { - case <-dataConn.stopCh: - return nil - default: - } + cfg := yamux.DefaultConfig() + cfg.EnableKeepAlive = false + cfg.LogOutput = io.Discard + cfg.AcceptBacklog = constants.YamuxAcceptBacklog - c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout)) - frame, err := protocol.ReadFrame(reader) - if err != nil { - // Timeout is OK, continue - if isTimeoutError(err) { - continue - } - return err - } + session, err := yamux.Client(bc, cfg) + if err != nil { + return fmt.Errorf("failed to init yamux session: %w", err) + } + c.session = session - dataConn.mu.Lock() - dataConn.LastActive = time.Now() - dataConn.mu.Unlock() + group.AddSession(req.ConnectionID, session) + defer group.RemoveSession(req.ConnectionID) - sf := protocol.WithFrame(frame) - - switch sf.Frame.Type { - case protocol.FrameTypeClose: - sf.Close() - c.logger.Info("Data connection closed by client", - zap.String("connection_id", dataConn.ID)) - return nil - - default: - sf.Close() - c.logger.Warn("Unexpected frame type on data connection", - zap.String("type", sf.Frame.Type.String()), - zap.String("connection_id", dataConn.ID), - ) - } + select { + case <-c.stopCh: + return nil + case <-session.CloseChan(): + return nil } } diff --git a/internal/server/tcp/connection_group.go b/internal/server/tcp/connection_group.go index 3d4cda6..c8dbda0 100644 --- a/internal/server/tcp/connection_group.go +++ b/internal/server/tcp/connection_group.go @@ -15,23 +15,11 @@ import ( "go.uber.org/zap" ) - -type DataConnection struct { - ID string - Conn net.Conn - LastActive time.Time - closed bool - closedMu sync.RWMutex - stopCh chan struct{} - mu sync.RWMutex -} - type ConnectionGroup struct { TunnelID string Subdomain string Token string PrimaryConn *Connection - DataConns map[string]*DataConnection Sessions map[string]*yamux.Session TunnelType protocol.TunnelType RegisteredAt time.Time @@ -50,7 +38,6 @@ func NewConnectionGroup(tunnelID, subdomain, token string, primaryConn *Connecti Subdomain: subdomain, Token: token, PrimaryConn: primaryConn, - DataConns: make(map[string]*DataConnection), Sessions: make(map[string]*yamux.Session), TunnelType: tunnelType, RegisteredAt: time.Now(), @@ -146,46 +133,6 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) { } } -func (g *ConnectionGroup) AddDataConnection(connID string, conn net.Conn) *DataConnection { - g.mu.Lock() - defer g.mu.Unlock() - - dataConn := &DataConnection{ - ID: connID, - Conn: conn, - LastActive: time.Now(), - stopCh: make(chan struct{}), - } - g.DataConns[connID] = dataConn - g.LastActivity = time.Now() - return dataConn -} - -func (g *ConnectionGroup) RemoveDataConnection(connID string) { - g.mu.Lock() - defer g.mu.Unlock() - - if dataConn, ok := g.DataConns[connID]; ok { - dataConn.closedMu.Lock() - if !dataConn.closed { - dataConn.closed = true - close(dataConn.stopCh) - if dataConn.Conn != nil { - _ = dataConn.Conn.SetDeadline(time.Now()) - dataConn.Conn.Close() - } - } - dataConn.closedMu.Unlock() - delete(g.DataConns, connID) - } -} - -func (g *ConnectionGroup) DataConnectionCount() int { - g.mu.RLock() - defer g.mu.RUnlock() - return len(g.DataConns) -} - func (g *ConnectionGroup) Close() { g.mu.Lock() @@ -197,12 +144,6 @@ func (g *ConnectionGroup) Close() { close(g.stopCh) } - dataConns := make([]*DataConnection, 0, len(g.DataConns)) - for _, dataConn := range g.DataConns { - dataConns = append(dataConns, dataConn) - } - g.DataConns = make(map[string]*DataConnection) - sessions := make([]*yamux.Session, 0, len(g.Sessions)) for _, session := range g.Sessions { if session != nil { @@ -213,19 +154,6 @@ func (g *ConnectionGroup) Close() { g.mu.Unlock() - for _, dataConn := range dataConns { - dataConn.closedMu.Lock() - if !dataConn.closed { - dataConn.closed = true - close(dataConn.stopCh) - if dataConn.Conn != nil { - _ = dataConn.Conn.SetDeadline(time.Now()) - _ = dataConn.Conn.Close() - } - } - dataConn.closedMu.Unlock() - } - for _, session := range sessions { _ = session.Close() } @@ -302,7 +230,13 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) { default: } - sessions := g.sessionsSnapshot() + // Prefer data sessions for data-plane traffic; keep the primary session + // as control-plane (client ping/latency), and only fall back to primary + // when no data session exists. + sessions := g.sessionsSnapshot(false) + if len(sessions) == 0 { + sessions = g.sessionsSnapshot(true) + } if len(sessions) == 0 { return nil, net.ErrClosed } @@ -380,7 +314,10 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) { } func (g *ConnectionGroup) selectSession() *yamux.Session { - sessions := g.sessionsSnapshot() + sessions := g.sessionsSnapshot(false) + if len(sessions) == 0 { + sessions = g.sessionsSnapshot(true) + } if len(sessions) == 0 { return nil } @@ -403,7 +340,7 @@ func (g *ConnectionGroup) selectSession() *yamux.Session { return best } -func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session { +func (g *ConnectionGroup) sessionsSnapshot(includePrimary bool) []*yamux.Session { g.mu.Lock() defer g.mu.Unlock() @@ -417,6 +354,9 @@ func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session { delete(g.Sessions, id) continue } + if id == "primary" && !includePrimary { + continue + } sessions = append(sessions, session) } diff --git a/internal/server/tcp/connection_group_manager.go b/internal/server/tcp/connection_group_manager.go index 89312db..93aa4d8 100644 --- a/internal/server/tcp/connection_group_manager.go +++ b/internal/server/tcp/connection_group_manager.go @@ -3,8 +3,6 @@ package tcp import ( "crypto/rand" "encoding/hex" - "fmt" - "net" "sync" "time" @@ -84,26 +82,6 @@ func (m *ConnectionGroupManager) RemoveGroup(tunnelID string) { } } -// AddDataConnection adds a data connection to a group -func (m *ConnectionGroupManager) AddDataConnection(req *protocol.DataConnectRequest, conn net.Conn) (*DataConnection, error) { - m.mu.RLock() - group, ok := m.groups[req.TunnelID] - m.mu.RUnlock() - - if !ok { - return nil, fmt.Errorf("tunnel not found: %s", req.TunnelID) - } - - // Validate token - if group.Token != "" && req.Token != group.Token { - return nil, fmt.Errorf("invalid token") - } - - dataConn := group.AddDataConnection(req.ConnectionID, conn) - - return dataConn, nil -} - // cleanupLoop periodically cleans up stale groups func (m *ConnectionGroupManager) cleanupLoop() { ticker := time.NewTicker(m.cleanupInterval) diff --git a/internal/server/tcp/tunnel.go b/internal/server/tcp/tunnel.go index 4f9c266..6ec4682 100644 --- a/internal/server/tcp/tunnel.go +++ b/internal/server/tcp/tunnel.go @@ -39,7 +39,7 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error { c.session = session openStream := session.Open - if c.tunnelID != "" && c.groupManager != nil { + if c.groupManager != nil { if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil { group.AddSession("primary", session) openStream = group.OpenStream @@ -78,7 +78,7 @@ func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error { c.session = session openStream := session.Open - if c.tunnelID != "" && c.groupManager != nil { + if c.groupManager != nil { if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil { group.AddSession("primary", session) openStream = group.OpenStream