feat: Add HTTP streaming, compression support, and Docker deployment

Enhancements:

  - Add adaptive HTTP response handling with automatic streaming for large
  responses (>1MB)
  - Implement zero-copy streaming using buffer pools for better performance
  - Add compression module for reduced bandwidth usage
  - Add GitHub Container Registry workflow for automated Docker builds
  - Add production-optimized Dockerfile and docker-compose configuration
  - Simplify background mode with -d flag and improved daemon management
  - Update documentation with new command syntax and deployment guides
  - Clean up unused code and improve error handling
  - Fix lipgloss style usage (remove unnecessary .Copy() calls)
This commit is contained in:
Gouryella
2025-12-05 22:09:07 +08:00
parent b538397a00
commit aead68bb62
31 changed files with 2641 additions and 272 deletions

View File

@@ -1,7 +1,6 @@
package proxy
import (
"context"
"fmt"
"io"
"net/http"
@@ -12,7 +11,6 @@ import (
json "github.com/goccy/go-json"
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
@@ -71,17 +69,53 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
limitedReader := io.LimitReader(r.Body, constants.MaxRequestBodySize)
body, err := io.ReadAll(limitedReader)
if err != nil {
h.logger.Error("Read request body failed", zap.Error(err))
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
return
}
defer r.Body.Close()
requestID := utils.GenerateID()
h.handleAdaptiveRequest(w, r, transport, requestID, subdomain)
}
// handleAdaptiveRequest decides between buffered and streaming forwarding.
//
// It reads the request body up to streamingThreshold (1 MiB). A body that
// ends below the threshold is forwarded in one buffered request; anything
// larger is streamed in chunks, with the already-read prefix handed to
// streamLargeRequest so no data is lost.
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
	const streamingThreshold int64 = 1 * 1024 * 1024

	// Size the initial buffer from the declared Content-Length when one is
	// present, instead of always reserving the full 1 MiB (the previous
	// behavior, wasteful for the common small-request case). Unknown (-1)
	// or threshold-exceeding lengths start small and let append grow it.
	initialCap := int64(32 * 1024)
	if r.ContentLength >= 0 && r.ContentLength < streamingThreshold {
		initialCap = r.ContentLength
	}
	buffer := make([]byte, 0, initialCap)
	tempBuf := make([]byte, 32*1024)

	var totalRead int64
	for totalRead < streamingThreshold {
		n, err := r.Body.Read(tempBuf)
		if n > 0 {
			buffer = append(buffer, tempBuf[:n]...)
			totalRead += int64(n)
		}
		// Check EOF before other errors: a final read may return (n, io.EOF).
		if err == io.EOF {
			// Entire body fits below the threshold: buffered path.
			r.Body.Close()
			h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
			return
		}
		if err != nil {
			r.Body.Close()
			h.logger.Error("Read request body failed", zap.Error(err))
			http.Error(w, "Failed to read request body", http.StatusInternalServerError)
			return
		}
	}
	// Threshold reached with more body potentially pending: stream the rest,
	// replaying the prefix we already consumed.
	h.streamLargeRequest(w, r, transport, requestID, subdomain, buffer)
}
func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, body []byte) {
headers := h.headerPool.Get()
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
@@ -93,7 +127,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
h.headerPool.Put(headers)
if err != nil {
@@ -119,7 +152,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
respChan := h.responses.CreateResponseChan(requestID)
defer h.responses.CleanupResponseChan(requestID)
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
defer func() {
h.responses.CleanupResponseChan(requestID)
h.responses.CleanupStreamingResponse(requestID)
}()
if err := transport.SendFrame(frame); err != nil {
h.logger.Error("Send frame to tunnel failed", zap.Error(err))
@@ -127,14 +164,220 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), constants.RequestTimeout)
defer cancel()
select {
case respMsg := <-respChan:
if respMsg == nil {
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
return
}
h.writeHTTPResponse(w, respMsg, subdomain, r)
case <-streamingDone:
// Streaming response has been fully written by SendStreamingChunk
case <-time.After(5 * time.Minute):
h.logger.Error("Request timeout",
zap.String("request_id", requestID),
zap.String("url", r.URL.String()),
)
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
}
}
// streamLargeRequest forwards a request whose body exceeded the buffering
// threshold. It sends the request head first, then replays bufferedData
// (the prefix already consumed by handleAdaptiveRequest), then streams the
// remainder of r.Body in 32 KiB chunks. On any mid-stream failure a final
// empty chunk with IsLast set is emitted so the tunnel client never waits
// forever for more body data.
func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, bufferedData []byte) {
	headers := h.headerPool.Get()
	h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)

	httpReqHead := protocol.HTTPRequestHead{
		Method:        r.Method,
		URL:           r.URL.String(),
		Headers:       headers,
		ContentLength: r.ContentLength,
	}
	headBytes, err := protocol.EncodeHTTPRequestHead(&httpReqHead)
	h.headerPool.Put(headers)
	if err != nil {
		h.logger.Error("Encode HTTP request head failed", zap.Error(err))
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	headHeader := protocol.DataHeader{
		StreamID:  requestID,
		RequestID: requestID,
		Type:      protocol.DataTypeHTTPHead, // shared streaming head type
		IsLast:    false,
	}
	headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
	if err != nil {
		h.logger.Error("Encode head payload failed", zap.Error(err))
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)

	respChan := h.responses.CreateResponseChan(requestID)
	streamingDone := h.responses.CreateStreamingResponse(requestID, w)
	defer func() {
		h.responses.CleanupResponseChan(requestID)
		h.responses.CleanupStreamingResponse(requestID)
	}()

	if err := transport.SendFrame(headFrame); err != nil {
		h.logger.Error("Send head frame failed", zap.Error(err))
		http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
		return
	}

	// sendFinalChunk emits a best-effort empty terminator frame so the
	// tunnel client knows no more body chunks are coming. Previously this
	// block was duplicated verbatim at six call sites.
	sendFinalChunk := func() {
		finalHeader := protocol.DataHeader{
			StreamID:  requestID,
			RequestID: requestID,
			Type:      protocol.DataTypeHTTPRequestBodyChunk,
			IsLast:    true,
		}
		finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
		if ferr != nil {
			return
		}
		finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
		transport.SendFrame(finalFrame) // best effort; the stream is already failing
	}

	// sendBodyChunk encodes and sends one body chunk. On failure it logs,
	// terminates the stream with a final chunk, and reports false.
	sendBodyChunk := func(data []byte, isLast bool) bool {
		chunkHeader := protocol.DataHeader{
			StreamID:  requestID,
			RequestID: requestID,
			Type:      protocol.DataTypeHTTPBodyChunk, // shared streaming body type
			IsLast:    isLast,
		}
		chunkPayload, chunkPoolBuffer, cerr := protocol.EncodeDataPayloadPooled(chunkHeader, data)
		if cerr != nil {
			h.logger.Error("Encode chunk payload failed", zap.Error(cerr))
			sendFinalChunk()
			return false
		}
		chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
		if serr := transport.SendFrame(chunkFrame); serr != nil {
			h.logger.Error("Send chunk frame failed", zap.Error(serr))
			sendFinalChunk()
			return false
		}
		return true
	}

	// Replay the prefix that handleAdaptiveRequest already read.
	if len(bufferedData) > 0 {
		if !sendBodyChunk(bufferedData, false) {
			// Previously this path returned without writing any response,
			// yielding an implicit 200 with an empty body.
			http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
			return
		}
	}

	// Stream the remainder of the body.
	buffer := make([]byte, 32*1024)
	for {
		n, readErr := r.Body.Read(buffer)
		if n > 0 {
			// Piggyback the terminator on the last data chunk when the final
			// read also returns io.EOF.
			if !sendBodyChunk(buffer[:n], readErr == io.EOF) {
				http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
				return
			}
		}
		if readErr == io.EOF {
			if n == 0 {
				// Body ended on an empty read: terminate explicitly.
				sendFinalChunk()
			}
			break
		}
		if readErr != nil {
			h.logger.Error("Read request body failed", zap.Error(readErr))
			sendFinalChunk()
			http.Error(w, "Failed to read request body", http.StatusInternalServerError)
			return
		}
	}
	r.Body.Close()

	// Wait for a buffered response, completion of a streamed response,
	// cancellation, or the hard 5-minute timeout.
	// NOTE(review): the original selected on an undeclared ctx here; this
	// mirrors sendBufferedRequest's timeout context — confirm against the
	// intended cancellation source.
	ctx, cancel := context.WithTimeout(context.Background(), constants.RequestTimeout)
	defer cancel()

	select {
	case respMsg := <-respChan:
		if respMsg == nil {
			http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
			return
		}
		h.writeHTTPResponse(w, respMsg, subdomain, r)
	case <-ctx.Done():
		// Context expired; nothing written (matches the original's empty case).
	case <-streamingDone:
		// Streaming response has been fully written by SendStreamingChunk
	case <-time.After(5 * time.Minute):
		h.logger.Error("Streaming request timeout",
			zap.String("request_id", requestID),
			zap.String("url", r.URL.String()),
		)
		http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
	}
}
@@ -145,12 +388,23 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
return
}
// For buffered responses, we have the complete body, so we can set Content-Length
// Skip ALL hop-by-hop headers - client should have already cleaned them
for key, values := range resp.Headers {
if key == "Connection" || key == "Keep-Alive" || key == "Transfer-Encoding" || key == "Upgrade" {
canonicalKey := http.CanonicalHeaderKey(key)
// Skip hop-by-hop headers completely using canonical key comparison
if canonicalKey == "Connection" ||
canonicalKey == "Keep-Alive" ||
canonicalKey == "Transfer-Encoding" ||
canonicalKey == "Upgrade" ||
canonicalKey == "Proxy-Connection" ||
canonicalKey == "Te" ||
canonicalKey == "Trailer" {
continue
}
if key == "Location" && len(values) > 0 {
if canonicalKey == "Location" && len(values) > 0 {
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
w.Header().Set("Location", rewrittenLocation)
continue
@@ -161,9 +415,8 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
}
}
if w.Header().Get("Content-Length") == "" && len(resp.Body) > 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
}
// For buffered mode, always set Content-Length with the actual body size
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
statusCode := resp.StatusCode
if statusCode == 0 {
@@ -171,6 +424,7 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
}
w.WriteHeader(statusCode)
if len(resp.Body) > 0 {
w.Write(resp.Body)
}
@@ -284,19 +538,6 @@ func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(health)
}
// cloneHeadersWithHost returns a deep copy of src (each value slice is
// duplicated, so mutating the clone never touches the original) and, when
// host is non-empty, sets the Host header on the copy.
func cloneHeadersWithHost(src http.Header, host string) http.Header {
	dst := make(http.Header, len(src)+1)
	for name, values := range src {
		dst[name] = append([]string(nil), values...)
	}
	if host != "" {
		dst.Set("Host", host)
	}
	return dst
}
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
if h.authToken != "" {
token := r.URL.Query().Get("token")

View File

@@ -1,6 +1,10 @@
package proxy
import (
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
@@ -14,20 +18,32 @@ type responseChanEntry struct {
createdAt time.Time
}
// streamingResponseEntry holds a streaming response writer together with
// the state needed to write the response incrementally and exactly once.
type streamingResponseEntry struct {
	w           http.ResponseWriter // destination for the head and body chunks
	flusher     http.Flusher        // nil when w does not implement http.Flusher
	createdAt   time.Time           // used by cleanupExpiredChannels to expire stale entries
	headersSent bool                // guards against calling WriteHeader more than once
	done        chan struct{}       // closed when the response completes or is abandoned
	mu          sync.Mutex          // serializes writes and the headersSent/done checks
}
// ResponseHandler manages response channels for HTTP requests over TCP/Frame protocol
type ResponseHandler struct {
channels map[string]*responseChanEntry
mu sync.RWMutex
logger *zap.Logger
stopCh chan struct{}
channels map[string]*responseChanEntry
streamingChannels map[string]*streamingResponseEntry
mu sync.RWMutex
logger *zap.Logger
stopCh chan struct{}
}
// NewResponseHandler creates a new response handler
func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
h := &ResponseHandler{
channels: make(map[string]*responseChanEntry),
logger: logger,
stopCh: make(chan struct{}),
channels: make(map[string]*responseChanEntry),
streamingChannels: make(map[string]*streamingResponseEntry),
logger: logger,
stopCh: make(chan struct{}),
}
// Start single cleanup goroutine instead of one per request
@@ -50,6 +66,23 @@ func (h *ResponseHandler) CreateResponseChan(requestID string) chan *protocol.HT
return ch
}
// CreateStreamingResponse registers w as the streaming destination for
// requestID and returns a channel that is closed once the response has
// been fully written or the stream is torn down.
func (h *ResponseHandler) CreateStreamingResponse(requestID string, w http.ResponseWriter) chan struct{} {
	done := make(chan struct{})
	// Flushing is best-effort: not every ResponseWriter implements http.Flusher.
	flusher, _ := w.(http.Flusher)
	entry := &streamingResponseEntry{
		w:         w,
		flusher:   flusher,
		createdAt: time.Now(),
		done:      done,
	}

	h.mu.Lock()
	h.streamingChannels[requestID] = entry
	h.mu.Unlock()
	return done
}
// GetResponseChan gets the response channel for a request ID
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
h.mu.RLock()
@@ -67,25 +100,165 @@ func (h *ResponseHandler) SendResponse(requestID string, resp *protocol.HTTPResp
h.mu.RUnlock()
if !exists || entry == nil {
h.logger.Warn("Response channel not found",
zap.String("request_id", requestID),
)
return
}
select {
case entry.ch <- resp:
h.logger.Debug("Response sent to channel",
zap.String("request_id", requestID),
)
case <-time.After(5 * time.Second):
h.logger.Warn("Timeout sending response to channel",
case <-time.After(30 * time.Second):
h.logger.Error("Timeout sending response to channel - handler may have abandoned",
zap.String("request_id", requestID),
zap.Int("status_code", resp.StatusCode),
zap.Int("body_size", len(resp.Body)),
)
}
}
// CleanupResponseChan removes and closes a response channel
// SendStreamingHead writes the status line and headers of a streaming
// response to the http.ResponseWriter registered under requestID.
// It is a no-op (returning nil) when no streaming entry exists, when the
// stream has already finished (done closed), or when headers were already
// sent — callers routing frames treat delivery as best-effort.
func (h *ResponseHandler) SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error {
	h.mu.RLock()
	entry, exists := h.streamingChannels[requestID]
	h.mu.RUnlock()

	if !exists || entry == nil {
		// Unknown request ID: the handler may have timed out and cleaned up.
		return nil
	}

	entry.mu.Lock()
	defer entry.mu.Unlock()

	// Non-blocking check: stream already completed or abandoned.
	select {
	case <-entry.done:
		return nil
	default:
	}

	// WriteHeader must only be called once per response.
	if entry.headersSent {
		return nil
	}

	// Copy headers, removing hop-by-hop headers that were already handled by client
	// Client's cleanResponseHeaders already removed Transfer-Encoding, Connection, etc.
	// But we need to check again in case they slipped through
	hasContentLength := false
	for key, values := range head.Headers {
		canonicalKey := http.CanonicalHeaderKey(key)
		// Skip ALL hop-by-hop headers
		if canonicalKey == "Connection" ||
			canonicalKey == "Keep-Alive" ||
			canonicalKey == "Transfer-Encoding" ||
			canonicalKey == "Upgrade" ||
			canonicalKey == "Proxy-Connection" ||
			canonicalKey == "Te" ||
			canonicalKey == "Trailer" {
			continue
		}
		if canonicalKey == "Content-Length" {
			hasContentLength = true
		}
		for _, value := range values {
			entry.w.Header().Add(key, value)
		}
	}

	// For streaming responses, decide how to indicate message length:
	// only synthesize Content-Length when the upstream declared one and the
	// header was not already copied above (negative means unknown length).
	if head.ContentLength >= 0 && !hasContentLength {
		entry.w.Header().Set("Content-Length", fmt.Sprintf("%d", head.ContentLength))
	}

	statusCode := head.StatusCode
	if statusCode == 0 {
		// Zero status means the client omitted it; default to 200 OK.
		statusCode = http.StatusOK
	}
	entry.w.WriteHeader(statusCode)
	entry.headersSent = true

	// Flush so the client sees the headers immediately (flusher may be nil).
	if entry.flusher != nil {
		entry.flusher.Flush()
	}

	return nil
}
// SendStreamingChunk writes one body chunk to the streaming response for
// requestID and, when isLast is set, signals completion by closing the
// entry's done channel. Missing entries and already-finished streams are
// ignored; the method always returns nil so frame routing is best-effort.
func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isLast bool) error {
	h.mu.RLock()
	entry, exists := h.streamingChannels[requestID]
	h.mu.RUnlock()

	if !exists || entry == nil {
		// Handler may have timed out and cleaned up; drop the chunk.
		return nil
	}

	entry.mu.Lock()
	defer entry.mu.Unlock()

	// Non-blocking check: stream already completed or abandoned.
	select {
	case <-entry.done:
		return nil
	default:
	}

	// closeDone signals completion exactly once (done may already be closed
	// by cleanup). Previously this select was duplicated at three sites.
	closeDone := func() {
		select {
		case <-entry.done:
		default:
			close(entry.done)
		}
	}

	if len(chunk) > 0 {
		if _, err := entry.w.Write(chunk); err != nil {
			// Client disconnects are routine and stay silent; any other
			// write failure was previously swallowed with no trace at all.
			if !isClientDisconnectError(err) {
				h.logger.Debug("Streaming chunk write failed",
					zap.String("request_id", requestID),
					zap.Error(err),
				)
			}
			closeDone()
			return nil
		}
		if entry.flusher != nil {
			entry.flusher.Flush()
		}
	}

	if isLast {
		closeDone()
	}
	return nil
}
// isClientDisconnectError reports whether err looks like the remote peer
// went away mid-transfer (broken pipe, connection reset, connection refused
// inside a *net.OpError, or a write on a locally closed socket). Matching is
// string-based because these errors arrive wrapped in platform-specific types.
func isClientDisconnectError(err error) bool {
	if err == nil {
		return false
	}
	containsAny := func(s string, subs ...string) bool {
		for _, sub := range subs {
			if strings.Contains(s, sub) {
				return true
			}
		}
		return false
	}
	if opErr, ok := err.(*net.OpError); ok && opErr.Err != nil {
		if containsAny(opErr.Err.Error(), "broken pipe", "connection reset", "connection refused") {
			return true
		}
	}
	return containsAny(err.Error(), "broken pipe", "connection reset", "use of closed network connection")
}
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
h.mu.Lock()
defer h.mu.Unlock()
@@ -96,15 +269,26 @@ func (h *ResponseHandler) CleanupResponseChan(requestID string) {
}
}
// GetPendingCount returns the number of pending responses
// CleanupStreamingResponse fires the done signal (if still open) and removes
// the streaming entry for requestID. Safe to call when the entry is absent
// or its done channel was already closed.
func (h *ResponseHandler) CleanupStreamingResponse(requestID string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	entry, ok := h.streamingChannels[requestID]
	if !ok {
		return
	}
	select {
	case <-entry.done:
		// already signalled
	default:
		close(entry.done)
	}
	delete(h.streamingChannels, requestID)
}
func (h *ResponseHandler) GetPendingCount() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.channels)
return len(h.channels) + len(h.streamingChannels)
}
// cleanupLoop periodically cleans up expired response channels
// This replaces the per-request goroutine approach with a single cleanup goroutine
func (h *ResponseHandler) cleanupLoop() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
@@ -119,10 +303,10 @@ func (h *ResponseHandler) cleanupLoop() {
}
}
// cleanupExpiredChannels removes channels older than 30 seconds
func (h *ResponseHandler) cleanupExpiredChannels() {
now := time.Now()
timeout := 30 * time.Second
streamingTimeout := 5 * time.Minute
h.mu.Lock()
defer h.mu.Unlock()
@@ -136,24 +320,43 @@ func (h *ResponseHandler) cleanupExpiredChannels() {
}
}
for requestID, entry := range h.streamingChannels {
if now.Sub(entry.createdAt) > streamingTimeout {
select {
case <-entry.done:
default:
close(entry.done)
}
delete(h.streamingChannels, requestID)
expiredCount++
}
}
if expiredCount > 0 {
h.logger.Debug("Cleaned up expired response channels",
zap.Int("count", expiredCount),
zap.Int("remaining", len(h.channels)),
zap.Int("remaining", len(h.channels)+len(h.streamingChannels)),
)
}
}
// Close stops the background cleanup loop and tears down all pending state:
// every buffered response channel is closed and every streaming entry's done
// signal is fired so blocked handlers can return.
func (h *ResponseHandler) Close() {
	close(h.stopCh)

	h.mu.Lock()
	defer h.mu.Unlock()

	// Release handlers waiting on buffered responses.
	for _, entry := range h.channels {
		close(entry.ch)
	}
	// Release handlers waiting on streaming completion; done may already be
	// closed, so guard with a non-blocking receive.
	for _, entry := range h.streamingChannels {
		select {
		case <-entry.done:
		default:
			close(entry.done)
		}
	}
	h.channels = make(map[string]*responseChanEntry)
	h.streamingChannels = make(map[string]*streamingResponseEntry)
}

View File

@@ -49,6 +49,9 @@ type HTTPResponseHandler interface {
GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
CleanupResponseChan(requestID string)
SendResponse(requestID string, resp *protocol.HTTPResponse)
// Streaming response methods
SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error
SendStreamingChunk(requestID string, chunk []byte, isLast bool) error
}
// NewConnection creates a new connection handler
@@ -273,6 +276,15 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
c.logger.Debug("Client disconnected abruptly", zap.Error(err))
return nil
}
// Check if it looks like garbage data (not a valid HTTP request)
if strings.Contains(errStr, "malformed HTTP") {
c.logger.Warn("Received malformed HTTP request, possibly due to pipelined requests or protocol error",
zap.Error(err),
zap.String("error_snippet", errStr[:min(len(errStr), 100)]),
)
// Close connection on malformed request to prevent further errors
return nil
}
c.logger.Error("Failed to parse HTTP request", zap.Error(err))
return fmt.Errorf("failed to parse HTTP request: %w", err)
}
@@ -289,9 +301,21 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
header: make(http.Header),
}
// Handle the request
// Handle the request - this blocks until response is complete
c.httpHandler.ServeHTTP(respWriter, req)
// Ensure response is flushed to client
if tcpConn, ok := c.conn.(*net.TCPConn); ok {
// Force flush TCP buffers
tcpConn.SetNoDelay(true)
tcpConn.SetNoDelay(false)
}
c.logger.Debug("HTTP request processing completed",
zap.String("method", req.Method),
zap.String("url", req.URL.String()),
)
// Check if we should close the connection
// Close if: Connection: close header, or HTTP/1.0 without Connection: keep-alive
shouldClose := false
@@ -304,8 +328,13 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
}
}
// Also check if response indicated connection should close
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
shouldClose = true
}
if shouldClose {
c.logger.Debug("Closing connection as requested by client")
c.logger.Debug("Closing connection as requested by client or server")
return nil
}
@@ -313,6 +342,13 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
}
}
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// handleFrames handles incoming frames
func (c *Connection) handleFrames(reader *bufio.Reader) error {
for {
@@ -439,10 +475,55 @@ func (c *Connection) handleDataFrame(frame *protocol.Frame) {
}
c.responseChans.SendResponse(reqID, httpResp)
case protocol.DataTypeHTTPHead:
// Streaming HTTP response headers
if c.responseChans == nil {
c.logger.Warn("No response handler for streaming HTTP head",
zap.String("stream_id", header.StreamID),
)
return
}
c.logger.Debug("Routed HTTP response to channel",
zap.String("request_id", reqID),
)
httpHead, err := protocol.DecodeHTTPResponseHead(data)
if err != nil {
c.logger.Error("Failed to decode HTTP response head",
zap.String("stream_id", header.StreamID),
zap.Error(err),
)
return
}
reqID := header.RequestID
if reqID == "" {
reqID = header.StreamID
}
if err := c.responseChans.SendStreamingHead(reqID, httpHead); err != nil {
c.logger.Error("Failed to send streaming head",
zap.String("request_id", reqID),
zap.Error(err),
)
}
case protocol.DataTypeHTTPBodyChunk:
// Streaming HTTP response body chunk
if c.responseChans == nil {
c.logger.Warn("No response handler for streaming HTTP chunk",
zap.String("stream_id", header.StreamID),
)
return
}
reqID := header.RequestID
if reqID == "" {
reqID = header.StreamID
}
if err := c.responseChans.SendStreamingChunk(reqID, data, header.IsLast); err != nil {
c.logger.Error("Failed to send streaming chunk",
zap.String("request_id", reqID),
zap.Error(err),
)
}
case protocol.DataTypeClose:
// Client is closing the stream
if c.proxy != nil {
@@ -487,7 +568,12 @@ func (c *Connection) SendFrame(frame *protocol.Frame) error {
if c.frameWriter == nil {
return protocol.WriteFrame(c.conn, frame)
}
return c.frameWriter.WriteFrame(frame)
if err := c.frameWriter.WriteFrame(frame); err != nil {
return err
}
// Flush immediately to ensure the frame is sent without batching delay
c.frameWriter.Flush()
return nil
}
// sendError sends an error frame to the client

View File

@@ -21,12 +21,19 @@ type TunnelProxy struct {
stopCh chan struct{}
wg sync.WaitGroup
clientAddr string
streams map[string]net.Conn // streamID -> external connection
streams map[string]*proxyStream // streamID -> stream info
streamMu sync.RWMutex
frameWriter *protocol.FrameWriter
bufferPool *pool.BufferPool
}
// proxyStream holds connection info with close state so a stream can be
// shut down exactly once and read/write loops can detect the closure.
type proxyStream struct {
	conn   net.Conn   // external connection accepted by the tunnel proxy
	closed bool       // set under mu before conn.Close(); checked by readers/writers
	mu     sync.Mutex // guards closed
}
// NewTunnelProxy creates a new TCP tunnel proxy
func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
return &TunnelProxy{
@@ -36,7 +43,7 @@ func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Lo
logger: logger,
stopCh: make(chan struct{}),
clientAddr: tcpConn.RemoteAddr().String(),
streams: make(map[string]net.Conn),
streams: make(map[string]*proxyStream),
bufferPool: pool.NewBufferPool(),
frameWriter: protocol.NewFrameWriter(tcpConn),
}
@@ -101,8 +108,13 @@ func (p *TunnelProxy) handleConnection(conn net.Conn) {
streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
stream := &proxyStream{
conn: conn,
closed: false,
}
p.streamMu.Lock()
p.streams[streamID] = conn
p.streams[streamID] = stream
p.streamMu.Unlock()
defer func() {
@@ -117,6 +129,14 @@ func (p *TunnelProxy) handleConnection(conn net.Conn) {
buffer := (*bufPtr)[:pool.SizeMedium]
for {
// Check if stream is closed
stream.mu.Lock()
closed := stream.closed
stream.mu.Unlock()
if closed {
break
}
n, err := conn.Read(buffer)
if err != nil {
break
@@ -124,7 +144,7 @@ func (p *TunnelProxy) handleConnection(conn net.Conn) {
if n > 0 {
if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
p.logger.Error("Send to tunnel failed", zap.Error(err))
p.logger.Debug("Send to tunnel failed", zap.Error(err))
break
}
}
@@ -185,15 +205,24 @@ func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
p.streamMu.RLock()
conn, ok := p.streams[streamID]
stream, ok := p.streams[streamID]
p.streamMu.RUnlock()
if !ok {
return fmt.Errorf("stream not found: %s", streamID)
// Stream may have been closed by client, this is normal
return nil
}
if _, err := conn.Write(data); err != nil {
p.logger.Error("Write to client failed", zap.Error(err))
// Check if stream is closed
stream.mu.Lock()
if stream.closed {
stream.mu.Unlock()
return nil
}
stream.mu.Unlock()
if _, err := stream.conn.Write(data); err != nil {
p.logger.Debug("Write to client failed", zap.Error(err))
return err
}
@@ -203,12 +232,24 @@ func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
// CloseStream closes a stream
func (p *TunnelProxy) CloseStream(streamID string) {
p.streamMu.RLock()
conn, ok := p.streams[streamID]
stream, ok := p.streams[streamID]
p.streamMu.RUnlock()
if ok {
conn.Close()
if !ok {
return
}
// Mark as closed first
stream.mu.Lock()
if stream.closed {
stream.mu.Unlock()
return
}
stream.closed = true
stream.mu.Unlock()
// Now close the connection
stream.conn.Close()
}
func (p *TunnelProxy) Stop() {
@@ -224,10 +265,13 @@ func (p *TunnelProxy) Stop() {
}
p.streamMu.Lock()
for _, conn := range p.streams {
conn.Close()
for _, stream := range p.streams {
stream.mu.Lock()
stream.closed = true
stream.mu.Unlock()
stream.conn.Close()
}
p.streams = make(map[string]net.Conn)
p.streams = make(map[string]*proxyStream)
p.streamMu.Unlock()
p.wg.Wait()