mirror of
https://github.com/Gouryella/drip.git
synced 2026-03-02 00:03:07 +00:00
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server
- Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly
- Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
439 lines
8.9 KiB
Go
439 lines
8.9 KiB
Go
package tcp
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/hashicorp/yamux"
|
|
|
|
"drip/internal/shared/constants"
|
|
"drip/internal/shared/protocol"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
|
|
// DataConnection is one data connection owned by a ConnectionGroup. The group
// stores it in its DataConns map under ID.
type DataConnection struct {
	// ID is the key under which the owning group stores this connection.
	ID string
	// Conn is the underlying connection. It is given an immediate deadline and
	// closed when the connection is removed or the group is closed.
	Conn net.Conn
	// LastActive is set to the creation time by AddDataConnection.
	LastActive time.Time
	// closed records that stopCh and Conn have already been closed, so the
	// teardown runs at most once. Guarded by closedMu.
	closed bool
	// closedMu protects closed.
	closedMu sync.RWMutex
	// stopCh is closed during teardown.
	stopCh chan struct{}
	// mu is not used by the code in this file.
	// NOTE(review): presumably guards LastActive — confirm against other users.
	mu sync.RWMutex
}
|
|
|
|
// ConnectionGroup bundles everything that belongs to one tunnel: the primary
// connection, the additional data connections, and the yamux sessions that
// streams are opened on.
type ConnectionGroup struct {
	// TunnelID identifies the tunnel; it is also attached to the group logger
	// as the "tunnel_id" field.
	TunnelID string
	// Subdomain and Token are stored as passed to NewConnectionGroup.
	Subdomain string
	Token     string
	// PrimaryConn is the connection passed to NewConnectionGroup.
	PrimaryConn *Connection
	// DataConns holds the data connections keyed by connection ID. Guarded by mu.
	DataConns map[string]*DataConnection
	// Sessions holds the yamux sessions keyed by connection ID. Guarded by mu.
	Sessions map[string]*yamux.Session
	// TunnelType is stored as passed to NewConnectionGroup.
	TunnelType protocol.TunnelType
	// RegisteredAt is the creation time of the group.
	RegisteredAt time.Time
	// LastActivity is refreshed when connections or sessions are added, when a
	// heartbeat ping succeeds, and when a non-empty session snapshot is taken.
	// IsStale compares it against a timeout. Guarded by mu.
	LastActivity time.Time
	// sessionIdx is an atomically incremented counter used to rotate the
	// position at which session selection starts.
	sessionIdx uint32
	// mu guards DataConns, Sessions, LastActivity and heartbeatStarted.
	mu sync.RWMutex
	// stopCh is closed by Close; the heartbeat loop and OpenStream watch it.
	stopCh chan struct{}
	// logger is the caller's logger with the tunnel ID attached.
	logger *zap.Logger

	// heartbeatStarted is set by AddSession the first time a session is added
	// so the heartbeat loop is launched only once.
	heartbeatStarted bool
}
|
|
|
|
func NewConnectionGroup(tunnelID, subdomain, token string, primaryConn *Connection, tunnelType protocol.TunnelType, logger *zap.Logger) *ConnectionGroup {
|
|
return &ConnectionGroup{
|
|
TunnelID: tunnelID,
|
|
Subdomain: subdomain,
|
|
Token: token,
|
|
PrimaryConn: primaryConn,
|
|
DataConns: make(map[string]*DataConnection),
|
|
Sessions: make(map[string]*yamux.Session),
|
|
TunnelType: tunnelType,
|
|
RegisteredAt: time.Now(),
|
|
LastActivity: time.Now(),
|
|
stopCh: make(chan struct{}),
|
|
logger: logger.With(zap.String("tunnel_id", tunnelID)),
|
|
}
|
|
}
|
|
|
|
// StartHeartbeat starts a goroutine that periodically pings all sessions
|
|
// and removes dead ones. The caller should ensure this is only called once.
|
|
func (g *ConnectionGroup) StartHeartbeat(interval, timeout time.Duration) {
|
|
go g.heartbeatLoop(interval, timeout)
|
|
}
|
|
|
|
func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
const maxConsecutiveFailures = 3
|
|
failureCount := make(map[string]int)
|
|
|
|
for {
|
|
select {
|
|
case <-g.stopCh:
|
|
return
|
|
case <-ticker.C:
|
|
}
|
|
|
|
g.mu.RLock()
|
|
sessions := make(map[string]*yamux.Session, len(g.Sessions))
|
|
for id, s := range g.Sessions {
|
|
sessions[id] = s
|
|
}
|
|
g.mu.RUnlock()
|
|
|
|
for id, session := range sessions {
|
|
if session == nil || session.IsClosed() {
|
|
g.RemoveSession(id)
|
|
delete(failureCount, id)
|
|
continue
|
|
}
|
|
|
|
// Ping with timeout
|
|
done := make(chan error, 1)
|
|
go func(s *yamux.Session) {
|
|
_, err := s.Ping()
|
|
done <- err
|
|
}(session)
|
|
|
|
var err error
|
|
select {
|
|
case err = <-done:
|
|
case <-time.After(timeout):
|
|
err = fmt.Errorf("ping timeout")
|
|
case <-g.stopCh:
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
failureCount[id]++
|
|
g.logger.Debug("Session ping failed",
|
|
zap.String("session_id", id),
|
|
zap.Int("consecutive_failures", failureCount[id]),
|
|
zap.Error(err),
|
|
)
|
|
|
|
if failureCount[id] >= maxConsecutiveFailures {
|
|
g.logger.Warn("Session ping failed too many times, removing",
|
|
zap.String("session_id", id),
|
|
zap.Int("failures", failureCount[id]),
|
|
)
|
|
g.RemoveSession(id)
|
|
delete(failureCount, id)
|
|
}
|
|
} else {
|
|
// Reset on success
|
|
failureCount[id] = 0
|
|
g.mu.Lock()
|
|
g.LastActivity = time.Now()
|
|
g.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// Check if all sessions are gone
|
|
g.mu.RLock()
|
|
sessionCount := len(g.Sessions)
|
|
g.mu.RUnlock()
|
|
|
|
if sessionCount == 0 {
|
|
g.logger.Info("All sessions closed, tunnel will be cleaned up")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (g *ConnectionGroup) AddDataConnection(connID string, conn net.Conn) *DataConnection {
|
|
g.mu.Lock()
|
|
defer g.mu.Unlock()
|
|
|
|
dataConn := &DataConnection{
|
|
ID: connID,
|
|
Conn: conn,
|
|
LastActive: time.Now(),
|
|
stopCh: make(chan struct{}),
|
|
}
|
|
g.DataConns[connID] = dataConn
|
|
g.LastActivity = time.Now()
|
|
return dataConn
|
|
}
|
|
|
|
func (g *ConnectionGroup) RemoveDataConnection(connID string) {
|
|
g.mu.Lock()
|
|
defer g.mu.Unlock()
|
|
|
|
if dataConn, ok := g.DataConns[connID]; ok {
|
|
dataConn.closedMu.Lock()
|
|
if !dataConn.closed {
|
|
dataConn.closed = true
|
|
close(dataConn.stopCh)
|
|
if dataConn.Conn != nil {
|
|
_ = dataConn.Conn.SetDeadline(time.Now())
|
|
dataConn.Conn.Close()
|
|
}
|
|
}
|
|
dataConn.closedMu.Unlock()
|
|
delete(g.DataConns, connID)
|
|
}
|
|
}
|
|
|
|
func (g *ConnectionGroup) DataConnectionCount() int {
|
|
g.mu.RLock()
|
|
defer g.mu.RUnlock()
|
|
return len(g.DataConns)
|
|
}
|
|
|
|
func (g *ConnectionGroup) Close() {
|
|
g.mu.Lock()
|
|
|
|
select {
|
|
case <-g.stopCh:
|
|
g.mu.Unlock()
|
|
return
|
|
default:
|
|
close(g.stopCh)
|
|
}
|
|
|
|
dataConns := make([]*DataConnection, 0, len(g.DataConns))
|
|
for _, dataConn := range g.DataConns {
|
|
dataConns = append(dataConns, dataConn)
|
|
}
|
|
g.DataConns = make(map[string]*DataConnection)
|
|
|
|
sessions := make([]*yamux.Session, 0, len(g.Sessions))
|
|
for _, session := range g.Sessions {
|
|
if session != nil {
|
|
sessions = append(sessions, session)
|
|
}
|
|
}
|
|
g.Sessions = make(map[string]*yamux.Session)
|
|
|
|
g.mu.Unlock()
|
|
|
|
for _, dataConn := range dataConns {
|
|
dataConn.closedMu.Lock()
|
|
if !dataConn.closed {
|
|
dataConn.closed = true
|
|
close(dataConn.stopCh)
|
|
if dataConn.Conn != nil {
|
|
_ = dataConn.Conn.SetDeadline(time.Now())
|
|
_ = dataConn.Conn.Close()
|
|
}
|
|
}
|
|
dataConn.closedMu.Unlock()
|
|
}
|
|
|
|
for _, session := range sessions {
|
|
_ = session.Close()
|
|
}
|
|
}
|
|
|
|
func (g *ConnectionGroup) IsStale(timeout time.Duration) bool {
|
|
g.mu.RLock()
|
|
defer g.mu.RUnlock()
|
|
return time.Since(g.LastActivity) > timeout
|
|
}
|
|
|
|
func (g *ConnectionGroup) AddSession(connID string, session *yamux.Session) {
|
|
if connID == "" || session == nil {
|
|
return
|
|
}
|
|
|
|
g.mu.Lock()
|
|
if g.Sessions == nil {
|
|
g.Sessions = make(map[string]*yamux.Session)
|
|
}
|
|
g.Sessions[connID] = session
|
|
g.LastActivity = time.Now()
|
|
|
|
// Start heartbeat on first session
|
|
shouldStartHeartbeat := !g.heartbeatStarted
|
|
if shouldStartHeartbeat {
|
|
g.heartbeatStarted = true
|
|
}
|
|
g.mu.Unlock()
|
|
|
|
if shouldStartHeartbeat {
|
|
g.StartHeartbeat(constants.HeartbeatInterval, constants.HeartbeatTimeout)
|
|
}
|
|
}
|
|
|
|
func (g *ConnectionGroup) RemoveSession(connID string) {
|
|
if connID == "" {
|
|
return
|
|
}
|
|
|
|
var session *yamux.Session
|
|
|
|
g.mu.Lock()
|
|
if g.Sessions != nil {
|
|
session = g.Sessions[connID]
|
|
delete(g.Sessions, connID)
|
|
}
|
|
g.mu.Unlock()
|
|
|
|
if session != nil {
|
|
_ = session.Close()
|
|
}
|
|
}
|
|
|
|
func (g *ConnectionGroup) SessionCount() int {
|
|
g.mu.RLock()
|
|
defer g.mu.RUnlock()
|
|
return len(g.Sessions)
|
|
}
|
|
|
|
func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
|
|
const (
|
|
maxStreamsPerSession = 256
|
|
maxRetries = 3
|
|
backoffBase = 25 * time.Millisecond
|
|
)
|
|
|
|
var lastErr error
|
|
|
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
|
select {
|
|
case <-g.stopCh:
|
|
return nil, net.ErrClosed
|
|
default:
|
|
}
|
|
|
|
sessions := g.sessionsSnapshot()
|
|
if len(sessions) == 0 {
|
|
return nil, net.ErrClosed
|
|
}
|
|
|
|
tried := make([]bool, len(sessions))
|
|
anyUnderCap := false
|
|
start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
|
|
|
|
for range sessions {
|
|
bestIdx := -1
|
|
minStreams := int(^uint(0) >> 1)
|
|
|
|
for i := 0; i < len(sessions); i++ {
|
|
idx := (start + i) % len(sessions)
|
|
if tried[idx] {
|
|
continue
|
|
}
|
|
|
|
session := sessions[idx]
|
|
if session == nil || session.IsClosed() {
|
|
tried[idx] = true
|
|
continue
|
|
}
|
|
|
|
n := session.NumStreams()
|
|
if n >= maxStreamsPerSession {
|
|
continue
|
|
}
|
|
anyUnderCap = true
|
|
|
|
if n < minStreams {
|
|
minStreams = n
|
|
bestIdx = idx
|
|
}
|
|
}
|
|
|
|
if bestIdx == -1 {
|
|
break
|
|
}
|
|
|
|
tried[bestIdx] = true
|
|
session := sessions[bestIdx]
|
|
if session == nil || session.IsClosed() {
|
|
continue
|
|
}
|
|
|
|
stream, err := session.Open()
|
|
if err == nil {
|
|
return stream, nil
|
|
}
|
|
lastErr = err
|
|
|
|
if session.IsClosed() {
|
|
g.deleteClosedSessions()
|
|
}
|
|
}
|
|
|
|
if !anyUnderCap {
|
|
lastErr = fmt.Errorf("all sessions are at stream capacity (%d)", maxStreamsPerSession)
|
|
}
|
|
|
|
if attempt < maxRetries-1 {
|
|
select {
|
|
case <-g.stopCh:
|
|
return nil, net.ErrClosed
|
|
case <-time.After(backoffBase * time.Duration(attempt+1)):
|
|
}
|
|
}
|
|
}
|
|
|
|
if lastErr == nil {
|
|
lastErr = fmt.Errorf("failed to open stream")
|
|
}
|
|
return nil, lastErr
|
|
}
|
|
|
|
func (g *ConnectionGroup) selectSession() *yamux.Session {
|
|
sessions := g.sessionsSnapshot()
|
|
if len(sessions) == 0 {
|
|
return nil
|
|
}
|
|
|
|
start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
|
|
minStreams := int(^uint(0) >> 1)
|
|
var best *yamux.Session
|
|
|
|
for i := 0; i < len(sessions); i++ {
|
|
session := sessions[(start+i)%len(sessions)]
|
|
if session == nil || session.IsClosed() {
|
|
continue
|
|
}
|
|
if n := session.NumStreams(); n < minStreams {
|
|
minStreams = n
|
|
best = session
|
|
}
|
|
}
|
|
|
|
return best
|
|
}
|
|
|
|
func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
|
|
g.mu.Lock()
|
|
defer g.mu.Unlock()
|
|
|
|
if len(g.Sessions) == 0 {
|
|
return nil
|
|
}
|
|
|
|
sessions := make([]*yamux.Session, 0, len(g.Sessions))
|
|
for id, session := range g.Sessions {
|
|
if session == nil || session.IsClosed() {
|
|
delete(g.Sessions, id)
|
|
continue
|
|
}
|
|
sessions = append(sessions, session)
|
|
}
|
|
|
|
if len(sessions) > 0 {
|
|
g.LastActivity = time.Now()
|
|
}
|
|
|
|
return sessions
|
|
}
|
|
|
|
func (g *ConnectionGroup) deleteClosedSessions() {
|
|
g.mu.Lock()
|
|
for id, session := range g.Sessions {
|
|
if session == nil || session.IsClosed() {
|
|
delete(g.Sessions, id)
|
|
}
|
|
}
|
|
g.mu.Unlock()
|
|
}
|