refactor(tcp): optimize TCP tunnel data connection handling

The data connection handling logic was refactored so that a success response is now sent before the connection is upgraded to a yamux session.
The redundant DataConnection structure and its related management methods were removed.
In ConnectionGroup, the session selection logic was adjusted to prefer non-primary (data) sessions for data transmission,
falling back to the primary session only when no data session is available, which improves forwarding efficiency and stability.
This commit is contained in:
Gouryella
2025-12-15 16:49:42 +08:00
parent be4fe2059c
commit 7431d821d8
4 changed files with 41 additions and 204 deletions

View File

@@ -634,66 +634,7 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
// Store tunnelID for cleanup
c.tunnelID = req.TunnelID
// For TCP tunnels, the data connection is upgraded to a yamux session and used for
// stream forwarding, not framed request/response routing.
if group.TunnelType == protocol.TunnelTypeTCP {
resp := protocol.DataConnectResponse{
Accepted: true,
ConnectionID: req.ConnectionID,
Message: "Data connection accepted",
}
respData, _ := json.Marshal(resp)
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
return fmt.Errorf("failed to send data connect ack: %w", err)
}
c.logger.Info("TCP data connection established",
zap.String("tunnel_id", req.TunnelID),
zap.String("connection_id", req.ConnectionID),
)
// Clear deadline for yamux data-plane.
_ = c.conn.SetReadDeadline(time.Time{})
// Public server acts as yamux Client, client connector acts as yamux Server.
bc := &bufferedConn{
Conn: c.conn,
reader: reader,
}
cfg := yamux.DefaultConfig()
cfg.EnableKeepAlive = false
cfg.LogOutput = io.Discard
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
session, err := yamux.Client(bc, cfg)
if err != nil {
return fmt.Errorf("failed to init yamux session: %w", err)
}
c.session = session
group.AddSession(req.ConnectionID, session)
defer group.RemoveSession(req.ConnectionID)
select {
case <-c.stopCh:
return nil
case <-session.CloseChan():
return nil
}
}
// Add data connection to group
dataConn, err := c.groupManager.AddDataConnection(&req, c.conn)
if err != nil {
c.sendDataConnectError("join_failed", err.Error())
return fmt.Errorf("failed to join connection group: %w", err)
}
// Send success response
// Send success response before upgrading the connection to yamux.
resp := protocol.DataConnectResponse{
Accepted: true,
ConnectionID: req.ConnectionID,
@@ -712,56 +653,34 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
zap.String("connection_id", req.ConnectionID),
)
// Handle data frames on this connection
return c.handleDataConnectionFrames(dataConn, reader)
}
// Clear deadline for yamux data-plane.
_ = c.conn.SetReadDeadline(time.Time{})
// handleDataConnectionFrames handles frames on a data connection
func (c *Connection) handleDataConnectionFrames(dataConn *DataConnection, reader *bufio.Reader) error {
defer func() {
// Get the group and remove this data connection
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok {
group.RemoveDataConnection(dataConn.ID)
}
}()
// Public server acts as yamux Client, client connector acts as yamux Server.
bc := &bufferedConn{
Conn: c.conn,
reader: reader,
}
for {
select {
case <-dataConn.stopCh:
return nil
default:
}
cfg := yamux.DefaultConfig()
cfg.EnableKeepAlive = false
cfg.LogOutput = io.Discard
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
frame, err := protocol.ReadFrame(reader)
if err != nil {
// Timeout is OK, continue
if isTimeoutError(err) {
continue
}
return err
}
session, err := yamux.Client(bc, cfg)
if err != nil {
return fmt.Errorf("failed to init yamux session: %w", err)
}
c.session = session
dataConn.mu.Lock()
dataConn.LastActive = time.Now()
dataConn.mu.Unlock()
group.AddSession(req.ConnectionID, session)
defer group.RemoveSession(req.ConnectionID)
sf := protocol.WithFrame(frame)
switch sf.Frame.Type {
case protocol.FrameTypeClose:
sf.Close()
c.logger.Info("Data connection closed by client",
zap.String("connection_id", dataConn.ID))
return nil
default:
sf.Close()
c.logger.Warn("Unexpected frame type on data connection",
zap.String("type", sf.Frame.Type.String()),
zap.String("connection_id", dataConn.ID),
)
}
select {
case <-c.stopCh:
return nil
case <-session.CloseChan():
return nil
}
}

View File

@@ -15,23 +15,11 @@ import (
"go.uber.org/zap"
)
// DataConnection tracks one framed data-plane connection that a client has
// attached to a ConnectionGroup. Close state is guarded separately from the
// activity fields so closing never blocks activity updates.
type DataConnection struct {
	// ID uniquely identifies this data connection within its group.
	ID string
	// Conn is the underlying network connection.
	Conn net.Conn
	// LastActive records the most recent activity; guarded by mu.
	LastActive time.Time
	// closed reports whether the connection has been shut down; guarded by closedMu.
	closed   bool
	closedMu sync.RWMutex
	// stopCh is closed exactly once to signal readers to stop.
	stopCh chan struct{}
	// mu guards LastActive.
	mu sync.RWMutex
}
type ConnectionGroup struct {
TunnelID string
Subdomain string
Token string
PrimaryConn *Connection
DataConns map[string]*DataConnection
Sessions map[string]*yamux.Session
TunnelType protocol.TunnelType
RegisteredAt time.Time
@@ -50,7 +38,6 @@ func NewConnectionGroup(tunnelID, subdomain, token string, primaryConn *Connecti
Subdomain: subdomain,
Token: token,
PrimaryConn: primaryConn,
DataConns: make(map[string]*DataConnection),
Sessions: make(map[string]*yamux.Session),
TunnelType: tunnelType,
RegisteredAt: time.Now(),
@@ -146,46 +133,6 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
}
}
// AddDataConnection registers a new data connection under connID, refreshes
// the group's activity timestamp, and returns the tracked entry. The entry
// remains owned by the group until RemoveDataConnection is called.
func (g *ConnectionGroup) AddDataConnection(connID string, conn net.Conn) *DataConnection {
	g.mu.Lock()
	defer g.mu.Unlock()

	dc := &DataConnection{
		ID:         connID,
		Conn:       conn,
		LastActive: time.Now(),
		stopCh:     make(chan struct{}),
	}
	g.DataConns[connID] = dc
	// Any data-plane attach counts as group activity for staleness tracking.
	g.LastActivity = time.Now()
	return dc
}
// RemoveDataConnection closes and forgets the data connection registered
// under connID. Unknown IDs are ignored, and repeated calls for the same
// connection are safe: the close is performed at most once.
func (g *ConnectionGroup) RemoveDataConnection(connID string) {
	g.mu.Lock()
	defer g.mu.Unlock()

	dc, ok := g.DataConns[connID]
	if !ok {
		return
	}
	delete(g.DataConns, connID)

	dc.closedMu.Lock()
	defer dc.closedMu.Unlock()
	if dc.closed {
		return
	}
	dc.closed = true
	close(dc.stopCh)
	if dc.Conn != nil {
		// Expire any in-flight reads/writes before closing the socket.
		_ = dc.Conn.SetDeadline(time.Now())
		_ = dc.Conn.Close()
	}
}
// DataConnectionCount reports how many data connections are currently
// registered with the group.
func (g *ConnectionGroup) DataConnectionCount() int {
	g.mu.RLock()
	n := len(g.DataConns)
	g.mu.RUnlock()
	return n
}
func (g *ConnectionGroup) Close() {
g.mu.Lock()
@@ -197,12 +144,6 @@ func (g *ConnectionGroup) Close() {
close(g.stopCh)
}
dataConns := make([]*DataConnection, 0, len(g.DataConns))
for _, dataConn := range g.DataConns {
dataConns = append(dataConns, dataConn)
}
g.DataConns = make(map[string]*DataConnection)
sessions := make([]*yamux.Session, 0, len(g.Sessions))
for _, session := range g.Sessions {
if session != nil {
@@ -213,19 +154,6 @@ func (g *ConnectionGroup) Close() {
g.mu.Unlock()
for _, dataConn := range dataConns {
dataConn.closedMu.Lock()
if !dataConn.closed {
dataConn.closed = true
close(dataConn.stopCh)
if dataConn.Conn != nil {
_ = dataConn.Conn.SetDeadline(time.Now())
_ = dataConn.Conn.Close()
}
}
dataConn.closedMu.Unlock()
}
for _, session := range sessions {
_ = session.Close()
}
@@ -302,7 +230,13 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
default:
}
sessions := g.sessionsSnapshot()
// Prefer data sessions for data-plane traffic; keep the primary session
// as control-plane (client ping/latency), and only fall back to primary
// when no data session exists.
sessions := g.sessionsSnapshot(false)
if len(sessions) == 0 {
sessions = g.sessionsSnapshot(true)
}
if len(sessions) == 0 {
return nil, net.ErrClosed
}
@@ -380,7 +314,10 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
}
func (g *ConnectionGroup) selectSession() *yamux.Session {
sessions := g.sessionsSnapshot()
sessions := g.sessionsSnapshot(false)
if len(sessions) == 0 {
sessions = g.sessionsSnapshot(true)
}
if len(sessions) == 0 {
return nil
}
@@ -403,7 +340,7 @@ func (g *ConnectionGroup) selectSession() *yamux.Session {
return best
}
func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
func (g *ConnectionGroup) sessionsSnapshot(includePrimary bool) []*yamux.Session {
g.mu.Lock()
defer g.mu.Unlock()
@@ -417,6 +354,9 @@ func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
delete(g.Sessions, id)
continue
}
if id == "primary" && !includePrimary {
continue
}
sessions = append(sessions, session)
}

View File

@@ -3,8 +3,6 @@ package tcp
import (
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"sync"
"time"
@@ -84,26 +82,6 @@ func (m *ConnectionGroupManager) RemoveGroup(tunnelID string) {
}
}
// AddDataConnection looks up the group for req.TunnelID, validates the
// request token against the group's token (when one is set), and attaches
// the connection to the group. It returns the tracked DataConnection, or an
// error when the tunnel is unknown or the token does not match.
func (m *ConnectionGroupManager) AddDataConnection(req *protocol.DataConnectRequest, conn net.Conn) (*DataConnection, error) {
	m.mu.RLock()
	grp, found := m.groups[req.TunnelID]
	m.mu.RUnlock()

	if !found {
		return nil, fmt.Errorf("tunnel not found: %s", req.TunnelID)
	}
	// An empty group token disables authentication for data connections.
	if grp.Token != "" && req.Token != grp.Token {
		return nil, fmt.Errorf("invalid token")
	}
	return grp.AddDataConnection(req.ConnectionID, conn), nil
}
// cleanupLoop periodically cleans up stale groups
func (m *ConnectionGroupManager) cleanupLoop() {
ticker := time.NewTicker(m.cleanupInterval)

View File

@@ -39,7 +39,7 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
c.session = session
openStream := session.Open
if c.tunnelID != "" && c.groupManager != nil {
if c.groupManager != nil {
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
group.AddSession("primary", session)
openStream = group.OpenStream
@@ -78,7 +78,7 @@ func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
c.session = session
openStream := session.Open
if c.tunnelID != "" && c.groupManager != nil {
if c.groupManager != nil {
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
group.AddSession("primary", session)
openStream = group.OpenStream