mirror of
https://github.com/Gouryella/drip.git
synced 2026-03-01 15:52:32 +00:00
feat(tunnel): switch to yamux stream proxying and connection pooling
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server - Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly - Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
This commit is contained in:
@@ -1,490 +1,251 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/httputil"
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const openStreamTimeout = 10 * time.Second
|
||||
|
||||
type Handler struct {
|
||||
manager *tunnel.Manager
|
||||
logger *zap.Logger
|
||||
responses *ResponseHandler
|
||||
domain string
|
||||
authToken string
|
||||
headerPool *pool.HeaderPool
|
||||
bufferPool *pool.AdaptiveBufferPool
|
||||
manager *tunnel.Manager
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
authToken string
|
||||
}
|
||||
|
||||
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, responses *ResponseHandler, domain string, authToken string) *Handler {
|
||||
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, authToken string) *Handler {
|
||||
return &Handler{
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
responses: responses,
|
||||
domain: domain,
|
||||
authToken: authToken,
|
||||
headerPool: pool.NewHeaderPool(),
|
||||
bufferPool: pool.NewAdaptiveBufferPool(),
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
domain: domain,
|
||||
authToken: authToken,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Always handle /health and /stats directly, regardless of subdomain
|
||||
// Always handle /health and /stats directly, regardless of subdomain.
|
||||
if r.URL.Path == "/health" {
|
||||
h.serveHealth(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/stats" {
|
||||
h.serveStats(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
subdomain := h.extractSubdomain(r.Host)
|
||||
|
||||
if subdomain == "" {
|
||||
h.serveHomePage(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
conn, ok := h.manager.Get(subdomain)
|
||||
if !ok {
|
||||
tconn, ok := h.manager.Get(subdomain)
|
||||
if !ok || tconn == nil {
|
||||
http.Error(w, "Tunnel not found. The tunnel may have been closed.", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if conn.IsClosed() {
|
||||
if tconn.IsClosed() {
|
||||
http.Error(w, "Tunnel connection closed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
transport := conn.GetTransport()
|
||||
if transport == nil {
|
||||
http.Error(w, "Tunnel control channel not ready", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
tType := conn.GetTunnelType()
|
||||
tType := tconn.GetTunnelType()
|
||||
if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
|
||||
http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
requestID := utils.GenerateID()
|
||||
// Check for WebSocket upgrade
|
||||
if httputil.IsWebSocketUpgrade(r) {
|
||||
h.handleWebSocket(w, r, tconn)
|
||||
return
|
||||
}
|
||||
|
||||
h.handleAdaptiveRequest(w, r, transport, requestID, subdomain)
|
||||
}
|
||||
// Open stream with timeout
|
||||
stream, err := h.openStreamWithTimeout(tconn)
|
||||
if err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
|
||||
const streamingThreshold int64 = 1 * 1024 * 1024
|
||||
// Track active connections
|
||||
tconn.IncActiveConnections()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
// Wrap stream with counting for traffic stats
|
||||
countingStream := netutil.NewCountingConn(stream,
|
||||
tconn.AddBytesOut, // Data read from stream = bytes out to client
|
||||
tconn.AddBytesIn, // Data written to stream = bytes in from client
|
||||
)
|
||||
|
||||
// 1) Write request over the stream (net/http handles large bodies correctly).
|
||||
if err := r.Write(countingStream); err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
_ = r.Body.Close()
|
||||
http.Error(w, "Forward failed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// 2) Read response from stream.
|
||||
resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
|
||||
if err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
http.Error(w, "Read response failed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 3) Copy headers (strip hop-by-hop).
|
||||
h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
|
||||
|
||||
statusCode := resp.StatusCode
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusOK
|
||||
}
|
||||
|
||||
// Ensure message delimiting works with our custom ResponseWriter:
|
||||
// - If Content-Length is known, send it.
|
||||
// - Otherwise, re-chunk the decoded body ourselves.
|
||||
if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
|
||||
if resp.ContentLength >= 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
|
||||
} else {
|
||||
w.Header().Del("Content-Length")
|
||||
}
|
||||
w.WriteHeader(statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
if resp.ContentLength >= 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
ctx := r.Context()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stream.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
close(done)
|
||||
stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Del("Content-Length")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
if len(resp.Trailer) > 0 {
|
||||
w.Header().Set("Trailer", trailerKeys(resp.Trailer))
|
||||
}
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
var cancelTransport func()
|
||||
if transport != nil {
|
||||
cancelOnce := sync.Once{}
|
||||
cancelFunc := func() {
|
||||
header := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeClose,
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
if err := transport.SendFrame(frame); err != nil {
|
||||
h.logger.Debug("Failed to send cancel frame to client",
|
||||
zap.String("request_id", requestID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
cancelTransport = func() {
|
||||
cancelOnce.Do(cancelFunc)
|
||||
}
|
||||
|
||||
h.responses.RegisterCancelFunc(requestID, cancelTransport)
|
||||
defer h.responses.CleanupCancelFunc(requestID)
|
||||
}
|
||||
|
||||
largeBufferPtr := h.bufferPool.GetLarge()
|
||||
tempBufPtr := h.bufferPool.GetMedium()
|
||||
|
||||
defer func() {
|
||||
h.bufferPool.PutLarge(largeBufferPtr)
|
||||
h.bufferPool.PutMedium(tempBufPtr)
|
||||
}()
|
||||
|
||||
buffer := (*largeBufferPtr)[:0]
|
||||
tempBuf := (*tempBufPtr)[:pool.MediumBufferSize]
|
||||
|
||||
var totalRead int64
|
||||
var hitThreshold bool
|
||||
|
||||
for totalRead < streamingThreshold {
|
||||
n, err := r.Body.Read(tempBuf)
|
||||
if n > 0 {
|
||||
buffer = append(buffer, tempBuf[:n]...)
|
||||
totalRead += int64(n)
|
||||
}
|
||||
if err == io.EOF {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
r.Body.Close()
|
||||
h.logger.Error("Read request body failed", zap.Error(err))
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if totalRead >= streamingThreshold {
|
||||
hitThreshold = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hitThreshold {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
return
|
||||
}
|
||||
|
||||
h.streamLargeRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
}
|
||||
|
||||
func (h *Handler) sendBufferedRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), body []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
httpReq := protocol.HTTPRequest{
|
||||
Method: r.Method,
|
||||
URL: r.URL.String(),
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
}
|
||||
|
||||
reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
|
||||
h.headerPool.Put(headers)
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("Encode HTTP request failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
header := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequest,
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, reqBytes)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode data payload failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
|
||||
respChan := h.responses.CreateResponseChan(requestID)
|
||||
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
|
||||
defer func() {
|
||||
h.responses.CleanupResponseChan(requestID)
|
||||
h.responses.CleanupStreamingResponse(requestID)
|
||||
}()
|
||||
|
||||
if err := transport.SendFrame(frame); err != nil {
|
||||
h.logger.Error("Send frame to tunnel failed", zap.Error(err))
|
||||
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case respMsg := <-respChan:
|
||||
if respMsg == nil {
|
||||
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("HTTP request context cancelled",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("url", r.URL.String()),
|
||||
)
|
||||
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) streamLargeRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), bufferedData []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
httpReqHead := protocol.HTTPRequestHead{
|
||||
Method: r.Method,
|
||||
URL: r.URL.String(),
|
||||
Headers: headers,
|
||||
ContentLength: r.ContentLength,
|
||||
}
|
||||
|
||||
headBytes, err := protocol.EncodeHTTPRequestHead(&httpReqHead)
|
||||
h.headerPool.Put(headers)
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("Encode HTTP request head failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPHead, // shared streaming head type
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode head payload failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
|
||||
|
||||
respChan := h.responses.CreateResponseChan(requestID)
|
||||
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
|
||||
defer func() {
|
||||
h.responses.CleanupResponseChan(requestID)
|
||||
h.responses.CleanupStreamingResponse(requestID)
|
||||
}()
|
||||
|
||||
if err := transport.SendFrame(headFrame); err != nil {
|
||||
h.logger.Error("Send head frame failed", zap.Error(err))
|
||||
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
if len(bufferedData) > 0 {
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, bufferedData)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode buffered chunk failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := transport.SendFrame(chunkFrame); err != nil {
|
||||
h.logger.Error("Send buffered chunk failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
streamBufPtr := h.bufferPool.GetMedium()
|
||||
defer h.bufferPool.PutMedium(streamBufPtr)
|
||||
buffer := (*streamBufPtr)[:pool.MediumBufferSize]
|
||||
for {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("Streaming request cancelled via context",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
default:
|
||||
stream.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
n, readErr := r.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
isLast := readErr == io.EOF
|
||||
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
|
||||
IsLast: isLast,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer[:n])
|
||||
if err != nil {
|
||||
h.logger.Error("Encode chunk payload failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := transport.SendFrame(chunkFrame); err != nil {
|
||||
h.logger.Error("Send chunk frame failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if readErr == io.EOF {
|
||||
if n == 0 {
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if err == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if readErr != nil {
|
||||
h.logger.Error("Read request body failed", zap.Error(readErr))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if err == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := writeChunked(w, resp.Body, resp.Trailer); err != nil {
|
||||
h.logger.Debug("Write chunked response failed", zap.Error(err))
|
||||
}
|
||||
close(done)
|
||||
stream.Close()
|
||||
}
|
||||
|
||||
r.Body.Close()
|
||||
func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) {
|
||||
type result struct {
|
||||
stream net.Conn
|
||||
err error
|
||||
}
|
||||
ch := make(chan result, 1)
|
||||
|
||||
go func() {
|
||||
s, err := tconn.OpenStream()
|
||||
ch <- result{s, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case respMsg := <-respChan:
|
||||
if respMsg == nil {
|
||||
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("Streaming HTTP request context cancelled",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Streaming request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("url", r.URL.String()),
|
||||
)
|
||||
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
||||
case r := <-ch:
|
||||
return r.stream, r.err
|
||||
case <-time.After(openStreamTimeout):
|
||||
return nil, fmt.Errorf("open stream timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPResponse, subdomain string, r *http.Request) {
|
||||
if resp == nil {
|
||||
http.Error(w, "Invalid response from tunnel", http.StatusBadGateway)
|
||||
func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) {
|
||||
stream, err := h.openStreamWithTimeout(tconn)
|
||||
if err != nil {
|
||||
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// For buffered responses, we have the complete body, so we can set Content-Length
|
||||
// Skip ALL hop-by-hop headers - client should have already cleaned them
|
||||
for key, values := range resp.Headers {
|
||||
tconn.IncActiveConnections()
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
stream.Close()
|
||||
tconn.DecActiveConnections()
|
||||
http.Error(w, "WebSocket not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
tconn.DecActiveConnections()
|
||||
http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.Write(stream); err != nil {
|
||||
stream.Close()
|
||||
clientConn.Close()
|
||||
tconn.DecActiveConnections()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
defer clientConn.Close()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
_ = netutil.PipeWithCallbacks(r.Context(), stream, clientConn,
|
||||
func(n int64) { tconn.AddBytesOut(n) },
|
||||
func(n int64) { tconn.AddBytesIn(n) },
|
||||
)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHost string) {
|
||||
for key, values := range src {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
|
||||
// Skip hop-by-hop headers completely using canonical key comparison
|
||||
// Hop-by-hop headers must not be forwarded.
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
@@ -496,29 +257,61 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
|
||||
}
|
||||
|
||||
if canonicalKey == "Location" && len(values) > 0 {
|
||||
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
|
||||
w.Header().Set("Location", rewrittenLocation)
|
||||
dst.Set("Location", h.rewriteLocationHeader(values[0], proxyHost))
|
||||
continue
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func trailerKeys(hdr http.Header) string {
|
||||
keys := make([]string, 0, len(hdr))
|
||||
for k := range hdr {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
// Deterministic order is nicer for debugging; no semantic impact.
|
||||
sortStrings(keys)
|
||||
return strings.Join(keys, ", ")
|
||||
}
|
||||
|
||||
func writeChunked(w io.Writer, r io.Reader, trailer http.Header) error {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
if _, werr := fmt.Fprintf(w, "%x\r\n", n); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if _, werr := w.Write(buf[:n]); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if _, werr := io.WriteString(w, "\r\n"); werr != nil {
|
||||
return werr
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// For buffered mode, always set Content-Length with the actual body size
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
|
||||
|
||||
statusCode := resp.StatusCode
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusOK
|
||||
if _, err := io.WriteString(w, "0\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
if len(resp.Body) > 0 {
|
||||
w.Write(resp.Body)
|
||||
for k, vv := range trailer {
|
||||
for _, v := range vv {
|
||||
if _, err := io.WriteString(w, fmt.Sprintf("%s: %s\r\n", k, v)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err := io.WriteString(w, "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
|
||||
@@ -535,22 +328,13 @@ func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
|
||||
strings.HasPrefix(locationURL.Host, "localhost:") ||
|
||||
locationURL.Host == "127.0.0.1" ||
|
||||
strings.HasPrefix(locationURL.Host, "127.0.0.1:") {
|
||||
scheme := "https"
|
||||
if strings.Contains(proxyHost, ":") && !strings.Contains(proxyHost, "https") {
|
||||
parts := strings.Split(proxyHost, ":")
|
||||
if len(parts) == 2 && parts[1] != "443" {
|
||||
scheme = "https"
|
||||
}
|
||||
}
|
||||
|
||||
rewritten := fmt.Sprintf("%s://%s%s", scheme, proxyHost, locationURL.Path)
|
||||
rewritten := fmt.Sprintf("https://%s%s", proxyHost, locationURL.Path)
|
||||
if locationURL.RawQuery != "" {
|
||||
rewritten += "?" + locationURL.RawQuery
|
||||
}
|
||||
if locationURL.Fragment != "" {
|
||||
rewritten += "#" + locationURL.Fragment
|
||||
}
|
||||
|
||||
return rewritten
|
||||
}
|
||||
|
||||
@@ -568,8 +352,7 @@ func (h *Handler) extractSubdomain(host string) string {
|
||||
|
||||
suffix := "." + h.domain
|
||||
if strings.HasSuffix(host, suffix) {
|
||||
subdomain := strings.TrimSuffix(host, suffix)
|
||||
return subdomain
|
||||
return strings.TrimSuffix(host, suffix)
|
||||
}
|
||||
|
||||
return ""
|
||||
@@ -652,9 +435,17 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
for _, conn := range connections {
|
||||
if conn == nil {
|
||||
continue
|
||||
}
|
||||
stats["tunnels"] = append(stats["tunnels"].([]map[string]interface{}), map[string]interface{}{
|
||||
"subdomain": conn.Subdomain,
|
||||
"last_active": conn.LastActive.Unix(),
|
||||
"subdomain": conn.Subdomain,
|
||||
"tunnel_type": string(conn.GetTunnelType()),
|
||||
"last_active": conn.LastActive.Unix(),
|
||||
"bytes_in": conn.GetBytesIn(),
|
||||
"bytes_out": conn.GetBytesOut(),
|
||||
"active_connections": conn.GetActiveConnections(),
|
||||
"total_bytes": conn.GetBytesIn() + conn.GetBytesOut(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -668,3 +459,13 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func sortStrings(s []string) {
|
||||
for i := 0; i < len(s); i++ {
|
||||
for j := i + 1; j < len(s); j++ {
|
||||
if s[j] < s[i] {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,421 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// responseChanEntry holds a response channel and its creation time
|
||||
type responseChanEntry struct {
|
||||
ch chan *protocol.HTTPResponse
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
// streamingResponseEntry holds a streaming response writer
|
||||
type streamingResponseEntry struct {
|
||||
w http.ResponseWriter
|
||||
flusher http.Flusher
|
||||
createdAt time.Time
|
||||
lastActivityAt time.Time
|
||||
headersSent bool
|
||||
done chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// ResponseHandler manages response channels for HTTP requests over TCP/Frame protocol
|
||||
type ResponseHandler struct {
|
||||
channels map[string]*responseChanEntry
|
||||
streamingChannels map[string]*streamingResponseEntry
|
||||
cancelFuncs map[string]func()
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewResponseHandler creates a new response handler
|
||||
func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
|
||||
h := &ResponseHandler{
|
||||
channels: make(map[string]*responseChanEntry),
|
||||
streamingChannels: make(map[string]*streamingResponseEntry),
|
||||
cancelFuncs: make(map[string]func()),
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start single cleanup goroutine instead of one per request
|
||||
go h.cleanupLoop()
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// CreateResponseChan creates a response channel for a request ID
|
||||
func (h *ResponseHandler) CreateResponseChan(requestID string) chan *protocol.HTTPResponse {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
ch := make(chan *protocol.HTTPResponse, 1)
|
||||
h.channels[requestID] = &responseChanEntry{
|
||||
ch: ch,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// CreateStreamingResponse creates a streaming response entry for a request ID
|
||||
func (h *ResponseHandler) CreateStreamingResponse(requestID string, w http.ResponseWriter) chan struct{} {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
flusher, _ := w.(http.Flusher)
|
||||
done := make(chan struct{})
|
||||
now := time.Now()
|
||||
h.streamingChannels[requestID] = &streamingResponseEntry{
|
||||
w: w,
|
||||
flusher: flusher,
|
||||
createdAt: now,
|
||||
lastActivityAt: now,
|
||||
done: done,
|
||||
}
|
||||
|
||||
return done
|
||||
}
|
||||
|
||||
// RegisterCancelFunc registers a callback to be invoked when the downstream disconnects.
|
||||
func (h *ResponseHandler) RegisterCancelFunc(requestID string, cancel func()) {
|
||||
if cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.cancelFuncs[requestID] = cancel
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetResponseChan gets the response channel for a request ID
|
||||
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
if entry := h.channels[requestID]; entry != nil {
|
||||
return entry.ch
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendResponse sends a response to the waiting channel
|
||||
func (h *ResponseHandler) SendResponse(requestID string, resp *protocol.HTTPResponse) {
|
||||
h.mu.RLock()
|
||||
entry, exists := h.channels[requestID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case entry.ch <- resp:
|
||||
case <-time.After(30 * time.Second):
|
||||
h.logger.Error("Timeout sending response to channel - handler may have abandoned",
|
||||
zap.String("request_id", requestID),
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
zap.Int("body_size", len(resp.Body)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error {
|
||||
h.mu.RLock()
|
||||
entry, exists := h.streamingChannels[requestID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-entry.done:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
if entry.headersSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy headers, removing hop-by-hop headers that were already handled by client
|
||||
// Client's cleanResponseHeaders already removed Transfer-Encoding, Connection, etc.
|
||||
// But we need to check again in case they slipped through
|
||||
hasContentLength := false
|
||||
|
||||
for key, values := range head.Headers {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
|
||||
// Skip ALL hop-by-hop headers
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
continue
|
||||
}
|
||||
|
||||
if canonicalKey == "Content-Length" {
|
||||
hasContentLength = true
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
entry.w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// For streaming responses, decide how to indicate message length
|
||||
if head.ContentLength >= 0 && !hasContentLength {
|
||||
entry.w.Header().Set("Content-Length", fmt.Sprintf("%d", head.ContentLength))
|
||||
}
|
||||
|
||||
statusCode := head.StatusCode
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusOK
|
||||
}
|
||||
|
||||
entry.w.WriteHeader(statusCode)
|
||||
entry.headersSent = true
|
||||
entry.lastActivityAt = time.Now()
|
||||
|
||||
if entry.flusher != nil {
|
||||
entry.flusher.Flush()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isLast bool) error {
|
||||
h.mu.RLock()
|
||||
entry, exists := h.streamingChannels[requestID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-entry.done:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
if len(chunk) > 0 {
|
||||
_, err := entry.w.Write(chunk)
|
||||
if err != nil {
|
||||
if isClientDisconnectError(err) {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
h.triggerCancel(requestID)
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
h.triggerCancel(requestID)
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.lastActivityAt = time.Now()
|
||||
|
||||
if entry.flusher != nil {
|
||||
entry.flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if isLast {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isClientDisconnectError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if netErr, ok := err.(*net.OpError); ok {
|
||||
if netErr.Err != nil {
|
||||
errStr := netErr.Err.Error()
|
||||
if strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection reset") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection reset") ||
|
||||
strings.Contains(errStr, "use of closed network connection")
|
||||
}
|
||||
|
||||
// triggerCancel invokes and removes the cancel callback for a request.
|
||||
func (h *ResponseHandler) triggerCancel(requestID string) {
|
||||
h.mu.Lock()
|
||||
cancel := h.cancelFuncs[requestID]
|
||||
if cancel != nil {
|
||||
delete(h.cancelFuncs, requestID)
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
go func() {
|
||||
cancel()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if entry, exists := h.channels[requestID]; exists {
|
||||
close(entry.ch)
|
||||
delete(h.channels, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) CleanupStreamingResponse(requestID string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if entry, exists := h.streamingChannels[requestID]; exists {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
delete(h.streamingChannels, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupCancelFunc removes a registered cancel callback.
|
||||
func (h *ResponseHandler) CleanupCancelFunc(requestID string) {
|
||||
h.mu.Lock()
|
||||
delete(h.cancelFuncs, requestID)
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) GetPendingCount() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.channels) + len(h.streamingChannels)
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) cleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
h.cleanupExpiredChannels()
|
||||
case <-h.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) cleanupExpiredChannels() {
|
||||
now := time.Now()
|
||||
timeout := 5 * time.Minute
|
||||
streamingTimeout := 5 * time.Minute
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
expiredCount := 0
|
||||
cancelList := make([]string, 0)
|
||||
for requestID, entry := range h.channels {
|
||||
if now.Sub(entry.createdAt) > timeout {
|
||||
close(entry.ch)
|
||||
delete(h.channels, requestID)
|
||||
expiredCount++
|
||||
}
|
||||
}
|
||||
|
||||
for requestID, entry := range h.streamingChannels {
|
||||
if now.Sub(entry.lastActivityAt) > streamingTimeout {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
delete(h.streamingChannels, requestID)
|
||||
cancelList = append(cancelList, requestID)
|
||||
expiredCount++
|
||||
}
|
||||
}
|
||||
|
||||
for _, requestID := range cancelList {
|
||||
if cancel := h.cancelFuncs[requestID]; cancel != nil {
|
||||
delete(h.cancelFuncs, requestID)
|
||||
go cancel()
|
||||
}
|
||||
}
|
||||
|
||||
if expiredCount > 0 {
|
||||
h.logger.Debug("Cleaned up expired response channels",
|
||||
zap.Int("count", expiredCount),
|
||||
zap.Int("remaining", len(h.channels)+len(h.streamingChannels)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) Close() {
|
||||
close(h.stopCh)
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
for _, entry := range h.channels {
|
||||
close(entry.ch)
|
||||
}
|
||||
h.channels = make(map[string]*responseChanEntry)
|
||||
|
||||
for _, entry := range h.streamingChannels {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
}
|
||||
h.streamingChannels = make(map[string]*streamingResponseEntry)
|
||||
|
||||
for _, cancel := range h.cancelFuncs {
|
||||
cancel()
|
||||
}
|
||||
h.cancelFuncs = make(map[string]func())
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/hashicorp/yamux"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/constants"
|
||||
@@ -33,36 +34,27 @@ type Connection struct {
|
||||
publicPort int
|
||||
portAlloc *PortAllocator
|
||||
tunnelConn *tunnel.Connection
|
||||
proxy *TunnelProxy
|
||||
stopCh chan struct{}
|
||||
once sync.Once
|
||||
lastHeartbeat time.Time
|
||||
mu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
httpHandler http.Handler
|
||||
responseChans HTTPResponseHandler
|
||||
tunnelType protocol.TunnelType // Track tunnel type
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Flow control
|
||||
paused bool
|
||||
pauseCond *sync.Cond
|
||||
}
|
||||
// gost-like TCP tunnel (yamux)
|
||||
session *yamux.Session
|
||||
proxy *Proxy
|
||||
|
||||
// HTTPResponseHandler interface for response channel operations
|
||||
type HTTPResponseHandler interface {
|
||||
CreateResponseChan(requestID string) chan *protocol.HTTPResponse
|
||||
GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
|
||||
CleanupResponseChan(requestID string)
|
||||
SendResponse(requestID string, resp *protocol.HTTPResponse)
|
||||
// Streaming response methods
|
||||
SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error
|
||||
SendStreamingChunk(requestID string, chunk []byte, isLast bool) error
|
||||
// Multi-connection support
|
||||
tunnelID string
|
||||
groupManager *ConnectionGroupManager
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager) *Connection {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c := &Connection{
|
||||
conn: conn,
|
||||
@@ -73,13 +65,12 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
|
||||
domain: domain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
responseChans: responseChans,
|
||||
stopCh: make(chan struct{}),
|
||||
lastHeartbeat: time.Now(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
groupManager: groupManager,
|
||||
}
|
||||
c.pauseCond = sync.NewCond(&c.mu)
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -97,8 +88,8 @@ func (c *Connection) Handle() error {
|
||||
// Use buffered reader to support peeking
|
||||
reader := bufio.NewReader(c.conn)
|
||||
|
||||
// Peek first 8 bytes to detect protocol
|
||||
peek, err := reader.Peek(8)
|
||||
// Peek first 4 bytes to detect protocol (HTTP methods are 4 bytes).
|
||||
peek, err := reader.Peek(4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to peek connection: %w", err)
|
||||
}
|
||||
@@ -127,6 +118,11 @@ func (c *Connection) Handle() error {
|
||||
sf := protocol.WithFrame(frame)
|
||||
defer sf.Close()
|
||||
|
||||
// Handle data connection request (for multi-connection pool)
|
||||
if sf.Frame.Type == protocol.FrameTypeDataConnect {
|
||||
return c.handleDataConnect(sf.Frame, reader)
|
||||
}
|
||||
|
||||
if sf.Frame.Type != protocol.FrameTypeRegister {
|
||||
return fmt.Errorf("expected register frame, got %s", sf.Frame.Type)
|
||||
}
|
||||
@@ -180,7 +176,6 @@ func (c *Connection) Handle() error {
|
||||
|
||||
// Store TCP connection reference and metadata for HTTP proxy routing
|
||||
c.tunnelConn.Conn = nil // We're using TCP, not WebSocket
|
||||
c.tunnelConn.SetTransport(c, req.TunnelType)
|
||||
c.tunnelConn.SetTunnelType(req.TunnelType)
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
@@ -208,11 +203,33 @@ func (c *Connection) Handle() error {
|
||||
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
|
||||
}
|
||||
|
||||
// Generate TunnelID for multi-connection support if client supports it
|
||||
var tunnelID string
|
||||
var supportsDataConn bool
|
||||
recommendedConns := 0
|
||||
|
||||
if req.PoolCapabilities != nil && req.ConnectionType == "primary" && c.groupManager != nil {
|
||||
// Client supports connection pooling
|
||||
group := c.groupManager.CreateGroup(subdomain, req.Token, c, req.TunnelType)
|
||||
tunnelID = group.TunnelID
|
||||
c.tunnelID = tunnelID
|
||||
supportsDataConn = true
|
||||
recommendedConns = 4 // Recommend 4 data connections
|
||||
|
||||
c.logger.Info("Created connection group for multi-connection support",
|
||||
zap.String("tunnel_id", tunnelID),
|
||||
zap.Int("max_data_conns", req.PoolCapabilities.MaxDataConns),
|
||||
)
|
||||
}
|
||||
|
||||
resp := protocol.RegisterResponse{
|
||||
Subdomain: subdomain,
|
||||
Port: c.port,
|
||||
URL: tunnelURL,
|
||||
Message: "Tunnel registered successfully",
|
||||
Subdomain: subdomain,
|
||||
Port: c.port,
|
||||
URL: tunnelURL,
|
||||
Message: "Tunnel registered successfully",
|
||||
TunnelID: tunnelID,
|
||||
SupportsDataConn: supportsDataConn,
|
||||
RecommendedConns: recommendedConns,
|
||||
}
|
||||
|
||||
respData, _ := json.Marshal(resp)
|
||||
@@ -224,6 +241,17 @@ func (c *Connection) Handle() error {
|
||||
return fmt.Errorf("failed to send registration ack: %w", err)
|
||||
}
|
||||
|
||||
// Clear deadline for tunnel data-plane.
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// gost-like tunnels: switch to yamux after RegisterAck.
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
return c.handleTCPTunnel(reader)
|
||||
}
|
||||
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
|
||||
return c.handleHTTPProxyTunnel(reader)
|
||||
}
|
||||
|
||||
c.frameWriter = protocol.NewFrameWriter(c.conn)
|
||||
|
||||
c.frameWriter.SetWriteErrorHandler(func(err error) {
|
||||
@@ -231,15 +259,6 @@ func (c *Connection) Handle() error {
|
||||
c.Close()
|
||||
})
|
||||
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
c.proxy = NewTunnelProxy(c.port, subdomain, c.conn, c.logger)
|
||||
if err := c.proxy.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start TCP proxy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
go c.heartbeatChecker()
|
||||
|
||||
return c.handleFrames(reader)
|
||||
@@ -376,7 +395,7 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
if isTimeoutError(err) {
|
||||
c.logger.Warn("Read timeout, connection may be dead")
|
||||
return fmt.Errorf("read timeout")
|
||||
}
|
||||
@@ -404,15 +423,6 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
c.handleHeartbeat()
|
||||
sf.Close()
|
||||
|
||||
case protocol.FrameTypeData:
|
||||
// Data frame from client (response to forwarded request)
|
||||
c.handleDataFrame(sf.Frame)
|
||||
sf.Close()
|
||||
|
||||
case protocol.FrameTypeFlowControl:
|
||||
c.handleFlowControl(sf.Frame)
|
||||
sf.Close()
|
||||
|
||||
case protocol.FrameTypeClose:
|
||||
sf.Close()
|
||||
c.logger.Info("Client requested close")
|
||||
@@ -436,127 +446,12 @@ func (c *Connection) handleHeartbeat() {
|
||||
// Send heartbeat ack
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeHeartbeatAck, nil)
|
||||
|
||||
err := c.frameWriter.WriteFrame(ackFrame)
|
||||
err := c.frameWriter.WriteControl(ackFrame)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to send heartbeat ack", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// handleDataFrame handles data frame (response from client)
|
||||
func (c *Connection) handleDataFrame(frame *protocol.Frame) {
|
||||
// Decode payload (auto-detects protocol version)
|
||||
header, data, err := protocol.DecodeDataPayload(frame.Payload)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode data payload",
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Debug("Received data frame",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.String("type", header.Type.String()),
|
||||
zap.Int("data_size", len(data)),
|
||||
)
|
||||
|
||||
switch header.Type {
|
||||
case protocol.DataTypeResponse:
|
||||
// TCP tunnel response, forward to proxy
|
||||
if c.proxy != nil {
|
||||
if err := c.proxy.HandleResponse(header.StreamID, data); err != nil {
|
||||
c.logger.Error("Failed to handle response",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
case protocol.DataTypeHTTPResponse:
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response channel handler for HTTP response",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Decode HTTP response (auto-detects JSON vs msgpack)
|
||||
httpResp, err := protocol.DecodeHTTPResponse(data)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode HTTP response",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Route by request ID when provided to keep request/response aligned.
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
c.responseChans.SendResponse(reqID, httpResp)
|
||||
case protocol.DataTypeHTTPHead:
|
||||
// Streaming HTTP response headers
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response handler for streaming HTTP head",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
httpHead, err := protocol.DecodeHTTPResponseHead(data)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode HTTP response head",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
if err := c.responseChans.SendStreamingHead(reqID, httpHead); err != nil {
|
||||
c.logger.Error("Failed to send streaming head",
|
||||
zap.String("request_id", reqID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
case protocol.DataTypeHTTPBodyChunk:
|
||||
// Streaming HTTP response body chunk
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response handler for streaming HTTP chunk",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
if err := c.responseChans.SendStreamingChunk(reqID, data, header.IsLast); err != nil {
|
||||
c.logger.Error("Failed to send streaming chunk",
|
||||
zap.String("request_id", reqID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
case protocol.DataTypeClose:
|
||||
// Client is closing the stream
|
||||
if c.proxy != nil {
|
||||
c.proxy.CloseStream(header.StreamID)
|
||||
}
|
||||
default:
|
||||
c.logger.Warn("Unknown data frame type",
|
||||
zap.String("type", header.Type.String()),
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeatChecker checks for heartbeat timeout
|
||||
func (c *Connection) heartbeatChecker() {
|
||||
ticker := time.NewTicker(constants.HeartbeatInterval)
|
||||
@@ -583,16 +478,6 @@ func (c *Connection) heartbeatChecker() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) SendFrame(frame *protocol.Frame) error {
|
||||
if c.frameWriter == nil {
|
||||
return protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
if frame.Type == protocol.FrameTypeData {
|
||||
return c.sendWithBackpressure(frame)
|
||||
}
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
|
||||
func (c *Connection) sendError(code, message string) {
|
||||
errMsg := protocol.ErrorMessage{
|
||||
Code: code,
|
||||
@@ -618,8 +503,12 @@ func (c *Connection) Close() {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
// Ensure any in-flight writes return quickly on shutdown to avoid hanging.
|
||||
if c.conn != nil {
|
||||
_ = c.conn.SetDeadline(time.Now())
|
||||
}
|
||||
|
||||
if c.frameWriter != nil {
|
||||
c.frameWriter.Flush()
|
||||
c.frameWriter.Close()
|
||||
}
|
||||
|
||||
@@ -627,7 +516,13 @@ func (c *Connection) Close() {
|
||||
c.proxy.Stop()
|
||||
}
|
||||
|
||||
c.conn.Close()
|
||||
if c.session != nil {
|
||||
_ = c.session.Close()
|
||||
}
|
||||
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
if c.port > 0 && c.portAlloc != nil {
|
||||
c.portAlloc.Release(c.port)
|
||||
@@ -635,6 +530,12 @@ func (c *Connection) Close() {
|
||||
|
||||
if c.subdomain != "" {
|
||||
c.manager.Unregister(c.subdomain)
|
||||
|
||||
// Clean up connection group when PRIMARY connection closes
|
||||
// (only primary connections have subdomain set)
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
c.groupManager.RemoveGroup(c.tunnelID)
|
||||
}
|
||||
}
|
||||
|
||||
c.logger.Info("Connection closed",
|
||||
@@ -643,11 +544,6 @@ func (c *Connection) Close() {
|
||||
})
|
||||
}
|
||||
|
||||
// GetSubdomain returns the assigned subdomain
|
||||
func (c *Connection) GetSubdomain() string {
|
||||
return c.subdomain
|
||||
}
|
||||
|
||||
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
|
||||
type httpResponseWriter struct {
|
||||
conn net.Conn
|
||||
@@ -698,39 +594,196 @@ func (w *httpResponseWriter) Write(data []byte) (int, error) {
|
||||
return w.writer.Write(data)
|
||||
}
|
||||
|
||||
func (c *Connection) handleFlowControl(frame *protocol.Frame) {
|
||||
msg, err := protocol.DecodeFlowControlMessage(frame.Payload)
|
||||
// handleDataConnect handles a data connection join request
|
||||
func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Reader) error {
|
||||
var req protocol.DataConnectRequest
|
||||
if err := json.Unmarshal(frame.Payload, &req); err != nil {
|
||||
c.sendError("invalid_request", "Failed to parse data connect request")
|
||||
return fmt.Errorf("failed to parse data connect request: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("Data connection request received",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Validate the request
|
||||
if c.groupManager == nil {
|
||||
c.sendDataConnectError("not_supported", "Multi-connection not supported")
|
||||
return fmt.Errorf("group manager not available")
|
||||
}
|
||||
|
||||
// Validate auth token
|
||||
if c.authToken != "" && req.Token != c.authToken {
|
||||
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed for data connection")
|
||||
}
|
||||
|
||||
group, ok := c.groupManager.GetGroup(req.TunnelID)
|
||||
if !ok || group == nil {
|
||||
c.sendDataConnectError("join_failed", "Tunnel not found")
|
||||
return fmt.Errorf("tunnel not found: %s", req.TunnelID)
|
||||
}
|
||||
|
||||
// Validate token against the primary registration token.
|
||||
if group.Token != "" && req.Token != group.Token {
|
||||
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed for data connection")
|
||||
}
|
||||
|
||||
// Store tunnelID for cleanup
|
||||
c.tunnelID = req.TunnelID
|
||||
|
||||
// For TCP tunnels, the data connection is upgraded to a yamux session and used for
|
||||
// stream forwarding, not framed request/response routing.
|
||||
if group.TunnelType == protocol.TunnelTypeTCP {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: true,
|
||||
ConnectionID: req.ConnectionID,
|
||||
Message: "Data connection accepted",
|
||||
}
|
||||
|
||||
respData, _ := json.Marshal(resp)
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
|
||||
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
|
||||
return fmt.Errorf("failed to send data connect ack: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("TCP data connection established",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Clear deadline for yamux data-plane.
|
||||
_ = c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
group.AddSession(req.ConnectionID, session)
|
||||
defer group.RemoveSession(req.ConnectionID)
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Add data connection to group
|
||||
dataConn, err := c.groupManager.AddDataConnection(&req, c.conn)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode flow control", zap.Error(err))
|
||||
return
|
||||
c.sendDataConnectError("join_failed", err.Error())
|
||||
return fmt.Errorf("failed to join connection group: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
// Send success response
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: true,
|
||||
ConnectionID: req.ConnectionID,
|
||||
Message: "Data connection accepted",
|
||||
}
|
||||
|
||||
switch msg.Action {
|
||||
case protocol.FlowControlPause:
|
||||
c.paused = true
|
||||
c.logger.Warn("Client requested pause",
|
||||
zap.String("stream", msg.StreamID))
|
||||
respData, _ := json.Marshal(resp)
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
|
||||
case protocol.FlowControlResume:
|
||||
c.paused = false
|
||||
c.pauseCond.Broadcast()
|
||||
c.logger.Info("Client requested resume",
|
||||
zap.String("stream", msg.StreamID))
|
||||
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
|
||||
return fmt.Errorf("failed to send data connect ack: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
c.logger.Warn("Unknown flow control action",
|
||||
zap.String("action", string(msg.Action)))
|
||||
c.logger.Info("Data connection established",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Handle data frames on this connection
|
||||
return c.handleDataConnectionFrames(dataConn, reader)
|
||||
}
|
||||
|
||||
// handleDataConnectionFrames handles frames on a data connection
|
||||
func (c *Connection) handleDataConnectionFrames(dataConn *DataConnection, reader *bufio.Reader) error {
|
||||
defer func() {
|
||||
// Get the group and remove this data connection
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok {
|
||||
group.RemoveDataConnection(dataConn.ID)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-dataConn.stopCh:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
// Timeout is OK, continue
|
||||
if isTimeoutError(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
dataConn.mu.Lock()
|
||||
dataConn.LastActive = time.Now()
|
||||
dataConn.mu.Unlock()
|
||||
|
||||
sf := protocol.WithFrame(frame)
|
||||
|
||||
switch sf.Frame.Type {
|
||||
case protocol.FrameTypeClose:
|
||||
sf.Close()
|
||||
c.logger.Info("Data connection closed by client",
|
||||
zap.String("connection_id", dataConn.ID))
|
||||
return nil
|
||||
|
||||
default:
|
||||
sf.Close()
|
||||
c.logger.Warn("Unexpected frame type on data connection",
|
||||
zap.String("type", sf.Frame.Type.String()),
|
||||
zap.String("connection_id", dataConn.ID),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) sendWithBackpressure(frame *protocol.Frame) error {
|
||||
c.mu.Lock()
|
||||
for c.paused {
|
||||
c.pauseCond.Wait()
|
||||
func isTimeoutError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
// Fallback for wrapped errors without net.Error
|
||||
return strings.Contains(err.Error(), "i/o timeout")
|
||||
}
|
||||
|
||||
// sendDataConnectError sends a data connect error response
|
||||
func (c *Connection) sendDataConnectError(code, message string) {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: false,
|
||||
Message: fmt.Sprintf("%s: %s", code, message),
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
|
||||
438
internal/server/tcp/connection_group.go
Normal file
438
internal/server/tcp/connection_group.go
Normal file
@@ -0,0 +1,438 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
|
||||
type DataConnection struct {
|
||||
ID string
|
||||
Conn net.Conn
|
||||
LastActive time.Time
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
stopCh chan struct{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type ConnectionGroup struct {
|
||||
TunnelID string
|
||||
Subdomain string
|
||||
Token string
|
||||
PrimaryConn *Connection
|
||||
DataConns map[string]*DataConnection
|
||||
Sessions map[string]*yamux.Session
|
||||
TunnelType protocol.TunnelType
|
||||
RegisteredAt time.Time
|
||||
LastActivity time.Time
|
||||
sessionIdx uint32
|
||||
mu sync.RWMutex
|
||||
stopCh chan struct{}
|
||||
logger *zap.Logger
|
||||
|
||||
heartbeatStarted bool
|
||||
}
|
||||
|
||||
func NewConnectionGroup(tunnelID, subdomain, token string, primaryConn *Connection, tunnelType protocol.TunnelType, logger *zap.Logger) *ConnectionGroup {
|
||||
return &ConnectionGroup{
|
||||
TunnelID: tunnelID,
|
||||
Subdomain: subdomain,
|
||||
Token: token,
|
||||
PrimaryConn: primaryConn,
|
||||
DataConns: make(map[string]*DataConnection),
|
||||
Sessions: make(map[string]*yamux.Session),
|
||||
TunnelType: tunnelType,
|
||||
RegisteredAt: time.Now(),
|
||||
LastActivity: time.Now(),
|
||||
stopCh: make(chan struct{}),
|
||||
logger: logger.With(zap.String("tunnel_id", tunnelID)),
|
||||
}
|
||||
}
|
||||
|
||||
// StartHeartbeat starts a goroutine that periodically pings all sessions
|
||||
// and removes dead ones. The caller should ensure this is only called once.
|
||||
func (g *ConnectionGroup) StartHeartbeat(interval, timeout time.Duration) {
|
||||
go g.heartbeatLoop(interval, timeout)
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
const maxConsecutiveFailures = 3
|
||||
failureCount := make(map[string]int)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
g.mu.RLock()
|
||||
sessions := make(map[string]*yamux.Session, len(g.Sessions))
|
||||
for id, s := range g.Sessions {
|
||||
sessions[id] = s
|
||||
}
|
||||
g.mu.RUnlock()
|
||||
|
||||
for id, session := range sessions {
|
||||
if session == nil || session.IsClosed() {
|
||||
g.RemoveSession(id)
|
||||
delete(failureCount, id)
|
||||
continue
|
||||
}
|
||||
|
||||
// Ping with timeout
|
||||
done := make(chan error, 1)
|
||||
go func(s *yamux.Session) {
|
||||
_, err := s.Ping()
|
||||
done <- err
|
||||
}(session)
|
||||
|
||||
var err error
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(timeout):
|
||||
err = fmt.Errorf("ping timeout")
|
||||
case <-g.stopCh:
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
failureCount[id]++
|
||||
g.logger.Debug("Session ping failed",
|
||||
zap.String("session_id", id),
|
||||
zap.Int("consecutive_failures", failureCount[id]),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
if failureCount[id] >= maxConsecutiveFailures {
|
||||
g.logger.Warn("Session ping failed too many times, removing",
|
||||
zap.String("session_id", id),
|
||||
zap.Int("failures", failureCount[id]),
|
||||
)
|
||||
g.RemoveSession(id)
|
||||
delete(failureCount, id)
|
||||
}
|
||||
} else {
|
||||
// Reset on success
|
||||
failureCount[id] = 0
|
||||
g.mu.Lock()
|
||||
g.LastActivity = time.Now()
|
||||
g.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all sessions are gone
|
||||
g.mu.RLock()
|
||||
sessionCount := len(g.Sessions)
|
||||
g.mu.RUnlock()
|
||||
|
||||
if sessionCount == 0 {
|
||||
g.logger.Info("All sessions closed, tunnel will be cleaned up")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) AddDataConnection(connID string, conn net.Conn) *DataConnection {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
dataConn := &DataConnection{
|
||||
ID: connID,
|
||||
Conn: conn,
|
||||
LastActive: time.Now(),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
g.DataConns[connID] = dataConn
|
||||
g.LastActivity = time.Now()
|
||||
return dataConn
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) RemoveDataConnection(connID string) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if dataConn, ok := g.DataConns[connID]; ok {
|
||||
dataConn.closedMu.Lock()
|
||||
if !dataConn.closed {
|
||||
dataConn.closed = true
|
||||
close(dataConn.stopCh)
|
||||
if dataConn.Conn != nil {
|
||||
_ = dataConn.Conn.SetDeadline(time.Now())
|
||||
dataConn.Conn.Close()
|
||||
}
|
||||
}
|
||||
dataConn.closedMu.Unlock()
|
||||
delete(g.DataConns, connID)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) DataConnectionCount() int {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
return len(g.DataConns)
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) Close() {
|
||||
g.mu.Lock()
|
||||
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
g.mu.Unlock()
|
||||
return
|
||||
default:
|
||||
close(g.stopCh)
|
||||
}
|
||||
|
||||
dataConns := make([]*DataConnection, 0, len(g.DataConns))
|
||||
for _, dataConn := range g.DataConns {
|
||||
dataConns = append(dataConns, dataConn)
|
||||
}
|
||||
g.DataConns = make(map[string]*DataConnection)
|
||||
|
||||
sessions := make([]*yamux.Session, 0, len(g.Sessions))
|
||||
for _, session := range g.Sessions {
|
||||
if session != nil {
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
}
|
||||
g.Sessions = make(map[string]*yamux.Session)
|
||||
|
||||
g.mu.Unlock()
|
||||
|
||||
for _, dataConn := range dataConns {
|
||||
dataConn.closedMu.Lock()
|
||||
if !dataConn.closed {
|
||||
dataConn.closed = true
|
||||
close(dataConn.stopCh)
|
||||
if dataConn.Conn != nil {
|
||||
_ = dataConn.Conn.SetDeadline(time.Now())
|
||||
_ = dataConn.Conn.Close()
|
||||
}
|
||||
}
|
||||
dataConn.closedMu.Unlock()
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
_ = session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) IsStale(timeout time.Duration) bool {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
return time.Since(g.LastActivity) > timeout
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) AddSession(connID string, session *yamux.Session) {
|
||||
if connID == "" || session == nil {
|
||||
return
|
||||
}
|
||||
|
||||
g.mu.Lock()
|
||||
if g.Sessions == nil {
|
||||
g.Sessions = make(map[string]*yamux.Session)
|
||||
}
|
||||
g.Sessions[connID] = session
|
||||
g.LastActivity = time.Now()
|
||||
|
||||
// Start heartbeat on first session
|
||||
shouldStartHeartbeat := !g.heartbeatStarted
|
||||
if shouldStartHeartbeat {
|
||||
g.heartbeatStarted = true
|
||||
}
|
||||
g.mu.Unlock()
|
||||
|
||||
if shouldStartHeartbeat {
|
||||
g.StartHeartbeat(constants.HeartbeatInterval, constants.HeartbeatTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) RemoveSession(connID string) {
|
||||
if connID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var session *yamux.Session
|
||||
|
||||
g.mu.Lock()
|
||||
if g.Sessions != nil {
|
||||
session = g.Sessions[connID]
|
||||
delete(g.Sessions, connID)
|
||||
}
|
||||
g.mu.Unlock()
|
||||
|
||||
if session != nil {
|
||||
_ = session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) SessionCount() int {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
return len(g.Sessions)
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
|
||||
const (
|
||||
maxStreamsPerSession = 256
|
||||
maxRetries = 3
|
||||
backoffBase = 25 * time.Millisecond
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
return nil, net.ErrClosed
|
||||
default:
|
||||
}
|
||||
|
||||
sessions := g.sessionsSnapshot()
|
||||
if len(sessions) == 0 {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
|
||||
tried := make([]bool, len(sessions))
|
||||
anyUnderCap := false
|
||||
start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
|
||||
|
||||
for range sessions {
|
||||
bestIdx := -1
|
||||
minStreams := int(^uint(0) >> 1)
|
||||
|
||||
for i := 0; i < len(sessions); i++ {
|
||||
idx := (start + i) % len(sessions)
|
||||
if tried[idx] {
|
||||
continue
|
||||
}
|
||||
|
||||
session := sessions[idx]
|
||||
if session == nil || session.IsClosed() {
|
||||
tried[idx] = true
|
||||
continue
|
||||
}
|
||||
|
||||
n := session.NumStreams()
|
||||
if n >= maxStreamsPerSession {
|
||||
continue
|
||||
}
|
||||
anyUnderCap = true
|
||||
|
||||
if n < minStreams {
|
||||
minStreams = n
|
||||
bestIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
if bestIdx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
tried[bestIdx] = true
|
||||
session := sessions[bestIdx]
|
||||
if session == nil || session.IsClosed() {
|
||||
continue
|
||||
}
|
||||
|
||||
stream, err := session.Open()
|
||||
if err == nil {
|
||||
return stream, nil
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
if session.IsClosed() {
|
||||
g.deleteClosedSessions()
|
||||
}
|
||||
}
|
||||
|
||||
if !anyUnderCap {
|
||||
lastErr = fmt.Errorf("all sessions are at stream capacity (%d)", maxStreamsPerSession)
|
||||
}
|
||||
|
||||
if attempt < maxRetries-1 {
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
return nil, net.ErrClosed
|
||||
case <-time.After(backoffBase * time.Duration(attempt+1)):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr == nil {
|
||||
lastErr = fmt.Errorf("failed to open stream")
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) selectSession() *yamux.Session {
|
||||
sessions := g.sessionsSnapshot()
|
||||
if len(sessions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
start := int(atomic.AddUint32(&g.sessionIdx, 1) - 1)
|
||||
minStreams := int(^uint(0) >> 1)
|
||||
var best *yamux.Session
|
||||
|
||||
for i := 0; i < len(sessions); i++ {
|
||||
session := sessions[(start+i)%len(sessions)]
|
||||
if session == nil || session.IsClosed() {
|
||||
continue
|
||||
}
|
||||
if n := session.NumStreams(); n < minStreams {
|
||||
minStreams = n
|
||||
best = session
|
||||
}
|
||||
}
|
||||
|
||||
return best
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if len(g.Sessions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sessions := make([]*yamux.Session, 0, len(g.Sessions))
|
||||
for id, session := range g.Sessions {
|
||||
if session == nil || session.IsClosed() {
|
||||
delete(g.Sessions, id)
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
|
||||
if len(sessions) > 0 {
|
||||
g.LastActivity = time.Now()
|
||||
}
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) deleteClosedSessions() {
|
||||
g.mu.Lock()
|
||||
for id, session := range g.Sessions {
|
||||
if session == nil || session.IsClosed() {
|
||||
delete(g.Sessions, id)
|
||||
}
|
||||
}
|
||||
g.mu.Unlock()
|
||||
}
|
||||
163
internal/server/tcp/connection_group_manager.go
Normal file
163
internal/server/tcp/connection_group_manager.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ConnectionGroupManager manages all connection groups
|
||||
type ConnectionGroupManager struct {
|
||||
groups map[string]*ConnectionGroup // TunnelID -> ConnectionGroup
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
|
||||
// Cleanup
|
||||
cleanupInterval time.Duration
|
||||
staleTimeout time.Duration
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewConnectionGroupManager creates a new connection group manager
|
||||
func NewConnectionGroupManager(logger *zap.Logger) *ConnectionGroupManager {
|
||||
m := &ConnectionGroupManager{
|
||||
groups: make(map[string]*ConnectionGroup),
|
||||
logger: logger,
|
||||
cleanupInterval: 60 * time.Second,
|
||||
staleTimeout: 5 * time.Minute,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
go m.cleanupLoop()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// GenerateTunnelID generates a unique tunnel ID
|
||||
func GenerateTunnelID() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// CreateGroup creates a new connection group
|
||||
func (m *ConnectionGroupManager) CreateGroup(subdomain, token string, primaryConn *Connection, tunnelType protocol.TunnelType) *ConnectionGroup {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
tunnelID := GenerateTunnelID()
|
||||
|
||||
group := NewConnectionGroup(tunnelID, subdomain, token, primaryConn, tunnelType, m.logger)
|
||||
|
||||
m.groups[tunnelID] = group
|
||||
|
||||
return group
|
||||
}
|
||||
|
||||
// GetGroup returns a connection group by tunnel ID
|
||||
func (m *ConnectionGroupManager) GetGroup(tunnelID string) (*ConnectionGroup, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
group, ok := m.groups[tunnelID]
|
||||
return group, ok
|
||||
}
|
||||
|
||||
// RemoveGroup removes and closes a connection group
|
||||
func (m *ConnectionGroupManager) RemoveGroup(tunnelID string) {
|
||||
m.mu.Lock()
|
||||
group, ok := m.groups[tunnelID]
|
||||
if ok {
|
||||
delete(m.groups, tunnelID)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if ok && group != nil {
|
||||
group.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// AddDataConnection adds a data connection to a group
|
||||
func (m *ConnectionGroupManager) AddDataConnection(req *protocol.DataConnectRequest, conn net.Conn) (*DataConnection, error) {
|
||||
m.mu.RLock()
|
||||
group, ok := m.groups[req.TunnelID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tunnel not found: %s", req.TunnelID)
|
||||
}
|
||||
|
||||
// Validate token
|
||||
if group.Token != "" && req.Token != group.Token {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
dataConn := group.AddDataConnection(req.ConnectionID, conn)
|
||||
|
||||
return dataConn, nil
|
||||
}
|
||||
|
||||
// cleanupLoop periodically cleans up stale groups
|
||||
func (m *ConnectionGroupManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(m.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.cleanupStaleGroups()
|
||||
case <-m.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ConnectionGroupManager) cleanupStaleGroups() {
|
||||
// Collect stale groups under lock
|
||||
m.mu.Lock()
|
||||
var staleGroups []*ConnectionGroup
|
||||
var staleIDs []string
|
||||
for tunnelID, group := range m.groups {
|
||||
if group.IsStale(m.staleTimeout) {
|
||||
staleIDs = append(staleIDs, tunnelID)
|
||||
staleGroups = append(staleGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove from map while holding lock
|
||||
for _, tunnelID := range staleIDs {
|
||||
delete(m.groups, tunnelID)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Close groups without holding lock to avoid blocking other operations
|
||||
for _, group := range staleGroups {
|
||||
group.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the manager
|
||||
func (m *ConnectionGroupManager) Close() {
|
||||
close(m.stopCh)
|
||||
|
||||
// Collect all groups under lock
|
||||
m.mu.Lock()
|
||||
groups := make([]*ConnectionGroup, 0, len(m.groups))
|
||||
for _, group := range m.groups {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
m.groups = make(map[string]*ConnectionGroup)
|
||||
m.mu.Unlock()
|
||||
|
||||
// Close groups without holding lock
|
||||
for _, group := range groups {
|
||||
group.Close()
|
||||
}
|
||||
}
|
||||
@@ -12,32 +12,34 @@ import (
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/recovery"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Listener handles TCP connections with TLS 1.3
|
||||
type Listener struct {
|
||||
address string
|
||||
tlsConfig *tls.Config
|
||||
authToken string
|
||||
manager *tunnel.Manager
|
||||
portAlloc *PortAllocator
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
publicPort int
|
||||
httpHandler http.Handler
|
||||
responseChans HTTPResponseHandler
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
connections map[string]*Connection
|
||||
connMu sync.RWMutex
|
||||
workerPool *pool.WorkerPool // Worker pool for connection handling
|
||||
recoverer *recovery.Recoverer
|
||||
address string
|
||||
tlsConfig *tls.Config
|
||||
authToken string
|
||||
manager *tunnel.Manager
|
||||
portAlloc *PortAllocator
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
publicPort int
|
||||
httpHandler http.Handler
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
connections map[string]*Connection
|
||||
connMu sync.RWMutex
|
||||
workerPool *pool.WorkerPool // Worker pool for connection handling
|
||||
recoverer *recovery.Recoverer
|
||||
panicMetrics *recovery.PanicMetrics
|
||||
|
||||
groupManager *ConnectionGroupManager
|
||||
}
|
||||
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Listener {
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler) *Listener {
|
||||
numCPU := pool.NumCPU()
|
||||
workers := numCPU * 5
|
||||
queueSize := workers * 20
|
||||
@@ -53,21 +55,21 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
recoverer := recovery.NewRecoverer(logger, panicMetrics)
|
||||
|
||||
return &Listener{
|
||||
address: address,
|
||||
tlsConfig: tlsConfig,
|
||||
authToken: authToken,
|
||||
manager: manager,
|
||||
portAlloc: portAlloc,
|
||||
logger: logger,
|
||||
domain: domain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
responseChans: responseChans,
|
||||
stopCh: make(chan struct{}),
|
||||
connections: make(map[string]*Connection),
|
||||
workerPool: workerPool,
|
||||
recoverer: recoverer,
|
||||
panicMetrics: panicMetrics,
|
||||
address: address,
|
||||
tlsConfig: tlsConfig,
|
||||
authToken: authToken,
|
||||
manager: manager,
|
||||
portAlloc: portAlloc,
|
||||
logger: logger,
|
||||
domain: domain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
stopCh: make(chan struct{}),
|
||||
connections: make(map[string]*Connection),
|
||||
workerPool: workerPool,
|
||||
recoverer: recoverer,
|
||||
panicMetrics: panicMetrics,
|
||||
groupManager: NewConnectionGroupManager(logger),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,7 +208,7 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.responseChans)
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager)
|
||||
|
||||
connID := netConn.RemoteAddr().String()
|
||||
l.connMu.Lock()
|
||||
@@ -222,14 +224,11 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
if err := conn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
// Client disconnection errors - normal network behavior, log as DEBUG
|
||||
if strings.Contains(errStr, "connection reset by peer") ||
|
||||
// Client disconnection errors - normal network behavior, ignore
|
||||
if strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
l.logger.Debug("Client disconnected",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -277,6 +276,10 @@ func (l *Listener) Stop() error {
|
||||
l.workerPool.Close()
|
||||
}
|
||||
|
||||
if l.groupManager != nil {
|
||||
l.groupManager.Close()
|
||||
}
|
||||
|
||||
l.logger.Info("TCP listener stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,64 +1,79 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TunnelProxy handles TCP connections for a specific tunnel
|
||||
type TunnelProxy struct {
|
||||
port int
|
||||
subdomain string
|
||||
tcpConn net.Conn // The tunnel control connection
|
||||
listener net.Listener
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
clientAddr string
|
||||
streams map[string]*proxyStream // streamID -> stream info
|
||||
streamMu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
bufferPool *pool.BufferPool
|
||||
// Proxy exposes a public TCP port and forwards each incoming
|
||||
// connection over a dedicated mux stream.
|
||||
type Proxy struct {
|
||||
port int
|
||||
subdomain string
|
||||
logger *zap.Logger
|
||||
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
once sync.Once
|
||||
wg sync.WaitGroup
|
||||
|
||||
openStream func() (net.Conn, error)
|
||||
stats trafficStats
|
||||
sem chan struct{}
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// proxyStream holds connection info with close state
|
||||
type proxyStream struct {
|
||||
conn net.Conn
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
type trafficStats interface {
|
||||
AddBytesIn(n int64)
|
||||
AddBytesOut(n int64)
|
||||
IncActiveConnections()
|
||||
DecActiveConnections()
|
||||
}
|
||||
|
||||
// NewTunnelProxy creates a new TCP tunnel proxy
|
||||
func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
|
||||
return &TunnelProxy{
|
||||
port: port,
|
||||
subdomain: subdomain,
|
||||
tcpConn: tcpConn,
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
clientAddr: tcpConn.RemoteAddr().String(),
|
||||
streams: make(map[string]*proxyStream),
|
||||
bufferPool: pool.NewBufferPool(),
|
||||
frameWriter: protocol.NewFrameWriter(tcpConn),
|
||||
func NewProxy(ctx context.Context, port int, subdomain string, openStream func() (net.Conn, error), stats trafficStats, logger *zap.Logger) *Proxy {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
cctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
const maxConcurrentConnections = 10000
|
||||
var sem chan struct{}
|
||||
if maxConcurrentConnections > 0 {
|
||||
sem = make(chan struct{}, maxConcurrentConnections)
|
||||
}
|
||||
|
||||
return &Proxy{
|
||||
port: port,
|
||||
subdomain: subdomain,
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
openStream: openStream,
|
||||
stats: stats,
|
||||
sem: sem,
|
||||
ctx: cctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts listening on the allocated port
|
||||
func (p *TunnelProxy) Start() error {
|
||||
func (p *Proxy) Start() error {
|
||||
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
|
||||
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
|
||||
}
|
||||
|
||||
p.listener = listener
|
||||
p.listener = ln
|
||||
|
||||
p.logger.Info("TCP proxy started",
|
||||
zap.Int("port", p.port),
|
||||
@@ -67,14 +82,47 @@ func (p *TunnelProxy) Start() error {
|
||||
|
||||
p.wg.Add(1)
|
||||
go p.acceptLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop accepts incoming TCP connections
|
||||
func (p *TunnelProxy) acceptLoop() {
|
||||
func (p *Proxy) Stop() {
|
||||
p.once.Do(func() {
|
||||
close(p.stopCh)
|
||||
p.cancel()
|
||||
|
||||
if p.listener != nil {
|
||||
_ = p.listener.Close()
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
p.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
const stopTimeout = 30 * time.Second
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
p.logger.Info("TCP proxy stopped",
|
||||
zap.Int("port", p.port),
|
||||
zap.String("subdomain", p.subdomain),
|
||||
)
|
||||
case <-time.After(stopTimeout):
|
||||
p.logger.Warn("TCP proxy stop timed out",
|
||||
zap.Int("port", p.port),
|
||||
zap.String("subdomain", p.subdomain),
|
||||
zap.Duration("timeout", stopTimeout),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Proxy) acceptLoop() {
|
||||
defer p.wg.Done()
|
||||
|
||||
tcpLn, _ := p.listener.(*net.TCPListener)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
@@ -82,11 +130,13 @@ func (p *TunnelProxy) acceptLoop() {
|
||||
default:
|
||||
}
|
||||
|
||||
p.listener.(*net.TCPListener).SetDeadline(time.Now().Add(1 * time.Second))
|
||||
if tcpLn != nil {
|
||||
_ = tcpLn.SetDeadline(time.Now().Add(1 * time.Second))
|
||||
}
|
||||
|
||||
conn, err := p.listener.Accept()
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
@@ -98,187 +148,86 @@ func (p *TunnelProxy) acceptLoop() {
|
||||
}
|
||||
|
||||
p.wg.Add(1)
|
||||
go p.handleConnection(conn)
|
||||
go p.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) handleConnection(conn net.Conn) {
|
||||
func (p *Proxy) handleConn(conn net.Conn) {
|
||||
defer p.wg.Done()
|
||||
defer conn.Close()
|
||||
|
||||
streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
|
||||
|
||||
stream := &proxyStream{
|
||||
conn: conn,
|
||||
closed: false,
|
||||
if p.sem != nil {
|
||||
select {
|
||||
case p.sem <- struct{}{}:
|
||||
defer func() { <-p.sem }()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
p.streamMu.Lock()
|
||||
p.streams[streamID] = stream
|
||||
p.streamMu.Unlock()
|
||||
if p.stats != nil {
|
||||
p.stats.IncActiveConnections()
|
||||
defer p.stats.DecActiveConnections()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
p.streamMu.Lock()
|
||||
delete(p.streams, streamID)
|
||||
p.streamMu.Unlock()
|
||||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||
_ = tcpConn.SetNoDelay(true)
|
||||
_ = tcpConn.SetKeepAlive(true)
|
||||
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
_ = tcpConn.SetReadBuffer(256 * 1024)
|
||||
_ = tcpConn.SetWriteBuffer(256 * 1024)
|
||||
}
|
||||
|
||||
if p.openStream == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Open stream with timeout to prevent goroutine leak
|
||||
const openStreamTimeout = 10 * time.Second
|
||||
type streamResult struct {
|
||||
stream net.Conn
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan streamResult, 1)
|
||||
|
||||
go func() {
|
||||
s, err := p.openStream()
|
||||
resultCh <- streamResult{s, err}
|
||||
}()
|
||||
|
||||
bufPtr := p.bufferPool.Get(pool.SizeMedium)
|
||||
defer p.bufferPool.Put(bufPtr)
|
||||
|
||||
buffer := (*bufPtr)[:pool.SizeMedium]
|
||||
|
||||
for {
|
||||
// Check if stream is closed
|
||||
stream.mu.Lock()
|
||||
closed := stream.closed
|
||||
stream.mu.Unlock()
|
||||
if closed {
|
||||
break
|
||||
}
|
||||
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
|
||||
p.logger.Debug("Send to tunnel failed", zap.Error(err))
|
||||
break
|
||||
var stream net.Conn
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
if result.err != nil {
|
||||
if !errors.Is(result.err, net.ErrClosed) {
|
||||
p.logger.Debug("Open stream failed", zap.Error(result.err))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
stream = result.stream
|
||||
case <-time.After(openStreamTimeout):
|
||||
p.logger.Debug("Open stream timeout")
|
||||
return
|
||||
case <-p.stopCh:
|
||||
default:
|
||||
p.sendCloseToTunnel(streamID)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) sendDataToTunnel(streamID string, data []byte) error {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return fmt.Errorf("tunnel proxy stopped")
|
||||
default:
|
||||
}
|
||||
|
||||
header := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: streamID,
|
||||
Type: protocol.DataTypeData,
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode payload: %w", err)
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
|
||||
err = p.frameWriter.WriteFrame(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write frame: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
|
||||
header := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: streamID,
|
||||
Type: protocol.DataTypeClose,
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
p.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
|
||||
p.streamMu.RLock()
|
||||
stream, ok := p.streams[streamID]
|
||||
p.streamMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
// Stream may have been closed by client, this is normal
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if stream is closed
|
||||
stream.mu.Lock()
|
||||
if stream.closed {
|
||||
stream.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
stream.mu.Unlock()
|
||||
|
||||
if _, err := stream.conn.Write(data); err != nil {
|
||||
p.logger.Debug("Write to client failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseStream closes a stream
|
||||
func (p *TunnelProxy) CloseStream(streamID string) {
|
||||
p.streamMu.RLock()
|
||||
stream, ok := p.streams[streamID]
|
||||
p.streamMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Mark as closed first
|
||||
stream.mu.Lock()
|
||||
if stream.closed {
|
||||
stream.mu.Unlock()
|
||||
return
|
||||
}
|
||||
stream.closed = true
|
||||
stream.mu.Unlock()
|
||||
|
||||
// Now close the connection
|
||||
stream.conn.Close()
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) Stop() {
|
||||
p.logger.Info("Stopping TCP proxy",
|
||||
zap.Int("port", p.port),
|
||||
zap.String("subdomain", p.subdomain),
|
||||
_ = netutil.PipeWithCallbacksAndBufferSize(
|
||||
p.ctx,
|
||||
conn,
|
||||
stream,
|
||||
pool.SizeLarge,
|
||||
func(n int64) {
|
||||
if p.stats != nil {
|
||||
p.stats.AddBytesIn(n)
|
||||
}
|
||||
},
|
||||
func(n int64) {
|
||||
if p.stats != nil {
|
||||
p.stats.AddBytesOut(n)
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
close(p.stopCh)
|
||||
|
||||
if p.listener != nil {
|
||||
p.listener.Close()
|
||||
}
|
||||
|
||||
p.streamMu.Lock()
|
||||
for _, stream := range p.streams {
|
||||
stream.mu.Lock()
|
||||
stream.closed = true
|
||||
stream.mu.Unlock()
|
||||
stream.conn.Close()
|
||||
}
|
||||
p.streams = make(map[string]*proxyStream)
|
||||
p.streamMu.Unlock()
|
||||
|
||||
p.wg.Wait()
|
||||
|
||||
if p.frameWriter != nil {
|
||||
p.frameWriter.Close()
|
||||
}
|
||||
|
||||
p.logger.Info("TCP proxy stopped", zap.Int("port", p.port))
|
||||
}
|
||||
|
||||
98
internal/server/tcp/tunnel.go
Normal file
98
internal/server/tcp/tunnel.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
)
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
openStream := session.Open
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
group.AddSession("primary", session)
|
||||
openStream = group.OpenStream
|
||||
}
|
||||
}
|
||||
|
||||
c.proxy = NewProxy(c.ctx, c.port, c.subdomain, openStream, c.tunnelConn, c.logger)
|
||||
if err := c.proxy.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start tcp proxy: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
openStream := session.Open
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
group.AddSession("primary", session)
|
||||
openStream = group.OpenStream
|
||||
}
|
||||
}
|
||||
|
||||
if c.tunnelConn != nil {
|
||||
c.tunnelConn.SetOpenStream(openStream)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/protocol"
|
||||
@@ -9,13 +11,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Transport represents the control channel to the client.
|
||||
// It is implemented by the TCP control connection so the HTTP proxy
|
||||
// can push frames directly to the client without depending on WebSockets.
|
||||
type Transport interface {
|
||||
SendFrame(frame *protocol.Frame) error
|
||||
}
|
||||
|
||||
// Connection represents a tunnel connection from a client
|
||||
type Connection struct {
|
||||
Subdomain string
|
||||
@@ -26,8 +21,12 @@ type Connection struct {
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
closed bool
|
||||
transport Transport
|
||||
tunnelType protocol.TunnelType
|
||||
openStream func() (net.Conn, error)
|
||||
|
||||
bytesIn atomic.Int64
|
||||
bytesOut atomic.Int64
|
||||
activeConnections atomic.Int64
|
||||
}
|
||||
|
||||
// NewConnection creates a new tunnel connection
|
||||
@@ -106,21 +105,6 @@ func (c *Connection) IsClosed() bool {
|
||||
return c.closed
|
||||
}
|
||||
|
||||
// SetTransport attaches the control transport and tunnel type.
|
||||
func (c *Connection) SetTransport(t Transport, tType protocol.TunnelType) {
|
||||
c.mu.Lock()
|
||||
c.transport = t
|
||||
c.tunnelType = tType
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetTransport returns the attached transport (if any).
|
||||
func (c *Connection) GetTransport() Transport {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.transport
|
||||
}
|
||||
|
||||
// SetTunnelType sets the tunnel type.
|
||||
func (c *Connection) SetTunnelType(tType protocol.TunnelType) {
|
||||
c.mu.Lock()
|
||||
@@ -135,6 +119,63 @@ func (c *Connection) GetTunnelType() protocol.TunnelType {
|
||||
return c.tunnelType
|
||||
}
|
||||
|
||||
// SetOpenStream registers a yamux stream opener for this tunnel.
|
||||
// It is used by the HTTP proxy to forward each request over a mux stream.
|
||||
func (c *Connection) SetOpenStream(open func() (net.Conn, error)) {
|
||||
c.mu.Lock()
|
||||
c.openStream = open
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// OpenStream opens a new mux stream to the tunnel client.
|
||||
func (c *Connection) OpenStream() (net.Conn, error) {
|
||||
c.mu.RLock()
|
||||
open := c.openStream
|
||||
closed := c.closed
|
||||
c.mu.RUnlock()
|
||||
|
||||
if closed || open == nil {
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
return open()
|
||||
}
|
||||
|
||||
func (c *Connection) AddBytesIn(n int64) {
|
||||
if n <= 0 {
|
||||
return
|
||||
}
|
||||
c.bytesIn.Add(n)
|
||||
}
|
||||
|
||||
func (c *Connection) AddBytesOut(n int64) {
|
||||
if n <= 0 {
|
||||
return
|
||||
}
|
||||
c.bytesOut.Add(n)
|
||||
}
|
||||
|
||||
func (c *Connection) GetBytesIn() int64 {
|
||||
return c.bytesIn.Load()
|
||||
}
|
||||
|
||||
func (c *Connection) GetBytesOut() int64 {
|
||||
return c.bytesOut.Load()
|
||||
}
|
||||
|
||||
func (c *Connection) IncActiveConnections() {
|
||||
c.activeConnections.Add(1)
|
||||
}
|
||||
|
||||
func (c *Connection) DecActiveConnections() {
|
||||
if v := c.activeConnections.Add(-1); v < 0 {
|
||||
c.activeConnections.Store(0)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) GetActiveConnections() int64 {
|
||||
return c.activeConnections.Load()
|
||||
}
|
||||
|
||||
// StartWritePump starts the write pump for sending messages
|
||||
func (c *Connection) StartWritePump() {
|
||||
// Skip write pump for TCP-only connections (no WebSocket)
|
||||
|
||||
Reference in New Issue
Block a user