feat(tunnel): switch to yamux stream proxying and connection pooling

- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server
- Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly
- Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
This commit is contained in:
Gouryella
2025-12-13 18:03:44 +08:00
parent 3c93789266
commit 0c19c3300c
55 changed files with 3380 additions and 4849 deletions

View File

@@ -1,64 +1,79 @@
package tcp
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// TunnelProxy handles TCP connections for a specific tunnel
type TunnelProxy struct {
port int
subdomain string
tcpConn net.Conn // The tunnel control connection
listener net.Listener
logger *zap.Logger
stopCh chan struct{}
wg sync.WaitGroup
clientAddr string
streams map[string]*proxyStream // streamID -> stream info
streamMu sync.RWMutex
frameWriter *protocol.FrameWriter
bufferPool *pool.BufferPool
// Proxy exposes a public TCP port and forwards each incoming
// connection over a dedicated mux stream.
type Proxy struct {
port int
subdomain string
logger *zap.Logger
listener net.Listener
stopCh chan struct{}
once sync.Once
wg sync.WaitGroup
openStream func() (net.Conn, error)
stats trafficStats
sem chan struct{}
ctx context.Context
cancel context.CancelFunc
}
// proxyStream holds connection info with close state
type proxyStream struct {
conn net.Conn
closed bool
mu sync.Mutex
type trafficStats interface {
AddBytesIn(n int64)
AddBytesOut(n int64)
IncActiveConnections()
DecActiveConnections()
}
// NewTunnelProxy creates a new TCP tunnel proxy
func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
return &TunnelProxy{
port: port,
subdomain: subdomain,
tcpConn: tcpConn,
logger: logger,
stopCh: make(chan struct{}),
clientAddr: tcpConn.RemoteAddr().String(),
streams: make(map[string]*proxyStream),
bufferPool: pool.NewBufferPool(),
frameWriter: protocol.NewFrameWriter(tcpConn),
func NewProxy(ctx context.Context, port int, subdomain string, openStream func() (net.Conn, error), stats trafficStats, logger *zap.Logger) *Proxy {
if ctx == nil {
ctx = context.Background()
}
cctx, cancel := context.WithCancel(ctx)
const maxConcurrentConnections = 10000
var sem chan struct{}
if maxConcurrentConnections > 0 {
sem = make(chan struct{}, maxConcurrentConnections)
}
return &Proxy{
port: port,
subdomain: subdomain,
logger: logger,
stopCh: make(chan struct{}),
openStream: openStream,
stats: stats,
sem: sem,
ctx: cctx,
cancel: cancel,
}
}
// Start starts listening on the allocated port
func (p *TunnelProxy) Start() error {
func (p *Proxy) Start() error {
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
listener, err := net.Listen("tcp", addr)
ln, err := net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
}
p.listener = listener
p.listener = ln
p.logger.Info("TCP proxy started",
zap.Int("port", p.port),
@@ -67,14 +82,47 @@ func (p *TunnelProxy) Start() error {
p.wg.Add(1)
go p.acceptLoop()
return nil
}
// acceptLoop accepts incoming TCP connections
func (p *TunnelProxy) acceptLoop() {
func (p *Proxy) Stop() {
p.once.Do(func() {
close(p.stopCh)
p.cancel()
if p.listener != nil {
_ = p.listener.Close()
}
done := make(chan struct{})
go func() {
p.wg.Wait()
close(done)
}()
const stopTimeout = 30 * time.Second
select {
case <-done:
p.logger.Info("TCP proxy stopped",
zap.Int("port", p.port),
zap.String("subdomain", p.subdomain),
)
case <-time.After(stopTimeout):
p.logger.Warn("TCP proxy stop timed out",
zap.Int("port", p.port),
zap.String("subdomain", p.subdomain),
zap.Duration("timeout", stopTimeout),
)
}
})
}
func (p *Proxy) acceptLoop() {
defer p.wg.Done()
tcpLn, _ := p.listener.(*net.TCPListener)
for {
select {
case <-p.stopCh:
@@ -82,11 +130,13 @@ func (p *TunnelProxy) acceptLoop() {
default:
}
p.listener.(*net.TCPListener).SetDeadline(time.Now().Add(1 * time.Second))
if tcpLn != nil {
_ = tcpLn.SetDeadline(time.Now().Add(1 * time.Second))
}
conn, err := p.listener.Accept()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
if ne, ok := err.(net.Error); ok && ne.Timeout() {
continue
}
select {
@@ -98,187 +148,86 @@ func (p *TunnelProxy) acceptLoop() {
}
p.wg.Add(1)
go p.handleConnection(conn)
go p.handleConn(conn)
}
}
func (p *TunnelProxy) handleConnection(conn net.Conn) {
func (p *Proxy) handleConn(conn net.Conn) {
defer p.wg.Done()
defer conn.Close()
streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
stream := &proxyStream{
conn: conn,
closed: false,
if p.sem != nil {
select {
case p.sem <- struct{}{}:
defer func() { <-p.sem }()
default:
return
}
}
p.streamMu.Lock()
p.streams[streamID] = stream
p.streamMu.Unlock()
if p.stats != nil {
p.stats.IncActiveConnections()
defer p.stats.DecActiveConnections()
}
defer func() {
p.streamMu.Lock()
delete(p.streams, streamID)
p.streamMu.Unlock()
if tcpConn, ok := conn.(*net.TCPConn); ok {
_ = tcpConn.SetNoDelay(true)
_ = tcpConn.SetKeepAlive(true)
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
_ = tcpConn.SetReadBuffer(256 * 1024)
_ = tcpConn.SetWriteBuffer(256 * 1024)
}
if p.openStream == nil {
return
}
// Open stream with timeout to prevent goroutine leak
const openStreamTimeout = 10 * time.Second
type streamResult struct {
stream net.Conn
err error
}
resultCh := make(chan streamResult, 1)
go func() {
s, err := p.openStream()
resultCh <- streamResult{s, err}
}()
bufPtr := p.bufferPool.Get(pool.SizeMedium)
defer p.bufferPool.Put(bufPtr)
buffer := (*bufPtr)[:pool.SizeMedium]
for {
// Check if stream is closed
stream.mu.Lock()
closed := stream.closed
stream.mu.Unlock()
if closed {
break
}
n, err := conn.Read(buffer)
if err != nil {
break
}
if n > 0 {
if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
p.logger.Debug("Send to tunnel failed", zap.Error(err))
break
var stream net.Conn
select {
case result := <-resultCh:
if result.err != nil {
if !errors.Is(result.err, net.ErrClosed) {
p.logger.Debug("Open stream failed", zap.Error(result.err))
}
return
}
}
select {
stream = result.stream
case <-time.After(openStreamTimeout):
p.logger.Debug("Open stream timeout")
return
case <-p.stopCh:
default:
p.sendCloseToTunnel(streamID)
}
}
func (p *TunnelProxy) sendDataToTunnel(streamID string, data []byte) error {
select {
case <-p.stopCh:
return fmt.Errorf("tunnel proxy stopped")
default:
}
header := protocol.DataHeader{
StreamID: streamID,
RequestID: streamID,
Type: protocol.DataTypeData,
IsLast: false,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, data)
if err != nil {
return fmt.Errorf("failed to encode payload: %w", err)
}
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
err = p.frameWriter.WriteFrame(frame)
if err != nil {
return fmt.Errorf("failed to write frame: %w", err)
}
return nil
}
func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
header := protocol.DataHeader{
StreamID: streamID,
RequestID: streamID,
Type: protocol.DataTypeClose,
IsLast: true,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
if err != nil {
return
}
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
p.frameWriter.WriteFrame(frame)
}
defer stream.Close()
func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
p.streamMu.RLock()
stream, ok := p.streams[streamID]
p.streamMu.RUnlock()
if !ok {
// Stream may have been closed by client, this is normal
return nil
}
// Check if stream is closed
stream.mu.Lock()
if stream.closed {
stream.mu.Unlock()
return nil
}
stream.mu.Unlock()
if _, err := stream.conn.Write(data); err != nil {
p.logger.Debug("Write to client failed", zap.Error(err))
return err
}
return nil
}
// CloseStream closes a stream
func (p *TunnelProxy) CloseStream(streamID string) {
p.streamMu.RLock()
stream, ok := p.streams[streamID]
p.streamMu.RUnlock()
if !ok {
return
}
// Mark as closed first
stream.mu.Lock()
if stream.closed {
stream.mu.Unlock()
return
}
stream.closed = true
stream.mu.Unlock()
// Now close the connection
stream.conn.Close()
}
func (p *TunnelProxy) Stop() {
p.logger.Info("Stopping TCP proxy",
zap.Int("port", p.port),
zap.String("subdomain", p.subdomain),
_ = netutil.PipeWithCallbacksAndBufferSize(
p.ctx,
conn,
stream,
pool.SizeLarge,
func(n int64) {
if p.stats != nil {
p.stats.AddBytesIn(n)
}
},
func(n int64) {
if p.stats != nil {
p.stats.AddBytesOut(n)
}
},
)
close(p.stopCh)
if p.listener != nil {
p.listener.Close()
}
p.streamMu.Lock()
for _, stream := range p.streams {
stream.mu.Lock()
stream.closed = true
stream.mu.Unlock()
stream.conn.Close()
}
p.streams = make(map[string]*proxyStream)
p.streamMu.Unlock()
p.wg.Wait()
if p.frameWriter != nil {
p.frameWriter.Close()
}
p.logger.Info("TCP proxy stopped", zap.Int("port", p.port))
}