feat(tunnel): switch to yamux stream proxying and connection pooling

- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server
- Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly
- Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
This commit is contained in:
Gouryella
2025-12-13 18:03:44 +08:00
parent 3c93789266
commit 0c19c3300c
55 changed files with 3380 additions and 4849 deletions

View File

@@ -1,490 +1,251 @@
package proxy
import (
"context"
"bufio"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
json "github.com/goccy/go-json"
"drip/internal/server/tunnel"
"drip/internal/shared/pool"
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
"go.uber.org/zap"
)
const openStreamTimeout = 10 * time.Second
type Handler struct {
manager *tunnel.Manager
logger *zap.Logger
responses *ResponseHandler
domain string
authToken string
headerPool *pool.HeaderPool
bufferPool *pool.AdaptiveBufferPool
manager *tunnel.Manager
logger *zap.Logger
domain string
authToken string
}
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, responses *ResponseHandler, domain string, authToken string) *Handler {
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, authToken string) *Handler {
return &Handler{
manager: manager,
logger: logger,
responses: responses,
domain: domain,
authToken: authToken,
headerPool: pool.NewHeaderPool(),
bufferPool: pool.NewAdaptiveBufferPool(),
manager: manager,
logger: logger,
domain: domain,
authToken: authToken,
}
}
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Always handle /health and /stats directly, regardless of subdomain
// Always handle /health and /stats directly, regardless of subdomain.
if r.URL.Path == "/health" {
h.serveHealth(w, r)
return
}
if r.URL.Path == "/stats" {
h.serveStats(w, r)
return
}
subdomain := h.extractSubdomain(r.Host)
if subdomain == "" {
h.serveHomePage(w, r)
return
}
conn, ok := h.manager.Get(subdomain)
if !ok {
tconn, ok := h.manager.Get(subdomain)
if !ok || tconn == nil {
http.Error(w, "Tunnel not found. The tunnel may have been closed.", http.StatusNotFound)
return
}
if conn.IsClosed() {
if tconn.IsClosed() {
http.Error(w, "Tunnel connection closed", http.StatusBadGateway)
return
}
transport := conn.GetTransport()
if transport == nil {
http.Error(w, "Tunnel control channel not ready", http.StatusBadGateway)
return
}
tType := conn.GetTunnelType()
tType := tconn.GetTunnelType()
if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
return
}
requestID := utils.GenerateID()
// Check for WebSocket upgrade
if httputil.IsWebSocketUpgrade(r) {
h.handleWebSocket(w, r, tconn)
return
}
h.handleAdaptiveRequest(w, r, transport, requestID, subdomain)
}
// Open stream with timeout
stream, err := h.openStreamWithTimeout(tconn)
if err != nil {
w.Header().Set("Connection", "close")
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
return
}
defer stream.Close()
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
const streamingThreshold int64 = 1 * 1024 * 1024
// Track active connections
tconn.IncActiveConnections()
defer tconn.DecActiveConnections()
// Wrap stream with counting for traffic stats
countingStream := netutil.NewCountingConn(stream,
tconn.AddBytesOut, // Data read from stream = bytes out to client
tconn.AddBytesIn, // Data written to stream = bytes in from client
)
// 1) Write request over the stream (net/http handles large bodies correctly).
if err := r.Write(countingStream); err != nil {
w.Header().Set("Connection", "close")
_ = r.Body.Close()
http.Error(w, "Forward failed", http.StatusBadGateway)
return
}
// 2) Read response from stream.
resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
if err != nil {
w.Header().Set("Connection", "close")
http.Error(w, "Read response failed", http.StatusBadGateway)
return
}
defer resp.Body.Close()
// 3) Copy headers (strip hop-by-hop).
h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
statusCode := resp.StatusCode
if statusCode == 0 {
statusCode = http.StatusOK
}
// Ensure message delimiting works with our custom ResponseWriter:
// - If Content-Length is known, send it.
// - Otherwise, re-chunk the decoded body ourselves.
if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
if resp.ContentLength >= 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
} else {
w.Header().Del("Content-Length")
}
w.WriteHeader(statusCode)
return
}
if resp.ContentLength >= 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
w.WriteHeader(statusCode)
ctx := r.Context()
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
stream.Close()
case <-done:
}
}()
_, _ = io.Copy(w, resp.Body)
close(done)
stream.Close()
return
}
w.Header().Del("Content-Length")
w.Header().Set("Transfer-Encoding", "chunked")
if len(resp.Trailer) > 0 {
w.Header().Set("Trailer", trailerKeys(resp.Trailer))
}
w.WriteHeader(statusCode)
ctx := r.Context()
var cancelTransport func()
if transport != nil {
cancelOnce := sync.Once{}
cancelFunc := func() {
header := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeClose,
IsLast: true,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
if err != nil {
return
}
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
if err := transport.SendFrame(frame); err != nil {
h.logger.Debug("Failed to send cancel frame to client",
zap.String("request_id", requestID),
zap.Error(err),
)
}
}
cancelTransport = func() {
cancelOnce.Do(cancelFunc)
}
h.responses.RegisterCancelFunc(requestID, cancelTransport)
defer h.responses.CleanupCancelFunc(requestID)
}
largeBufferPtr := h.bufferPool.GetLarge()
tempBufPtr := h.bufferPool.GetMedium()
defer func() {
h.bufferPool.PutLarge(largeBufferPtr)
h.bufferPool.PutMedium(tempBufPtr)
}()
buffer := (*largeBufferPtr)[:0]
tempBuf := (*tempBufPtr)[:pool.MediumBufferSize]
var totalRead int64
var hitThreshold bool
for totalRead < streamingThreshold {
n, err := r.Body.Read(tempBuf)
if n > 0 {
buffer = append(buffer, tempBuf[:n]...)
totalRead += int64(n)
}
if err == io.EOF {
r.Body.Close()
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
return
}
if err != nil {
r.Body.Close()
h.logger.Error("Read request body failed", zap.Error(err))
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
return
}
if totalRead >= streamingThreshold {
hitThreshold = true
break
}
}
if !hitThreshold {
r.Body.Close()
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
return
}
h.streamLargeRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
}
func (h *Handler) sendBufferedRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), body []byte) {
headers := h.headerPool.Get()
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
httpReq := protocol.HTTPRequest{
Method: r.Method,
URL: r.URL.String(),
Headers: headers,
Body: body,
}
reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
h.headerPool.Put(headers)
if err != nil {
h.logger.Error("Encode HTTP request failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
header := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequest,
IsLast: true,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, reqBytes)
if err != nil {
h.logger.Error("Encode data payload failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
respChan := h.responses.CreateResponseChan(requestID)
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
defer func() {
h.responses.CleanupResponseChan(requestID)
h.responses.CleanupStreamingResponse(requestID)
}()
if err := transport.SendFrame(frame); err != nil {
h.logger.Error("Send frame to tunnel failed", zap.Error(err))
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
return
}
select {
case respMsg := <-respChan:
if respMsg == nil {
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
return
}
h.writeHTTPResponse(w, respMsg, subdomain, r)
case <-streamingDone:
// Streaming response has been fully written by SendStreamingChunk
case <-ctx.Done():
if cancelTransport != nil {
cancelTransport()
}
h.logger.Debug("HTTP request context cancelled",
zap.String("request_id", requestID),
zap.String("subdomain", subdomain),
)
return
case <-time.After(5 * time.Minute):
h.logger.Error("Request timeout",
zap.String("request_id", requestID),
zap.String("url", r.URL.String()),
)
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
}
}
func (h *Handler) streamLargeRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), bufferedData []byte) {
headers := h.headerPool.Get()
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
httpReqHead := protocol.HTTPRequestHead{
Method: r.Method,
URL: r.URL.String(),
Headers: headers,
ContentLength: r.ContentLength,
}
headBytes, err := protocol.EncodeHTTPRequestHead(&httpReqHead)
h.headerPool.Put(headers)
if err != nil {
h.logger.Error("Encode HTTP request head failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
headHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPHead, // shared streaming head type
IsLast: false,
}
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
if err != nil {
h.logger.Error("Encode head payload failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
respChan := h.responses.CreateResponseChan(requestID)
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
defer func() {
h.responses.CleanupResponseChan(requestID)
h.responses.CleanupStreamingResponse(requestID)
}()
if err := transport.SendFrame(headFrame); err != nil {
h.logger.Error("Send head frame failed", zap.Error(err))
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
return
}
if len(bufferedData) > 0 {
chunkHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
IsLast: false,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, bufferedData)
if err != nil {
h.logger.Error("Encode buffered chunk failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := transport.SendFrame(chunkFrame); err != nil {
h.logger.Error("Send buffered chunk failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
}
streamBufPtr := h.bufferPool.GetMedium()
defer h.bufferPool.PutMedium(streamBufPtr)
buffer := (*streamBufPtr)[:pool.MediumBufferSize]
for {
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
if cancelTransport != nil {
cancelTransport()
}
h.logger.Debug("Streaming request cancelled via context",
zap.String("request_id", requestID),
zap.String("subdomain", subdomain),
)
return
default:
stream.Close()
case <-done:
}
}()
n, readErr := r.Body.Read(buffer)
if n > 0 {
isLast := readErr == io.EOF
chunkHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
IsLast: isLast,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer[:n])
if err != nil {
h.logger.Error("Encode chunk payload failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := transport.SendFrame(chunkFrame); err != nil {
h.logger.Error("Send chunk frame failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
}
if readErr == io.EOF {
if n == 0 {
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if err == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
}
break
}
if readErr != nil {
h.logger.Error("Read request body failed", zap.Error(readErr))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if err == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
return
}
if err := writeChunked(w, resp.Body, resp.Trailer); err != nil {
h.logger.Debug("Write chunked response failed", zap.Error(err))
}
close(done)
stream.Close()
}
r.Body.Close()
func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) {
type result struct {
stream net.Conn
err error
}
ch := make(chan result, 1)
go func() {
s, err := tconn.OpenStream()
ch <- result{s, err}
}()
select {
case respMsg := <-respChan:
if respMsg == nil {
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
return
}
h.writeHTTPResponse(w, respMsg, subdomain, r)
case <-streamingDone:
// Streaming response has been fully written by SendStreamingChunk
case <-ctx.Done():
if cancelTransport != nil {
cancelTransport()
}
h.logger.Debug("Streaming HTTP request context cancelled",
zap.String("request_id", requestID),
zap.String("subdomain", subdomain),
)
return
case <-time.After(5 * time.Minute):
h.logger.Error("Streaming request timeout",
zap.String("request_id", requestID),
zap.String("url", r.URL.String()),
)
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
case r := <-ch:
return r.stream, r.err
case <-time.After(openStreamTimeout):
return nil, fmt.Errorf("open stream timeout")
}
}
func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPResponse, subdomain string, r *http.Request) {
if resp == nil {
http.Error(w, "Invalid response from tunnel", http.StatusBadGateway)
func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) {
stream, err := h.openStreamWithTimeout(tconn)
if err != nil {
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
return
}
// For buffered responses, we have the complete body, so we can set Content-Length
// Skip ALL hop-by-hop headers - client should have already cleaned them
for key, values := range resp.Headers {
tconn.IncActiveConnections()
hj, ok := w.(http.Hijacker)
if !ok {
stream.Close()
tconn.DecActiveConnections()
http.Error(w, "WebSocket not supported", http.StatusInternalServerError)
return
}
clientConn, _, err := hj.Hijack()
if err != nil {
stream.Close()
tconn.DecActiveConnections()
http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
return
}
if err := r.Write(stream); err != nil {
stream.Close()
clientConn.Close()
tconn.DecActiveConnections()
return
}
go func() {
defer stream.Close()
defer clientConn.Close()
defer tconn.DecActiveConnections()
_ = netutil.PipeWithCallbacks(r.Context(), stream, clientConn,
func(n int64) { tconn.AddBytesOut(n) },
func(n int64) { tconn.AddBytesIn(n) },
)
}()
}
func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHost string) {
for key, values := range src {
canonicalKey := http.CanonicalHeaderKey(key)
// Skip hop-by-hop headers completely using canonical key comparison
// Hop-by-hop headers must not be forwarded.
if canonicalKey == "Connection" ||
canonicalKey == "Keep-Alive" ||
canonicalKey == "Transfer-Encoding" ||
@@ -496,29 +257,61 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
}
if canonicalKey == "Location" && len(values) > 0 {
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
w.Header().Set("Location", rewrittenLocation)
dst.Set("Location", h.rewriteLocationHeader(values[0], proxyHost))
continue
}
for _, value := range values {
w.Header().Add(key, value)
dst.Add(key, value)
}
}
}
func trailerKeys(hdr http.Header) string {
keys := make([]string, 0, len(hdr))
for k := range hdr {
keys = append(keys, k)
}
// Deterministic order is nicer for debugging; no semantic impact.
sortStrings(keys)
return strings.Join(keys, ", ")
}
func writeChunked(w io.Writer, r io.Reader, trailer http.Header) error {
buf := make([]byte, 32*1024)
for {
n, err := r.Read(buf)
if n > 0 {
if _, werr := fmt.Fprintf(w, "%x\r\n", n); werr != nil {
return werr
}
if _, werr := w.Write(buf[:n]); werr != nil {
return werr
}
if _, werr := io.WriteString(w, "\r\n"); werr != nil {
return werr
}
}
if err == io.EOF {
break
}
if err != nil {
return err
}
}
// For buffered mode, always set Content-Length with the actual body size
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
statusCode := resp.StatusCode
if statusCode == 0 {
statusCode = http.StatusOK
if _, err := io.WriteString(w, "0\r\n"); err != nil {
return err
}
w.WriteHeader(statusCode)
if len(resp.Body) > 0 {
w.Write(resp.Body)
for k, vv := range trailer {
for _, v := range vv {
if _, err := io.WriteString(w, fmt.Sprintf("%s: %s\r\n", k, v)); err != nil {
return err
}
}
}
_, err := io.WriteString(w, "\r\n")
return err
}
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
@@ -535,22 +328,13 @@ func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
strings.HasPrefix(locationURL.Host, "localhost:") ||
locationURL.Host == "127.0.0.1" ||
strings.HasPrefix(locationURL.Host, "127.0.0.1:") {
scheme := "https"
if strings.Contains(proxyHost, ":") && !strings.Contains(proxyHost, "https") {
parts := strings.Split(proxyHost, ":")
if len(parts) == 2 && parts[1] != "443" {
scheme = "https"
}
}
rewritten := fmt.Sprintf("%s://%s%s", scheme, proxyHost, locationURL.Path)
rewritten := fmt.Sprintf("https://%s%s", proxyHost, locationURL.Path)
if locationURL.RawQuery != "" {
rewritten += "?" + locationURL.RawQuery
}
if locationURL.Fragment != "" {
rewritten += "#" + locationURL.Fragment
}
return rewritten
}
@@ -568,8 +352,7 @@ func (h *Handler) extractSubdomain(host string) string {
suffix := "." + h.domain
if strings.HasSuffix(host, suffix) {
subdomain := strings.TrimSuffix(host, suffix)
return subdomain
return strings.TrimSuffix(host, suffix)
}
return ""
@@ -652,9 +435,17 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
}
for _, conn := range connections {
if conn == nil {
continue
}
stats["tunnels"] = append(stats["tunnels"].([]map[string]interface{}), map[string]interface{}{
"subdomain": conn.Subdomain,
"last_active": conn.LastActive.Unix(),
"subdomain": conn.Subdomain,
"tunnel_type": string(conn.GetTunnelType()),
"last_active": conn.LastActive.Unix(),
"bytes_in": conn.GetBytesIn(),
"bytes_out": conn.GetBytesOut(),
"active_connections": conn.GetActiveConnections(),
"total_bytes": conn.GetBytesIn() + conn.GetBytesOut(),
})
}
@@ -668,3 +459,13 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Write(data)
}
func sortStrings(s []string) {
for i := 0; i < len(s); i++ {
for j := i + 1; j < len(s); j++ {
if s[j] < s[i] {
s[i], s[j] = s[j], s[i]
}
}
}
}

View File

@@ -1,421 +0,0 @@
package proxy
import (
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// responseChanEntry holds a response channel and its creation time
type responseChanEntry struct {
ch chan *protocol.HTTPResponse
createdAt time.Time
}
// streamingResponseEntry holds a streaming response writer
type streamingResponseEntry struct {
w http.ResponseWriter
flusher http.Flusher
createdAt time.Time
lastActivityAt time.Time
headersSent bool
done chan struct{}
mu sync.Mutex
}
// ResponseHandler manages response channels for HTTP requests over TCP/Frame protocol
type ResponseHandler struct {
channels map[string]*responseChanEntry
streamingChannels map[string]*streamingResponseEntry
cancelFuncs map[string]func()
mu sync.RWMutex
logger *zap.Logger
stopCh chan struct{}
}
// NewResponseHandler creates a new response handler
func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
h := &ResponseHandler{
channels: make(map[string]*responseChanEntry),
streamingChannels: make(map[string]*streamingResponseEntry),
cancelFuncs: make(map[string]func()),
logger: logger,
stopCh: make(chan struct{}),
}
// Start single cleanup goroutine instead of one per request
go h.cleanupLoop()
return h
}
// CreateResponseChan creates a response channel for a request ID
func (h *ResponseHandler) CreateResponseChan(requestID string) chan *protocol.HTTPResponse {
h.mu.Lock()
defer h.mu.Unlock()
ch := make(chan *protocol.HTTPResponse, 1)
h.channels[requestID] = &responseChanEntry{
ch: ch,
createdAt: time.Now(),
}
return ch
}
// CreateStreamingResponse creates a streaming response entry for a request ID
func (h *ResponseHandler) CreateStreamingResponse(requestID string, w http.ResponseWriter) chan struct{} {
h.mu.Lock()
defer h.mu.Unlock()
flusher, _ := w.(http.Flusher)
done := make(chan struct{})
now := time.Now()
h.streamingChannels[requestID] = &streamingResponseEntry{
w: w,
flusher: flusher,
createdAt: now,
lastActivityAt: now,
done: done,
}
return done
}
// RegisterCancelFunc registers a callback to be invoked when the downstream disconnects.
func (h *ResponseHandler) RegisterCancelFunc(requestID string, cancel func()) {
if cancel == nil {
return
}
h.mu.Lock()
h.cancelFuncs[requestID] = cancel
h.mu.Unlock()
}
// GetResponseChan gets the response channel for a request ID
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
h.mu.RLock()
defer h.mu.RUnlock()
if entry := h.channels[requestID]; entry != nil {
return entry.ch
}
return nil
}
// SendResponse sends a response to the waiting channel
func (h *ResponseHandler) SendResponse(requestID string, resp *protocol.HTTPResponse) {
h.mu.RLock()
entry, exists := h.channels[requestID]
h.mu.RUnlock()
if !exists || entry == nil {
return
}
select {
case entry.ch <- resp:
case <-time.After(30 * time.Second):
h.logger.Error("Timeout sending response to channel - handler may have abandoned",
zap.String("request_id", requestID),
zap.Int("status_code", resp.StatusCode),
zap.Int("body_size", len(resp.Body)),
)
}
}
func (h *ResponseHandler) SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error {
h.mu.RLock()
entry, exists := h.streamingChannels[requestID]
h.mu.RUnlock()
if !exists || entry == nil {
return nil
}
entry.mu.Lock()
defer entry.mu.Unlock()
select {
case <-entry.done:
return nil
default:
}
if entry.headersSent {
return nil
}
// Copy headers, removing hop-by-hop headers that were already handled by client
// Client's cleanResponseHeaders already removed Transfer-Encoding, Connection, etc.
// But we need to check again in case they slipped through
hasContentLength := false
for key, values := range head.Headers {
canonicalKey := http.CanonicalHeaderKey(key)
// Skip ALL hop-by-hop headers
if canonicalKey == "Connection" ||
canonicalKey == "Keep-Alive" ||
canonicalKey == "Transfer-Encoding" ||
canonicalKey == "Upgrade" ||
canonicalKey == "Proxy-Connection" ||
canonicalKey == "Te" ||
canonicalKey == "Trailer" {
continue
}
if canonicalKey == "Content-Length" {
hasContentLength = true
}
for _, value := range values {
entry.w.Header().Add(key, value)
}
}
// For streaming responses, decide how to indicate message length
if head.ContentLength >= 0 && !hasContentLength {
entry.w.Header().Set("Content-Length", fmt.Sprintf("%d", head.ContentLength))
}
statusCode := head.StatusCode
if statusCode == 0 {
statusCode = http.StatusOK
}
entry.w.WriteHeader(statusCode)
entry.headersSent = true
entry.lastActivityAt = time.Now()
if entry.flusher != nil {
entry.flusher.Flush()
}
return nil
}
func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isLast bool) error {
h.mu.RLock()
entry, exists := h.streamingChannels[requestID]
h.mu.RUnlock()
if !exists || entry == nil {
return nil
}
entry.mu.Lock()
defer entry.mu.Unlock()
select {
case <-entry.done:
return nil
default:
}
if len(chunk) > 0 {
_, err := entry.w.Write(chunk)
if err != nil {
if isClientDisconnectError(err) {
select {
case <-entry.done:
default:
close(entry.done)
}
h.triggerCancel(requestID)
return nil
}
select {
case <-entry.done:
default:
close(entry.done)
}
h.triggerCancel(requestID)
return nil
}
entry.lastActivityAt = time.Now()
if entry.flusher != nil {
entry.flusher.Flush()
}
}
if isLast {
select {
case <-entry.done:
default:
close(entry.done)
}
}
return nil
}
func isClientDisconnectError(err error) bool {
if err == nil {
return false
}
if netErr, ok := err.(*net.OpError); ok {
if netErr.Err != nil {
errStr := netErr.Err.Error()
if strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection reset") ||
strings.Contains(errStr, "connection refused") {
return true
}
}
}
errStr := err.Error()
return strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection reset") ||
strings.Contains(errStr, "use of closed network connection")
}
// triggerCancel invokes and removes the cancel callback for a request.
func (h *ResponseHandler) triggerCancel(requestID string) {
h.mu.Lock()
cancel := h.cancelFuncs[requestID]
if cancel != nil {
delete(h.cancelFuncs, requestID)
}
h.mu.Unlock()
if cancel != nil {
go func() {
cancel()
}()
}
}
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
h.mu.Lock()
defer h.mu.Unlock()
if entry, exists := h.channels[requestID]; exists {
close(entry.ch)
delete(h.channels, requestID)
}
}
func (h *ResponseHandler) CleanupStreamingResponse(requestID string) {
h.mu.Lock()
defer h.mu.Unlock()
if entry, exists := h.streamingChannels[requestID]; exists {
select {
case <-entry.done:
default:
close(entry.done)
}
delete(h.streamingChannels, requestID)
}
}
// CleanupCancelFunc removes a registered cancel callback.
func (h *ResponseHandler) CleanupCancelFunc(requestID string) {
h.mu.Lock()
delete(h.cancelFuncs, requestID)
h.mu.Unlock()
}
func (h *ResponseHandler) GetPendingCount() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.channels) + len(h.streamingChannels)
}
func (h *ResponseHandler) cleanupLoop() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
h.cleanupExpiredChannels()
case <-h.stopCh:
return
}
}
}
func (h *ResponseHandler) cleanupExpiredChannels() {
now := time.Now()
timeout := 5 * time.Minute
streamingTimeout := 5 * time.Minute
h.mu.Lock()
defer h.mu.Unlock()
expiredCount := 0
cancelList := make([]string, 0)
for requestID, entry := range h.channels {
if now.Sub(entry.createdAt) > timeout {
close(entry.ch)
delete(h.channels, requestID)
expiredCount++
}
}
for requestID, entry := range h.streamingChannels {
if now.Sub(entry.lastActivityAt) > streamingTimeout {
select {
case <-entry.done:
default:
close(entry.done)
}
delete(h.streamingChannels, requestID)
cancelList = append(cancelList, requestID)
expiredCount++
}
}
for _, requestID := range cancelList {
if cancel := h.cancelFuncs[requestID]; cancel != nil {
delete(h.cancelFuncs, requestID)
go cancel()
}
}
if expiredCount > 0 {
h.logger.Debug("Cleaned up expired response channels",
zap.Int("count", expiredCount),
zap.Int("remaining", len(h.channels)+len(h.streamingChannels)),
)
}
}
func (h *ResponseHandler) Close() {
close(h.stopCh)
h.mu.Lock()
defer h.mu.Unlock()
for _, entry := range h.channels {
close(entry.ch)
}
h.channels = make(map[string]*responseChanEntry)
for _, entry := range h.streamingChannels {
select {
case <-entry.done:
default:
close(entry.done)
}
}
h.streamingChannels = make(map[string]*streamingResponseEntry)
for _, cancel := range h.cancelFuncs {
cancel()
}
h.cancelFuncs = make(map[string]func())
}