Mirror of https://github.com/Gouryella/drip.git (synced 2026-04-28 21:29:58 +00:00)
feat(tunnel): switch to yamux stream proxying and connection pooling
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server
- Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly
- Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
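For orientation, here is a sketch of how the pooled-session plumbing fits the code below: `NewProxy` takes an `openStream func() (net.Conn, error)` callback, and with yamux each proxied TCP connection rides its own multiplexed stream over one pooled tunnel connection. The `dialTunnelSession` helper and the client/server role choice are assumptions for illustration; only `NewProxy`'s signature comes from this diff.

```go
package main

import (
	"net"

	"github.com/hashicorp/yamux"
)

// dialTunnelSession (hypothetical) adapts one pooled tunnel connection
// into the openStream callback that NewProxy expects: each call opens
// a fresh multiplexed stream instead of a fresh TCP dial.
func dialTunnelSession(tunnelConn net.Conn) (func() (net.Conn, error), error) {
	// Whether this end is yamux.Client or yamux.Server depends on which
	// side initiated the tunnel; Client is assumed here.
	session, err := yamux.Client(tunnelConn, nil)
	if err != nil {
		return nil, err
	}
	return func() (net.Conn, error) { return session.Open() }, nil
}
```

A proxy would then be built as `NewProxy(ctx, port, subdomain, openStream, stats, logger)` and started with `Start()`.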
@@ -1,64 +1,79 @@
 package tcp
 
 import (
+	"context"
+	"errors"
 	"fmt"
 	"net"
 	"sync"
 	"time"
 
+	"drip/internal/shared/netutil"
 	"drip/internal/shared/pool"
-	"drip/internal/shared/protocol"
 
 	"go.uber.org/zap"
 )
 
-// TunnelProxy handles TCP connections for a specific tunnel
-type TunnelProxy struct {
-	port        int
-	subdomain   string
-	tcpConn     net.Conn // The tunnel control connection
-	listener    net.Listener
-	logger      *zap.Logger
-	stopCh      chan struct{}
-	wg          sync.WaitGroup
-	clientAddr  string
-	streams     map[string]*proxyStream // streamID -> stream info
-	streamMu    sync.RWMutex
-	frameWriter *protocol.FrameWriter
-	bufferPool  *pool.BufferPool
+// Proxy exposes a public TCP port and forwards each incoming
+// connection over a dedicated mux stream.
+type Proxy struct {
+	port      int
+	subdomain string
+	logger    *zap.Logger
+
+	listener net.Listener
+	stopCh   chan struct{}
+	once     sync.Once
+	wg       sync.WaitGroup
+
+	openStream func() (net.Conn, error)
+	stats      trafficStats
+	sem        chan struct{}
+
+	ctx    context.Context
+	cancel context.CancelFunc
 }
 
-// proxyStream holds connection info with close state
-type proxyStream struct {
-	conn   net.Conn
-	closed bool
-	mu     sync.Mutex
+type trafficStats interface {
+	AddBytesIn(n int64)
+	AddBytesOut(n int64)
+	IncActiveConnections()
+	DecActiveConnections()
 }
 
-// NewTunnelProxy creates a new TCP tunnel proxy
-func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
-	return &TunnelProxy{
-		port:        port,
-		subdomain:   subdomain,
-		tcpConn:     tcpConn,
-		logger:      logger,
-		stopCh:      make(chan struct{}),
-		clientAddr:  tcpConn.RemoteAddr().String(),
-		streams:     make(map[string]*proxyStream),
-		bufferPool:  pool.NewBufferPool(),
-		frameWriter: protocol.NewFrameWriter(tcpConn),
+func NewProxy(ctx context.Context, port int, subdomain string, openStream func() (net.Conn, error), stats trafficStats, logger *zap.Logger) *Proxy {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	cctx, cancel := context.WithCancel(ctx)
+
+	const maxConcurrentConnections = 10000
+	var sem chan struct{}
+	if maxConcurrentConnections > 0 {
+		sem = make(chan struct{}, maxConcurrentConnections)
+	}
+
+	return &Proxy{
+		port:       port,
+		subdomain:  subdomain,
+		logger:     logger,
+		stopCh:     make(chan struct{}),
+		openStream: openStream,
+		stats:      stats,
+		sem:        sem,
+		ctx:        cctx,
+		cancel:     cancel,
 	}
 }
 
 // Start starts listening on the allocated port
-func (p *TunnelProxy) Start() error {
+func (p *Proxy) Start() error {
 	addr := fmt.Sprintf("0.0.0.0:%d", p.port)
 
-	listener, err := net.Listen("tcp", addr)
+	ln, err := net.Listen("tcp", addr)
 	if err != nil {
 		return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
 	}
 
-	p.listener = listener
+	p.listener = ln
 
 	p.logger.Info("TCP proxy started",
 		zap.Int("port", p.port),
@@ -67,14 +82,47 @@ func (p *TunnelProxy) Start() error {
 
 	p.wg.Add(1)
 	go p.acceptLoop()
 
 	return nil
 }
 
-// acceptLoop accepts incoming TCP connections
-func (p *TunnelProxy) acceptLoop() {
+func (p *Proxy) Stop() {
+	p.once.Do(func() {
+		close(p.stopCh)
+		p.cancel()
+
+		if p.listener != nil {
+			_ = p.listener.Close()
+		}
+
+		done := make(chan struct{})
+		go func() {
+			p.wg.Wait()
+			close(done)
+		}()
+
+		const stopTimeout = 30 * time.Second
+
+		select {
+		case <-done:
+			p.logger.Info("TCP proxy stopped",
+				zap.Int("port", p.port),
+				zap.String("subdomain", p.subdomain),
+			)
+		case <-time.After(stopTimeout):
+			p.logger.Warn("TCP proxy stop timed out",
+				zap.Int("port", p.port),
+				zap.String("subdomain", p.subdomain),
+				zap.Duration("timeout", stopTimeout),
+			)
+		}
+	})
+}
+
+func (p *Proxy) acceptLoop() {
 	defer p.wg.Done()
 
+	tcpLn, _ := p.listener.(*net.TCPListener)
+
 	for {
 		select {
 		case <-p.stopCh:
@@ -82,11 +130,13 @@ func (p *TunnelProxy) acceptLoop() {
 		default:
 		}
 
-		p.listener.(*net.TCPListener).SetDeadline(time.Now().Add(1 * time.Second))
+		if tcpLn != nil {
+			_ = tcpLn.SetDeadline(time.Now().Add(1 * time.Second))
+		}
 
 		conn, err := p.listener.Accept()
 		if err != nil {
-			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+			if ne, ok := err.(net.Error); ok && ne.Timeout() {
 				continue
 			}
 			select {
@@ -98,187 +148,86 @@ func (p *TunnelProxy) acceptLoop() {
 		}
 
 		p.wg.Add(1)
-		go p.handleConnection(conn)
+		go p.handleConn(conn)
 	}
 }
 
-func (p *TunnelProxy) handleConnection(conn net.Conn) {
+func (p *Proxy) handleConn(conn net.Conn) {
 	defer p.wg.Done()
 	defer conn.Close()
 
-	streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
-
-	stream := &proxyStream{
-		conn:   conn,
-		closed: false,
-	}
-
-	p.streamMu.Lock()
-	p.streams[streamID] = stream
-	p.streamMu.Unlock()
-
-	defer func() {
-		p.streamMu.Lock()
-		delete(p.streams, streamID)
-		p.streamMu.Unlock()
-	}()
-
-	bufPtr := p.bufferPool.Get(pool.SizeMedium)
-	defer p.bufferPool.Put(bufPtr)
-
-	buffer := (*bufPtr)[:pool.SizeMedium]
-
-	for {
-		// Check if stream is closed
-		stream.mu.Lock()
-		closed := stream.closed
-		stream.mu.Unlock()
-		if closed {
-			break
-		}
-
-		n, err := conn.Read(buffer)
-		if err != nil {
-			break
-		}
-
-		if n > 0 {
-			if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
-				p.logger.Debug("Send to tunnel failed", zap.Error(err))
-				break
-			}
-		}
-	}
-
-	select {
-	case <-p.stopCh:
-	default:
-		p.sendCloseToTunnel(streamID)
-	}
-}
+	if p.sem != nil {
+		select {
+		case p.sem <- struct{}{}:
+			defer func() { <-p.sem }()
+		default:
+			return
+		}
+	}
+
+	if p.stats != nil {
+		p.stats.IncActiveConnections()
+		defer p.stats.DecActiveConnections()
+	}
+
+	if tcpConn, ok := conn.(*net.TCPConn); ok {
+		_ = tcpConn.SetNoDelay(true)
+		_ = tcpConn.SetKeepAlive(true)
+		_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
+		_ = tcpConn.SetReadBuffer(256 * 1024)
+		_ = tcpConn.SetWriteBuffer(256 * 1024)
+	}
+
+	if p.openStream == nil {
+		return
+	}
+
+	// Open stream with timeout to prevent goroutine leak
+	const openStreamTimeout = 10 * time.Second
+	type streamResult struct {
+		stream net.Conn
+		err    error
+	}
+	resultCh := make(chan streamResult, 1)
+
+	go func() {
+		s, err := p.openStream()
+		resultCh <- streamResult{s, err}
+	}()
+
+	var stream net.Conn
+	select {
+	case result := <-resultCh:
+		if result.err != nil {
+			if !errors.Is(result.err, net.ErrClosed) {
+				p.logger.Debug("Open stream failed", zap.Error(result.err))
+			}
+			return
+		}
+		stream = result.stream
+	case <-time.After(openStreamTimeout):
+		p.logger.Debug("Open stream timeout")
+		return
+	case <-p.stopCh:
+		return
+	}
 
-func (p *TunnelProxy) sendDataToTunnel(streamID string, data []byte) error {
-	select {
-	case <-p.stopCh:
-		return fmt.Errorf("tunnel proxy stopped")
-	default:
-	}
-
-	header := protocol.DataHeader{
-		StreamID:  streamID,
-		RequestID: streamID,
-		Type:      protocol.DataTypeData,
-		IsLast:    false,
-	}
-
-	payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, data)
-	if err != nil {
-		return fmt.Errorf("failed to encode payload: %w", err)
-	}
-
-	frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
-
-	err = p.frameWriter.WriteFrame(frame)
-	if err != nil {
-		return fmt.Errorf("failed to write frame: %w", err)
-	}
-
-	return nil
-}
-
-func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
-	header := protocol.DataHeader{
-		StreamID:  streamID,
-		RequestID: streamID,
-		Type:      protocol.DataTypeClose,
-		IsLast:    true,
-	}
-
-	payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
-	if err != nil {
-		return
-	}
-
-	frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
-	p.frameWriter.WriteFrame(frame)
-}
-
-func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
-	p.streamMu.RLock()
-	stream, ok := p.streams[streamID]
-	p.streamMu.RUnlock()
-
-	if !ok {
-		// Stream may have been closed by client, this is normal
-		return nil
-	}
-
-	// Check if stream is closed
-	stream.mu.Lock()
-	if stream.closed {
-		stream.mu.Unlock()
-		return nil
-	}
-	stream.mu.Unlock()
-
-	if _, err := stream.conn.Write(data); err != nil {
-		p.logger.Debug("Write to client failed", zap.Error(err))
-		return err
-	}
-
-	return nil
-}
-
-// CloseStream closes a stream
-func (p *TunnelProxy) CloseStream(streamID string) {
-	p.streamMu.RLock()
-	stream, ok := p.streams[streamID]
-	p.streamMu.RUnlock()
-
-	if !ok {
-		return
-	}
-
-	// Mark as closed first
-	stream.mu.Lock()
-	if stream.closed {
-		stream.mu.Unlock()
-		return
-	}
-	stream.closed = true
-	stream.mu.Unlock()
-
-	// Now close the connection
-	stream.conn.Close()
-}
-
-func (p *TunnelProxy) Stop() {
-	p.logger.Info("Stopping TCP proxy",
-		zap.Int("port", p.port),
-		zap.String("subdomain", p.subdomain),
-	)
-
-	close(p.stopCh)
-
-	if p.listener != nil {
-		p.listener.Close()
-	}
-
-	p.streamMu.Lock()
-	for _, stream := range p.streams {
-		stream.mu.Lock()
-		stream.closed = true
-		stream.mu.Unlock()
-		stream.conn.Close()
-	}
-	p.streams = make(map[string]*proxyStream)
-	p.streamMu.Unlock()
-
-	p.wg.Wait()
-
-	if p.frameWriter != nil {
-		p.frameWriter.Close()
-	}
-
-	p.logger.Info("TCP proxy stopped", zap.Int("port", p.port))
-}
+	defer stream.Close()
+
+	_ = netutil.PipeWithCallbacksAndBufferSize(
+		p.ctx,
+		conn,
+		stream,
+		pool.SizeLarge,
+		func(n int64) {
+			if p.stats != nil {
+				p.stats.AddBytesIn(n)
+			}
+		},
+		func(n int64) {
+			if p.stats != nil {
+				p.stats.AddBytesOut(n)
+			}
+		},
+	)
+}
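The `trafficStats` interface in the diff above decouples the proxy from the stats code this commit moves into `internal/shared`. As a sketch only (this type is not in the diff), a minimal implementation can satisfy it with atomic counters:

```go
package tcp

import "sync/atomic"

// atomicStats is a hypothetical trafficStats implementation;
// the real stats live in internal/shared.
type atomicStats struct {
	bytesIn  atomic.Int64
	bytesOut atomic.Int64
	active   atomic.Int64
}

func (s *atomicStats) AddBytesIn(n int64)    { s.bytesIn.Add(n) }
func (s *atomicStats) AddBytesOut(n int64)   { s.bytesOut.Add(n) }
func (s *atomicStats) IncActiveConnections() { s.active.Add(1) }
func (s *atomicStats) DecActiveConnections() { s.active.Add(-1) }
```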
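`netutil.PipeWithCallbacksAndBufferSize` is called but not defined in this hunk. Judging by the call site, it copies bytes in both directions between the accepted TCP connection and the mux stream, reporting transfer sizes through the two callbacks. A rough stand-in, with guessed name, signature, and shutdown behavior (not the real `internal/shared/netutil` code):

```go
package netutil

import (
	"context"
	"io"
	"net"
	"sync"
)

// countingWriter reports each successful write's size to cb.
type countingWriter struct {
	w  io.Writer
	cb func(int64)
}

func (c *countingWriter) Write(p []byte) (int, error) {
	n, err := c.w.Write(p)
	if c.cb != nil && n > 0 {
		c.cb(int64(n))
	}
	return n, err
}

// pipeBothWays is an illustrative equivalent of a pipe-with-callbacks
// helper: copy a<->b until one direction ends or ctx is cancelled.
func pipeBothWays(ctx context.Context, a, b net.Conn, bufSize int, onAToB, onBToA func(int64)) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Closing both conns unblocks the copy loops on cancellation.
	go func() {
		<-ctx.Done()
		_ = a.Close()
		_ = b.Close()
	}()

	var wg sync.WaitGroup
	var errA, errB error
	wg.Add(2)
	go func() {
		defer wg.Done()
		_, errA = io.CopyBuffer(&countingWriter{w: b, cb: onAToB}, a, make([]byte, bufSize))
		cancel() // one direction finished; tear down the other
	}()
	go func() {
		defer wg.Done()
		_, errB = io.CopyBuffer(&countingWriter{w: a, cb: onBToA}, b, make([]byte, bufSize))
		cancel()
	}()
	wg.Wait()

	if errA != nil {
		return errA
	}
	return errB
}
```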