mirror of
https://github.com/Gouryella/drip.git
synced 2026-03-01 15:52:32 +00:00
refactor(tcp): Optimizes the TCP tunnel data connection processing logic.
The data connection processing logic was refactored, and a successful response was sent in advance before upgrading to a yamux session. Redundant DataConnection structures and related management methods were removed. Adjustments were also made. In ConnectionGroup, the session selection logic prioritizes using non-primary sessions for data transmission. Only fall back to the main session when no data session is available, in order to improve forwarding efficiency and stability.
This commit is contained in:
@@ -634,66 +634,7 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
|
||||
// Store tunnelID for cleanup
|
||||
c.tunnelID = req.TunnelID
|
||||
|
||||
// For TCP tunnels, the data connection is upgraded to a yamux session and used for
|
||||
// stream forwarding, not framed request/response routing.
|
||||
if group.TunnelType == protocol.TunnelTypeTCP {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: true,
|
||||
ConnectionID: req.ConnectionID,
|
||||
Message: "Data connection accepted",
|
||||
}
|
||||
|
||||
respData, _ := json.Marshal(resp)
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
|
||||
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
|
||||
return fmt.Errorf("failed to send data connect ack: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("TCP data connection established",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Clear deadline for yamux data-plane.
|
||||
_ = c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
group.AddSession(req.ConnectionID, session)
|
||||
defer group.RemoveSession(req.ConnectionID)
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Add data connection to group
|
||||
dataConn, err := c.groupManager.AddDataConnection(&req, c.conn)
|
||||
if err != nil {
|
||||
c.sendDataConnectError("join_failed", err.Error())
|
||||
return fmt.Errorf("failed to join connection group: %w", err)
|
||||
}
|
||||
|
||||
// Send success response
|
||||
// Send success response before upgrading the connection to yamux.
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: true,
|
||||
ConnectionID: req.ConnectionID,
|
||||
@@ -712,56 +653,34 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Handle data frames on this connection
|
||||
return c.handleDataConnectionFrames(dataConn, reader)
|
||||
}
|
||||
// Clear deadline for yamux data-plane.
|
||||
_ = c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// handleDataConnectionFrames handles frames on a data connection
|
||||
func (c *Connection) handleDataConnectionFrames(dataConn *DataConnection, reader *bufio.Reader) error {
|
||||
defer func() {
|
||||
// Get the group and remove this data connection
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok {
|
||||
group.RemoveDataConnection(dataConn.ID)
|
||||
}
|
||||
}()
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-dataConn.stopCh:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
// Timeout is OK, continue
|
||||
if isTimeoutError(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
dataConn.mu.Lock()
|
||||
dataConn.LastActive = time.Now()
|
||||
dataConn.mu.Unlock()
|
||||
group.AddSession(req.ConnectionID, session)
|
||||
defer group.RemoveSession(req.ConnectionID)
|
||||
|
||||
sf := protocol.WithFrame(frame)
|
||||
|
||||
switch sf.Frame.Type {
|
||||
case protocol.FrameTypeClose:
|
||||
sf.Close()
|
||||
c.logger.Info("Data connection closed by client",
|
||||
zap.String("connection_id", dataConn.ID))
|
||||
return nil
|
||||
|
||||
default:
|
||||
sf.Close()
|
||||
c.logger.Warn("Unexpected frame type on data connection",
|
||||
zap.String("type", sf.Frame.Type.String()),
|
||||
zap.String("connection_id", dataConn.ID),
|
||||
)
|
||||
}
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,23 +15,11 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
|
||||
type DataConnection struct {
|
||||
ID string
|
||||
Conn net.Conn
|
||||
LastActive time.Time
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
stopCh chan struct{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type ConnectionGroup struct {
|
||||
TunnelID string
|
||||
Subdomain string
|
||||
Token string
|
||||
PrimaryConn *Connection
|
||||
DataConns map[string]*DataConnection
|
||||
Sessions map[string]*yamux.Session
|
||||
TunnelType protocol.TunnelType
|
||||
RegisteredAt time.Time
|
||||
@@ -50,7 +38,6 @@ func NewConnectionGroup(tunnelID, subdomain, token string, primaryConn *Connecti
|
||||
Subdomain: subdomain,
|
||||
Token: token,
|
||||
PrimaryConn: primaryConn,
|
||||
DataConns: make(map[string]*DataConnection),
|
||||
Sessions: make(map[string]*yamux.Session),
|
||||
TunnelType: tunnelType,
|
||||
RegisteredAt: time.Now(),
|
||||
@@ -146,46 +133,6 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) AddDataConnection(connID string, conn net.Conn) *DataConnection {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
dataConn := &DataConnection{
|
||||
ID: connID,
|
||||
Conn: conn,
|
||||
LastActive: time.Now(),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
g.DataConns[connID] = dataConn
|
||||
g.LastActivity = time.Now()
|
||||
return dataConn
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) RemoveDataConnection(connID string) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if dataConn, ok := g.DataConns[connID]; ok {
|
||||
dataConn.closedMu.Lock()
|
||||
if !dataConn.closed {
|
||||
dataConn.closed = true
|
||||
close(dataConn.stopCh)
|
||||
if dataConn.Conn != nil {
|
||||
_ = dataConn.Conn.SetDeadline(time.Now())
|
||||
dataConn.Conn.Close()
|
||||
}
|
||||
}
|
||||
dataConn.closedMu.Unlock()
|
||||
delete(g.DataConns, connID)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) DataConnectionCount() int {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
return len(g.DataConns)
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) Close() {
|
||||
g.mu.Lock()
|
||||
|
||||
@@ -197,12 +144,6 @@ func (g *ConnectionGroup) Close() {
|
||||
close(g.stopCh)
|
||||
}
|
||||
|
||||
dataConns := make([]*DataConnection, 0, len(g.DataConns))
|
||||
for _, dataConn := range g.DataConns {
|
||||
dataConns = append(dataConns, dataConn)
|
||||
}
|
||||
g.DataConns = make(map[string]*DataConnection)
|
||||
|
||||
sessions := make([]*yamux.Session, 0, len(g.Sessions))
|
||||
for _, session := range g.Sessions {
|
||||
if session != nil {
|
||||
@@ -213,19 +154,6 @@ func (g *ConnectionGroup) Close() {
|
||||
|
||||
g.mu.Unlock()
|
||||
|
||||
for _, dataConn := range dataConns {
|
||||
dataConn.closedMu.Lock()
|
||||
if !dataConn.closed {
|
||||
dataConn.closed = true
|
||||
close(dataConn.stopCh)
|
||||
if dataConn.Conn != nil {
|
||||
_ = dataConn.Conn.SetDeadline(time.Now())
|
||||
_ = dataConn.Conn.Close()
|
||||
}
|
||||
}
|
||||
dataConn.closedMu.Unlock()
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
_ = session.Close()
|
||||
}
|
||||
@@ -302,7 +230,13 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
|
||||
default:
|
||||
}
|
||||
|
||||
sessions := g.sessionsSnapshot()
|
||||
// Prefer data sessions for data-plane traffic; keep the primary session
|
||||
// as control-plane (client ping/latency), and only fall back to primary
|
||||
// when no data session exists.
|
||||
sessions := g.sessionsSnapshot(false)
|
||||
if len(sessions) == 0 {
|
||||
sessions = g.sessionsSnapshot(true)
|
||||
}
|
||||
if len(sessions) == 0 {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
@@ -380,7 +314,10 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) selectSession() *yamux.Session {
|
||||
sessions := g.sessionsSnapshot()
|
||||
sessions := g.sessionsSnapshot(false)
|
||||
if len(sessions) == 0 {
|
||||
sessions = g.sessionsSnapshot(true)
|
||||
}
|
||||
if len(sessions) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -403,7 +340,7 @@ func (g *ConnectionGroup) selectSession() *yamux.Session {
|
||||
return best
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
|
||||
func (g *ConnectionGroup) sessionsSnapshot(includePrimary bool) []*yamux.Session {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
@@ -417,6 +354,9 @@ func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
|
||||
delete(g.Sessions, id)
|
||||
continue
|
||||
}
|
||||
if id == "primary" && !includePrimary {
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,6 @@ package tcp
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -84,26 +82,6 @@ func (m *ConnectionGroupManager) RemoveGroup(tunnelID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// AddDataConnection adds a data connection to a group
|
||||
func (m *ConnectionGroupManager) AddDataConnection(req *protocol.DataConnectRequest, conn net.Conn) (*DataConnection, error) {
|
||||
m.mu.RLock()
|
||||
group, ok := m.groups[req.TunnelID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tunnel not found: %s", req.TunnelID)
|
||||
}
|
||||
|
||||
// Validate token
|
||||
if group.Token != "" && req.Token != group.Token {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
dataConn := group.AddDataConnection(req.ConnectionID, conn)
|
||||
|
||||
return dataConn, nil
|
||||
}
|
||||
|
||||
// cleanupLoop periodically cleans up stale groups
|
||||
func (m *ConnectionGroupManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(m.cleanupInterval)
|
||||
|
||||
@@ -39,7 +39,7 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
c.session = session
|
||||
|
||||
openStream := session.Open
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
if c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
group.AddSession("primary", session)
|
||||
openStream = group.OpenStream
|
||||
@@ -78,7 +78,7 @@ func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
|
||||
c.session = session
|
||||
|
||||
openStream := session.Open
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
if c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
group.AddSession("primary", session)
|
||||
openStream = group.OpenStream
|
||||
|
||||
Reference in New Issue
Block a user