mirror of
https://github.com/Gouryella/drip.git
synced 2026-03-04 12:55:53 +00:00
feat: Add HTTP streaming, compression support, and Docker deployment
enhancements - Add adaptive HTTP response handling with automatic streaming for large responses (>1MB) - Implement zero-copy streaming using buffer pools for better performance - Add compression module for reduced bandwidth usage - Add GitHub Container Registry workflow for automated Docker builds - Add production-optimized Dockerfile and docker-compose configuration - Simplify background mode with -d flag and improved daemon management - Update documentation with new command syntax and deployment guides - Clean up unused code and improve error handling - Fix lipgloss style usage (remove unnecessary .Copy() calls)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -12,7 +11,6 @@ import (
|
||||
json "github.com/goccy/go-json"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
@@ -71,17 +69,53 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
limitedReader := io.LimitReader(r.Body, constants.MaxRequestBodySize)
|
||||
body, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
h.logger.Error("Read request body failed", zap.Error(err))
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
requestID := utils.GenerateID()
|
||||
|
||||
h.handleAdaptiveRequest(w, r, transport, requestID, subdomain)
|
||||
}
|
||||
|
||||
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
|
||||
const streamingThreshold int64 = 1 * 1024 * 1024
|
||||
|
||||
buffer := make([]byte, 0, streamingThreshold)
|
||||
tempBuf := make([]byte, 32*1024)
|
||||
|
||||
var totalRead int64
|
||||
var hitThreshold bool
|
||||
|
||||
for totalRead < streamingThreshold {
|
||||
n, err := r.Body.Read(tempBuf)
|
||||
if n > 0 {
|
||||
buffer = append(buffer, tempBuf[:n]...)
|
||||
totalRead += int64(n)
|
||||
}
|
||||
if err == io.EOF {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
r.Body.Close()
|
||||
h.logger.Error("Read request body failed", zap.Error(err))
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if totalRead >= streamingThreshold {
|
||||
hitThreshold = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hitThreshold {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
|
||||
return
|
||||
}
|
||||
|
||||
h.streamLargeRequest(w, r, transport, requestID, subdomain, buffer)
|
||||
}
|
||||
|
||||
func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, body []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
@@ -93,7 +127,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
|
||||
|
||||
h.headerPool.Put(headers)
|
||||
|
||||
if err != nil {
|
||||
@@ -119,7 +152,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
|
||||
respChan := h.responses.CreateResponseChan(requestID)
|
||||
defer h.responses.CleanupResponseChan(requestID)
|
||||
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
|
||||
defer func() {
|
||||
h.responses.CleanupResponseChan(requestID)
|
||||
h.responses.CleanupStreamingResponse(requestID)
|
||||
}()
|
||||
|
||||
if err := transport.SendFrame(frame); err != nil {
|
||||
h.logger.Error("Send frame to tunnel failed", zap.Error(err))
|
||||
@@ -127,14 +164,220 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), constants.RequestTimeout)
|
||||
defer cancel()
|
||||
select {
|
||||
case respMsg := <-respChan:
|
||||
if respMsg == nil {
|
||||
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("url", r.URL.String()),
|
||||
)
|
||||
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, bufferedData []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
httpReqHead := protocol.HTTPRequestHead{
|
||||
Method: r.Method,
|
||||
URL: r.URL.String(),
|
||||
Headers: headers,
|
||||
ContentLength: r.ContentLength,
|
||||
}
|
||||
|
||||
headBytes, err := protocol.EncodeHTTPRequestHead(&httpReqHead)
|
||||
h.headerPool.Put(headers)
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("Encode HTTP request head failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPHead, // shared streaming head type
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode head payload failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
|
||||
|
||||
respChan := h.responses.CreateResponseChan(requestID)
|
||||
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
|
||||
defer func() {
|
||||
h.responses.CleanupResponseChan(requestID)
|
||||
h.responses.CleanupStreamingResponse(requestID)
|
||||
}()
|
||||
|
||||
if err := transport.SendFrame(headFrame); err != nil {
|
||||
h.logger.Error("Send head frame failed", zap.Error(err))
|
||||
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
if len(bufferedData) > 0 {
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, bufferedData)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode buffered chunk failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := transport.SendFrame(chunkFrame); err != nil {
|
||||
h.logger.Error("Send buffered chunk failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
n, readErr := r.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
isLast := readErr == io.EOF
|
||||
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
|
||||
IsLast: isLast,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer[:n])
|
||||
if err != nil {
|
||||
h.logger.Error("Encode chunk payload failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := transport.SendFrame(chunkFrame); err != nil {
|
||||
h.logger.Error("Send chunk frame failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if readErr == io.EOF {
|
||||
if n == 0 {
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if err == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if readErr != nil {
|
||||
h.logger.Error("Read request body failed", zap.Error(readErr))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if err == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
r.Body.Close()
|
||||
|
||||
select {
|
||||
case respMsg := <-respChan:
|
||||
if respMsg == nil {
|
||||
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
|
||||
case <-ctx.Done():
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Streaming request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("url", r.URL.String()),
|
||||
)
|
||||
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
||||
}
|
||||
}
|
||||
@@ -145,12 +388,23 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
|
||||
return
|
||||
}
|
||||
|
||||
// For buffered responses, we have the complete body, so we can set Content-Length
|
||||
// Skip ALL hop-by-hop headers - client should have already cleaned them
|
||||
for key, values := range resp.Headers {
|
||||
if key == "Connection" || key == "Keep-Alive" || key == "Transfer-Encoding" || key == "Upgrade" {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
|
||||
// Skip hop-by-hop headers completely using canonical key comparison
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
continue
|
||||
}
|
||||
|
||||
if key == "Location" && len(values) > 0 {
|
||||
if canonicalKey == "Location" && len(values) > 0 {
|
||||
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
|
||||
w.Header().Set("Location", rewrittenLocation)
|
||||
continue
|
||||
@@ -161,9 +415,8 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
|
||||
}
|
||||
}
|
||||
|
||||
if w.Header().Get("Content-Length") == "" && len(resp.Body) > 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
|
||||
}
|
||||
// For buffered mode, always set Content-Length with the actual body size
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
|
||||
|
||||
statusCode := resp.StatusCode
|
||||
if statusCode == 0 {
|
||||
@@ -171,6 +424,7 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
|
||||
}
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
if len(resp.Body) > 0 {
|
||||
w.Write(resp.Body)
|
||||
}
|
||||
@@ -284,19 +538,6 @@ func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(health)
|
||||
}
|
||||
|
||||
func cloneHeadersWithHost(src http.Header, host string) http.Header {
|
||||
dst := make(http.Header, len(src)+1)
|
||||
for k, v := range src {
|
||||
copied := make([]string, len(v))
|
||||
copy(copied, v)
|
||||
dst[k] = copied
|
||||
}
|
||||
if host != "" {
|
||||
dst.Set("Host", host)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
if h.authToken != "" {
|
||||
token := r.URL.Query().Get("token")
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -14,20 +18,32 @@ type responseChanEntry struct {
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
// streamingResponseEntry holds a streaming response writer
|
||||
type streamingResponseEntry struct {
|
||||
w http.ResponseWriter
|
||||
flusher http.Flusher
|
||||
createdAt time.Time
|
||||
headersSent bool
|
||||
done chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// ResponseHandler manages response channels for HTTP requests over TCP/Frame protocol
|
||||
type ResponseHandler struct {
|
||||
channels map[string]*responseChanEntry
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
channels map[string]*responseChanEntry
|
||||
streamingChannels map[string]*streamingResponseEntry
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewResponseHandler creates a new response handler
|
||||
func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
|
||||
h := &ResponseHandler{
|
||||
channels: make(map[string]*responseChanEntry),
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
channels: make(map[string]*responseChanEntry),
|
||||
streamingChannels: make(map[string]*streamingResponseEntry),
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start single cleanup goroutine instead of one per request
|
||||
@@ -50,6 +66,23 @@ func (h *ResponseHandler) CreateResponseChan(requestID string) chan *protocol.HT
|
||||
return ch
|
||||
}
|
||||
|
||||
// CreateStreamingResponse creates a streaming response entry for a request ID
|
||||
func (h *ResponseHandler) CreateStreamingResponse(requestID string, w http.ResponseWriter) chan struct{} {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
flusher, _ := w.(http.Flusher)
|
||||
done := make(chan struct{})
|
||||
h.streamingChannels[requestID] = &streamingResponseEntry{
|
||||
w: w,
|
||||
flusher: flusher,
|
||||
createdAt: time.Now(),
|
||||
done: done,
|
||||
}
|
||||
|
||||
return done
|
||||
}
|
||||
|
||||
// GetResponseChan gets the response channel for a request ID
|
||||
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
|
||||
h.mu.RLock()
|
||||
@@ -67,25 +100,165 @@ func (h *ResponseHandler) SendResponse(requestID string, resp *protocol.HTTPResp
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
h.logger.Warn("Response channel not found",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case entry.ch <- resp:
|
||||
h.logger.Debug("Response sent to channel",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
case <-time.After(5 * time.Second):
|
||||
h.logger.Warn("Timeout sending response to channel",
|
||||
case <-time.After(30 * time.Second):
|
||||
h.logger.Error("Timeout sending response to channel - handler may have abandoned",
|
||||
zap.String("request_id", requestID),
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
zap.Int("body_size", len(resp.Body)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupResponseChan removes and closes a response channel
|
||||
func (h *ResponseHandler) SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error {
|
||||
h.mu.RLock()
|
||||
entry, exists := h.streamingChannels[requestID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-entry.done:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
if entry.headersSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy headers, removing hop-by-hop headers that were already handled by client
|
||||
// Client's cleanResponseHeaders already removed Transfer-Encoding, Connection, etc.
|
||||
// But we need to check again in case they slipped through
|
||||
hasContentLength := false
|
||||
|
||||
for key, values := range head.Headers {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
|
||||
// Skip ALL hop-by-hop headers
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
continue
|
||||
}
|
||||
|
||||
if canonicalKey == "Content-Length" {
|
||||
hasContentLength = true
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
entry.w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// For streaming responses, decide how to indicate message length
|
||||
if head.ContentLength >= 0 && !hasContentLength {
|
||||
entry.w.Header().Set("Content-Length", fmt.Sprintf("%d", head.ContentLength))
|
||||
}
|
||||
|
||||
statusCode := head.StatusCode
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusOK
|
||||
}
|
||||
|
||||
entry.w.WriteHeader(statusCode)
|
||||
entry.headersSent = true
|
||||
|
||||
if entry.flusher != nil {
|
||||
entry.flusher.Flush()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isLast bool) error {
|
||||
h.mu.RLock()
|
||||
entry, exists := h.streamingChannels[requestID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-entry.done:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
if len(chunk) > 0 {
|
||||
_, err := entry.w.Write(chunk)
|
||||
if err != nil {
|
||||
if isClientDisconnectError(err) {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if entry.flusher != nil {
|
||||
entry.flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if isLast {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isClientDisconnectError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if netErr, ok := err.(*net.OpError); ok {
|
||||
if netErr.Err != nil {
|
||||
errStr := netErr.Err.Error()
|
||||
if strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection reset") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection reset") ||
|
||||
strings.Contains(errStr, "use of closed network connection")
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
@@ -96,15 +269,26 @@ func (h *ResponseHandler) CleanupResponseChan(requestID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetPendingCount returns the number of pending responses
|
||||
func (h *ResponseHandler) CleanupStreamingResponse(requestID string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if entry, exists := h.streamingChannels[requestID]; exists {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
delete(h.streamingChannels, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) GetPendingCount() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.channels)
|
||||
return len(h.channels) + len(h.streamingChannels)
|
||||
}
|
||||
|
||||
// cleanupLoop periodically cleans up expired response channels
|
||||
// This replaces the per-request goroutine approach with a single cleanup goroutine
|
||||
func (h *ResponseHandler) cleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -119,10 +303,10 @@ func (h *ResponseHandler) cleanupLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpiredChannels removes channels older than 30 seconds
|
||||
func (h *ResponseHandler) cleanupExpiredChannels() {
|
||||
now := time.Now()
|
||||
timeout := 30 * time.Second
|
||||
streamingTimeout := 5 * time.Minute
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
@@ -136,24 +320,43 @@ func (h *ResponseHandler) cleanupExpiredChannels() {
|
||||
}
|
||||
}
|
||||
|
||||
for requestID, entry := range h.streamingChannels {
|
||||
if now.Sub(entry.createdAt) > streamingTimeout {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
delete(h.streamingChannels, requestID)
|
||||
expiredCount++
|
||||
}
|
||||
}
|
||||
|
||||
if expiredCount > 0 {
|
||||
h.logger.Debug("Cleaned up expired response channels",
|
||||
zap.Int("count", expiredCount),
|
||||
zap.Int("remaining", len(h.channels)),
|
||||
zap.Int("remaining", len(h.channels)+len(h.streamingChannels)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup loop
|
||||
func (h *ResponseHandler) Close() {
|
||||
close(h.stopCh)
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// Close all remaining channels
|
||||
for _, entry := range h.channels {
|
||||
close(entry.ch)
|
||||
}
|
||||
h.channels = make(map[string]*responseChanEntry)
|
||||
|
||||
for _, entry := range h.streamingChannels {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
}
|
||||
h.streamingChannels = make(map[string]*streamingResponseEntry)
|
||||
}
|
||||
|
||||
@@ -49,6 +49,9 @@ type HTTPResponseHandler interface {
|
||||
GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
|
||||
CleanupResponseChan(requestID string)
|
||||
SendResponse(requestID string, resp *protocol.HTTPResponse)
|
||||
// Streaming response methods
|
||||
SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error
|
||||
SendStreamingChunk(requestID string, chunk []byte, isLast bool) error
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
@@ -273,6 +276,15 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
c.logger.Debug("Client disconnected abruptly", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
// Check if it looks like garbage data (not a valid HTTP request)
|
||||
if strings.Contains(errStr, "malformed HTTP") {
|
||||
c.logger.Warn("Received malformed HTTP request, possibly due to pipelined requests or protocol error",
|
||||
zap.Error(err),
|
||||
zap.String("error_snippet", errStr[:min(len(errStr), 100)]),
|
||||
)
|
||||
// Close connection on malformed request to prevent further errors
|
||||
return nil
|
||||
}
|
||||
c.logger.Error("Failed to parse HTTP request", zap.Error(err))
|
||||
return fmt.Errorf("failed to parse HTTP request: %w", err)
|
||||
}
|
||||
@@ -289,9 +301,21 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
header: make(http.Header),
|
||||
}
|
||||
|
||||
// Handle the request
|
||||
// Handle the request - this blocks until response is complete
|
||||
c.httpHandler.ServeHTTP(respWriter, req)
|
||||
|
||||
// Ensure response is flushed to client
|
||||
if tcpConn, ok := c.conn.(*net.TCPConn); ok {
|
||||
// Force flush TCP buffers
|
||||
tcpConn.SetNoDelay(true)
|
||||
tcpConn.SetNoDelay(false)
|
||||
}
|
||||
|
||||
c.logger.Debug("HTTP request processing completed",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("url", req.URL.String()),
|
||||
)
|
||||
|
||||
// Check if we should close the connection
|
||||
// Close if: Connection: close header, or HTTP/1.0 without Connection: keep-alive
|
||||
shouldClose := false
|
||||
@@ -304,8 +328,13 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Also check if response indicated connection should close
|
||||
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
|
||||
shouldClose = true
|
||||
}
|
||||
|
||||
if shouldClose {
|
||||
c.logger.Debug("Closing connection as requested by client")
|
||||
c.logger.Debug("Closing connection as requested by client or server")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -313,6 +342,13 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
}
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// handleFrames handles incoming frames
|
||||
func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
for {
|
||||
@@ -439,10 +475,55 @@ func (c *Connection) handleDataFrame(frame *protocol.Frame) {
|
||||
}
|
||||
|
||||
c.responseChans.SendResponse(reqID, httpResp)
|
||||
case protocol.DataTypeHTTPHead:
|
||||
// Streaming HTTP response headers
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response handler for streaming HTTP head",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Debug("Routed HTTP response to channel",
|
||||
zap.String("request_id", reqID),
|
||||
)
|
||||
httpHead, err := protocol.DecodeHTTPResponseHead(data)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode HTTP response head",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
if err := c.responseChans.SendStreamingHead(reqID, httpHead); err != nil {
|
||||
c.logger.Error("Failed to send streaming head",
|
||||
zap.String("request_id", reqID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
case protocol.DataTypeHTTPBodyChunk:
|
||||
// Streaming HTTP response body chunk
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response handler for streaming HTTP chunk",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
if err := c.responseChans.SendStreamingChunk(reqID, data, header.IsLast); err != nil {
|
||||
c.logger.Error("Failed to send streaming chunk",
|
||||
zap.String("request_id", reqID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
case protocol.DataTypeClose:
|
||||
// Client is closing the stream
|
||||
if c.proxy != nil {
|
||||
@@ -487,7 +568,12 @@ func (c *Connection) SendFrame(frame *protocol.Frame) error {
|
||||
if c.frameWriter == nil {
|
||||
return protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
if err := c.frameWriter.WriteFrame(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
// Flush immediately to ensure the frame is sent without batching delay
|
||||
c.frameWriter.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendError sends an error frame to the client
|
||||
|
||||
@@ -21,12 +21,19 @@ type TunnelProxy struct {
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
clientAddr string
|
||||
streams map[string]net.Conn // streamID -> external connection
|
||||
streams map[string]*proxyStream // streamID -> stream info
|
||||
streamMu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
bufferPool *pool.BufferPool
|
||||
}
|
||||
|
||||
// proxyStream holds connection info with close state
|
||||
type proxyStream struct {
|
||||
conn net.Conn
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTunnelProxy creates a new TCP tunnel proxy
|
||||
func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
|
||||
return &TunnelProxy{
|
||||
@@ -36,7 +43,7 @@ func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Lo
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
clientAddr: tcpConn.RemoteAddr().String(),
|
||||
streams: make(map[string]net.Conn),
|
||||
streams: make(map[string]*proxyStream),
|
||||
bufferPool: pool.NewBufferPool(),
|
||||
frameWriter: protocol.NewFrameWriter(tcpConn),
|
||||
}
|
||||
@@ -101,8 +108,13 @@ func (p *TunnelProxy) handleConnection(conn net.Conn) {
|
||||
|
||||
streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
|
||||
|
||||
stream := &proxyStream{
|
||||
conn: conn,
|
||||
closed: false,
|
||||
}
|
||||
|
||||
p.streamMu.Lock()
|
||||
p.streams[streamID] = conn
|
||||
p.streams[streamID] = stream
|
||||
p.streamMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
@@ -117,6 +129,14 @@ func (p *TunnelProxy) handleConnection(conn net.Conn) {
|
||||
buffer := (*bufPtr)[:pool.SizeMedium]
|
||||
|
||||
for {
|
||||
// Check if stream is closed
|
||||
stream.mu.Lock()
|
||||
closed := stream.closed
|
||||
stream.mu.Unlock()
|
||||
if closed {
|
||||
break
|
||||
}
|
||||
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
break
|
||||
@@ -124,7 +144,7 @@ func (p *TunnelProxy) handleConnection(conn net.Conn) {
|
||||
|
||||
if n > 0 {
|
||||
if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
|
||||
p.logger.Error("Send to tunnel failed", zap.Error(err))
|
||||
p.logger.Debug("Send to tunnel failed", zap.Error(err))
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -185,15 +205,24 @@ func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
|
||||
|
||||
func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
|
||||
p.streamMu.RLock()
|
||||
conn, ok := p.streams[streamID]
|
||||
stream, ok := p.streams[streamID]
|
||||
p.streamMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("stream not found: %s", streamID)
|
||||
// Stream may have been closed by client, this is normal
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
p.logger.Error("Write to client failed", zap.Error(err))
|
||||
// Check if stream is closed
|
||||
stream.mu.Lock()
|
||||
if stream.closed {
|
||||
stream.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
stream.mu.Unlock()
|
||||
|
||||
if _, err := stream.conn.Write(data); err != nil {
|
||||
p.logger.Debug("Write to client failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -203,12 +232,24 @@ func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
|
||||
// CloseStream closes a stream
|
||||
func (p *TunnelProxy) CloseStream(streamID string) {
|
||||
p.streamMu.RLock()
|
||||
conn, ok := p.streams[streamID]
|
||||
stream, ok := p.streams[streamID]
|
||||
p.streamMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
conn.Close()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Mark as closed first
|
||||
stream.mu.Lock()
|
||||
if stream.closed {
|
||||
stream.mu.Unlock()
|
||||
return
|
||||
}
|
||||
stream.closed = true
|
||||
stream.mu.Unlock()
|
||||
|
||||
// Now close the connection
|
||||
stream.conn.Close()
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) Stop() {
|
||||
@@ -224,10 +265,13 @@ func (p *TunnelProxy) Stop() {
|
||||
}
|
||||
|
||||
p.streamMu.Lock()
|
||||
for _, conn := range p.streams {
|
||||
conn.Close()
|
||||
for _, stream := range p.streams {
|
||||
stream.mu.Lock()
|
||||
stream.closed = true
|
||||
stream.mu.Unlock()
|
||||
stream.conn.Close()
|
||||
}
|
||||
p.streams = make(map[string]net.Conn)
|
||||
p.streams = make(map[string]*proxyStream)
|
||||
p.streamMu.Unlock()
|
||||
|
||||
p.wg.Wait()
|
||||
|
||||
Reference in New Issue
Block a user