mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-24 05:10:43 +00:00
enhancements - Add adaptive HTTP response handling with automatic streaming for large responses (>1MB) - Implement zero-copy streaming using buffer pools for better performance - Add compression module for reduced bandwidth usage - Add GitHub Container Registry workflow for automated Docker builds - Add production-optimized Dockerfile and docker-compose configuration - Simplify background mode with -d flag and improved daemon management - Update documentation with new command syntax and deployment guides - Clean up unused code and improve error handling - Fix lipgloss style usage (remove unnecessary .Copy() calls)
1079 lines
27 KiB
Go
1079 lines
27 KiB
Go
package tcp
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"drip/internal/shared/pool"
|
|
"drip/internal/shared/protocol"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// FrameHandler handles data frames and forwards to local service
|
|
type FrameHandler struct {
|
|
conn net.Conn
|
|
frameWriter *protocol.FrameWriter
|
|
localHost string
|
|
localPort int
|
|
logger *zap.Logger
|
|
streams map[string]*Stream
|
|
streamMu sync.RWMutex
|
|
streamingRequests map[string]*StreamingRequest
|
|
streamingReqMu sync.RWMutex
|
|
tunnelType protocol.TunnelType
|
|
httpClient *http.Client
|
|
stats *TrafficStats
|
|
isClosedCheck func() bool
|
|
bufferPool *pool.BufferPool
|
|
headerPool *pool.HeaderPool
|
|
}
|
|
|
|
// Stream represents a single request/response stream
|
|
type Stream struct {
|
|
ID string
|
|
LocalConn net.Conn
|
|
ResponseCh chan []byte
|
|
Done chan struct{}
|
|
closed bool
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// StreamingRequest represents a streaming upload request in progress
|
|
type StreamingRequest struct {
|
|
RequestID string
|
|
Writer *io.PipeWriter
|
|
Done chan struct{}
|
|
chunkQueue chan *chunkData
|
|
closed bool
|
|
mu sync.Mutex
|
|
}
|
|
|
|
type chunkData struct {
|
|
data []byte
|
|
isLast bool
|
|
}
|
|
|
|
func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost string, localPort int, tunnelType protocol.TunnelType, logger *zap.Logger, isClosedCheck func() bool, bufferPool *pool.BufferPool) *FrameHandler {
|
|
var tlsConfig *tls.Config
|
|
if tunnelType == protocol.TunnelTypeHTTPS {
|
|
tlsConfig = &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
}
|
|
}
|
|
|
|
return &FrameHandler{
|
|
conn: conn,
|
|
frameWriter: frameWriter,
|
|
localHost: localHost,
|
|
localPort: localPort,
|
|
logger: logger,
|
|
streams: make(map[string]*Stream),
|
|
streamingRequests: make(map[string]*StreamingRequest),
|
|
tunnelType: tunnelType,
|
|
stats: NewTrafficStats(),
|
|
isClosedCheck: isClosedCheck,
|
|
bufferPool: bufferPool,
|
|
headerPool: pool.NewHeaderPool(),
|
|
httpClient: &http.Client{
|
|
// No overall timeout - streaming responses can take arbitrary time
|
|
Transport: &http.Transport{
|
|
MaxIdleConns: 1000,
|
|
MaxIdleConnsPerHost: 500,
|
|
MaxConnsPerHost: 0,
|
|
IdleConnTimeout: 180 * time.Second,
|
|
DisableCompression: true,
|
|
DisableKeepAlives: false,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
TLSClientConfig: tlsConfig,
|
|
ResponseHeaderTimeout: 30 * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
DialContext: (&net.Dialer{
|
|
Timeout: 5 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
}).DialContext,
|
|
},
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
return http.ErrUseLastResponse
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (h *FrameHandler) HandleDataFrame(frame *protocol.Frame) error {
|
|
h.stats.AddBytesIn(int64(len(frame.Payload)))
|
|
h.stats.AddRequest()
|
|
|
|
header, data, err := protocol.DecodeDataPayload(frame.Payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decode data payload: %w", err)
|
|
}
|
|
|
|
if header.Type == protocol.DataTypeHTTPRequest {
|
|
return h.handleHTTPFrame(header, data)
|
|
}
|
|
|
|
if header.Type == protocol.DataTypeHTTPRequestHead || header.Type == protocol.DataTypeHTTPHead {
|
|
return h.handleHTTPRequestHead(header, data)
|
|
}
|
|
|
|
if header.Type == protocol.DataTypeHTTPRequestBodyChunk || header.Type == protocol.DataTypeHTTPBodyChunk {
|
|
return h.handleHTTPRequestBodyChunk(header, data)
|
|
}
|
|
|
|
if header.Type == protocol.DataTypeClose {
|
|
h.closeStream(header.StreamID)
|
|
return nil
|
|
}
|
|
|
|
stream, err := h.getOrCreateStream(header.StreamID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get stream: %w", err)
|
|
}
|
|
|
|
h.forwardToLocal(stream, data)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *FrameHandler) getOrCreateStream(streamID string) (*Stream, error) {
|
|
h.streamMu.Lock()
|
|
defer h.streamMu.Unlock()
|
|
|
|
if stream, ok := h.streams[streamID]; ok {
|
|
return stream, nil
|
|
}
|
|
|
|
localAddr := net.JoinHostPort(h.localHost, fmt.Sprintf("%d", h.localPort))
|
|
localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to local service: %w", err)
|
|
}
|
|
|
|
stream := &Stream{
|
|
ID: streamID,
|
|
LocalConn: localConn,
|
|
ResponseCh: make(chan []byte, 10),
|
|
Done: make(chan struct{}),
|
|
}
|
|
|
|
h.streams[streamID] = stream
|
|
|
|
go h.handleLocalResponse(stream)
|
|
|
|
return stream, nil
|
|
}
|
|
|
|
func (h *FrameHandler) forwardToLocal(stream *Stream, data []byte) {
|
|
// Check if stream is closed using mutex
|
|
stream.mu.Lock()
|
|
if stream.closed {
|
|
stream.mu.Unlock()
|
|
return
|
|
}
|
|
stream.mu.Unlock()
|
|
|
|
// Double check with Done channel
|
|
select {
|
|
case <-stream.Done:
|
|
// Stream already closed, ignore data
|
|
return
|
|
default:
|
|
}
|
|
|
|
if _, err := stream.LocalConn.Write(data); err != nil {
|
|
// Only log at debug level since connection close is often expected
|
|
h.logger.Debug("Failed to write to local service",
|
|
zap.String("stream_id", stream.ID),
|
|
zap.Error(err),
|
|
)
|
|
h.closeStream(stream.ID)
|
|
}
|
|
}
|
|
|
|
func (h *FrameHandler) handleLocalResponse(stream *Stream) {
|
|
defer h.closeStream(stream.ID)
|
|
|
|
bufPtr := h.bufferPool.Get(pool.SizeMedium)
|
|
defer h.bufferPool.Put(bufPtr)
|
|
buf := (*bufPtr)[:pool.SizeMedium]
|
|
|
|
for {
|
|
// Check if stream is closed before reading
|
|
stream.mu.Lock()
|
|
closed := stream.closed
|
|
stream.mu.Unlock()
|
|
if closed {
|
|
break
|
|
}
|
|
|
|
n, err := stream.LocalConn.Read(buf)
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
if n > 0 {
|
|
if h.isClosedCheck != nil && h.isClosedCheck() {
|
|
break
|
|
}
|
|
|
|
// Check again after read
|
|
stream.mu.Lock()
|
|
closed = stream.closed
|
|
stream.mu.Unlock()
|
|
if closed {
|
|
break
|
|
}
|
|
|
|
header := protocol.DataHeader{
|
|
StreamID: stream.ID,
|
|
Type: protocol.DataTypeResponse,
|
|
IsLast: false,
|
|
}
|
|
|
|
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, buf[:n])
|
|
if err != nil {
|
|
h.logger.Debug("Encode payload failed", zap.Error(err))
|
|
break
|
|
}
|
|
|
|
dataFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
|
err = h.frameWriter.WriteFrame(dataFrame)
|
|
if err != nil {
|
|
h.logger.Debug("Send frame failed", zap.Error(err))
|
|
break
|
|
}
|
|
|
|
h.stats.AddBytesOut(int64(len(payload)))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *FrameHandler) handleHTTPFrame(header protocol.DataHeader, payload []byte) error {
|
|
if h.tunnelType != protocol.TunnelTypeHTTP && h.tunnelType != protocol.TunnelTypeHTTPS {
|
|
return nil
|
|
}
|
|
|
|
httpReq, err := protocol.DecodeHTTPRequest(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decode HTTP request: %w", err)
|
|
}
|
|
|
|
targetURL := httpReq.URL
|
|
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
|
|
scheme := "http"
|
|
if h.tunnelType == protocol.TunnelTypeHTTPS {
|
|
scheme = "https"
|
|
}
|
|
targetURL = fmt.Sprintf("%s://%s:%d%s", scheme, h.localHost, h.localPort, targetURL)
|
|
}
|
|
|
|
req, err := http.NewRequest(httpReq.Method, targetURL, bytes.NewReader(httpReq.Body))
|
|
if err != nil {
|
|
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("build request: %v", err))
|
|
}
|
|
|
|
origHost := ""
|
|
for key, values := range httpReq.Headers {
|
|
for _, value := range values {
|
|
req.Header.Add(key, value)
|
|
}
|
|
}
|
|
if host := req.Header.Get("Host"); host != "" {
|
|
origHost = host
|
|
}
|
|
|
|
isLocalTarget := h.isLocalAddress(h.localHost)
|
|
|
|
if isLocalTarget {
|
|
if origHost != "" {
|
|
req.Host = origHost
|
|
req.Header.Set("Host", origHost)
|
|
} else {
|
|
localHostPort := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
|
req.Host = localHostPort
|
|
req.Header.Set("Host", localHostPort)
|
|
}
|
|
if origHost != "" {
|
|
req.Header.Set("X-Forwarded-Host", origHost)
|
|
}
|
|
} else {
|
|
targetHost := h.localHost
|
|
if h.localPort != 443 && h.localPort != 80 {
|
|
targetHost = fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
|
}
|
|
req.Host = targetHost
|
|
req.Header.Set("Host", targetHost)
|
|
if origHost != "" {
|
|
req.Header.Set("X-Forwarded-Host", origHost)
|
|
}
|
|
}
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
|
|
resp, err := h.httpClient.Do(req)
|
|
if err != nil {
|
|
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("local request failed: %v", err))
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Threshold for switching from buffered to streaming mode
|
|
const bufferThreshold int64 = 1 * 1024 * 1024 // 1MB
|
|
|
|
// If Content-Length is known and large, use streaming directly
|
|
if resp.ContentLength > bufferThreshold {
|
|
return h.streamHTTPResponse(header.StreamID, header.RequestID, resp)
|
|
}
|
|
|
|
// For small or unknown size: try buffered first, switch to streaming if too large
|
|
return h.adaptiveHTTPResponse(header.StreamID, header.RequestID, resp, bufferThreshold)
|
|
}
|
|
|
|
// adaptiveHTTPResponse tries buffered mode first, switches to streaming if data exceeds threshold
|
|
func (h *FrameHandler) adaptiveHTTPResponse(streamID, requestID string, resp *http.Response, threshold int64) error {
|
|
if h.isClosedCheck != nil && h.isClosedCheck() {
|
|
return nil
|
|
}
|
|
|
|
// Buffer for initial read
|
|
buffer := make([]byte, 0, threshold)
|
|
tempBuf := make([]byte, 32*1024) // 32KB read chunks
|
|
|
|
var totalRead int64
|
|
var hitThreshold bool
|
|
|
|
// Try to read up to threshold
|
|
for totalRead < threshold {
|
|
n, err := resp.Body.Read(tempBuf)
|
|
if n > 0 {
|
|
buffer = append(buffer, tempBuf[:n]...)
|
|
totalRead += int64(n)
|
|
}
|
|
if err == io.EOF {
|
|
// Response completed within threshold - use buffered mode
|
|
break
|
|
}
|
|
if err != nil {
|
|
return h.sendHTTPError(streamID, requestID, http.StatusBadGateway, fmt.Sprintf("read response: %v", err))
|
|
}
|
|
if totalRead >= threshold {
|
|
hitThreshold = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !hitThreshold {
|
|
// Small response - send as buffered
|
|
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
|
|
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
|
|
|
|
httpResp := protocol.HTTPResponse{
|
|
StatusCode: resp.StatusCode,
|
|
Status: resp.Status,
|
|
Headers: cleanedHeaders,
|
|
Body: buffer,
|
|
}
|
|
return h.sendHTTPResponse(streamID, requestID, &httpResp)
|
|
}
|
|
|
|
// Large response - switch to streaming mode
|
|
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
|
|
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
|
|
|
|
// First send headers
|
|
httpHead := protocol.HTTPResponseHead{
|
|
StatusCode: resp.StatusCode,
|
|
Status: resp.Status,
|
|
Headers: cleanedHeaders,
|
|
ContentLength: resp.ContentLength, // -1 if unknown
|
|
}
|
|
|
|
headBytes, err := protocol.EncodeHTTPResponseHead(&httpHead)
|
|
if err != nil {
|
|
return fmt.Errorf("encode http head: %w", err)
|
|
}
|
|
|
|
headHeader := protocol.DataHeader{
|
|
StreamID: streamID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPHead,
|
|
IsLast: false,
|
|
}
|
|
|
|
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
|
|
if err != nil {
|
|
return fmt.Errorf("encode head payload: %w", err)
|
|
}
|
|
|
|
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
|
|
if err := h.frameWriter.WriteFrame(headFrame); err != nil {
|
|
return err
|
|
}
|
|
h.frameWriter.Flush()
|
|
h.stats.AddBytesOut(int64(len(headPayload)))
|
|
|
|
// Send buffered data as first chunk
|
|
if len(buffer) > 0 {
|
|
chunkHeader := protocol.DataHeader{
|
|
StreamID: streamID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPBodyChunk,
|
|
IsLast: false,
|
|
}
|
|
|
|
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer)
|
|
if err != nil {
|
|
return fmt.Errorf("encode chunk payload: %w", err)
|
|
}
|
|
|
|
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
|
if err := h.frameWriter.WriteFrame(chunkFrame); err != nil {
|
|
return err
|
|
}
|
|
h.stats.AddBytesOut(int64(len(chunkPayload)))
|
|
}
|
|
|
|
// Clear buffer to free memory
|
|
buffer = nil
|
|
|
|
// Continue streaming remaining data
|
|
bufPtr := h.bufferPool.Get(pool.SizeMedium)
|
|
defer h.bufferPool.Put(bufPtr)
|
|
buf := (*bufPtr)[:pool.SizeMedium]
|
|
|
|
for {
|
|
if h.isClosedCheck != nil && h.isClosedCheck() {
|
|
return nil
|
|
}
|
|
|
|
n, readErr := resp.Body.Read(buf)
|
|
if n > 0 {
|
|
isLast := readErr == io.EOF
|
|
|
|
chunkHeader := protocol.DataHeader{
|
|
StreamID: streamID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPBodyChunk,
|
|
IsLast: isLast,
|
|
}
|
|
|
|
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buf[:n])
|
|
if err != nil {
|
|
return fmt.Errorf("encode chunk payload: %w", err)
|
|
}
|
|
|
|
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
|
if err := h.frameWriter.WriteFrame(chunkFrame); err != nil {
|
|
return err
|
|
}
|
|
h.stats.AddBytesOut(int64(len(chunkPayload)))
|
|
}
|
|
|
|
if readErr == io.EOF {
|
|
if n == 0 {
|
|
// Send final empty chunk
|
|
finalHeader := protocol.DataHeader{
|
|
StreamID: streamID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPBodyChunk,
|
|
IsLast: true,
|
|
}
|
|
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("encode final payload: %w", err)
|
|
}
|
|
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
|
if err := h.frameWriter.WriteFrame(finalFrame); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
h.frameWriter.Flush()
|
|
break
|
|
}
|
|
if readErr != nil {
|
|
return fmt.Errorf("read response body: %w", readErr)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *FrameHandler) sendHTTPError(streamID, requestID string, status int, message string) error {
|
|
headers := h.headerPool.Get()
|
|
headers.Set("Content-Type", "text/plain")
|
|
|
|
httpResp := protocol.HTTPResponse{
|
|
StatusCode: status,
|
|
Status: http.StatusText(status),
|
|
Headers: headers,
|
|
Body: []byte(message),
|
|
}
|
|
|
|
err := h.sendHTTPResponse(streamID, requestID, &httpResp)
|
|
|
|
h.headerPool.Put(headers)
|
|
|
|
return err
|
|
}
|
|
|
|
// streamHTTPResponse streams HTTP response using zero-copy approach
|
|
// First sends headers, then streams body chunks
|
|
func (h *FrameHandler) streamHTTPResponse(streamID, requestID string, resp *http.Response) error {
|
|
if h.isClosedCheck != nil && h.isClosedCheck() {
|
|
return nil
|
|
}
|
|
|
|
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
|
|
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
|
|
|
|
// Send HTTP headers first
|
|
contentLength := resp.ContentLength // -1 if unknown
|
|
httpHead := protocol.HTTPResponseHead{
|
|
StatusCode: resp.StatusCode,
|
|
Status: resp.Status,
|
|
Headers: cleanedHeaders,
|
|
ContentLength: contentLength,
|
|
}
|
|
|
|
headBytes, err := protocol.EncodeHTTPResponseHead(&httpHead)
|
|
if err != nil {
|
|
return fmt.Errorf("encode http head: %w", err)
|
|
}
|
|
|
|
headHeader := protocol.DataHeader{
|
|
StreamID: streamID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPHead,
|
|
IsLast: false,
|
|
}
|
|
|
|
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
|
|
if err != nil {
|
|
return fmt.Errorf("encode head payload: %w", err)
|
|
}
|
|
|
|
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
|
|
if err := h.frameWriter.WriteFrame(headFrame); err != nil {
|
|
return err
|
|
}
|
|
h.frameWriter.Flush()
|
|
h.stats.AddBytesOut(int64(len(headPayload)))
|
|
|
|
// Stream body chunks - zero copy using buffer pool
|
|
bufPtr := h.bufferPool.Get(pool.SizeMedium)
|
|
defer h.bufferPool.Put(bufPtr)
|
|
buf := (*bufPtr)[:pool.SizeMedium]
|
|
|
|
for {
|
|
if h.isClosedCheck != nil && h.isClosedCheck() {
|
|
return nil
|
|
}
|
|
|
|
n, readErr := resp.Body.Read(buf)
|
|
if n > 0 {
|
|
isLast := readErr == io.EOF
|
|
|
|
chunkHeader := protocol.DataHeader{
|
|
StreamID: streamID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPBodyChunk,
|
|
IsLast: isLast,
|
|
}
|
|
|
|
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buf[:n])
|
|
if err != nil {
|
|
return fmt.Errorf("encode chunk payload: %w", err)
|
|
}
|
|
|
|
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
|
if err := h.frameWriter.WriteFrame(chunkFrame); err != nil {
|
|
return err
|
|
}
|
|
h.stats.AddBytesOut(int64(len(chunkPayload)))
|
|
}
|
|
|
|
if readErr == io.EOF {
|
|
// Send final empty chunk with IsLast=true if we haven't already
|
|
if n == 0 {
|
|
finalHeader := protocol.DataHeader{
|
|
StreamID: streamID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPBodyChunk,
|
|
IsLast: true,
|
|
}
|
|
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("encode final payload: %w", err)
|
|
}
|
|
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
|
if err := h.frameWriter.WriteFrame(finalFrame); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
h.frameWriter.Flush()
|
|
break
|
|
}
|
|
if readErr != nil {
|
|
return fmt.Errorf("read response body: %w", readErr)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *FrameHandler) sendHTTPResponse(streamID, requestID string, resp *protocol.HTTPResponse) error {
|
|
if h.isClosedCheck != nil && h.isClosedCheck() {
|
|
return nil
|
|
}
|
|
|
|
header := protocol.DataHeader{
|
|
StreamID: streamID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPResponse,
|
|
IsLast: true,
|
|
}
|
|
|
|
respBytes, err := protocol.EncodeHTTPResponse(resp)
|
|
if err != nil {
|
|
return fmt.Errorf("encode http response: %w", err)
|
|
}
|
|
|
|
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, respBytes)
|
|
if err != nil {
|
|
return fmt.Errorf("encode payload: %w", err)
|
|
}
|
|
|
|
dataFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
|
|
|
h.stats.AddBytesOut(int64(len(payload)))
|
|
|
|
if err := h.frameWriter.WriteFrame(dataFrame); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Flush immediately to ensure the response is sent without batching delay
|
|
h.frameWriter.Flush()
|
|
return nil
|
|
}
|
|
|
|
func (h *FrameHandler) closeStream(streamID string) {
|
|
h.streamMu.Lock()
|
|
stream, ok := h.streams[streamID]
|
|
if !ok {
|
|
h.streamMu.Unlock()
|
|
return
|
|
}
|
|
|
|
// Use stream-level mutex to prevent race conditions
|
|
stream.mu.Lock()
|
|
if stream.closed {
|
|
stream.mu.Unlock()
|
|
h.streamMu.Unlock()
|
|
return
|
|
}
|
|
stream.closed = true
|
|
stream.mu.Unlock()
|
|
|
|
// Remove from map first to prevent concurrent access
|
|
delete(h.streams, streamID)
|
|
h.streamMu.Unlock()
|
|
|
|
// Now safe to close resources without holding the main lock
|
|
if stream.LocalConn != nil {
|
|
stream.LocalConn.Close()
|
|
}
|
|
|
|
close(stream.Done)
|
|
|
|
if h.isClosedCheck != nil && h.isClosedCheck() {
|
|
return
|
|
}
|
|
|
|
header := protocol.DataHeader{
|
|
StreamID: streamID,
|
|
RequestID: streamID,
|
|
Type: protocol.DataTypeClose,
|
|
IsLast: true,
|
|
}
|
|
|
|
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
closeFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
|
|
|
h.frameWriter.WriteFrame(closeFrame)
|
|
}
|
|
|
|
// Close closes all streams
|
|
func (h *FrameHandler) Close() {
|
|
h.streamMu.Lock()
|
|
for streamID, stream := range h.streams {
|
|
stream.mu.Lock()
|
|
if !stream.closed {
|
|
stream.closed = true
|
|
if stream.LocalConn != nil {
|
|
stream.LocalConn.Close()
|
|
}
|
|
close(stream.Done)
|
|
}
|
|
stream.mu.Unlock()
|
|
delete(h.streams, streamID)
|
|
}
|
|
h.streamMu.Unlock()
|
|
|
|
h.streamingReqMu.Lock()
|
|
for requestID, streamingReq := range h.streamingRequests {
|
|
h.closeStreamingRequest(requestID, streamingReq)
|
|
if streamingReq.Writer != nil {
|
|
streamingReq.Writer.CloseWithError(fmt.Errorf("tunnel connection closed"))
|
|
}
|
|
delete(h.streamingRequests, requestID)
|
|
}
|
|
h.streamingReqMu.Unlock()
|
|
}
|
|
|
|
// GetStats returns the traffic stats tracker
|
|
func (h *FrameHandler) GetStats() *TrafficStats {
|
|
return h.stats
|
|
}
|
|
|
|
func (h *FrameHandler) WarmupConnectionPool(numConnections int) {
|
|
if h.tunnelType != protocol.TunnelTypeHTTP {
|
|
return
|
|
}
|
|
|
|
targetURL := fmt.Sprintf("http://%s:%d/", h.localHost, h.localPort)
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < numConnections; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
req, err := http.NewRequest("HEAD", targetURL, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
resp, err := h.httpClient.Do(req)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
io.Copy(io.Discard, resp.Body)
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func (h *FrameHandler) isLocalAddress(addr string) bool {
|
|
if addr == "localhost" || addr == "127.0.0.1" || addr == "::1" {
|
|
return true
|
|
}
|
|
|
|
if strings.HasPrefix(addr, "192.168.") ||
|
|
strings.HasPrefix(addr, "10.") ||
|
|
strings.HasPrefix(addr, "172.16.") ||
|
|
strings.HasPrefix(addr, "172.17.") ||
|
|
strings.HasPrefix(addr, "172.18.") ||
|
|
strings.HasPrefix(addr, "172.19.") ||
|
|
strings.HasPrefix(addr, "172.20.") ||
|
|
strings.HasPrefix(addr, "172.21.") ||
|
|
strings.HasPrefix(addr, "172.22.") ||
|
|
strings.HasPrefix(addr, "172.23.") ||
|
|
strings.HasPrefix(addr, "172.24.") ||
|
|
strings.HasPrefix(addr, "172.25.") ||
|
|
strings.HasPrefix(addr, "172.26.") ||
|
|
strings.HasPrefix(addr, "172.27.") ||
|
|
strings.HasPrefix(addr, "172.28.") ||
|
|
strings.HasPrefix(addr, "172.29.") ||
|
|
strings.HasPrefix(addr, "172.30.") ||
|
|
strings.HasPrefix(addr, "172.31.") {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// cleanResponseHeaders removes hop-by-hop headers that should not be forwarded
|
|
// Go's http.Client automatically handles chunked encoding, so we need to remove
|
|
// the Transfer-Encoding header to avoid sending decoded body with chunked header
|
|
func (h *FrameHandler) cleanResponseHeaders(headers http.Header) http.Header {
|
|
cleaned := make(http.Header)
|
|
|
|
// List of hop-by-hop headers to remove (RFC 2616)
|
|
hopByHopHeaders := map[string]bool{
|
|
"Connection": true,
|
|
"Keep-Alive": true,
|
|
"Proxy-Authenticate": true,
|
|
"Proxy-Authorization": true,
|
|
"Te": true,
|
|
"Trailers": true,
|
|
"Transfer-Encoding": true,
|
|
"Upgrade": true,
|
|
"Proxy-Connection": true,
|
|
}
|
|
|
|
for key, values := range headers {
|
|
canonicalKey := http.CanonicalHeaderKey(key)
|
|
|
|
if hopByHopHeaders[canonicalKey] {
|
|
continue
|
|
}
|
|
|
|
// Also check if this header is listed in Connection header
|
|
connectionHeaders := headers.Get("Connection")
|
|
if connectionHeaders != "" {
|
|
tokens := strings.Split(connectionHeaders, ",")
|
|
skip := false
|
|
for _, token := range tokens {
|
|
if strings.TrimSpace(token) == key {
|
|
skip = true
|
|
break
|
|
}
|
|
}
|
|
if skip {
|
|
continue
|
|
}
|
|
}
|
|
|
|
for _, value := range values {
|
|
cleaned.Add(key, value)
|
|
}
|
|
}
|
|
|
|
return cleaned
|
|
}
|
|
|
|
func (h *FrameHandler) handleHTTPRequestHead(header protocol.DataHeader, payload []byte) error {
|
|
httpReqHead, err := protocol.DecodeHTTPRequestHead(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decode HTTP request head: %w", err)
|
|
}
|
|
|
|
requestID := header.RequestID
|
|
if requestID == "" {
|
|
requestID = header.StreamID
|
|
}
|
|
|
|
targetURL := httpReqHead.URL
|
|
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
|
|
scheme := "http"
|
|
if h.tunnelType == protocol.TunnelTypeHTTPS {
|
|
scheme = "https"
|
|
}
|
|
targetURL = fmt.Sprintf("%s://%s:%d%s", scheme, h.localHost, h.localPort, targetURL)
|
|
}
|
|
|
|
pipeReader, pipeWriter := io.Pipe()
|
|
|
|
req, err := http.NewRequest(httpReqHead.Method, targetURL, pipeReader)
|
|
if err != nil {
|
|
pipeWriter.Close()
|
|
return h.sendHTTPError(header.StreamID, requestID, http.StatusBadGateway, fmt.Sprintf("build request: %v", err))
|
|
}
|
|
|
|
origHost := ""
|
|
for key, values := range httpReqHead.Headers {
|
|
if key == "Content-Length" {
|
|
continue
|
|
}
|
|
for _, value := range values {
|
|
req.Header.Add(key, value)
|
|
}
|
|
}
|
|
if host := req.Header.Get("Host"); host != "" {
|
|
origHost = host
|
|
}
|
|
|
|
req.ContentLength = -1
|
|
|
|
isLocalTarget := h.isLocalAddress(h.localHost)
|
|
if isLocalTarget {
|
|
if origHost != "" {
|
|
req.Host = origHost
|
|
req.Header.Set("Host", origHost)
|
|
} else {
|
|
localHostPort := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
|
req.Host = localHostPort
|
|
req.Header.Set("Host", localHostPort)
|
|
}
|
|
if origHost != "" {
|
|
req.Header.Set("X-Forwarded-Host", origHost)
|
|
}
|
|
} else {
|
|
targetHost := h.localHost
|
|
if h.localPort != 443 && h.localPort != 80 {
|
|
targetHost = fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
|
}
|
|
req.Host = targetHost
|
|
req.Header.Set("Host", targetHost)
|
|
if origHost != "" {
|
|
req.Header.Set("X-Forwarded-Host", origHost)
|
|
}
|
|
}
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
|
|
streamingReq := &StreamingRequest{
|
|
RequestID: requestID,
|
|
Writer: pipeWriter,
|
|
Done: make(chan struct{}),
|
|
chunkQueue: make(chan *chunkData, 512), // deeper buffer for bursty body chunks
|
|
}
|
|
|
|
h.streamingReqMu.Lock()
|
|
h.streamingRequests[requestID] = streamingReq
|
|
h.streamingReqMu.Unlock()
|
|
|
|
go func() {
|
|
defer pipeWriter.Close()
|
|
|
|
timeout := time.NewTimer(5 * time.Minute) // Timeout for receiving body chunks
|
|
defer timeout.Stop()
|
|
|
|
for {
|
|
select {
|
|
case chunk, ok := <-streamingReq.chunkQueue:
|
|
if !ok || chunk == nil {
|
|
return
|
|
}
|
|
|
|
// Reset timeout on each chunk
|
|
if !timeout.Stop() {
|
|
select {
|
|
case <-timeout.C:
|
|
default:
|
|
}
|
|
}
|
|
timeout.Reset(5 * time.Minute)
|
|
|
|
if len(chunk.data) > 0 {
|
|
if _, err := pipeWriter.Write(chunk.data); err != nil {
|
|
h.logger.Error("Failed to write to pipe",
|
|
zap.String("request_id", requestID),
|
|
zap.Error(err),
|
|
)
|
|
pipeWriter.CloseWithError(err)
|
|
return
|
|
}
|
|
}
|
|
|
|
if chunk.isLast {
|
|
return
|
|
}
|
|
case <-streamingReq.Done:
|
|
return
|
|
case <-timeout.C:
|
|
h.logger.Warn("Timeout waiting for request body chunks",
|
|
zap.String("request_id", requestID),
|
|
)
|
|
pipeWriter.CloseWithError(fmt.Errorf("timeout waiting for body chunks"))
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
defer func() {
|
|
h.closeStreamingRequest(requestID, streamingReq)
|
|
h.streamingReqMu.Lock()
|
|
delete(h.streamingRequests, requestID)
|
|
h.streamingReqMu.Unlock()
|
|
}()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
|
defer cancel()
|
|
|
|
reqWithCtx := req.WithContext(ctx)
|
|
|
|
resp, err := h.httpClient.Do(reqWithCtx)
|
|
if err != nil {
|
|
h.sendHTTPError(header.StreamID, requestID, http.StatusBadGateway, fmt.Sprintf("local request failed: %v", err))
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
const bufferThreshold int64 = 1 * 1024 * 1024
|
|
if resp.ContentLength > bufferThreshold {
|
|
h.streamHTTPResponse(header.StreamID, requestID, resp)
|
|
} else {
|
|
h.adaptiveHTTPResponse(header.StreamID, requestID, resp, bufferThreshold)
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *FrameHandler) handleHTTPRequestBodyChunk(header protocol.DataHeader, data []byte) error {
|
|
requestID := header.RequestID
|
|
if requestID == "" {
|
|
requestID = header.StreamID
|
|
}
|
|
|
|
h.streamingReqMu.RLock()
|
|
streamingReq, exists := h.streamingRequests[requestID]
|
|
h.streamingReqMu.RUnlock()
|
|
|
|
if !exists {
|
|
h.logger.Warn("Streaming request not found for body chunk",
|
|
zap.String("request_id", requestID),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
streamingReq.mu.Lock()
|
|
if streamingReq.closed {
|
|
streamingReq.mu.Unlock()
|
|
h.logger.Debug("Streaming request already closed",
|
|
zap.String("request_id", requestID),
|
|
)
|
|
return nil
|
|
}
|
|
streamingReq.mu.Unlock()
|
|
|
|
chunk := &chunkData{
|
|
data: make([]byte, len(data)),
|
|
isLast: header.IsLast,
|
|
}
|
|
copy(chunk.data, data)
|
|
|
|
select {
|
|
case streamingReq.chunkQueue <- chunk:
|
|
case <-streamingReq.Done:
|
|
h.logger.Debug("Streaming request already closed",
|
|
zap.String("request_id", requestID),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
if header.IsLast {
|
|
h.closeStreamingRequest(requestID, streamingReq)
|
|
h.streamingReqMu.Lock()
|
|
delete(h.streamingRequests, requestID)
|
|
h.streamingReqMu.Unlock()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// closeStreamingRequest marks a streaming request closed and signals goroutines.
|
|
func (h *FrameHandler) closeStreamingRequest(requestID string, streamingReq *StreamingRequest) {
|
|
streamingReq.mu.Lock()
|
|
if streamingReq.closed {
|
|
streamingReq.mu.Unlock()
|
|
return
|
|
}
|
|
streamingReq.closed = true
|
|
close(streamingReq.Done)
|
|
streamingReq.mu.Unlock()
|
|
}
|