Files
drip/internal/server/tcp/connection_group.go
Gouryella 0c19c3300c feat(tunnel): switch to yamux stream proxying and connection pooling
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server
- Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly
- Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
2025-12-13 18:03:44 +08:00

439 lines
8.9 KiB
Go

package tcp
import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/yamux"
"drip/internal/shared/constants"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// DataConnection is one server-side data connection belonging to a tunnel's
// connection group. It carries the raw transport plus the bookkeeping needed
// to tear the connection down exactly once.
type DataConnection struct {
	ID         string        // unique identifier for this data connection
	Conn       net.Conn      // underlying transport connection
	LastActive time.Time     // set at creation; last observed activity
	closed     bool          // true once teardown has run (guarded by closedMu)
	closedMu   sync.RWMutex  // guards closed and serializes the close sequence
	stopCh     chan struct{} // closed during teardown to signal dependent goroutines
	mu         sync.RWMutex  // NOTE(review): unused in this file — presumably guards fields from other files in the package; verify before removing
}
// ConnectionGroup bundles everything the server tracks for a single tunnel:
// the primary control connection, pooled data connections, the yamux sessions
// used for stream proxying, and liveness bookkeeping.
type ConnectionGroup struct {
	TunnelID     string                     // unique tunnel identifier (also tags the logger)
	Subdomain    string                     // subdomain this tunnel is registered under
	Token        string                     // tunnel token (presumably for client auth — verify against caller)
	PrimaryConn  *Connection                // control connection that registered the tunnel
	DataConns    map[string]*DataConnection // data connections keyed by connection ID (guarded by mu)
	Sessions     map[string]*yamux.Session  // yamux sessions keyed by connection ID (guarded by mu)
	TunnelType   protocol.TunnelType        // tunnel protocol type from the protocol package
	RegisteredAt time.Time                  // when the group was created
	LastActivity time.Time                  // last observed activity; consumed by IsStale (guarded by mu)
	sessionIdx   uint32                     // atomic round-robin cursor for session selection
	mu           sync.RWMutex               // guards the maps, timestamps, and heartbeatStarted
	stopCh       chan struct{}              // closed exactly once by Close to stop background work
	logger       *zap.Logger                // logger pre-annotated with tunnel_id
	heartbeatStarted bool                   // true once the heartbeat goroutine has been launched
}
// NewConnectionGroup constructs a ConnectionGroup for the given tunnel with
// empty data-connection and session maps, creation/activity timestamps set to
// now, and a logger annotated with the tunnel ID.
func NewConnectionGroup(tunnelID, subdomain, token string, primaryConn *Connection, tunnelType protocol.TunnelType, logger *zap.Logger) *ConnectionGroup {
	g := &ConnectionGroup{
		TunnelID:     tunnelID,
		Subdomain:    subdomain,
		Token:        token,
		PrimaryConn:  primaryConn,
		TunnelType:   tunnelType,
		RegisteredAt: time.Now(),
		LastActivity: time.Now(),
		DataConns:    make(map[string]*DataConnection),
		Sessions:     make(map[string]*yamux.Session),
		stopCh:       make(chan struct{}),
		logger:       logger.With(zap.String("tunnel_id", tunnelID)),
	}
	return g
}
// StartHeartbeat launches a background goroutine that pings every yamux
// session once per interval (with the given per-ping timeout) and evicts
// sessions that stop responding. It must be invoked at most once per group;
// AddSession enforces this via the heartbeatStarted flag.
func (g *ConnectionGroup) StartHeartbeat(interval, timeout time.Duration) {
	go g.heartbeatLoop(interval, timeout)
}
// heartbeatLoop pings every session once per tick and removes sessions that
// are closed or that fail maxConsecutiveFailures pings in a row. It exits
// when the group's stopCh is closed.
func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	const maxConsecutiveFailures = 3
	// Consecutive ping failures per session ID; reset to zero on success.
	failureCount := make(map[string]int)
	for {
		select {
		case <-g.stopCh:
			return
		case <-ticker.C:
		}
		// Snapshot sessions so we don't hold the lock while pinging.
		g.mu.RLock()
		sessions := make(map[string]*yamux.Session, len(g.Sessions))
		for id, s := range g.Sessions {
			sessions[id] = s
		}
		g.mu.RUnlock()
		// Drop counters for sessions removed elsewhere (RemoveSession, Close);
		// without this pruning the map grows without bound as connection IDs
		// rotate over the lifetime of a long-lived tunnel.
		for id := range failureCount {
			if _, ok := sessions[id]; !ok {
				delete(failureCount, id)
			}
		}
		for id, session := range sessions {
			if session == nil || session.IsClosed() {
				g.RemoveSession(id)
				delete(failureCount, id)
				continue
			}
			// Ping with a timeout. The buffered channel lets the goroutine
			// complete and be collected even if we stop waiting on it.
			done := make(chan error, 1)
			go func(s *yamux.Session) {
				_, err := s.Ping()
				done <- err
			}(session)
			var err error
			select {
			case err = <-done:
			case <-time.After(timeout):
				err = fmt.Errorf("ping timeout")
			case <-g.stopCh:
				return
			}
			if err != nil {
				failureCount[id]++
				g.logger.Debug("Session ping failed",
					zap.String("session_id", id),
					zap.Int("consecutive_failures", failureCount[id]),
					zap.Error(err),
				)
				if failureCount[id] >= maxConsecutiveFailures {
					g.logger.Warn("Session ping failed too many times, removing",
						zap.String("session_id", id),
						zap.Int("failures", failureCount[id]),
					)
					g.RemoveSession(id)
					delete(failureCount, id)
				}
			} else {
				// Reset the failure streak and record liveness.
				failureCount[id] = 0
				g.mu.Lock()
				g.LastActivity = time.Now()
				g.mu.Unlock()
			}
		}
		// Informational only: actual cleanup of an empty group happens elsewhere.
		g.mu.RLock()
		sessionCount := len(g.Sessions)
		g.mu.RUnlock()
		if sessionCount == 0 {
			g.logger.Info("All sessions closed, tunnel will be cleaned up")
		}
	}
}
// AddDataConnection registers a new data connection under connID, refreshes
// the group's LastActivity timestamp, and returns the tracking record.
func (g *ConnectionGroup) AddDataConnection(connID string, conn net.Conn) *DataConnection {
	now := time.Now()
	dc := &DataConnection{
		ID:         connID,
		Conn:       conn,
		LastActive: now,
		stopCh:     make(chan struct{}),
	}
	g.mu.Lock()
	g.DataConns[connID] = dc
	g.LastActivity = now
	g.mu.Unlock()
	return dc
}
// RemoveDataConnection detaches connID from the group and tears down the
// underlying connection exactly once (guarded by DataConnection.closed).
// The network close runs after g.mu is released so a slow or blocking
// Conn.Close cannot stall other group operations — the same detach-then-close
// pattern Close uses.
func (g *ConnectionGroup) RemoveDataConnection(connID string) {
	g.mu.Lock()
	dataConn, ok := g.DataConns[connID]
	if ok {
		delete(g.DataConns, connID)
	}
	g.mu.Unlock()
	if !ok {
		return
	}
	dataConn.closedMu.Lock()
	if !dataConn.closed {
		dataConn.closed = true
		close(dataConn.stopCh)
		if dataConn.Conn != nil {
			// Expire the deadline so in-flight reads/writes fail immediately.
			_ = dataConn.Conn.SetDeadline(time.Now())
			_ = dataConn.Conn.Close()
		}
	}
	dataConn.closedMu.Unlock()
}
// DataConnectionCount reports how many data connections are currently tracked.
func (g *ConnectionGroup) DataConnectionCount() int {
	g.mu.RLock()
	n := len(g.DataConns)
	g.mu.RUnlock()
	return n
}
// Close shuts the whole group down. It closes stopCh exactly once (further
// calls are no-ops), detaches every data connection and session while holding
// the lock, then performs the potentially slow network closes with the lock
// released.
func (g *ConnectionGroup) Close() {
	g.mu.Lock()
	select {
	case <-g.stopCh:
		// Already shut down.
		g.mu.Unlock()
		return
	default:
		close(g.stopCh)
	}
	// Detach everything under the lock...
	dataConns := make([]*DataConnection, 0, len(g.DataConns))
	for _, dc := range g.DataConns {
		dataConns = append(dataConns, dc)
	}
	g.DataConns = make(map[string]*DataConnection)
	sessions := make([]*yamux.Session, 0, len(g.Sessions))
	for _, s := range g.Sessions {
		if s != nil {
			sessions = append(sessions, s)
		}
	}
	g.Sessions = make(map[string]*yamux.Session)
	g.mu.Unlock()
	// ...then close it all outside the lock.
	for _, dc := range dataConns {
		dc.closedMu.Lock()
		if !dc.closed {
			dc.closed = true
			close(dc.stopCh)
			if dc.Conn != nil {
				// Expire the deadline so blocked reads/writes fail fast.
				_ = dc.Conn.SetDeadline(time.Now())
				_ = dc.Conn.Close()
			}
		}
		dc.closedMu.Unlock()
	}
	for _, s := range sessions {
		_ = s.Close()
	}
}
// IsStale reports whether the group has seen no activity for longer than timeout.
func (g *ConnectionGroup) IsStale(timeout time.Duration) bool {
	g.mu.RLock()
	last := g.LastActivity
	g.mu.RUnlock()
	return time.Since(last) > timeout
}
// AddSession registers a yamux session under connID and, on the very first
// session, starts the group's heartbeat loop. Empty IDs and nil sessions are
// ignored.
func (g *ConnectionGroup) AddSession(connID string, session *yamux.Session) {
	if session == nil || connID == "" {
		return
	}
	g.mu.Lock()
	if g.Sessions == nil {
		g.Sessions = make(map[string]*yamux.Session)
	}
	g.Sessions[connID] = session
	g.LastActivity = time.Now()
	// The heartbeat is launched exactly once, on the first registered session.
	first := !g.heartbeatStarted
	g.heartbeatStarted = true
	g.mu.Unlock()
	if first {
		g.StartHeartbeat(constants.HeartbeatInterval, constants.HeartbeatTimeout)
	}
}
// RemoveSession detaches connID's session from the group and closes it with
// the lock released. Unknown or empty IDs are no-ops.
func (g *ConnectionGroup) RemoveSession(connID string) {
	if connID == "" {
		return
	}
	var victim *yamux.Session
	g.mu.Lock()
	if g.Sessions != nil {
		victim = g.Sessions[connID]
		delete(g.Sessions, connID)
	}
	g.mu.Unlock()
	if victim == nil {
		return
	}
	_ = victim.Close()
}
// SessionCount reports how many yamux sessions are currently registered.
func (g *ConnectionGroup) SessionCount() int {
	g.mu.RLock()
	n := len(g.Sessions)
	g.mu.RUnlock()
	return n
}
// OpenStream opens a new yamux stream on the least-loaded live session,
// retrying with linear backoff when every candidate fails. Candidates are
// scanned least-streams-first from a rotating start index, skipping closed
// sessions and sessions already carrying maxStreamsPerSession streams.
//
// Returns net.ErrClosed when the group is shut down or has no sessions.
func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
	const (
		maxStreamsPerSession = 256
		maxRetries           = 3
		backoffBase          = 25 * time.Millisecond
	)
	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		select {
		case <-g.stopCh:
			return nil, net.ErrClosed
		default:
		}
		sessions := g.sessionsSnapshot()
		if len(sessions) == 0 {
			return nil, net.ErrClosed
		}
		tried := make([]bool, len(sessions))
		anyAlive := false    // at least one open session was seen this attempt
		anyUnderCap := false // at least one open session was below the stream cap
		// Rotate the scan start so load spreads across sessions.
		start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
		for range sessions {
			// Pick the untried live session with the fewest active streams.
			bestIdx := -1
			minStreams := int(^uint(0) >> 1)
			for i := 0; i < len(sessions); i++ {
				idx := (start + i) % len(sessions)
				if tried[idx] {
					continue
				}
				session := sessions[idx]
				if session == nil || session.IsClosed() {
					tried[idx] = true
					continue
				}
				anyAlive = true
				n := session.NumStreams()
				if n >= maxStreamsPerSession {
					continue
				}
				anyUnderCap = true
				if n < minStreams {
					minStreams = n
					bestIdx = idx
				}
			}
			if bestIdx == -1 {
				break
			}
			tried[bestIdx] = true
			session := sessions[bestIdx]
			if session == nil || session.IsClosed() {
				continue
			}
			stream, err := session.Open()
			if err == nil {
				return stream, nil
			}
			lastErr = err
			if session.IsClosed() {
				g.deleteClosedSessions()
			}
		}
		// Only claim "at capacity" when live sessions actually existed this
		// attempt. Previously this message was also emitted when every session
		// had simply closed between snapshot and use, which misdiagnosed the
		// failure and clobbered any real open error.
		if anyAlive && !anyUnderCap {
			lastErr = fmt.Errorf("all sessions are at stream capacity (%d)", maxStreamsPerSession)
		}
		if attempt < maxRetries-1 {
			select {
			case <-g.stopCh:
				return nil, net.ErrClosed
			case <-time.After(backoffBase * time.Duration(attempt+1)):
			}
		}
	}
	if lastErr == nil {
		lastErr = fmt.Errorf("failed to open stream")
	}
	return nil, lastErr
}
// selectSession returns the live session carrying the fewest active streams,
// scanning from a rotating start index; nil when no live session exists.
func (g *ConnectionGroup) selectSession() *yamux.Session {
	sessions := g.sessionsSnapshot()
	count := len(sessions)
	if count == 0 {
		return nil
	}
	offset := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
	var best *yamux.Session
	fewest := int(^uint(0) >> 1)
	for i := 0; i < count; i++ {
		candidate := sessions[(offset+i)%count]
		if candidate == nil || candidate.IsClosed() {
			continue
		}
		if streams := candidate.NumStreams(); streams < fewest {
			fewest = streams
			best = candidate
		}
	}
	return best
}
// sessionsSnapshot returns the current live sessions as a slice. As a side
// effect it prunes nil/closed entries from the map and, when at least one
// live session remains, refreshes LastActivity (so an active pool keeps the
// tunnel from being considered stale). Takes the write lock because of these
// mutations.
func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
	g.mu.Lock()
	defer g.mu.Unlock()
	if len(g.Sessions) == 0 {
		return nil
	}
	out := make([]*yamux.Session, 0, len(g.Sessions))
	for id, s := range g.Sessions {
		if s == nil || s.IsClosed() {
			delete(g.Sessions, id)
		} else {
			out = append(out, s)
		}
	}
	if len(out) != 0 {
		g.LastActivity = time.Now()
	}
	return out
}
// deleteClosedSessions prunes every nil or closed session from the map.
func (g *ConnectionGroup) deleteClosedSessions() {
	g.mu.Lock()
	defer g.mu.Unlock()
	for id, s := range g.Sessions {
		if s != nil && !s.IsClosed() {
			continue
		}
		delete(g.Sessions, id)
	}
}