Files
drip/internal/client/tcp/pool_session.go
Gouryella 0c19c3300c feat(tunnel): switch to yamux stream proxying and connection pooling
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server
- Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly
- Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
2025-12-13 18:03:44 +08:00

313 lines
6.4 KiB
Go

package tcp
import (
"fmt"
"io"
"net"
"sync/atomic"
"time"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
"drip/internal/shared/constants"
"drip/internal/shared/protocol"
)
// sessionHandle bundles one yamux session with the transport connection it
// runs over, plus the bookkeeping the pool uses when deciding which sessions
// to keep and which to retire.
type sessionHandle struct {
	// id is the connection ID; it is also the key under which the handle is
	// stored in the pool's dataSessions map.
	id string

	// conn is the raw transport connection that session multiplexes over.
	conn    net.Conn
	session *yamux.Session

	// active counts in-flight work on this session; the pool treats zero as
	// idle. (Presumably incremented/decremented per stream elsewhere — confirm.)
	active atomic.Int64
	// lastActive is the most recent touch time, kept as Unix nanoseconds so it
	// can be read and written atomically without the pool mutex.
	lastActive atomic.Int64
	// closed is flipped false->true exactly once, by whoever tears the
	// session down.
	closed atomic.Bool
}
// touch stamps the handle with the current wall-clock time as its most recent
// activity.
func (h *sessionHandle) touch() {
	now := time.Now()
	h.lastActive.Store(now.UnixNano())
}
// lastActiveTime reports when the handle was last touched. A handle that has
// never been touched yields the zero time.Time (detectable with IsZero), which
// also orders before every real timestamp.
func (h *sessionHandle) lastActiveTime() time.Time {
	var stamp time.Time
	if nanos := h.lastActive.Load(); nanos != 0 {
		stamp = time.Unix(0, nanos)
	}
	return stamp
}
// scalerLoop periodically re-evaluates how many sessions the pool should hold
// and then asks ensureSessions to converge on that number.
//
// Every checkInterval it computes a utilization ratio: active connections
// (from c.stats) divided by current sessions × capacityPerSession. It then
// acts as follows:
//   - Above scaleUpLoad, it raises desiredTotal by one (capped at maxSessions),
//     provided scaleUpCooldown has passed since the last change.
//   - Below scaleDownLoad, it lowers desiredTotal by one (floored at
//     minSessions), provided scaleDownCooldown has passed.
//
// An unset (zero) desiredTotal is seeded from initialSessions. The loop exits
// when stopCh is closed and signals c.wg on return.
func (c *PoolClient) scalerLoop() {
	defer c.wg.Done()

	const (
		checkInterval      = 5 * time.Second
		scaleUpCooldown    = 5 * time.Second
		scaleDownCooldown  = 60 * time.Second
		capacityPerSession = int64(64)
		scaleUpLoad        = 0.7
		scaleDownLoad      = 0.3
	)

	tick := time.NewTicker(checkInterval)
	defer tick.Stop()

	for {
		select {
		case <-c.stopCh:
			return
		case <-tick.C:
		}

		c.mu.Lock()
		if c.desiredTotal == 0 {
			c.desiredTotal = c.initialSessions
		}
		target := c.desiredTotal
		prevScale := c.lastScale
		c.mu.Unlock()

		sessions := c.sessionCount()
		if sessions <= 0 {
			continue
		}

		capacity := int64(sessions) * capacityPerSession
		utilization := float64(c.stats.GetActiveConnections()) / float64(capacity)
		elapsed := time.Since(prevScale)

		switch {
		case elapsed >= scaleUpCooldown && utilization > scaleUpLoad && target < c.maxSessions:
			c.mu.Lock()
			c.desiredTotal = min(c.desiredTotal+1, c.maxSessions)
			c.lastScale = time.Now()
			c.mu.Unlock()
		case elapsed >= scaleDownCooldown && utilization < scaleDownLoad && target > c.minSessions:
			c.mu.Lock()
			c.desiredTotal = max(c.desiredTotal-1, c.minSessions)
			c.lastScale = time.Now()
			c.mu.Unlock()
		}

		c.ensureSessions()
	}
}
// ensureSessions reconciles the live session count with desiredTotal, after
// clamping the target into [minSessions, maxSessions].
//
// If sessions are missing, it dials new data sessions one at a time and stops
// at the first failure, which is logged at debug level. If there are too many,
// it retires the surplus via removeIdleSessions. It does nothing when the
// client is closed or no tunnel ID is known.
func (c *PoolClient) ensureSessions() {
	if c.IsClosed() || c.tunnelID == "" {
		return
	}

	c.mu.RLock()
	target := c.desiredTotal
	c.mu.RUnlock()
	target = min(max(target, c.minSessions), c.maxSessions)

	have := c.sessionCount()
	switch {
	case have < target:
		for missing := target - have; missing > 0; missing-- {
			if err := c.addDataSession(); err != nil {
				c.logger.Debug("Add data session failed", zap.Error(err))
				return
			}
		}
	case have > target:
		c.removeIdleSessions(have - target)
	}
}
// dataConnSeq disambiguates data connection IDs generated within the same
// clock tick (concurrent callers, or platforms with a coarse wall clock).
// Without it, two sessions could share a key in dataSessions, and the later
// one would silently overwrite, and orphan, the earlier one.
var dataConnSeq atomic.Uint64

// addDataSession dials a new TLS connection, registers it with the server as
// a data connection for c.tunnelID, wraps it in a yamux session and adds that
// session to c.dataSessions.
//
// On success the session is serviced by acceptLoop and sessionWatcher
// goroutines, both tracked by c.wg. Every failure path closes the dialed
// connection before returning. net.ErrClosed is returned if the pool is
// stopping, either before dialing or after the handshake completed.
func (c *PoolClient) addDataSession() error {
	select {
	case <-c.stopCh:
		return net.ErrClosed
	default:
	}

	if c.tunnelID == "" {
		return fmt.Errorf("server does not support data connections")
	}

	conn, err := c.dialTLS()
	if err != nil {
		return err
	}

	// Single cleanup path: the connection is closed on any early return and
	// kept only once the session has been published to the pool.
	published := false
	defer func() {
		if !published {
			_ = conn.Close()
		}
	}()

	connID := fmt.Sprintf("data-%d-%d", time.Now().UnixNano(), dataConnSeq.Add(1))

	payload, err := json.Marshal(protocol.DataConnectRequest{
		TunnelID:     c.tunnelID,
		Token:        c.token,
		ConnectionID: connID,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal data connect request: %w", err)
	}
	if err := protocol.WriteFrame(conn, protocol.NewFrame(protocol.FrameTypeDataConnect, payload)); err != nil {
		return fmt.Errorf("failed to send data connect: %w", err)
	}

	// Bound the wait for the server's answer. The deadline is cleared once the
	// ack arrives, so later yamux traffic is not subject to it.
	_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
	ack, err := protocol.ReadFrame(conn)
	if err != nil {
		return fmt.Errorf("failed to read data connect ack: %w", err)
	}
	defer ack.Release()
	_ = conn.SetReadDeadline(time.Time{})

	switch ack.Type {
	case protocol.FrameTypeError:
		var errMsg protocol.ErrorMessage
		if e := json.Unmarshal(ack.Payload, &errMsg); e == nil {
			return fmt.Errorf("data connect error: %s - %s", errMsg.Code, errMsg.Message)
		}
		return fmt.Errorf("data connect error")
	case protocol.FrameTypeDataConnectAck:
		// Expected; fall through to parse the response.
	default:
		return fmt.Errorf("unexpected data connect ack frame: %s", ack.Type)
	}

	var resp protocol.DataConnectResponse
	if err := json.Unmarshal(ack.Payload, &resp); err != nil {
		return fmt.Errorf("failed to parse data connect response: %w", err)
	}
	if !resp.Accepted {
		return fmt.Errorf("data connection rejected: %s", resp.Message)
	}

	yamuxCfg := yamux.DefaultConfig()
	yamuxCfg.EnableKeepAlive = false
	yamuxCfg.LogOutput = io.Discard
	yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
	// This side takes the yamux server role; acceptLoop receives the streams
	// (presumably opened by the remote end — confirm against the server code).
	session, err := yamux.Server(conn, yamuxCfg)
	if err != nil {
		return fmt.Errorf("failed to init yamux session: %w", err)
	}

	h := &sessionHandle{id: connID, conn: conn, session: session}
	h.touch()

	// The dial and handshake can take seconds. Re-check for shutdown under
	// c.mu before publishing, so a session is not registered (and goroutines
	// are not started) after the pool began stopping.
	c.mu.Lock()
	select {
	case <-c.stopCh:
		c.mu.Unlock()
		_ = session.Close()
		return net.ErrClosed
	default:
	}
	c.dataSessions[connID] = h
	c.wg.Add(2)
	c.mu.Unlock()
	published = true

	go c.acceptLoop(h, false)
	go c.sessionWatcher(h, false)
	return nil
}
// removeIdleSessions retires up to n data sessions that report zero active
// work. Among those, it picks the ones that have been idle longest first
// (oldest lastActive). Busy sessions are never chosen.
//
// Selection works on a snapshot taken under c.mu.RLock, so it stops early if
// the snapshot runs out of idle candidates. An ID that removeDataSession
// declines (already gone or already closed) does not count towards n.
func (c *PoolClient) removeIdleSessions(n int) {
	if n <= 0 {
		return
	}

	type idleEntry struct {
		id   string
		seen time.Time
	}

	c.mu.RLock()
	idle := make([]idleEntry, 0, len(c.dataSessions))
	for id, h := range c.dataSessions {
		if h.active.Load() != 0 {
			continue
		}
		idle = append(idle, idleEntry{id: id, seen: h.lastActiveTime()})
	}
	c.mu.RUnlock()

	for removed := 0; removed < n && len(idle) > 0; {
		// The first strictly-oldest entry wins, so ties keep snapshot order.
		oldest := 0
		for i := 1; i < len(idle); i++ {
			if idle[i].seen.Before(idle[oldest].seen) {
				oldest = i
			}
		}
		victim := idle[oldest].id
		// Order-preserving delete, so later tie-breaks are unaffected.
		idle = append(idle[:oldest], idle[oldest+1:]...)

		if c.removeDataSession(victim) {
			removed++
		}
	}
}
// removeDataSession unregisters the data session stored under id and tears it
// down: it closes the yamux session, then the underlying connection.
//
// It returns true only for the call that actually performed the teardown.
// False means no handle was registered under id, or the handle had already
// been marked closed, so a given handle is closed at most once here.
func (c *PoolClient) removeDataSession(id string) bool {
	c.mu.Lock()
	h := c.dataSessions[id]
	delete(c.dataSessions, id)
	c.mu.Unlock()

	if h == nil || !h.closed.CompareAndSwap(false, true) {
		return false
	}

	if s := h.session; s != nil {
		_ = s.Close()
	}
	if conn := h.conn; conn != nil {
		_ = conn.Close()
	}
	return true
}
// sessionCount reports how many sessions the pool currently holds: every
// registered data session, plus one for the primary session when it is set.
// Sessions are counted whether or not they are carrying traffic.
func (c *PoolClient) sessionCount() int {
	c.mu.RLock()
	total, hasPrimary := len(c.dataSessions), c.primary != nil
	c.mu.RUnlock()

	if hasPrimary {
		total++
	}
	return total
}