mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-27 06:42:05 +00:00
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server - Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly - Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
313 lines
6.4 KiB
Go
313 lines
6.4 KiB
Go
package tcp
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
json "github.com/goccy/go-json"
|
|
"github.com/hashicorp/yamux"
|
|
"go.uber.org/zap"
|
|
|
|
"drip/internal/shared/constants"
|
|
"drip/internal/shared/protocol"
|
|
)
|
|
|
|
// sessionHandle wraps a yamux session with metadata.
//
// One handle exists per pooled data connection. It is stored in
// PoolClient.dataSessions under its id, and removeDataSession closes it.
type sessionHandle struct {
	// id is the ConnectionID sent in the DataConnect request; it is also the
	// key of this handle in PoolClient.dataSessions.
	id string

	// conn is the underlying network connection the yamux session runs over.
	conn net.Conn

	// session is the yamux session multiplexed over conn.
	session *yamux.Session

	// active is treated as "busy" when non-zero by removeIdleSessions;
	// presumably it counts in-flight streams (incremented elsewhere — verify).
	active atomic.Int64

	// lastActive holds the last touch() time as Unix nanoseconds; 0 means the
	// handle has never been touched.
	lastActive atomic.Int64 // unix nanos

	// closed is flipped false->true exactly once via CompareAndSwap so the
	// session and connection are torn down only once.
	closed atomic.Bool
}
|
|
|
|
func (h *sessionHandle) touch() {
|
|
h.lastActive.Store(time.Now().UnixNano())
|
|
}
|
|
|
|
func (h *sessionHandle) lastActiveTime() time.Time {
|
|
n := h.lastActive.Load()
|
|
if n == 0 {
|
|
return time.Time{}
|
|
}
|
|
return time.Unix(0, n)
|
|
}
|
|
|
|
// scalerLoop monitors load and adjusts session count.
|
|
func (c *PoolClient) scalerLoop() {
|
|
defer c.wg.Done()
|
|
|
|
const (
|
|
checkInterval = 5 * time.Second
|
|
scaleUpCooldown = 5 * time.Second
|
|
scaleDownCooldown = 60 * time.Second
|
|
capacityPerSession = int64(64)
|
|
scaleUpLoad = 0.7
|
|
scaleDownLoad = 0.3
|
|
)
|
|
|
|
ticker := time.NewTicker(checkInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-c.stopCh:
|
|
return
|
|
case <-ticker.C:
|
|
}
|
|
|
|
c.mu.Lock()
|
|
desired := c.desiredTotal
|
|
if desired == 0 {
|
|
desired = c.initialSessions
|
|
c.desiredTotal = desired
|
|
}
|
|
lastScale := c.lastScale
|
|
c.mu.Unlock()
|
|
|
|
current := c.sessionCount()
|
|
if current <= 0 {
|
|
continue
|
|
}
|
|
|
|
active := c.stats.GetActiveConnections()
|
|
load := float64(active) / float64(int64(current)*capacityPerSession)
|
|
|
|
sinceLastScale := time.Since(lastScale)
|
|
if sinceLastScale >= scaleUpCooldown && load > scaleUpLoad && desired < c.maxSessions {
|
|
c.mu.Lock()
|
|
c.desiredTotal = min(c.desiredTotal+1, c.maxSessions)
|
|
c.lastScale = time.Now()
|
|
c.mu.Unlock()
|
|
} else if sinceLastScale >= scaleDownCooldown && load < scaleDownLoad && desired > c.minSessions {
|
|
c.mu.Lock()
|
|
c.desiredTotal = max(c.desiredTotal-1, c.minSessions)
|
|
c.lastScale = time.Now()
|
|
c.mu.Unlock()
|
|
}
|
|
|
|
c.ensureSessions()
|
|
}
|
|
}
|
|
|
|
// ensureSessions adjusts session count to match desired.
|
|
func (c *PoolClient) ensureSessions() {
|
|
if c.IsClosed() || c.tunnelID == "" {
|
|
return
|
|
}
|
|
|
|
c.mu.RLock()
|
|
desired := c.desiredTotal
|
|
c.mu.RUnlock()
|
|
|
|
desired = min(max(desired, c.minSessions), c.maxSessions)
|
|
|
|
current := c.sessionCount()
|
|
if current < desired {
|
|
for i := 0; i < desired-current; i++ {
|
|
if err := c.addDataSession(); err != nil {
|
|
c.logger.Debug("Add data session failed", zap.Error(err))
|
|
break
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
if current > desired {
|
|
c.removeIdleSessions(current - desired)
|
|
}
|
|
}
|
|
|
|
// addDataSession creates a new data session.
|
|
func (c *PoolClient) addDataSession() error {
|
|
select {
|
|
case <-c.stopCh:
|
|
return net.ErrClosed
|
|
default:
|
|
}
|
|
|
|
if c.tunnelID == "" {
|
|
return fmt.Errorf("server does not support data connections")
|
|
}
|
|
|
|
conn, err := c.dialTLS()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
connID := fmt.Sprintf("data-%d", time.Now().UnixNano())
|
|
|
|
req := protocol.DataConnectRequest{
|
|
TunnelID: c.tunnelID,
|
|
Token: c.token,
|
|
ConnectionID: connID,
|
|
}
|
|
|
|
payload, err := json.Marshal(req)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return fmt.Errorf("failed to marshal data connect request: %w", err)
|
|
}
|
|
|
|
if err := protocol.WriteFrame(conn, protocol.NewFrame(protocol.FrameTypeDataConnect, payload)); err != nil {
|
|
_ = conn.Close()
|
|
return fmt.Errorf("failed to send data connect: %w", err)
|
|
}
|
|
|
|
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
|
ack, err := protocol.ReadFrame(conn)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return fmt.Errorf("failed to read data connect ack: %w", err)
|
|
}
|
|
defer ack.Release()
|
|
_ = conn.SetReadDeadline(time.Time{})
|
|
|
|
if ack.Type == protocol.FrameTypeError {
|
|
var errMsg protocol.ErrorMessage
|
|
if e := json.Unmarshal(ack.Payload, &errMsg); e == nil {
|
|
_ = conn.Close()
|
|
return fmt.Errorf("data connect error: %s - %s", errMsg.Code, errMsg.Message)
|
|
}
|
|
_ = conn.Close()
|
|
return fmt.Errorf("data connect error")
|
|
}
|
|
if ack.Type != protocol.FrameTypeDataConnectAck {
|
|
_ = conn.Close()
|
|
return fmt.Errorf("unexpected data connect ack frame: %s", ack.Type)
|
|
}
|
|
|
|
var resp protocol.DataConnectResponse
|
|
if err := json.Unmarshal(ack.Payload, &resp); err != nil {
|
|
_ = conn.Close()
|
|
return fmt.Errorf("failed to parse data connect response: %w", err)
|
|
}
|
|
if !resp.Accepted {
|
|
_ = conn.Close()
|
|
return fmt.Errorf("data connection rejected: %s", resp.Message)
|
|
}
|
|
|
|
yamuxCfg := yamux.DefaultConfig()
|
|
yamuxCfg.EnableKeepAlive = false
|
|
yamuxCfg.LogOutput = io.Discard
|
|
yamuxCfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
|
|
|
session, err := yamux.Server(conn, yamuxCfg)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return fmt.Errorf("failed to init yamux session: %w", err)
|
|
}
|
|
|
|
h := &sessionHandle{
|
|
id: connID,
|
|
conn: conn,
|
|
session: session,
|
|
}
|
|
h.touch()
|
|
|
|
c.mu.Lock()
|
|
c.dataSessions[connID] = h
|
|
c.mu.Unlock()
|
|
|
|
c.wg.Add(1)
|
|
go c.acceptLoop(h, false)
|
|
|
|
c.wg.Add(1)
|
|
go c.sessionWatcher(h, false)
|
|
|
|
return nil
|
|
}
|
|
|
|
// removeIdleSessions removes n idle sessions.
|
|
func (c *PoolClient) removeIdleSessions(n int) {
|
|
if n <= 0 {
|
|
return
|
|
}
|
|
|
|
type candidate struct {
|
|
id string
|
|
active int64
|
|
lastActive time.Time
|
|
}
|
|
|
|
c.mu.RLock()
|
|
candidates := make([]candidate, 0, len(c.dataSessions))
|
|
for id, h := range c.dataSessions {
|
|
candidates = append(candidates, candidate{
|
|
id: id,
|
|
active: h.active.Load(),
|
|
lastActive: h.lastActiveTime(),
|
|
})
|
|
}
|
|
c.mu.RUnlock()
|
|
|
|
removed := 0
|
|
for removed < n {
|
|
var best candidate
|
|
found := false
|
|
for _, cand := range candidates {
|
|
if cand.active != 0 {
|
|
continue
|
|
}
|
|
if !found || cand.lastActive.Before(best.lastActive) {
|
|
best = cand
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
return
|
|
}
|
|
if c.removeDataSession(best.id) {
|
|
removed++
|
|
}
|
|
for i := range candidates {
|
|
if candidates[i].id == best.id {
|
|
candidates[i].active = 1
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// removeDataSession removes a data session by ID.
|
|
func (c *PoolClient) removeDataSession(id string) bool {
|
|
var h *sessionHandle
|
|
|
|
c.mu.Lock()
|
|
h = c.dataSessions[id]
|
|
delete(c.dataSessions, id)
|
|
c.mu.Unlock()
|
|
|
|
if h == nil {
|
|
return false
|
|
}
|
|
|
|
if !h.closed.CompareAndSwap(false, true) {
|
|
return false
|
|
}
|
|
|
|
if h.session != nil {
|
|
_ = h.session.Close()
|
|
}
|
|
if h.conn != nil {
|
|
_ = h.conn.Close()
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// sessionCount returns the total number of active sessions.
|
|
func (c *PoolClient) sessionCount() int {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
count := len(c.dataSessions)
|
|
if c.primary != nil {
|
|
count++
|
|
}
|
|
return count
|
|
}
|