feat(tunnel): switch to yamux stream proxying and connection pooling

- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server
- Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly
- Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
This commit is contained in:
Gouryella
2025-12-13 18:03:44 +08:00
parent 3c93789266
commit 0c19c3300c
55 changed files with 3380 additions and 4849 deletions

View File

@@ -1,490 +1,251 @@
package proxy
import (
"context"
"bufio"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
json "github.com/goccy/go-json"
"drip/internal/server/tunnel"
"drip/internal/shared/pool"
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
"go.uber.org/zap"
)
const openStreamTimeout = 10 * time.Second
type Handler struct {
manager *tunnel.Manager
logger *zap.Logger
responses *ResponseHandler
domain string
authToken string
headerPool *pool.HeaderPool
bufferPool *pool.AdaptiveBufferPool
manager *tunnel.Manager
logger *zap.Logger
domain string
authToken string
}
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, responses *ResponseHandler, domain string, authToken string) *Handler {
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, authToken string) *Handler {
return &Handler{
manager: manager,
logger: logger,
responses: responses,
domain: domain,
authToken: authToken,
headerPool: pool.NewHeaderPool(),
bufferPool: pool.NewAdaptiveBufferPool(),
manager: manager,
logger: logger,
domain: domain,
authToken: authToken,
}
}
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Always handle /health and /stats directly, regardless of subdomain
// Always handle /health and /stats directly, regardless of subdomain.
if r.URL.Path == "/health" {
h.serveHealth(w, r)
return
}
if r.URL.Path == "/stats" {
h.serveStats(w, r)
return
}
subdomain := h.extractSubdomain(r.Host)
if subdomain == "" {
h.serveHomePage(w, r)
return
}
conn, ok := h.manager.Get(subdomain)
if !ok {
tconn, ok := h.manager.Get(subdomain)
if !ok || tconn == nil {
http.Error(w, "Tunnel not found. The tunnel may have been closed.", http.StatusNotFound)
return
}
if conn.IsClosed() {
if tconn.IsClosed() {
http.Error(w, "Tunnel connection closed", http.StatusBadGateway)
return
}
transport := conn.GetTransport()
if transport == nil {
http.Error(w, "Tunnel control channel not ready", http.StatusBadGateway)
return
}
tType := conn.GetTunnelType()
tType := tconn.GetTunnelType()
if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
return
}
requestID := utils.GenerateID()
// Check for WebSocket upgrade
if httputil.IsWebSocketUpgrade(r) {
h.handleWebSocket(w, r, tconn)
return
}
h.handleAdaptiveRequest(w, r, transport, requestID, subdomain)
}
// Open stream with timeout
stream, err := h.openStreamWithTimeout(tconn)
if err != nil {
w.Header().Set("Connection", "close")
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
return
}
defer stream.Close()
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
const streamingThreshold int64 = 1 * 1024 * 1024
// Track active connections
tconn.IncActiveConnections()
defer tconn.DecActiveConnections()
// Wrap stream with counting for traffic stats
countingStream := netutil.NewCountingConn(stream,
tconn.AddBytesOut, // Data read from stream = bytes out to client
tconn.AddBytesIn, // Data written to stream = bytes in from client
)
// 1) Write request over the stream (net/http handles large bodies correctly).
if err := r.Write(countingStream); err != nil {
w.Header().Set("Connection", "close")
_ = r.Body.Close()
http.Error(w, "Forward failed", http.StatusBadGateway)
return
}
// 2) Read response from stream.
resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
if err != nil {
w.Header().Set("Connection", "close")
http.Error(w, "Read response failed", http.StatusBadGateway)
return
}
defer resp.Body.Close()
// 3) Copy headers (strip hop-by-hop).
h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
statusCode := resp.StatusCode
if statusCode == 0 {
statusCode = http.StatusOK
}
// Ensure message delimiting works with our custom ResponseWriter:
// - If Content-Length is known, send it.
// - Otherwise, re-chunk the decoded body ourselves.
if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
if resp.ContentLength >= 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
} else {
w.Header().Del("Content-Length")
}
w.WriteHeader(statusCode)
return
}
if resp.ContentLength >= 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
w.WriteHeader(statusCode)
ctx := r.Context()
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
stream.Close()
case <-done:
}
}()
_, _ = io.Copy(w, resp.Body)
close(done)
stream.Close()
return
}
w.Header().Del("Content-Length")
w.Header().Set("Transfer-Encoding", "chunked")
if len(resp.Trailer) > 0 {
w.Header().Set("Trailer", trailerKeys(resp.Trailer))
}
w.WriteHeader(statusCode)
ctx := r.Context()
var cancelTransport func()
if transport != nil {
cancelOnce := sync.Once{}
cancelFunc := func() {
header := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeClose,
IsLast: true,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
if err != nil {
return
}
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
if err := transport.SendFrame(frame); err != nil {
h.logger.Debug("Failed to send cancel frame to client",
zap.String("request_id", requestID),
zap.Error(err),
)
}
}
cancelTransport = func() {
cancelOnce.Do(cancelFunc)
}
h.responses.RegisterCancelFunc(requestID, cancelTransport)
defer h.responses.CleanupCancelFunc(requestID)
}
largeBufferPtr := h.bufferPool.GetLarge()
tempBufPtr := h.bufferPool.GetMedium()
defer func() {
h.bufferPool.PutLarge(largeBufferPtr)
h.bufferPool.PutMedium(tempBufPtr)
}()
buffer := (*largeBufferPtr)[:0]
tempBuf := (*tempBufPtr)[:pool.MediumBufferSize]
var totalRead int64
var hitThreshold bool
for totalRead < streamingThreshold {
n, err := r.Body.Read(tempBuf)
if n > 0 {
buffer = append(buffer, tempBuf[:n]...)
totalRead += int64(n)
}
if err == io.EOF {
r.Body.Close()
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
return
}
if err != nil {
r.Body.Close()
h.logger.Error("Read request body failed", zap.Error(err))
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
return
}
if totalRead >= streamingThreshold {
hitThreshold = true
break
}
}
if !hitThreshold {
r.Body.Close()
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
return
}
h.streamLargeRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
}
func (h *Handler) sendBufferedRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), body []byte) {
headers := h.headerPool.Get()
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
httpReq := protocol.HTTPRequest{
Method: r.Method,
URL: r.URL.String(),
Headers: headers,
Body: body,
}
reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
h.headerPool.Put(headers)
if err != nil {
h.logger.Error("Encode HTTP request failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
header := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequest,
IsLast: true,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, reqBytes)
if err != nil {
h.logger.Error("Encode data payload failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
respChan := h.responses.CreateResponseChan(requestID)
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
defer func() {
h.responses.CleanupResponseChan(requestID)
h.responses.CleanupStreamingResponse(requestID)
}()
if err := transport.SendFrame(frame); err != nil {
h.logger.Error("Send frame to tunnel failed", zap.Error(err))
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
return
}
select {
case respMsg := <-respChan:
if respMsg == nil {
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
return
}
h.writeHTTPResponse(w, respMsg, subdomain, r)
case <-streamingDone:
// Streaming response has been fully written by SendStreamingChunk
case <-ctx.Done():
if cancelTransport != nil {
cancelTransport()
}
h.logger.Debug("HTTP request context cancelled",
zap.String("request_id", requestID),
zap.String("subdomain", subdomain),
)
return
case <-time.After(5 * time.Minute):
h.logger.Error("Request timeout",
zap.String("request_id", requestID),
zap.String("url", r.URL.String()),
)
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
}
}
func (h *Handler) streamLargeRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), bufferedData []byte) {
headers := h.headerPool.Get()
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
httpReqHead := protocol.HTTPRequestHead{
Method: r.Method,
URL: r.URL.String(),
Headers: headers,
ContentLength: r.ContentLength,
}
headBytes, err := protocol.EncodeHTTPRequestHead(&httpReqHead)
h.headerPool.Put(headers)
if err != nil {
h.logger.Error("Encode HTTP request head failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
headHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPHead, // shared streaming head type
IsLast: false,
}
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
if err != nil {
h.logger.Error("Encode head payload failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
respChan := h.responses.CreateResponseChan(requestID)
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
defer func() {
h.responses.CleanupResponseChan(requestID)
h.responses.CleanupStreamingResponse(requestID)
}()
if err := transport.SendFrame(headFrame); err != nil {
h.logger.Error("Send head frame failed", zap.Error(err))
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
return
}
if len(bufferedData) > 0 {
chunkHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
IsLast: false,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, bufferedData)
if err != nil {
h.logger.Error("Encode buffered chunk failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := transport.SendFrame(chunkFrame); err != nil {
h.logger.Error("Send buffered chunk failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
}
streamBufPtr := h.bufferPool.GetMedium()
defer h.bufferPool.PutMedium(streamBufPtr)
buffer := (*streamBufPtr)[:pool.MediumBufferSize]
for {
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
if cancelTransport != nil {
cancelTransport()
}
h.logger.Debug("Streaming request cancelled via context",
zap.String("request_id", requestID),
zap.String("subdomain", subdomain),
)
return
default:
stream.Close()
case <-done:
}
}()
n, readErr := r.Body.Read(buffer)
if n > 0 {
isLast := readErr == io.EOF
chunkHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
IsLast: isLast,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer[:n])
if err != nil {
h.logger.Error("Encode chunk payload failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := transport.SendFrame(chunkFrame); err != nil {
h.logger.Error("Send chunk frame failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
}
if readErr == io.EOF {
if n == 0 {
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if err == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
}
break
}
if readErr != nil {
h.logger.Error("Read request body failed", zap.Error(readErr))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if err == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
return
}
if err := writeChunked(w, resp.Body, resp.Trailer); err != nil {
h.logger.Debug("Write chunked response failed", zap.Error(err))
}
close(done)
stream.Close()
}
r.Body.Close()
func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) {
type result struct {
stream net.Conn
err error
}
ch := make(chan result, 1)
go func() {
s, err := tconn.OpenStream()
ch <- result{s, err}
}()
select {
case respMsg := <-respChan:
if respMsg == nil {
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
return
}
h.writeHTTPResponse(w, respMsg, subdomain, r)
case <-streamingDone:
// Streaming response has been fully written by SendStreamingChunk
case <-ctx.Done():
if cancelTransport != nil {
cancelTransport()
}
h.logger.Debug("Streaming HTTP request context cancelled",
zap.String("request_id", requestID),
zap.String("subdomain", subdomain),
)
return
case <-time.After(5 * time.Minute):
h.logger.Error("Streaming request timeout",
zap.String("request_id", requestID),
zap.String("url", r.URL.String()),
)
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
case r := <-ch:
return r.stream, r.err
case <-time.After(openStreamTimeout):
return nil, fmt.Errorf("open stream timeout")
}
}
func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPResponse, subdomain string, r *http.Request) {
if resp == nil {
http.Error(w, "Invalid response from tunnel", http.StatusBadGateway)
func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) {
stream, err := h.openStreamWithTimeout(tconn)
if err != nil {
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
return
}
// For buffered responses, we have the complete body, so we can set Content-Length
// Skip ALL hop-by-hop headers - client should have already cleaned them
for key, values := range resp.Headers {
tconn.IncActiveConnections()
hj, ok := w.(http.Hijacker)
if !ok {
stream.Close()
tconn.DecActiveConnections()
http.Error(w, "WebSocket not supported", http.StatusInternalServerError)
return
}
clientConn, _, err := hj.Hijack()
if err != nil {
stream.Close()
tconn.DecActiveConnections()
http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
return
}
if err := r.Write(stream); err != nil {
stream.Close()
clientConn.Close()
tconn.DecActiveConnections()
return
}
go func() {
defer stream.Close()
defer clientConn.Close()
defer tconn.DecActiveConnections()
_ = netutil.PipeWithCallbacks(r.Context(), stream, clientConn,
func(n int64) { tconn.AddBytesOut(n) },
func(n int64) { tconn.AddBytesIn(n) },
)
}()
}
func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHost string) {
for key, values := range src {
canonicalKey := http.CanonicalHeaderKey(key)
// Skip hop-by-hop headers completely using canonical key comparison
// Hop-by-hop headers must not be forwarded.
if canonicalKey == "Connection" ||
canonicalKey == "Keep-Alive" ||
canonicalKey == "Transfer-Encoding" ||
@@ -496,29 +257,61 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
}
if canonicalKey == "Location" && len(values) > 0 {
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
w.Header().Set("Location", rewrittenLocation)
dst.Set("Location", h.rewriteLocationHeader(values[0], proxyHost))
continue
}
for _, value := range values {
w.Header().Add(key, value)
dst.Add(key, value)
}
}
}
// trailerKeys renders the keys of hdr as a comma-separated list, sorted so
// the output is deterministic (nicer for debugging; no semantic impact).
func trailerKeys(hdr http.Header) string {
	names := make([]string, 0, len(hdr))
	for name := range hdr {
		names = append(names, name)
	}
	sortStrings(names)
	return strings.Join(names, ", ")
}
func writeChunked(w io.Writer, r io.Reader, trailer http.Header) error {
buf := make([]byte, 32*1024)
for {
n, err := r.Read(buf)
if n > 0 {
if _, werr := fmt.Fprintf(w, "%x\r\n", n); werr != nil {
return werr
}
if _, werr := w.Write(buf[:n]); werr != nil {
return werr
}
if _, werr := io.WriteString(w, "\r\n"); werr != nil {
return werr
}
}
if err == io.EOF {
break
}
if err != nil {
return err
}
}
// For buffered mode, always set Content-Length with the actual body size
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
statusCode := resp.StatusCode
if statusCode == 0 {
statusCode = http.StatusOK
if _, err := io.WriteString(w, "0\r\n"); err != nil {
return err
}
w.WriteHeader(statusCode)
if len(resp.Body) > 0 {
w.Write(resp.Body)
for k, vv := range trailer {
for _, v := range vv {
if _, err := io.WriteString(w, fmt.Sprintf("%s: %s\r\n", k, v)); err != nil {
return err
}
}
}
_, err := io.WriteString(w, "\r\n")
return err
}
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
@@ -535,22 +328,13 @@ func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
strings.HasPrefix(locationURL.Host, "localhost:") ||
locationURL.Host == "127.0.0.1" ||
strings.HasPrefix(locationURL.Host, "127.0.0.1:") {
scheme := "https"
if strings.Contains(proxyHost, ":") && !strings.Contains(proxyHost, "https") {
parts := strings.Split(proxyHost, ":")
if len(parts) == 2 && parts[1] != "443" {
scheme = "https"
}
}
rewritten := fmt.Sprintf("%s://%s%s", scheme, proxyHost, locationURL.Path)
rewritten := fmt.Sprintf("https://%s%s", proxyHost, locationURL.Path)
if locationURL.RawQuery != "" {
rewritten += "?" + locationURL.RawQuery
}
if locationURL.Fragment != "" {
rewritten += "#" + locationURL.Fragment
}
return rewritten
}
@@ -568,8 +352,7 @@ func (h *Handler) extractSubdomain(host string) string {
suffix := "." + h.domain
if strings.HasSuffix(host, suffix) {
subdomain := strings.TrimSuffix(host, suffix)
return subdomain
return strings.TrimSuffix(host, suffix)
}
return ""
@@ -652,9 +435,17 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
}
for _, conn := range connections {
if conn == nil {
continue
}
stats["tunnels"] = append(stats["tunnels"].([]map[string]interface{}), map[string]interface{}{
"subdomain": conn.Subdomain,
"last_active": conn.LastActive.Unix(),
"subdomain": conn.Subdomain,
"tunnel_type": string(conn.GetTunnelType()),
"last_active": conn.LastActive.Unix(),
"bytes_in": conn.GetBytesIn(),
"bytes_out": conn.GetBytesOut(),
"active_connections": conn.GetActiveConnections(),
"total_bytes": conn.GetBytesIn() + conn.GetBytesOut(),
})
}
@@ -668,3 +459,13 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Write(data)
}
// sortStrings orders s ascending in place. Inputs here are tiny header-key
// lists, so a dependency-free insertion sort is plenty.
func sortStrings(s []string) {
	for i := 1; i < len(s); i++ {
		for j := i; j > 0 && s[j] < s[j-1]; j-- {
			s[j], s[j-1] = s[j-1], s[j]
		}
	}
}

View File

@@ -1,421 +0,0 @@
package proxy
import (
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// responseChanEntry pairs a buffered response channel with its creation
// time so the cleanup loop can expire channels that were never answered.
type responseChanEntry struct {
	ch        chan *protocol.HTTPResponse // buffered (capacity 1); closed by CleanupResponseChan
	createdAt time.Time                   // age basis for cleanupExpiredChannels
}

// streamingResponseEntry holds a streaming response writer plus the state
// needed to write a response incrementally to it from another goroutine.
type streamingResponseEntry struct {
	w              http.ResponseWriter
	flusher        http.Flusher // nil when w does not implement http.Flusher
	createdAt      time.Time
	lastActivityAt time.Time // refreshed on every chunk; idle basis for expiry
	headersSent    bool      // guards against a second WriteHeader call
	done           chan struct{} // closed when the stream ends or the client disconnects
	mu             sync.Mutex    // serializes writes to w and to the fields above
}

// ResponseHandler manages response channels for HTTP requests over TCP/Frame protocol
type ResponseHandler struct {
	channels          map[string]*responseChanEntry      // buffered responses keyed by request ID
	streamingChannels map[string]*streamingResponseEntry // streaming responses keyed by request ID
	cancelFuncs       map[string]func()                  // run when the downstream client goes away
	mu                sync.RWMutex                       // guards the three maps above
	logger            *zap.Logger
	stopCh            chan struct{} // closed by Close to stop cleanupLoop
}
// NewResponseHandler builds a ResponseHandler and starts its background
// janitor goroutine, which periodically expires abandoned channels.
func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
	rh := &ResponseHandler{
		channels:          make(map[string]*responseChanEntry),
		streamingChannels: make(map[string]*streamingResponseEntry),
		cancelFuncs:       make(map[string]func()),
		logger:            logger,
		stopCh:            make(chan struct{}),
	}
	// A single shared cleanup goroutine amortizes expiry work across all
	// requests instead of spawning one timer goroutine per request.
	go rh.cleanupLoop()
	return rh
}
// CreateResponseChan registers and returns a buffered response channel for
// requestID. The channel is closed later by CleanupResponseChan.
func (h *ResponseHandler) CreateResponseChan(requestID string) chan *protocol.HTTPResponse {
	respCh := make(chan *protocol.HTTPResponse, 1)
	entry := &responseChanEntry{ch: respCh, createdAt: time.Now()}
	h.mu.Lock()
	h.channels[requestID] = entry
	h.mu.Unlock()
	return respCh
}
// CreateStreamingResponse registers w as the streaming sink for requestID
// and returns a channel that is closed once the response is complete.
func (h *ResponseHandler) CreateStreamingResponse(requestID string, w http.ResponseWriter) chan struct{} {
	doneCh := make(chan struct{})
	fl, _ := w.(http.Flusher) // nil when the writer cannot flush incrementally
	ts := time.Now()
	entry := &streamingResponseEntry{
		w:              w,
		flusher:        fl,
		createdAt:      ts,
		lastActivityAt: ts,
		done:           doneCh,
	}
	h.mu.Lock()
	h.streamingChannels[requestID] = entry
	h.mu.Unlock()
	return doneCh
}
// RegisterCancelFunc stores a callback to invoke when the downstream client
// disconnects. A nil callback is ignored.
func (h *ResponseHandler) RegisterCancelFunc(requestID string, cancel func()) {
	if cancel == nil {
		return
	}
	h.mu.Lock()
	defer h.mu.Unlock()
	h.cancelFuncs[requestID] = cancel
}
// GetResponseChan returns the receive side of the response channel
// registered for requestID, or nil when none exists.
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
	h.mu.RLock()
	entry := h.channels[requestID]
	h.mu.RUnlock()
	if entry == nil {
		return nil
	}
	return entry.ch
}
// SendResponse delivers resp to the handler goroutine waiting on requestID.
// It is a no-op when no channel is registered. If the waiting handler never
// drains the channel, the send is abandoned after a 30s safety timeout.
func (h *ResponseHandler) SendResponse(requestID string, resp *protocol.HTTPResponse) {
	h.mu.RLock()
	entry, exists := h.channels[requestID]
	h.mu.RUnlock()
	if !exists || entry == nil {
		return
	}
	// BUG FIX: CleanupResponseChan can close entry.ch concurrently with the
	// send below (the map lock is not held across the send). Recover so a
	// lost race drops the response instead of panicking the whole process
	// with "send on closed channel".
	defer func() {
		if r := recover(); r != nil {
			h.logger.Debug("Response channel closed while sending",
				zap.String("request_id", requestID),
			)
		}
	}()
	select {
	case entry.ch <- resp:
	case <-time.After(30 * time.Second):
		h.logger.Error("Timeout sending response to channel - handler may have abandoned",
			zap.String("request_id", requestID),
			zap.Int("status_code", resp.StatusCode),
			zap.Int("body_size", len(resp.Body)),
		)
	}
}
// SendStreamingHead writes the status line and headers for the streaming
// response registered under requestID. It is idempotent: calls after the
// stream is done, or after headers were already sent, are no-ops. Hop-by-hop
// headers are stripped before forwarding. Always returns nil.
func (h *ResponseHandler) SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error {
	h.mu.RLock()
	entry, exists := h.streamingChannels[requestID]
	h.mu.RUnlock()
	if !exists || entry == nil {
		// No streaming sink registered for this request; nothing to do.
		return nil
	}
	entry.mu.Lock()
	defer entry.mu.Unlock()
	// Stream already finished (client gone or response complete).
	select {
	case <-entry.done:
		return nil
	default:
	}
	if entry.headersSent {
		return nil
	}
	// Copy headers, removing hop-by-hop headers that were already handled by client
	// Client's cleanResponseHeaders already removed Transfer-Encoding, Connection, etc.
	// But we need to check again in case they slipped through
	hasContentLength := false
	for key, values := range head.Headers {
		canonicalKey := http.CanonicalHeaderKey(key)
		// Skip ALL hop-by-hop headers
		if canonicalKey == "Connection" ||
			canonicalKey == "Keep-Alive" ||
			canonicalKey == "Transfer-Encoding" ||
			canonicalKey == "Upgrade" ||
			canonicalKey == "Proxy-Connection" ||
			canonicalKey == "Te" ||
			canonicalKey == "Trailer" {
			continue
		}
		if canonicalKey == "Content-Length" {
			hasContentLength = true
		}
		for _, value := range values {
			entry.w.Header().Add(key, value)
		}
	}
	// For streaming responses, decide how to indicate message length
	if head.ContentLength >= 0 && !hasContentLength {
		entry.w.Header().Set("Content-Length", fmt.Sprintf("%d", head.ContentLength))
	}
	// A zero status code means the client side did not set one; default to 200.
	statusCode := head.StatusCode
	if statusCode == 0 {
		statusCode = http.StatusOK
	}
	entry.w.WriteHeader(statusCode)
	entry.headersSent = true
	entry.lastActivityAt = time.Now()
	// Flush so the client sees headers immediately rather than waiting for
	// the first body chunk.
	if entry.flusher != nil {
		entry.flusher.Flush()
	}
	return nil
}
// SendStreamingChunk writes one body chunk to the streaming response
// registered under requestID and flushes it. When isLast is true the stream
// is marked done. Write failures are treated as a client disconnect: the
// stream is closed and the registered cancel callback fires. Always returns
// nil so transport-level frame handling is never aborted by a slow client.
func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isLast bool) error {
	h.mu.RLock()
	entry, exists := h.streamingChannels[requestID]
	h.mu.RUnlock()
	if !exists || entry == nil {
		return nil
	}
	entry.mu.Lock()
	defer entry.mu.Unlock()
	// Stream already finished; drop the chunk.
	select {
	case <-entry.done:
		return nil
	default:
	}
	if len(chunk) > 0 {
		if _, err := entry.w.Write(chunk); err != nil {
			// The original code branched on isClientDisconnectError here but
			// both branches were identical, so every write error ends the
			// stream and triggers the upstream cancel.
			closeStreamDone(entry)
			h.triggerCancel(requestID)
			return nil
		}
		entry.lastActivityAt = time.Now()
		if entry.flusher != nil {
			entry.flusher.Flush()
		}
	}
	if isLast {
		closeStreamDone(entry)
	}
	return nil
}

// closeStreamDone closes entry.done exactly once. Callers must hold entry.mu.
func closeStreamDone(entry *streamingResponseEntry) {
	select {
	case <-entry.done:
	default:
		close(entry.done)
	}
}
// isClientDisconnectError reports whether err looks like the downstream
// client went away: broken pipe, connection reset/refused, or use of a
// closed network connection. Matching is by error text because these
// failures surface as wrapped *net.OpError values whose messages embed the
// underlying cause; for an OpError the inner error's text is inspected to
// avoid false positives from the op/address prefix.
func isClientDisconnectError(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	if opErr, ok := err.(*net.OpError); ok && opErr.Err != nil {
		msg = opErr.Err.Error()
	}
	for _, marker := range []string{
		"broken pipe",
		"connection reset",
		"connection refused",
		"use of closed network connection",
	} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
// triggerCancel removes the cancel callback registered for requestID and,
// if one existed, runs it on its own goroutine so no handler lock is held
// while user code executes.
func (h *ResponseHandler) triggerCancel(requestID string) {
	h.mu.Lock()
	cancel, ok := h.cancelFuncs[requestID]
	if ok {
		delete(h.cancelFuncs, requestID)
	}
	h.mu.Unlock()
	if cancel != nil {
		go cancel()
	}
}
// CleanupResponseChan closes and unregisters the response channel for
// requestID, if one is registered.
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	entry, ok := h.channels[requestID]
	if !ok {
		return
	}
	close(entry.ch)
	delete(h.channels, requestID)
}
// CleanupStreamingResponse marks the streaming response for requestID as
// finished (closing its done channel if still open) and unregisters it.
func (h *ResponseHandler) CleanupStreamingResponse(requestID string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	entry, ok := h.streamingChannels[requestID]
	if !ok {
		return
	}
	select {
	case <-entry.done:
		// already closed by the streaming path
	default:
		close(entry.done)
	}
	delete(h.streamingChannels, requestID)
}
// CleanupCancelFunc drops the cancel callback for requestID without
// invoking it.
func (h *ResponseHandler) CleanupCancelFunc(requestID string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	delete(h.cancelFuncs, requestID)
}
// GetPendingCount reports how many requests are currently awaiting a
// response (buffered plus streaming).
func (h *ResponseHandler) GetPendingCount() int {
	h.mu.RLock()
	pending := len(h.channels) + len(h.streamingChannels)
	h.mu.RUnlock()
	return pending
}
// cleanupLoop periodically expires abandoned channels until Close is called.
func (h *ResponseHandler) cleanupLoop() {
	const sweepInterval = 5 * time.Second
	ticker := time.NewTicker(sweepInterval)
	defer ticker.Stop()
	for {
		select {
		case <-h.stopCh:
			return
		case <-ticker.C:
			h.cleanupExpiredChannels()
		}
	}
}
// cleanupExpiredChannels expires buffered channels older than 5 minutes and
// streaming responses idle for more than 5 minutes. Expired streaming
// requests also have their cancel callbacks invoked (on fresh goroutines,
// after map mutation) so the upstream side stops sending.
func (h *ResponseHandler) cleanupExpiredChannels() {
	now := time.Now()
	timeout := 5 * time.Minute
	streamingTimeout := 5 * time.Minute
	h.mu.Lock()
	defer h.mu.Unlock()
	expiredCount := 0
	cancelList := make([]string, 0)
	// Buffered channels expire by age: a handler either drained the channel
	// long ago or abandoned the request entirely.
	for requestID, entry := range h.channels {
		if now.Sub(entry.createdAt) > timeout {
			close(entry.ch)
			delete(h.channels, requestID)
			expiredCount++
		}
	}
	// Streaming responses expire by inactivity, not age, so long downloads
	// that are still making progress are never cut off.
	for requestID, entry := range h.streamingChannels {
		if now.Sub(entry.lastActivityAt) > streamingTimeout {
			select {
			case <-entry.done:
			default:
				close(entry.done)
			}
			delete(h.streamingChannels, requestID)
			cancelList = append(cancelList, requestID)
			expiredCount++
		}
	}
	// Run cancel callbacks on their own goroutines; h.mu is still held here,
	// so invoking them inline could deadlock if a callback re-enters the handler.
	for _, requestID := range cancelList {
		if cancel := h.cancelFuncs[requestID]; cancel != nil {
			delete(h.cancelFuncs, requestID)
			go cancel()
		}
	}
	if expiredCount > 0 {
		h.logger.Debug("Cleaned up expired response channels",
			zap.Int("count", expiredCount),
			zap.Int("remaining", len(h.channels)+len(h.streamingChannels)),
		)
	}
}
// Close stops the cleanup loop, closes every pending buffered and streaming
// channel, runs all registered cancel callbacks synchronously, and resets
// the internal maps.
// NOTE(review): calling Close more than once panics on the repeated
// close(h.stopCh); callers must ensure Close runs at most once — confirm
// with the owner of the ResponseHandler lifecycle.
func (h *ResponseHandler) Close() {
	close(h.stopCh)
	h.mu.Lock()
	defer h.mu.Unlock()
	for _, entry := range h.channels {
		close(entry.ch)
	}
	h.channels = make(map[string]*responseChanEntry)
	for _, entry := range h.streamingChannels {
		// done may already have been closed by the streaming path.
		select {
		case <-entry.done:
		default:
			close(entry.done)
		}
	}
	h.streamingChannels = make(map[string]*streamingResponseEntry)
	// Unlike triggerCancel, callbacks run inline here: shutdown is
	// single-threaded and the callbacks must finish before teardown.
	for _, cancel := range h.cancelFuncs {
		cancel()
	}
	h.cancelFuncs = make(map[string]func())
}

View File

@@ -13,6 +13,7 @@ import (
"time"
json "github.com/goccy/go-json"
"github.com/hashicorp/yamux"
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
@@ -33,36 +34,27 @@ type Connection struct {
publicPort int
portAlloc *PortAllocator
tunnelConn *tunnel.Connection
proxy *TunnelProxy
stopCh chan struct{}
once sync.Once
lastHeartbeat time.Time
mu sync.RWMutex
frameWriter *protocol.FrameWriter
httpHandler http.Handler
responseChans HTTPResponseHandler
tunnelType protocol.TunnelType // Track tunnel type
ctx context.Context
cancel context.CancelFunc
// Flow control
paused bool
pauseCond *sync.Cond
}
// gost-like TCP tunnel (yamux)
session *yamux.Session
proxy *Proxy
// HTTPResponseHandler interface for response channel operations
type HTTPResponseHandler interface {
CreateResponseChan(requestID string) chan *protocol.HTTPResponse
GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
CleanupResponseChan(requestID string)
SendResponse(requestID string, resp *protocol.HTTPResponse)
// Streaming response methods
SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error
SendStreamingChunk(requestID string, chunk []byte, isLast bool) error
// Multi-connection support
tunnelID string
groupManager *ConnectionGroupManager
}
// NewConnection creates a new connection handler
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager) *Connection {
ctx, cancel := context.WithCancel(context.Background())
c := &Connection{
conn: conn,
@@ -73,13 +65,12 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
domain: domain,
publicPort: publicPort,
httpHandler: httpHandler,
responseChans: responseChans,
stopCh: make(chan struct{}),
lastHeartbeat: time.Now(),
ctx: ctx,
cancel: cancel,
groupManager: groupManager,
}
c.pauseCond = sync.NewCond(&c.mu)
return c
}
@@ -97,8 +88,8 @@ func (c *Connection) Handle() error {
// Use buffered reader to support peeking
reader := bufio.NewReader(c.conn)
// Peek first 8 bytes to detect protocol
peek, err := reader.Peek(8)
// Peek first 4 bytes to detect protocol (HTTP methods are 4 bytes).
peek, err := reader.Peek(4)
if err != nil {
return fmt.Errorf("failed to peek connection: %w", err)
}
@@ -127,6 +118,11 @@ func (c *Connection) Handle() error {
sf := protocol.WithFrame(frame)
defer sf.Close()
// Handle data connection request (for multi-connection pool)
if sf.Frame.Type == protocol.FrameTypeDataConnect {
return c.handleDataConnect(sf.Frame, reader)
}
if sf.Frame.Type != protocol.FrameTypeRegister {
return fmt.Errorf("expected register frame, got %s", sf.Frame.Type)
}
@@ -180,7 +176,6 @@ func (c *Connection) Handle() error {
// Store TCP connection reference and metadata for HTTP proxy routing
c.tunnelConn.Conn = nil // We're using TCP, not WebSocket
c.tunnelConn.SetTransport(c, req.TunnelType)
c.tunnelConn.SetTunnelType(req.TunnelType)
c.tunnelType = req.TunnelType
@@ -208,11 +203,33 @@ func (c *Connection) Handle() error {
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
}
// Generate TunnelID for multi-connection support if client supports it
var tunnelID string
var supportsDataConn bool
recommendedConns := 0
if req.PoolCapabilities != nil && req.ConnectionType == "primary" && c.groupManager != nil {
// Client supports connection pooling
group := c.groupManager.CreateGroup(subdomain, req.Token, c, req.TunnelType)
tunnelID = group.TunnelID
c.tunnelID = tunnelID
supportsDataConn = true
recommendedConns = 4 // Recommend 4 data connections
c.logger.Info("Created connection group for multi-connection support",
zap.String("tunnel_id", tunnelID),
zap.Int("max_data_conns", req.PoolCapabilities.MaxDataConns),
)
}
resp := protocol.RegisterResponse{
Subdomain: subdomain,
Port: c.port,
URL: tunnelURL,
Message: "Tunnel registered successfully",
Subdomain: subdomain,
Port: c.port,
URL: tunnelURL,
Message: "Tunnel registered successfully",
TunnelID: tunnelID,
SupportsDataConn: supportsDataConn,
RecommendedConns: recommendedConns,
}
respData, _ := json.Marshal(resp)
@@ -224,6 +241,17 @@ func (c *Connection) Handle() error {
return fmt.Errorf("failed to send registration ack: %w", err)
}
// Clear deadline for tunnel data-plane.
c.conn.SetReadDeadline(time.Time{})
// gost-like tunnels: switch to yamux after RegisterAck.
if req.TunnelType == protocol.TunnelTypeTCP {
return c.handleTCPTunnel(reader)
}
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
return c.handleHTTPProxyTunnel(reader)
}
c.frameWriter = protocol.NewFrameWriter(c.conn)
c.frameWriter.SetWriteErrorHandler(func(err error) {
@@ -231,15 +259,6 @@ func (c *Connection) Handle() error {
c.Close()
})
c.conn.SetReadDeadline(time.Time{})
if req.TunnelType == protocol.TunnelTypeTCP {
c.proxy = NewTunnelProxy(c.port, subdomain, c.conn, c.logger)
if err := c.proxy.Start(); err != nil {
return fmt.Errorf("failed to start TCP proxy: %w", err)
}
}
go c.heartbeatChecker()
return c.handleFrames(reader)
@@ -376,7 +395,7 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
frame, err := protocol.ReadFrame(reader)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
if isTimeoutError(err) {
c.logger.Warn("Read timeout, connection may be dead")
return fmt.Errorf("read timeout")
}
@@ -404,15 +423,6 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
c.handleHeartbeat()
sf.Close()
case protocol.FrameTypeData:
// Data frame from client (response to forwarded request)
c.handleDataFrame(sf.Frame)
sf.Close()
case protocol.FrameTypeFlowControl:
c.handleFlowControl(sf.Frame)
sf.Close()
case protocol.FrameTypeClose:
sf.Close()
c.logger.Info("Client requested close")
@@ -436,127 +446,12 @@ func (c *Connection) handleHeartbeat() {
// Send heartbeat ack
ackFrame := protocol.NewFrame(protocol.FrameTypeHeartbeatAck, nil)
err := c.frameWriter.WriteFrame(ackFrame)
err := c.frameWriter.WriteControl(ackFrame)
if err != nil {
c.logger.Error("Failed to send heartbeat ack", zap.Error(err))
}
}
// handleDataFrame handles data frame (response from client)
func (c *Connection) handleDataFrame(frame *protocol.Frame) {
// Decode payload (auto-detects protocol version)
header, data, err := protocol.DecodeDataPayload(frame.Payload)
if err != nil {
c.logger.Error("Failed to decode data payload",
zap.Error(err),
)
return
}
c.logger.Debug("Received data frame",
zap.String("stream_id", header.StreamID),
zap.String("type", header.Type.String()),
zap.Int("data_size", len(data)),
)
switch header.Type {
case protocol.DataTypeResponse:
// TCP tunnel response, forward to proxy
if c.proxy != nil {
if err := c.proxy.HandleResponse(header.StreamID, data); err != nil {
c.logger.Error("Failed to handle response",
zap.String("stream_id", header.StreamID),
zap.Error(err),
)
}
}
case protocol.DataTypeHTTPResponse:
if c.responseChans == nil {
c.logger.Warn("No response channel handler for HTTP response",
zap.String("stream_id", header.StreamID),
)
return
}
// Decode HTTP response (auto-detects JSON vs msgpack)
httpResp, err := protocol.DecodeHTTPResponse(data)
if err != nil {
c.logger.Error("Failed to decode HTTP response",
zap.String("stream_id", header.StreamID),
zap.Error(err),
)
return
}
// Route by request ID when provided to keep request/response aligned.
reqID := header.RequestID
if reqID == "" {
reqID = header.StreamID
}
c.responseChans.SendResponse(reqID, httpResp)
case protocol.DataTypeHTTPHead:
// Streaming HTTP response headers
if c.responseChans == nil {
c.logger.Warn("No response handler for streaming HTTP head",
zap.String("stream_id", header.StreamID),
)
return
}
httpHead, err := protocol.DecodeHTTPResponseHead(data)
if err != nil {
c.logger.Error("Failed to decode HTTP response head",
zap.String("stream_id", header.StreamID),
zap.Error(err),
)
return
}
reqID := header.RequestID
if reqID == "" {
reqID = header.StreamID
}
if err := c.responseChans.SendStreamingHead(reqID, httpHead); err != nil {
c.logger.Error("Failed to send streaming head",
zap.String("request_id", reqID),
zap.Error(err),
)
}
case protocol.DataTypeHTTPBodyChunk:
// Streaming HTTP response body chunk
if c.responseChans == nil {
c.logger.Warn("No response handler for streaming HTTP chunk",
zap.String("stream_id", header.StreamID),
)
return
}
reqID := header.RequestID
if reqID == "" {
reqID = header.StreamID
}
if err := c.responseChans.SendStreamingChunk(reqID, data, header.IsLast); err != nil {
c.logger.Error("Failed to send streaming chunk",
zap.String("request_id", reqID),
zap.Error(err),
)
}
case protocol.DataTypeClose:
// Client is closing the stream
if c.proxy != nil {
c.proxy.CloseStream(header.StreamID)
}
default:
c.logger.Warn("Unknown data frame type",
zap.String("type", header.Type.String()),
zap.String("stream_id", header.StreamID),
)
}
}
// heartbeatChecker checks for heartbeat timeout
func (c *Connection) heartbeatChecker() {
ticker := time.NewTicker(constants.HeartbeatInterval)
@@ -583,16 +478,6 @@ func (c *Connection) heartbeatChecker() {
}
}
func (c *Connection) SendFrame(frame *protocol.Frame) error {
if c.frameWriter == nil {
return protocol.WriteFrame(c.conn, frame)
}
if frame.Type == protocol.FrameTypeData {
return c.sendWithBackpressure(frame)
}
return c.frameWriter.WriteFrame(frame)
}
func (c *Connection) sendError(code, message string) {
errMsg := protocol.ErrorMessage{
Code: code,
@@ -618,8 +503,12 @@ func (c *Connection) Close() {
c.cancel()
}
// Ensure any in-flight writes return quickly on shutdown to avoid hanging.
if c.conn != nil {
_ = c.conn.SetDeadline(time.Now())
}
if c.frameWriter != nil {
c.frameWriter.Flush()
c.frameWriter.Close()
}
@@ -627,7 +516,13 @@ func (c *Connection) Close() {
c.proxy.Stop()
}
c.conn.Close()
if c.session != nil {
_ = c.session.Close()
}
if c.conn != nil {
c.conn.Close()
}
if c.port > 0 && c.portAlloc != nil {
c.portAlloc.Release(c.port)
@@ -635,6 +530,12 @@ func (c *Connection) Close() {
if c.subdomain != "" {
c.manager.Unregister(c.subdomain)
// Clean up connection group when PRIMARY connection closes
// (only primary connections have subdomain set)
if c.tunnelID != "" && c.groupManager != nil {
c.groupManager.RemoveGroup(c.tunnelID)
}
}
c.logger.Info("Connection closed",
@@ -643,11 +544,6 @@ func (c *Connection) Close() {
})
}
// GetSubdomain returns the assigned subdomain
func (c *Connection) GetSubdomain() string {
return c.subdomain
}
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
type httpResponseWriter struct {
conn net.Conn
@@ -698,39 +594,196 @@ func (w *httpResponseWriter) Write(data []byte) (int, error) {
return w.writer.Write(data)
}
func (c *Connection) handleFlowControl(frame *protocol.Frame) {
msg, err := protocol.DecodeFlowControlMessage(frame.Payload)
// handleDataConnect handles a data connection join request
func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Reader) error {
var req protocol.DataConnectRequest
if err := json.Unmarshal(frame.Payload, &req); err != nil {
c.sendError("invalid_request", "Failed to parse data connect request")
return fmt.Errorf("failed to parse data connect request: %w", err)
}
c.logger.Info("Data connection request received",
zap.String("tunnel_id", req.TunnelID),
zap.String("connection_id", req.ConnectionID),
)
// Validate the request
if c.groupManager == nil {
c.sendDataConnectError("not_supported", "Multi-connection not supported")
return fmt.Errorf("group manager not available")
}
// Validate auth token
if c.authToken != "" && req.Token != c.authToken {
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed for data connection")
}
group, ok := c.groupManager.GetGroup(req.TunnelID)
if !ok || group == nil {
c.sendDataConnectError("join_failed", "Tunnel not found")
return fmt.Errorf("tunnel not found: %s", req.TunnelID)
}
// Validate token against the primary registration token.
if group.Token != "" && req.Token != group.Token {
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed for data connection")
}
// Store tunnelID for cleanup
c.tunnelID = req.TunnelID
// For TCP tunnels, the data connection is upgraded to a yamux session and used for
// stream forwarding, not framed request/response routing.
if group.TunnelType == protocol.TunnelTypeTCP {
resp := protocol.DataConnectResponse{
Accepted: true,
ConnectionID: req.ConnectionID,
Message: "Data connection accepted",
}
respData, _ := json.Marshal(resp)
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
return fmt.Errorf("failed to send data connect ack: %w", err)
}
c.logger.Info("TCP data connection established",
zap.String("tunnel_id", req.TunnelID),
zap.String("connection_id", req.ConnectionID),
)
// Clear deadline for yamux data-plane.
_ = c.conn.SetReadDeadline(time.Time{})
// Public server acts as yamux Client, client connector acts as yamux Server.
bc := &bufferedConn{
Conn: c.conn,
reader: reader,
}
cfg := yamux.DefaultConfig()
cfg.EnableKeepAlive = false
cfg.LogOutput = io.Discard
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
session, err := yamux.Client(bc, cfg)
if err != nil {
return fmt.Errorf("failed to init yamux session: %w", err)
}
c.session = session
group.AddSession(req.ConnectionID, session)
defer group.RemoveSession(req.ConnectionID)
select {
case <-c.stopCh:
return nil
case <-session.CloseChan():
return nil
}
}
// Add data connection to group
dataConn, err := c.groupManager.AddDataConnection(&req, c.conn)
if err != nil {
c.logger.Error("Failed to decode flow control", zap.Error(err))
return
c.sendDataConnectError("join_failed", err.Error())
return fmt.Errorf("failed to join connection group: %w", err)
}
c.mu.Lock()
defer c.mu.Unlock()
// Send success response
resp := protocol.DataConnectResponse{
Accepted: true,
ConnectionID: req.ConnectionID,
Message: "Data connection accepted",
}
switch msg.Action {
case protocol.FlowControlPause:
c.paused = true
c.logger.Warn("Client requested pause",
zap.String("stream", msg.StreamID))
respData, _ := json.Marshal(resp)
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
case protocol.FlowControlResume:
c.paused = false
c.pauseCond.Broadcast()
c.logger.Info("Client requested resume",
zap.String("stream", msg.StreamID))
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
return fmt.Errorf("failed to send data connect ack: %w", err)
}
default:
c.logger.Warn("Unknown flow control action",
zap.String("action", string(msg.Action)))
c.logger.Info("Data connection established",
zap.String("tunnel_id", req.TunnelID),
zap.String("connection_id", req.ConnectionID),
)
// Handle data frames on this connection
return c.handleDataConnectionFrames(dataConn, reader)
}
// handleDataConnectionFrames handles frames on a data connection.
//
// It loops reading control frames until the data connection's stopCh fires,
// the client sends a Close frame, or a non-timeout read error occurs. Read
// timeouts are expected: the short deadline doubles as a periodic liveness
// check against stopCh. On exit the data connection is removed from its
// group.
func (c *Connection) handleDataConnectionFrames(dataConn *DataConnection, reader *bufio.Reader) error {
	defer func() {
		// Get the group and remove this data connection
		if group, ok := c.groupManager.GetGroup(c.tunnelID); ok {
			group.RemoveDataConnection(dataConn.ID)
		}
	}()
	for {
		// Non-blocking stop check before each blocking read.
		select {
		case <-dataConn.stopCh:
			return nil
		default:
		}
		c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
		frame, err := protocol.ReadFrame(reader)
		if err != nil {
			// Timeout is OK, continue
			if isTimeoutError(err) {
				continue
			}
			return err
		}
		// Record activity for staleness tracking.
		dataConn.mu.Lock()
		dataConn.LastActive = time.Now()
		dataConn.mu.Unlock()
		sf := protocol.WithFrame(frame)
		switch sf.Frame.Type {
		case protocol.FrameTypeClose:
			sf.Close()
			c.logger.Info("Data connection closed by client",
				zap.String("connection_id", dataConn.ID))
			return nil
		default:
			// Idle data connections carry no other control frames; log and drop.
			sf.Close()
			c.logger.Warn("Unexpected frame type on data connection",
				zap.String("type", sf.Frame.Type.String()),
				zap.String("connection_id", dataConn.ID),
			)
		}
	}
}
func (c *Connection) sendWithBackpressure(frame *protocol.Frame) error {
c.mu.Lock()
for c.paused {
c.pauseCond.Wait()
func isTimeoutError(err error) bool {
if err == nil {
return false
}
c.mu.Unlock()
return c.frameWriter.WriteFrame(frame)
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
// Fallback for wrapped errors without net.Error
return strings.Contains(err.Error(), "i/o timeout")
}
// sendDataConnectError sends a rejection DataConnectAck to the client with
// the given machine-readable code and human-readable message.
func (c *Connection) sendDataConnectError(code, message string) {
	resp := protocol.DataConnectResponse{
		Accepted: false,
		Message:  fmt.Sprintf("%s: %s", code, message),
	}
	respData, _ := json.Marshal(resp)
	frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
	// Best-effort: the connection is about to be torn down anyway, but a
	// write failure is still worth surfacing instead of being ignored.
	if err := protocol.WriteFrame(c.conn, frame); err != nil {
		c.logger.Debug("Failed to send data connect error response", zap.Error(err))
	}
}

View File

@@ -0,0 +1,438 @@
package tcp
import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/yamux"
"drip/internal/shared/constants"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// DataConnection is a single pooled data-plane connection belonging to a
// ConnectionGroup.
type DataConnection struct {
	ID         string        // connection ID supplied by the client in DataConnectRequest
	Conn       net.Conn      // underlying network connection
	LastActive time.Time     // last frame activity; guarded by mu
	closed     bool          // set once the connection is torn down; guarded by closedMu
	closedMu   sync.RWMutex  // guards closed and the close sequence
	stopCh     chan struct{} // closed exactly once to signal shutdown
	mu         sync.RWMutex  // guards LastActive
}
// ConnectionGroup bundles a tunnel's primary control connection with its
// pooled data connections and yamux sessions.
type ConnectionGroup struct {
	TunnelID         string                     // unique ID handed to the client at registration
	Subdomain        string                     // public subdomain served by this tunnel
	Token            string                     // auth token data connections must present
	PrimaryConn      *Connection                // control connection that registered the tunnel
	DataConns        map[string]*DataConnection // connection ID -> framed data connection
	Sessions         map[string]*yamux.Session  // connection ID -> yamux session
	TunnelType       protocol.TunnelType
	RegisteredAt     time.Time
	LastActivity     time.Time    // refreshed on session/data-conn activity; guarded by mu
	sessionIdx       uint32       // atomic round-robin cursor for session selection
	mu               sync.RWMutex // guards the maps and timestamps above
	stopCh           chan struct{} // closed at most once, in Close
	logger           *zap.Logger
	heartbeatStarted bool // ensures the heartbeat loop starts only once; guarded by mu
}
// NewConnectionGroup builds a group for the given tunnel with empty
// data-connection and session maps and activity timestamps set to now.
func NewConnectionGroup(tunnelID, subdomain, token string, primaryConn *Connection, tunnelType protocol.TunnelType, logger *zap.Logger) *ConnectionGroup {
	now := time.Now()
	g := &ConnectionGroup{
		TunnelID:     tunnelID,
		Subdomain:    subdomain,
		Token:        token,
		PrimaryConn:  primaryConn,
		TunnelType:   tunnelType,
		RegisteredAt: now,
		LastActivity: now,
		DataConns:    make(map[string]*DataConnection),
		Sessions:     make(map[string]*yamux.Session),
		stopCh:       make(chan struct{}),
		logger:       logger.With(zap.String("tunnel_id", tunnelID)),
	}
	return g
}
// StartHeartbeat starts a goroutine that periodically pings all sessions
// and removes dead ones. The caller should ensure this is only called once.
// interval is the ping period; timeout bounds how long a single ping may
// take before the session is counted as failing.
func (g *ConnectionGroup) StartHeartbeat(interval, timeout time.Duration) {
	go g.heartbeatLoop(interval, timeout)
}
// heartbeatLoop pings every registered yamux session once per interval and
// removes sessions that fail maxConsecutiveFailures pings in a row (or are
// already closed). A successful ping refreshes the group's LastActivity.
// The loop exits when the group's stopCh is closed.
func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	const maxConsecutiveFailures = 3
	failureCount := make(map[string]int)
	for {
		select {
		case <-g.stopCh:
			return
		case <-ticker.C:
		}
		// Snapshot the session map so pings run without holding the lock.
		g.mu.RLock()
		sessions := make(map[string]*yamux.Session, len(g.Sessions))
		for id, s := range g.Sessions {
			sessions[id] = s
		}
		g.mu.RUnlock()
		for id, session := range sessions {
			if session == nil || session.IsClosed() {
				g.RemoveSession(id)
				delete(failureCount, id)
				continue
			}
			// Ping with timeout. The buffered channel lets the goroutine
			// complete (and be collected) even if we stop waiting first.
			done := make(chan error, 1)
			go func(s *yamux.Session) {
				_, err := s.Ping()
				done <- err
			}(session)
			var err error
			select {
			case err = <-done:
			case <-time.After(timeout):
				err = fmt.Errorf("ping timeout")
			case <-g.stopCh:
				return
			}
			if err != nil {
				failureCount[id]++
				g.logger.Debug("Session ping failed",
					zap.String("session_id", id),
					zap.Int("consecutive_failures", failureCount[id]),
					zap.Error(err),
				)
				if failureCount[id] >= maxConsecutiveFailures {
					g.logger.Warn("Session ping failed too many times, removing",
						zap.String("session_id", id),
						zap.Int("failures", failureCount[id]),
					)
					g.RemoveSession(id)
					delete(failureCount, id)
				}
			} else {
				// Reset on success
				failureCount[id] = 0
				g.mu.Lock()
				g.LastActivity = time.Now()
				g.mu.Unlock()
			}
		}
		// Check if all sessions are gone
		g.mu.RLock()
		sessionCount := len(g.Sessions)
		g.mu.RUnlock()
		if sessionCount == 0 {
			g.logger.Info("All sessions closed, tunnel will be cleaned up")
		}
	}
}
// AddDataConnection registers a new framed data connection under connID and
// refreshes the group's activity timestamp.
func (g *ConnectionGroup) AddDataConnection(connID string, conn net.Conn) *DataConnection {
	dc := &DataConnection{
		ID:         connID,
		Conn:       conn,
		LastActive: time.Now(),
		stopCh:     make(chan struct{}),
	}

	g.mu.Lock()
	g.DataConns[connID] = dc
	g.LastActivity = time.Now()
	g.mu.Unlock()

	return dc
}
// RemoveDataConnection tears down and forgets the data connection connID.
// Unknown IDs are ignored; the per-connection close sequence runs at most once.
func (g *ConnectionGroup) RemoveDataConnection(connID string) {
	g.mu.Lock()
	defer g.mu.Unlock()

	dc, ok := g.DataConns[connID]
	if !ok {
		return
	}
	dc.closedMu.Lock()
	if !dc.closed {
		dc.closed = true
		close(dc.stopCh)
		if dc.Conn != nil {
			// Expire any blocked reads/writes before closing.
			_ = dc.Conn.SetDeadline(time.Now())
			dc.Conn.Close()
		}
	}
	dc.closedMu.Unlock()
	delete(g.DataConns, connID)
}
// DataConnectionCount reports how many data connections are registered.
func (g *ConnectionGroup) DataConnectionCount() int {
	g.mu.RLock()
	n := len(g.DataConns)
	g.mu.RUnlock()
	return n
}
// Close tears the group down exactly once: it signals stopCh, detaches all
// data connections and sessions under the lock, then closes them with the
// lock released so teardown cannot block concurrent group operations.
func (g *ConnectionGroup) Close() {
	g.mu.Lock()
	// Idempotency guard: stopCh is closed at most once.
	select {
	case <-g.stopCh:
		g.mu.Unlock()
		return
	default:
		close(g.stopCh)
	}
	dataConns := make([]*DataConnection, 0, len(g.DataConns))
	for _, dataConn := range g.DataConns {
		dataConns = append(dataConns, dataConn)
	}
	g.DataConns = make(map[string]*DataConnection)
	sessions := make([]*yamux.Session, 0, len(g.Sessions))
	for _, session := range g.Sessions {
		if session != nil {
			sessions = append(sessions, session)
		}
	}
	g.Sessions = make(map[string]*yamux.Session)
	g.mu.Unlock()
	for _, dataConn := range dataConns {
		// Per-connection close guard mirrors RemoveDataConnection.
		dataConn.closedMu.Lock()
		if !dataConn.closed {
			dataConn.closed = true
			close(dataConn.stopCh)
			if dataConn.Conn != nil {
				// Expire blocked reads/writes before closing.
				_ = dataConn.Conn.SetDeadline(time.Now())
				_ = dataConn.Conn.Close()
			}
		}
		dataConn.closedMu.Unlock()
	}
	for _, session := range sessions {
		_ = session.Close()
	}
}
// IsStale reports whether the group has been idle for longer than timeout.
func (g *ConnectionGroup) IsStale(timeout time.Duration) bool {
	g.mu.RLock()
	last := g.LastActivity
	g.mu.RUnlock()
	return time.Since(last) > timeout
}
// AddSession registers a yamux session under connID and starts the group's
// heartbeat loop when the first session arrives. No-op for empty IDs or nil
// sessions.
func (g *ConnectionGroup) AddSession(connID string, session *yamux.Session) {
	if connID == "" || session == nil {
		return
	}

	g.mu.Lock()
	if g.Sessions == nil {
		g.Sessions = make(map[string]*yamux.Session)
	}
	g.Sessions[connID] = session
	g.LastActivity = time.Now()
	first := !g.heartbeatStarted
	g.heartbeatStarted = true
	g.mu.Unlock()

	// Kick off the heartbeat outside the lock; only the first session does.
	if first {
		g.StartHeartbeat(constants.HeartbeatInterval, constants.HeartbeatTimeout)
	}
}
// RemoveSession forgets the session registered under connID and closes it.
// Unknown or empty IDs are ignored.
func (g *ConnectionGroup) RemoveSession(connID string) {
	if connID == "" {
		return
	}

	g.mu.Lock()
	var session *yamux.Session
	if g.Sessions != nil {
		session = g.Sessions[connID]
		delete(g.Sessions, connID)
	}
	g.mu.Unlock()

	// Close outside the lock; ignore the error, the session is being dropped.
	if session != nil {
		_ = session.Close()
	}
}
// SessionCount reports how many yamux sessions are registered.
func (g *ConnectionGroup) SessionCount() int {
	g.mu.RLock()
	n := len(g.Sessions)
	g.mu.RUnlock()
	return n
}
// OpenStream opens a new mux stream on the least-loaded live session,
// retrying with linear backoff when all sessions are saturated or opens
// fail. Returns net.ErrClosed when the group is stopped or has no sessions.
func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
	const (
		maxStreamsPerSession = 256
		maxRetries           = 3
		backoffBase          = 25 * time.Millisecond
	)
	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		// Bail out promptly if the group was closed.
		select {
		case <-g.stopCh:
			return nil, net.ErrClosed
		default:
		}
		sessions := g.sessionsSnapshot()
		if len(sessions) == 0 {
			return nil, net.ErrClosed
		}
		tried := make([]bool, len(sessions))
		anyUnderCap := false
		// Round-robin starting point shared across callers via sessionIdx.
		start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
		for range sessions {
			// Pick the untried, live session with the fewest open streams.
			bestIdx := -1
			minStreams := int(^uint(0) >> 1)
			for i := 0; i < len(sessions); i++ {
				idx := (start + i) % len(sessions)
				if tried[idx] {
					continue
				}
				session := sessions[idx]
				if session == nil || session.IsClosed() {
					tried[idx] = true
					continue
				}
				n := session.NumStreams()
				if n >= maxStreamsPerSession {
					continue
				}
				anyUnderCap = true
				if n < minStreams {
					minStreams = n
					bestIdx = idx
				}
			}
			if bestIdx == -1 {
				break
			}
			tried[bestIdx] = true
			session := sessions[bestIdx]
			if session == nil || session.IsClosed() {
				continue
			}
			stream, err := session.Open()
			if err == nil {
				return stream, nil
			}
			// Open failed: remember the error and try the next candidate.
			lastErr = err
			if session.IsClosed() {
				g.deleteClosedSessions()
			}
		}
		if !anyUnderCap {
			lastErr = fmt.Errorf("all sessions are at stream capacity (%d)", maxStreamsPerSession)
		}
		// Linear backoff between attempts, abandoned if the group closes.
		if attempt < maxRetries-1 {
			select {
			case <-g.stopCh:
				return nil, net.ErrClosed
			case <-time.After(backoffBase * time.Duration(attempt+1)):
			}
		}
	}
	if lastErr == nil {
		lastErr = fmt.Errorf("failed to open stream")
	}
	return nil, lastErr
}
// selectSession returns the live session with the fewest open streams,
// scanning from a round-robin starting point so ties between equally loaded
// sessions rotate between callers. Returns nil when no live session exists.
// NOTE(review): OpenStream implements its own selection; confirm this helper
// still has callers.
func (g *ConnectionGroup) selectSession() *yamux.Session {
	sessions := g.sessionsSnapshot()
	if len(sessions) == 0 {
		return nil
	}
	start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
	minStreams := int(^uint(0) >> 1)
	var best *yamux.Session
	for i := 0; i < len(sessions); i++ {
		session := sessions[(start+i)%len(sessions)]
		if session == nil || session.IsClosed() {
			continue
		}
		if n := session.NumStreams(); n < minStreams {
			minStreams = n
			best = session
		}
	}
	return best
}
// sessionsSnapshot returns the live sessions as a slice, pruning nil or
// closed entries from the map as a side effect (hence the write lock) and
// refreshing LastActivity when at least one live session remains.
func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
	g.mu.Lock()
	defer g.mu.Unlock()
	if len(g.Sessions) == 0 {
		return nil
	}
	sessions := make([]*yamux.Session, 0, len(g.Sessions))
	for id, session := range g.Sessions {
		if session == nil || session.IsClosed() {
			delete(g.Sessions, id)
			continue
		}
		sessions = append(sessions, session)
	}
	if len(sessions) > 0 {
		g.LastActivity = time.Now()
	}
	return sessions
}
// deleteClosedSessions drops nil or already-closed sessions from the map.
func (g *ConnectionGroup) deleteClosedSessions() {
	g.mu.Lock()
	defer g.mu.Unlock()
	for id, s := range g.Sessions {
		if s == nil || s.IsClosed() {
			delete(g.Sessions, id)
		}
	}
}

View File

@@ -0,0 +1,163 @@
package tcp
import (
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"sync"
"time"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// ConnectionGroupManager manages all connection groups
type ConnectionGroupManager struct {
	groups map[string]*ConnectionGroup // TunnelID -> ConnectionGroup
	mu     sync.RWMutex
	logger *zap.Logger
	// Cleanup
	cleanupInterval time.Duration // how often cleanupLoop scans for stale groups
	staleTimeout    time.Duration // inactivity threshold before a group is reaped
	stopCh          chan struct{} // closed by Close to stop cleanupLoop
}
// NewConnectionGroupManager creates a new connection group manager and
// starts its background cleanup loop.
func NewConnectionGroupManager(logger *zap.Logger) *ConnectionGroupManager {
	mgr := &ConnectionGroupManager{
		groups:          make(map[string]*ConnectionGroup),
		logger:          logger,
		cleanupInterval: time.Minute,
		staleTimeout:    5 * time.Minute,
		stopCh:          make(chan struct{}),
	}
	go mgr.cleanupLoop()
	return mgr
}
// GenerateTunnelID generates a unique 128-bit tunnel ID encoded as 32 hex
// characters. It panics if the system's secure random source is unavailable,
// since issuing predictable tunnel IDs would be a security hole.
func GenerateTunnelID() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		// crypto/rand failing means no safe IDs can be produced at all.
		panic(fmt.Sprintf("failed to read random bytes for tunnel ID: %v", err))
	}
	return hex.EncodeToString(b)
}
// CreateGroup creates a new connection group
func (m *ConnectionGroupManager) CreateGroup(subdomain, token string, primaryConn *Connection, tunnelType protocol.TunnelType) *ConnectionGroup {
	tunnelID := GenerateTunnelID()
	group := NewConnectionGroup(tunnelID, subdomain, token, primaryConn, tunnelType, m.logger)

	m.mu.Lock()
	m.groups[tunnelID] = group
	m.mu.Unlock()

	return group
}
// GetGroup returns a connection group by tunnel ID
func (m *ConnectionGroupManager) GetGroup(tunnelID string) (*ConnectionGroup, bool) {
	m.mu.RLock()
	g, ok := m.groups[tunnelID]
	m.mu.RUnlock()
	return g, ok
}
// RemoveGroup removes and closes a connection group
func (m *ConnectionGroupManager) RemoveGroup(tunnelID string) {
	m.mu.Lock()
	group := m.groups[tunnelID]
	delete(m.groups, tunnelID)
	m.mu.Unlock()

	// Close outside the lock so teardown can't block other operations.
	if group != nil {
		group.Close()
	}
}
// AddDataConnection adds a data connection to a group after validating the
// tunnel ID and auth token carried in the request.
func (m *ConnectionGroupManager) AddDataConnection(req *protocol.DataConnectRequest, conn net.Conn) (*DataConnection, error) {
	group, ok := m.GetGroup(req.TunnelID)
	if !ok {
		return nil, fmt.Errorf("tunnel not found: %s", req.TunnelID)
	}
	// Validate token
	if group.Token != "" && req.Token != group.Token {
		return nil, fmt.Errorf("invalid token")
	}
	return group.AddDataConnection(req.ConnectionID, conn), nil
}
// cleanupLoop periodically cleans up stale groups
func (m *ConnectionGroupManager) cleanupLoop() {
	ticker := time.NewTicker(m.cleanupInterval)
	defer ticker.Stop()
	for {
		select {
		case <-m.stopCh:
			return
		case <-ticker.C:
			m.cleanupStaleGroups()
		}
	}
}
// cleanupStaleGroups reaps groups that exceeded the stale timeout. Stale
// groups are unlinked under the lock, then closed with the lock released so
// teardown never blocks registration traffic.
func (m *ConnectionGroupManager) cleanupStaleGroups() {
	var stale []*ConnectionGroup

	m.mu.Lock()
	for id, group := range m.groups {
		if group.IsStale(m.staleTimeout) {
			stale = append(stale, group)
			delete(m.groups, id)
		}
	}
	m.mu.Unlock()

	for _, group := range stale {
		group.Close()
	}
}
// Close shuts down the manager: it stops the cleanup loop, detaches every
// group under the lock, and closes them with the lock released. Safe to
// call more than once.
func (m *ConnectionGroupManager) Close() {
	m.mu.Lock()
	// Idempotency guard: a second Close must not panic on re-closing
	// stopCh. Mirrors the pattern used by ConnectionGroup.Close.
	select {
	case <-m.stopCh:
	default:
		close(m.stopCh)
	}
	groups := make([]*ConnectionGroup, 0, len(m.groups))
	for _, group := range m.groups {
		groups = append(groups, group)
	}
	m.groups = make(map[string]*ConnectionGroup)
	m.mu.Unlock()
	// Close groups without holding the lock.
	for _, group := range groups {
		group.Close()
	}
}

View File

@@ -12,32 +12,34 @@ import (
"drip/internal/server/tunnel"
"drip/internal/shared/pool"
"drip/internal/shared/recovery"
"go.uber.org/zap"
)
// Listener handles TCP connections with TLS 1.3
type Listener struct {
address string
tlsConfig *tls.Config
authToken string
manager *tunnel.Manager
portAlloc *PortAllocator
logger *zap.Logger
domain string
publicPort int
httpHandler http.Handler
responseChans HTTPResponseHandler
listener net.Listener
stopCh chan struct{}
wg sync.WaitGroup
connections map[string]*Connection
connMu sync.RWMutex
workerPool *pool.WorkerPool // Worker pool for connection handling
recoverer *recovery.Recoverer
address string
tlsConfig *tls.Config
authToken string
manager *tunnel.Manager
portAlloc *PortAllocator
logger *zap.Logger
domain string
publicPort int
httpHandler http.Handler
listener net.Listener
stopCh chan struct{}
wg sync.WaitGroup
connections map[string]*Connection
connMu sync.RWMutex
workerPool *pool.WorkerPool // Worker pool for connection handling
recoverer *recovery.Recoverer
panicMetrics *recovery.PanicMetrics
groupManager *ConnectionGroupManager
}
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Listener {
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler) *Listener {
numCPU := pool.NumCPU()
workers := numCPU * 5
queueSize := workers * 20
@@ -53,21 +55,21 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
recoverer := recovery.NewRecoverer(logger, panicMetrics)
return &Listener{
address: address,
tlsConfig: tlsConfig,
authToken: authToken,
manager: manager,
portAlloc: portAlloc,
logger: logger,
domain: domain,
publicPort: publicPort,
httpHandler: httpHandler,
responseChans: responseChans,
stopCh: make(chan struct{}),
connections: make(map[string]*Connection),
workerPool: workerPool,
recoverer: recoverer,
panicMetrics: panicMetrics,
address: address,
tlsConfig: tlsConfig,
authToken: authToken,
manager: manager,
portAlloc: portAlloc,
logger: logger,
domain: domain,
publicPort: publicPort,
httpHandler: httpHandler,
stopCh: make(chan struct{}),
connections: make(map[string]*Connection),
workerPool: workerPool,
recoverer: recoverer,
panicMetrics: panicMetrics,
groupManager: NewConnectionGroupManager(logger),
}
}
@@ -206,7 +208,7 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.responseChans)
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager)
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
@@ -222,14 +224,11 @@ func (l *Listener) handleConnection(netConn net.Conn) {
if err := conn.Handle(); err != nil {
errStr := err.Error()
// Client disconnection errors - normal network behavior, log as DEBUG
if strings.Contains(errStr, "connection reset by peer") ||
// Client disconnection errors - normal network behavior, ignore
if strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") {
l.logger.Debug("Client disconnected",
zap.String("remote_addr", connID),
zap.Error(err),
)
return
}
@@ -277,6 +276,10 @@ func (l *Listener) Stop() error {
l.workerPool.Close()
}
if l.groupManager != nil {
l.groupManager.Close()
}
l.logger.Info("TCP listener stopped")
return nil
}

View File

@@ -1,64 +1,79 @@
package tcp
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// TunnelProxy handles TCP connections for a specific tunnel
type TunnelProxy struct {
port int
subdomain string
tcpConn net.Conn // The tunnel control connection
listener net.Listener
logger *zap.Logger
stopCh chan struct{}
wg sync.WaitGroup
clientAddr string
streams map[string]*proxyStream // streamID -> stream info
streamMu sync.RWMutex
frameWriter *protocol.FrameWriter
bufferPool *pool.BufferPool
// Proxy exposes a public TCP port and forwards each incoming
// connection over a dedicated mux stream.
type Proxy struct {
port int
subdomain string
logger *zap.Logger
listener net.Listener
stopCh chan struct{}
once sync.Once
wg sync.WaitGroup
openStream func() (net.Conn, error)
stats trafficStats
sem chan struct{}
ctx context.Context
cancel context.CancelFunc
}
// proxyStream holds connection info with close state
type proxyStream struct {
conn net.Conn
closed bool
mu sync.Mutex
// trafficStats receives byte and connection accounting from the proxy.
// Methods are invoked from concurrent per-connection handler goroutines,
// so implementations are expected to be safe for concurrent use.
type trafficStats interface {
	AddBytesIn(n int64)  // bytes received from the public client
	AddBytesOut(n int64) // bytes sent back to the public client
	IncActiveConnections()
	DecActiveConnections()
}
// NewTunnelProxy creates a new TCP tunnel proxy
func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
return &TunnelProxy{
port: port,
subdomain: subdomain,
tcpConn: tcpConn,
logger: logger,
stopCh: make(chan struct{}),
clientAddr: tcpConn.RemoteAddr().String(),
streams: make(map[string]*proxyStream),
bufferPool: pool.NewBufferPool(),
frameWriter: protocol.NewFrameWriter(tcpConn),
// NewProxy constructs a public TCP proxy for the given port and
// subdomain. openStream is called once per accepted connection to open
// a mux stream toward the tunnel client; stats (optional, may be nil)
// receives traffic accounting. A nil ctx is replaced with Background.
func NewProxy(ctx context.Context, port int, subdomain string, openStream func() (net.Conn, error), stats trafficStats, logger *zap.Logger) *Proxy {
	if ctx == nil {
		ctx = context.Background()
	}
	childCtx, cancel := context.WithCancel(ctx)

	// Bound the number of concurrently proxied connections; a nil
	// semaphore would disable the cap entirely.
	const maxConcurrentConnections = 10000
	var limiter chan struct{}
	if maxConcurrentConnections > 0 {
		limiter = make(chan struct{}, maxConcurrentConnections)
	}

	return &Proxy{
		port:       port,
		subdomain:  subdomain,
		logger:     logger,
		stopCh:     make(chan struct{}),
		openStream: openStream,
		stats:      stats,
		sem:        limiter,
		ctx:        childCtx,
		cancel:     cancel,
	}
}
// Start starts listening on the allocated port
func (p *TunnelProxy) Start() error {
func (p *Proxy) Start() error {
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
listener, err := net.Listen("tcp", addr)
ln, err := net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
}
p.listener = listener
p.listener = ln
p.logger.Info("TCP proxy started",
zap.Int("port", p.port),
@@ -67,14 +82,47 @@ func (p *TunnelProxy) Start() error {
p.wg.Add(1)
go p.acceptLoop()
return nil
}
// acceptLoop accepts incoming TCP connections
func (p *TunnelProxy) acceptLoop() {
// Stop shuts the proxy down exactly once: it signals the accept loop
// and handlers, cancels in-flight pipes, closes the public listener,
// and waits (with a bound) for all goroutines to drain.
func (p *Proxy) Stop() {
	p.once.Do(func() {
		close(p.stopCh)
		p.cancel()
		if p.listener != nil {
			_ = p.listener.Close()
		}

		const stopTimeout = 30 * time.Second

		// Wait for the accept loop and per-connection handlers, but
		// never block shutdown indefinitely.
		drained := make(chan struct{})
		go func() {
			p.wg.Wait()
			close(drained)
		}()

		select {
		case <-drained:
			p.logger.Info("TCP proxy stopped",
				zap.Int("port", p.port),
				zap.String("subdomain", p.subdomain),
			)
		case <-time.After(stopTimeout):
			p.logger.Warn("TCP proxy stop timed out",
				zap.Int("port", p.port),
				zap.String("subdomain", p.subdomain),
				zap.Duration("timeout", stopTimeout),
			)
		}
	})
}
func (p *Proxy) acceptLoop() {
defer p.wg.Done()
tcpLn, _ := p.listener.(*net.TCPListener)
for {
select {
case <-p.stopCh:
@@ -82,11 +130,13 @@ func (p *TunnelProxy) acceptLoop() {
default:
}
p.listener.(*net.TCPListener).SetDeadline(time.Now().Add(1 * time.Second))
if tcpLn != nil {
_ = tcpLn.SetDeadline(time.Now().Add(1 * time.Second))
}
conn, err := p.listener.Accept()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
if ne, ok := err.(net.Error); ok && ne.Timeout() {
continue
}
select {
@@ -98,187 +148,86 @@ func (p *TunnelProxy) acceptLoop() {
}
p.wg.Add(1)
go p.handleConnection(conn)
go p.handleConn(conn)
}
}
func (p *TunnelProxy) handleConnection(conn net.Conn) {
func (p *Proxy) handleConn(conn net.Conn) {
defer p.wg.Done()
defer conn.Close()
streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
stream := &proxyStream{
conn: conn,
closed: false,
if p.sem != nil {
select {
case p.sem <- struct{}{}:
defer func() { <-p.sem }()
default:
return
}
}
p.streamMu.Lock()
p.streams[streamID] = stream
p.streamMu.Unlock()
if p.stats != nil {
p.stats.IncActiveConnections()
defer p.stats.DecActiveConnections()
}
defer func() {
p.streamMu.Lock()
delete(p.streams, streamID)
p.streamMu.Unlock()
if tcpConn, ok := conn.(*net.TCPConn); ok {
_ = tcpConn.SetNoDelay(true)
_ = tcpConn.SetKeepAlive(true)
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
_ = tcpConn.SetReadBuffer(256 * 1024)
_ = tcpConn.SetWriteBuffer(256 * 1024)
}
if p.openStream == nil {
return
}
// Open stream with timeout to prevent goroutine leak
const openStreamTimeout = 10 * time.Second
type streamResult struct {
stream net.Conn
err error
}
resultCh := make(chan streamResult, 1)
go func() {
s, err := p.openStream()
resultCh <- streamResult{s, err}
}()
bufPtr := p.bufferPool.Get(pool.SizeMedium)
defer p.bufferPool.Put(bufPtr)
buffer := (*bufPtr)[:pool.SizeMedium]
for {
// Check if stream is closed
stream.mu.Lock()
closed := stream.closed
stream.mu.Unlock()
if closed {
break
}
n, err := conn.Read(buffer)
if err != nil {
break
}
if n > 0 {
if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
p.logger.Debug("Send to tunnel failed", zap.Error(err))
break
var stream net.Conn
select {
case result := <-resultCh:
if result.err != nil {
if !errors.Is(result.err, net.ErrClosed) {
p.logger.Debug("Open stream failed", zap.Error(result.err))
}
return
}
}
select {
stream = result.stream
case <-time.After(openStreamTimeout):
p.logger.Debug("Open stream timeout")
return
case <-p.stopCh:
default:
p.sendCloseToTunnel(streamID)
}
}
// sendDataToTunnel forwards a chunk read from the public TCP client to
// the tunnel client over the control connection, framed as a
// DataTypeData payload keyed by streamID. Any error returned is treated
// by the caller as fatal for the stream.
func (p *TunnelProxy) sendDataToTunnel(streamID string, data []byte) error {
// Refuse new writes once Stop() has closed stopCh.
select {
case <-p.stopCh:
return fmt.Errorf("tunnel proxy stopped")
default:
}
header := protocol.DataHeader{
StreamID: streamID,
RequestID: streamID,
Type: protocol.DataTypeData,
IsLast: false,
}
// Pooled encode + frame helpers; the pool buffer is owned by the frame
// until the frame writer releases it.
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, data)
if err != nil {
return fmt.Errorf("failed to encode payload: %w", err)
}
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
err = p.frameWriter.WriteFrame(frame)
if err != nil {
return fmt.Errorf("failed to write frame: %w", err)
}
return nil
}
// sendCloseToTunnel notifies the tunnel client that streamID is done
// (DataTypeClose with IsLast set). This is best-effort teardown
// signalling: encode and write errors are deliberately ignored because
// the stream is being dismantled either way.
func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
header := protocol.DataHeader{
StreamID: streamID,
RequestID: streamID,
Type: protocol.DataTypeClose,
IsLast: true,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
if err != nil {
return
}
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
p.frameWriter.WriteFrame(frame)
}
defer stream.Close()
// HandleResponse writes data received from the tunnel client back to
// the public TCP client identified by streamID.
// Unknown or already-closed streams return nil: the public client may
// have disconnected first, which is normal and not an error.
func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
p.streamMu.RLock()
stream, ok := p.streams[streamID]
p.streamMu.RUnlock()
if !ok {
// Stream may have been closed by client, this is normal
return nil
}
// Check the closed flag under the per-stream lock.
// NOTE(review): the Write below runs after the lock is released, so a
// concurrent CloseStream can still close the conn mid-write; the
// resulting write error is logged and surfaced to the caller.
stream.mu.Lock()
if stream.closed {
stream.mu.Unlock()
return nil
}
stream.mu.Unlock()
if _, err := stream.conn.Write(data); err != nil {
p.logger.Debug("Write to client failed", zap.Error(err))
return err
}
return nil
}
// CloseStream tears down a single proxied client connection by ID.
// It is idempotent: only the first call for a stream closes the
// underlying socket, and unknown IDs are ignored.
func (p *TunnelProxy) CloseStream(streamID string) {
	p.streamMu.RLock()
	stream, ok := p.streams[streamID]
	p.streamMu.RUnlock()
	if !ok {
		return
	}

	// Flip the closed flag under the stream lock; remember whether we
	// were the ones who flipped it.
	stream.mu.Lock()
	alreadyClosed := stream.closed
	stream.closed = true
	stream.mu.Unlock()

	if alreadyClosed {
		return
	}
	stream.conn.Close()
}
func (p *TunnelProxy) Stop() {
p.logger.Info("Stopping TCP proxy",
zap.Int("port", p.port),
zap.String("subdomain", p.subdomain),
_ = netutil.PipeWithCallbacksAndBufferSize(
p.ctx,
conn,
stream,
pool.SizeLarge,
func(n int64) {
if p.stats != nil {
p.stats.AddBytesIn(n)
}
},
func(n int64) {
if p.stats != nil {
p.stats.AddBytesOut(n)
}
},
)
close(p.stopCh)
if p.listener != nil {
p.listener.Close()
}
p.streamMu.Lock()
for _, stream := range p.streams {
stream.mu.Lock()
stream.closed = true
stream.mu.Unlock()
stream.conn.Close()
}
p.streams = make(map[string]*proxyStream)
p.streamMu.Unlock()
p.wg.Wait()
if p.frameWriter != nil {
p.frameWriter.Close()
}
p.logger.Info("TCP proxy stopped", zap.Int("port", p.port))
}

View File

@@ -0,0 +1,98 @@
package tcp
import (
"bufio"
"fmt"
"io"
"net"
"github.com/hashicorp/yamux"
"drip/internal/shared/constants"
)
type bufferedConn struct {
net.Conn
reader *bufio.Reader
}
func (c *bufferedConn) Read(p []byte) (int, error) {
return c.reader.Read(p)
}
// newMuxClientSession upgrades the control connection into a yamux
// session and resolves the stream opener used for proxied traffic.
// The public server always plays the yamux Client role; the client-side
// connector plays the Server role. When this tunnel belongs to a pooled
// group (tunnelID set and a group registered), the session is added to
// the group and the group's opener is returned instead of the raw
// session opener.
//
// Extracted because handleTCPTunnel and handleHTTPProxyTunnel carried
// verbatim copies of this setup.
func (c *Connection) newMuxClientSession(reader *bufio.Reader) (*yamux.Session, func() (net.Conn, error), error) {
	// Wrap the conn so bytes already buffered during the handshake
	// remain visible to yamux.
	bc := &bufferedConn{
		Conn:   c.conn,
		reader: reader,
	}
	cfg := yamux.DefaultConfig()
	cfg.EnableKeepAlive = false
	cfg.LogOutput = io.Discard
	cfg.AcceptBacklog = constants.YamuxAcceptBacklog
	session, err := yamux.Client(bc, cfg)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to init yamux session: %w", err)
	}
	c.session = session

	openStream := session.Open
	if c.tunnelID != "" && c.groupManager != nil {
		if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
			group.AddSession("primary", session)
			openStream = group.OpenStream
		}
	}
	return session, openStream, nil
}

// handleTCPTunnel serves a raw TCP tunnel: it multiplexes the control
// connection with yamux, starts the public-port proxy that opens one
// stream per accepted connection, and blocks until either this
// connection is asked to stop or the mux session dies.
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
	// Public server acts as yamux Client, client connector acts as yamux Server.
	session, openStream, err := c.newMuxClientSession(reader)
	if err != nil {
		return err
	}
	c.proxy = NewProxy(c.ctx, c.port, c.subdomain, openStream, c.tunnelConn, c.logger)
	if err := c.proxy.Start(); err != nil {
		return fmt.Errorf("failed to start tcp proxy: %w", err)
	}
	select {
	case <-c.stopCh:
		return nil
	case <-session.CloseChan():
		return nil
	}
}

// handleHTTPProxyTunnel serves an HTTP/HTTPS tunnel: it multiplexes the
// control connection with yamux and hands the stream opener to the
// tunnel connection so the HTTP proxy can open one stream per request.
// Blocks until stop is requested or the mux session dies.
func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
	// Public server acts as yamux Client, client connector acts as yamux Server.
	session, openStream, err := c.newMuxClientSession(reader)
	if err != nil {
		return err
	}
	if c.tunnelConn != nil {
		c.tunnelConn.SetOpenStream(openStream)
	}
	select {
	case <-c.stopCh:
		return nil
	case <-session.CloseChan():
		return nil
	}
}

View File

@@ -1,7 +1,9 @@
package tunnel
import (
"net"
"sync"
"sync/atomic"
"time"
"drip/internal/shared/protocol"
@@ -9,13 +11,6 @@ import (
"go.uber.org/zap"
)
// Transport represents the control channel to the client.
// It is implemented by the TCP control connection so the HTTP proxy
// can push frames directly to the client without depending on WebSockets.
type Transport interface {
	// SendFrame delivers one protocol frame to the tunnel client.
	SendFrame(frame *protocol.Frame) error
}
// Connection represents a tunnel connection from a client
type Connection struct {
Subdomain string
@@ -26,8 +21,12 @@ type Connection struct {
mu sync.RWMutex
logger *zap.Logger
closed bool
transport Transport
tunnelType protocol.TunnelType
openStream func() (net.Conn, error)
bytesIn atomic.Int64
bytesOut atomic.Int64
activeConnections atomic.Int64
}
// NewConnection creates a new tunnel connection
@@ -106,21 +105,6 @@ func (c *Connection) IsClosed() bool {
return c.closed
}
// SetTransport attaches the control transport and records the tunnel
// type, both under the connection lock.
func (c *Connection) SetTransport(t Transport, tType protocol.TunnelType) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.transport = t
	c.tunnelType = tType
}
// GetTransport reports the currently attached control transport; it is
// nil when none has been set.
func (c *Connection) GetTransport() Transport {
	c.mu.RLock()
	t := c.transport
	c.mu.RUnlock()
	return t
}
// SetTunnelType sets the tunnel type.
func (c *Connection) SetTunnelType(tType protocol.TunnelType) {
c.mu.Lock()
@@ -135,6 +119,63 @@ func (c *Connection) GetTunnelType() protocol.TunnelType {
return c.tunnelType
}
// SetOpenStream registers the function used to open a fresh mux stream
// toward the tunnel client. The HTTP proxy uses it to forward each
// request over its own stream.
func (c *Connection) SetOpenStream(open func() (net.Conn, error)) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.openStream = open
}
// OpenStream opens a new mux stream to the tunnel client.
// It returns ErrConnectionClosed when the tunnel has been closed or no
// opener has been registered yet.
func (c *Connection) OpenStream() (net.Conn, error) {
	c.mu.RLock()
	opener, isClosed := c.openStream, c.closed
	c.mu.RUnlock()

	if isClosed || opener == nil {
		return nil, ErrConnectionClosed
	}
	return opener()
}
// AddBytesIn accumulates inbound traffic for this tunnel.
// Non-positive deltas are ignored.
func (c *Connection) AddBytesIn(n int64) {
	if n <= 0 {
		return
	}
	c.bytesIn.Add(n)
}

// AddBytesOut accumulates outbound traffic for this tunnel.
// Non-positive deltas are ignored.
func (c *Connection) AddBytesOut(n int64) {
	if n <= 0 {
		return
	}
	c.bytesOut.Add(n)
}

// GetBytesIn returns the total inbound byte count.
func (c *Connection) GetBytesIn() int64 {
	return c.bytesIn.Load()
}

// GetBytesOut returns the total outbound byte count.
func (c *Connection) GetBytesOut() int64 {
	return c.bytesOut.Load()
}

// IncActiveConnections records one more in-flight proxied connection.
func (c *Connection) IncActiveConnections() {
	c.activeConnections.Add(1)
}

// DecActiveConnections records the end of one proxied connection,
// clamping the counter at zero.
//
// The clamp uses a CAS loop: the previous Add(-1)-then-Store(0) version
// was not atomic as a unit, so a concurrent IncActiveConnections
// landing between the Add and the Store could be erased by the Store.
func (c *Connection) DecActiveConnections() {
	for {
		v := c.activeConnections.Load()
		if v <= 0 {
			return // already at the floor; nothing to decrement
		}
		if c.activeConnections.CompareAndSwap(v, v-1) {
			return
		}
	}
}

// GetActiveConnections returns the number of in-flight proxied
// connections.
func (c *Connection) GetActiveConnections() int64 {
	return c.activeConnections.Load()
}
// StartWritePump starts the write pump for sending messages
func (c *Connection) StartWritePump() {
// Skip write pump for TCP-only connections (no WebSocket)