Files
drip/internal/client/tcp/frame_handler.go
Gouryella aead68bb62 feat: Add HTTP streaming, compression support, and Docker deployment
enhancements

  - Add adaptive HTTP response handling with automatic streaming for large
  responses (>1MB)
  - Implement zero-copy streaming using buffer pools for better performance
  - Add compression module for reduced bandwidth usage
  - Add GitHub Container Registry workflow for automated Docker builds
  - Add production-optimized Dockerfile and docker-compose configuration
  - Simplify background mode with -d flag and improved daemon management
  - Update documentation with new command syntax and deployment guides
  - Clean up unused code and improve error handling
  - Fix lipgloss style usage (remove unnecessary .Copy() calls)
2025-12-05 22:09:07 +08:00

1079 lines
27 KiB
Go

package tcp
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// FrameHandler handles data frames and forwards to local service
type FrameHandler struct {
conn net.Conn
frameWriter *protocol.FrameWriter
localHost string
localPort int
logger *zap.Logger
streams map[string]*Stream
streamMu sync.RWMutex
streamingRequests map[string]*StreamingRequest
streamingReqMu sync.RWMutex
tunnelType protocol.TunnelType
httpClient *http.Client
stats *TrafficStats
isClosedCheck func() bool
bufferPool *pool.BufferPool
headerPool *pool.HeaderPool
}
// Stream represents a single request/response stream
type Stream struct {
ID string
LocalConn net.Conn
ResponseCh chan []byte
Done chan struct{}
closed bool
mu sync.Mutex
}
// StreamingRequest represents a streaming upload request in progress
type StreamingRequest struct {
RequestID string
Writer *io.PipeWriter
Done chan struct{}
chunkQueue chan *chunkData
closed bool
mu sync.Mutex
}
type chunkData struct {
data []byte
isLast bool
}
func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost string, localPort int, tunnelType protocol.TunnelType, logger *zap.Logger, isClosedCheck func() bool, bufferPool *pool.BufferPool) *FrameHandler {
var tlsConfig *tls.Config
if tunnelType == protocol.TunnelTypeHTTPS {
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
return &FrameHandler{
conn: conn,
frameWriter: frameWriter,
localHost: localHost,
localPort: localPort,
logger: logger,
streams: make(map[string]*Stream),
streamingRequests: make(map[string]*StreamingRequest),
tunnelType: tunnelType,
stats: NewTrafficStats(),
isClosedCheck: isClosedCheck,
bufferPool: bufferPool,
headerPool: pool.NewHeaderPool(),
httpClient: &http.Client{
// No overall timeout - streaming responses can take arbitrary time
Transport: &http.Transport{
MaxIdleConns: 1000,
MaxIdleConnsPerHost: 500,
MaxConnsPerHost: 0,
IdleConnTimeout: 180 * time.Second,
DisableCompression: true,
DisableKeepAlives: false,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig,
ResponseHeaderTimeout: 30 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: (&net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
},
}
}
func (h *FrameHandler) HandleDataFrame(frame *protocol.Frame) error {
h.stats.AddBytesIn(int64(len(frame.Payload)))
h.stats.AddRequest()
header, data, err := protocol.DecodeDataPayload(frame.Payload)
if err != nil {
return fmt.Errorf("failed to decode data payload: %w", err)
}
if header.Type == protocol.DataTypeHTTPRequest {
return h.handleHTTPFrame(header, data)
}
if header.Type == protocol.DataTypeHTTPRequestHead || header.Type == protocol.DataTypeHTTPHead {
return h.handleHTTPRequestHead(header, data)
}
if header.Type == protocol.DataTypeHTTPRequestBodyChunk || header.Type == protocol.DataTypeHTTPBodyChunk {
return h.handleHTTPRequestBodyChunk(header, data)
}
if header.Type == protocol.DataTypeClose {
h.closeStream(header.StreamID)
return nil
}
stream, err := h.getOrCreateStream(header.StreamID)
if err != nil {
return fmt.Errorf("failed to get stream: %w", err)
}
h.forwardToLocal(stream, data)
return nil
}
func (h *FrameHandler) getOrCreateStream(streamID string) (*Stream, error) {
h.streamMu.Lock()
defer h.streamMu.Unlock()
if stream, ok := h.streams[streamID]; ok {
return stream, nil
}
localAddr := net.JoinHostPort(h.localHost, fmt.Sprintf("%d", h.localPort))
localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
if err != nil {
return nil, fmt.Errorf("failed to connect to local service: %w", err)
}
stream := &Stream{
ID: streamID,
LocalConn: localConn,
ResponseCh: make(chan []byte, 10),
Done: make(chan struct{}),
}
h.streams[streamID] = stream
go h.handleLocalResponse(stream)
return stream, nil
}
func (h *FrameHandler) forwardToLocal(stream *Stream, data []byte) {
// Check if stream is closed using mutex
stream.mu.Lock()
if stream.closed {
stream.mu.Unlock()
return
}
stream.mu.Unlock()
// Double check with Done channel
select {
case <-stream.Done:
// Stream already closed, ignore data
return
default:
}
if _, err := stream.LocalConn.Write(data); err != nil {
// Only log at debug level since connection close is often expected
h.logger.Debug("Failed to write to local service",
zap.String("stream_id", stream.ID),
zap.Error(err),
)
h.closeStream(stream.ID)
}
}
func (h *FrameHandler) handleLocalResponse(stream *Stream) {
defer h.closeStream(stream.ID)
bufPtr := h.bufferPool.Get(pool.SizeMedium)
defer h.bufferPool.Put(bufPtr)
buf := (*bufPtr)[:pool.SizeMedium]
for {
// Check if stream is closed before reading
stream.mu.Lock()
closed := stream.closed
stream.mu.Unlock()
if closed {
break
}
n, err := stream.LocalConn.Read(buf)
if err != nil {
break
}
if n > 0 {
if h.isClosedCheck != nil && h.isClosedCheck() {
break
}
// Check again after read
stream.mu.Lock()
closed = stream.closed
stream.mu.Unlock()
if closed {
break
}
header := protocol.DataHeader{
StreamID: stream.ID,
Type: protocol.DataTypeResponse,
IsLast: false,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, buf[:n])
if err != nil {
h.logger.Debug("Encode payload failed", zap.Error(err))
break
}
dataFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
err = h.frameWriter.WriteFrame(dataFrame)
if err != nil {
h.logger.Debug("Send frame failed", zap.Error(err))
break
}
h.stats.AddBytesOut(int64(len(payload)))
}
}
}
func (h *FrameHandler) handleHTTPFrame(header protocol.DataHeader, payload []byte) error {
if h.tunnelType != protocol.TunnelTypeHTTP && h.tunnelType != protocol.TunnelTypeHTTPS {
return nil
}
httpReq, err := protocol.DecodeHTTPRequest(payload)
if err != nil {
return fmt.Errorf("failed to decode HTTP request: %w", err)
}
targetURL := httpReq.URL
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
scheme := "http"
if h.tunnelType == protocol.TunnelTypeHTTPS {
scheme = "https"
}
targetURL = fmt.Sprintf("%s://%s:%d%s", scheme, h.localHost, h.localPort, targetURL)
}
req, err := http.NewRequest(httpReq.Method, targetURL, bytes.NewReader(httpReq.Body))
if err != nil {
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("build request: %v", err))
}
origHost := ""
for key, values := range httpReq.Headers {
for _, value := range values {
req.Header.Add(key, value)
}
}
if host := req.Header.Get("Host"); host != "" {
origHost = host
}
isLocalTarget := h.isLocalAddress(h.localHost)
if isLocalTarget {
if origHost != "" {
req.Host = origHost
req.Header.Set("Host", origHost)
} else {
localHostPort := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
req.Host = localHostPort
req.Header.Set("Host", localHostPort)
}
if origHost != "" {
req.Header.Set("X-Forwarded-Host", origHost)
}
} else {
targetHost := h.localHost
if h.localPort != 443 && h.localPort != 80 {
targetHost = fmt.Sprintf("%s:%d", h.localHost, h.localPort)
}
req.Host = targetHost
req.Header.Set("Host", targetHost)
if origHost != "" {
req.Header.Set("X-Forwarded-Host", origHost)
}
}
req.Header.Set("X-Forwarded-Proto", "https")
resp, err := h.httpClient.Do(req)
if err != nil {
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("local request failed: %v", err))
}
defer resp.Body.Close()
// Threshold for switching from buffered to streaming mode
const bufferThreshold int64 = 1 * 1024 * 1024 // 1MB
// If Content-Length is known and large, use streaming directly
if resp.ContentLength > bufferThreshold {
return h.streamHTTPResponse(header.StreamID, header.RequestID, resp)
}
// For small or unknown size: try buffered first, switch to streaming if too large
return h.adaptiveHTTPResponse(header.StreamID, header.RequestID, resp, bufferThreshold)
}
// adaptiveHTTPResponse tries buffered mode first, switches to streaming if data exceeds threshold
func (h *FrameHandler) adaptiveHTTPResponse(streamID, requestID string, resp *http.Response, threshold int64) error {
if h.isClosedCheck != nil && h.isClosedCheck() {
return nil
}
// Buffer for initial read
buffer := make([]byte, 0, threshold)
tempBuf := make([]byte, 32*1024) // 32KB read chunks
var totalRead int64
var hitThreshold bool
// Try to read up to threshold
for totalRead < threshold {
n, err := resp.Body.Read(tempBuf)
if n > 0 {
buffer = append(buffer, tempBuf[:n]...)
totalRead += int64(n)
}
if err == io.EOF {
// Response completed within threshold - use buffered mode
break
}
if err != nil {
return h.sendHTTPError(streamID, requestID, http.StatusBadGateway, fmt.Sprintf("read response: %v", err))
}
if totalRead >= threshold {
hitThreshold = true
break
}
}
if !hitThreshold {
// Small response - send as buffered
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
httpResp := protocol.HTTPResponse{
StatusCode: resp.StatusCode,
Status: resp.Status,
Headers: cleanedHeaders,
Body: buffer,
}
return h.sendHTTPResponse(streamID, requestID, &httpResp)
}
// Large response - switch to streaming mode
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
// First send headers
httpHead := protocol.HTTPResponseHead{
StatusCode: resp.StatusCode,
Status: resp.Status,
Headers: cleanedHeaders,
ContentLength: resp.ContentLength, // -1 if unknown
}
headBytes, err := protocol.EncodeHTTPResponseHead(&httpHead)
if err != nil {
return fmt.Errorf("encode http head: %w", err)
}
headHeader := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: protocol.DataTypeHTTPHead,
IsLast: false,
}
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
if err != nil {
return fmt.Errorf("encode head payload: %w", err)
}
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
if err := h.frameWriter.WriteFrame(headFrame); err != nil {
return err
}
h.frameWriter.Flush()
h.stats.AddBytesOut(int64(len(headPayload)))
// Send buffered data as first chunk
if len(buffer) > 0 {
chunkHeader := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk,
IsLast: false,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer)
if err != nil {
return fmt.Errorf("encode chunk payload: %w", err)
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := h.frameWriter.WriteFrame(chunkFrame); err != nil {
return err
}
h.stats.AddBytesOut(int64(len(chunkPayload)))
}
// Clear buffer to free memory
buffer = nil
// Continue streaming remaining data
bufPtr := h.bufferPool.Get(pool.SizeMedium)
defer h.bufferPool.Put(bufPtr)
buf := (*bufPtr)[:pool.SizeMedium]
for {
if h.isClosedCheck != nil && h.isClosedCheck() {
return nil
}
n, readErr := resp.Body.Read(buf)
if n > 0 {
isLast := readErr == io.EOF
chunkHeader := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk,
IsLast: isLast,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buf[:n])
if err != nil {
return fmt.Errorf("encode chunk payload: %w", err)
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := h.frameWriter.WriteFrame(chunkFrame); err != nil {
return err
}
h.stats.AddBytesOut(int64(len(chunkPayload)))
}
if readErr == io.EOF {
if n == 0 {
// Send final empty chunk
finalHeader := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if err != nil {
return fmt.Errorf("encode final payload: %w", err)
}
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
if err := h.frameWriter.WriteFrame(finalFrame); err != nil {
return err
}
}
h.frameWriter.Flush()
break
}
if readErr != nil {
return fmt.Errorf("read response body: %w", readErr)
}
}
return nil
}
func (h *FrameHandler) sendHTTPError(streamID, requestID string, status int, message string) error {
headers := h.headerPool.Get()
headers.Set("Content-Type", "text/plain")
httpResp := protocol.HTTPResponse{
StatusCode: status,
Status: http.StatusText(status),
Headers: headers,
Body: []byte(message),
}
err := h.sendHTTPResponse(streamID, requestID, &httpResp)
h.headerPool.Put(headers)
return err
}
// streamHTTPResponse streams HTTP response using zero-copy approach
// First sends headers, then streams body chunks
func (h *FrameHandler) streamHTTPResponse(streamID, requestID string, resp *http.Response) error {
if h.isClosedCheck != nil && h.isClosedCheck() {
return nil
}
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
// Send HTTP headers first
contentLength := resp.ContentLength // -1 if unknown
httpHead := protocol.HTTPResponseHead{
StatusCode: resp.StatusCode,
Status: resp.Status,
Headers: cleanedHeaders,
ContentLength: contentLength,
}
headBytes, err := protocol.EncodeHTTPResponseHead(&httpHead)
if err != nil {
return fmt.Errorf("encode http head: %w", err)
}
headHeader := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: protocol.DataTypeHTTPHead,
IsLast: false,
}
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
if err != nil {
return fmt.Errorf("encode head payload: %w", err)
}
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
if err := h.frameWriter.WriteFrame(headFrame); err != nil {
return err
}
h.frameWriter.Flush()
h.stats.AddBytesOut(int64(len(headPayload)))
// Stream body chunks - zero copy using buffer pool
bufPtr := h.bufferPool.Get(pool.SizeMedium)
defer h.bufferPool.Put(bufPtr)
buf := (*bufPtr)[:pool.SizeMedium]
for {
if h.isClosedCheck != nil && h.isClosedCheck() {
return nil
}
n, readErr := resp.Body.Read(buf)
if n > 0 {
isLast := readErr == io.EOF
chunkHeader := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk,
IsLast: isLast,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buf[:n])
if err != nil {
return fmt.Errorf("encode chunk payload: %w", err)
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := h.frameWriter.WriteFrame(chunkFrame); err != nil {
return err
}
h.stats.AddBytesOut(int64(len(chunkPayload)))
}
if readErr == io.EOF {
// Send final empty chunk with IsLast=true if we haven't already
if n == 0 {
finalHeader := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if err != nil {
return fmt.Errorf("encode final payload: %w", err)
}
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
if err := h.frameWriter.WriteFrame(finalFrame); err != nil {
return err
}
}
h.frameWriter.Flush()
break
}
if readErr != nil {
return fmt.Errorf("read response body: %w", readErr)
}
}
return nil
}
func (h *FrameHandler) sendHTTPResponse(streamID, requestID string, resp *protocol.HTTPResponse) error {
if h.isClosedCheck != nil && h.isClosedCheck() {
return nil
}
header := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: protocol.DataTypeHTTPResponse,
IsLast: true,
}
respBytes, err := protocol.EncodeHTTPResponse(resp)
if err != nil {
return fmt.Errorf("encode http response: %w", err)
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, respBytes)
if err != nil {
return fmt.Errorf("encode payload: %w", err)
}
dataFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
h.stats.AddBytesOut(int64(len(payload)))
if err := h.frameWriter.WriteFrame(dataFrame); err != nil {
return err
}
// Flush immediately to ensure the response is sent without batching delay
h.frameWriter.Flush()
return nil
}
func (h *FrameHandler) closeStream(streamID string) {
h.streamMu.Lock()
stream, ok := h.streams[streamID]
if !ok {
h.streamMu.Unlock()
return
}
// Use stream-level mutex to prevent race conditions
stream.mu.Lock()
if stream.closed {
stream.mu.Unlock()
h.streamMu.Unlock()
return
}
stream.closed = true
stream.mu.Unlock()
// Remove from map first to prevent concurrent access
delete(h.streams, streamID)
h.streamMu.Unlock()
// Now safe to close resources without holding the main lock
if stream.LocalConn != nil {
stream.LocalConn.Close()
}
close(stream.Done)
if h.isClosedCheck != nil && h.isClosedCheck() {
return
}
header := protocol.DataHeader{
StreamID: streamID,
RequestID: streamID,
Type: protocol.DataTypeClose,
IsLast: true,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
if err != nil {
return
}
closeFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
h.frameWriter.WriteFrame(closeFrame)
}
// Close closes all streams
func (h *FrameHandler) Close() {
h.streamMu.Lock()
for streamID, stream := range h.streams {
stream.mu.Lock()
if !stream.closed {
stream.closed = true
if stream.LocalConn != nil {
stream.LocalConn.Close()
}
close(stream.Done)
}
stream.mu.Unlock()
delete(h.streams, streamID)
}
h.streamMu.Unlock()
h.streamingReqMu.Lock()
for requestID, streamingReq := range h.streamingRequests {
h.closeStreamingRequest(requestID, streamingReq)
if streamingReq.Writer != nil {
streamingReq.Writer.CloseWithError(fmt.Errorf("tunnel connection closed"))
}
delete(h.streamingRequests, requestID)
}
h.streamingReqMu.Unlock()
}
// GetStats returns the traffic stats tracker
func (h *FrameHandler) GetStats() *TrafficStats {
return h.stats
}
func (h *FrameHandler) WarmupConnectionPool(numConnections int) {
if h.tunnelType != protocol.TunnelTypeHTTP {
return
}
targetURL := fmt.Sprintf("http://%s:%d/", h.localHost, h.localPort)
var wg sync.WaitGroup
for i := 0; i < numConnections; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req, err := http.NewRequest("HEAD", targetURL, nil)
if err != nil {
return
}
resp, err := h.httpClient.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
io.Copy(io.Discard, resp.Body)
}()
}
wg.Wait()
}
func (h *FrameHandler) isLocalAddress(addr string) bool {
if addr == "localhost" || addr == "127.0.0.1" || addr == "::1" {
return true
}
if strings.HasPrefix(addr, "192.168.") ||
strings.HasPrefix(addr, "10.") ||
strings.HasPrefix(addr, "172.16.") ||
strings.HasPrefix(addr, "172.17.") ||
strings.HasPrefix(addr, "172.18.") ||
strings.HasPrefix(addr, "172.19.") ||
strings.HasPrefix(addr, "172.20.") ||
strings.HasPrefix(addr, "172.21.") ||
strings.HasPrefix(addr, "172.22.") ||
strings.HasPrefix(addr, "172.23.") ||
strings.HasPrefix(addr, "172.24.") ||
strings.HasPrefix(addr, "172.25.") ||
strings.HasPrefix(addr, "172.26.") ||
strings.HasPrefix(addr, "172.27.") ||
strings.HasPrefix(addr, "172.28.") ||
strings.HasPrefix(addr, "172.29.") ||
strings.HasPrefix(addr, "172.30.") ||
strings.HasPrefix(addr, "172.31.") {
return true
}
return false
}
// cleanResponseHeaders removes hop-by-hop headers that should not be forwarded
// Go's http.Client automatically handles chunked encoding, so we need to remove
// the Transfer-Encoding header to avoid sending decoded body with chunked header
func (h *FrameHandler) cleanResponseHeaders(headers http.Header) http.Header {
cleaned := make(http.Header)
// List of hop-by-hop headers to remove (RFC 2616)
hopByHopHeaders := map[string]bool{
"Connection": true,
"Keep-Alive": true,
"Proxy-Authenticate": true,
"Proxy-Authorization": true,
"Te": true,
"Trailers": true,
"Transfer-Encoding": true,
"Upgrade": true,
"Proxy-Connection": true,
}
for key, values := range headers {
canonicalKey := http.CanonicalHeaderKey(key)
if hopByHopHeaders[canonicalKey] {
continue
}
// Also check if this header is listed in Connection header
connectionHeaders := headers.Get("Connection")
if connectionHeaders != "" {
tokens := strings.Split(connectionHeaders, ",")
skip := false
for _, token := range tokens {
if strings.TrimSpace(token) == key {
skip = true
break
}
}
if skip {
continue
}
}
for _, value := range values {
cleaned.Add(key, value)
}
}
return cleaned
}
func (h *FrameHandler) handleHTTPRequestHead(header protocol.DataHeader, payload []byte) error {
httpReqHead, err := protocol.DecodeHTTPRequestHead(payload)
if err != nil {
return fmt.Errorf("failed to decode HTTP request head: %w", err)
}
requestID := header.RequestID
if requestID == "" {
requestID = header.StreamID
}
targetURL := httpReqHead.URL
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
scheme := "http"
if h.tunnelType == protocol.TunnelTypeHTTPS {
scheme = "https"
}
targetURL = fmt.Sprintf("%s://%s:%d%s", scheme, h.localHost, h.localPort, targetURL)
}
pipeReader, pipeWriter := io.Pipe()
req, err := http.NewRequest(httpReqHead.Method, targetURL, pipeReader)
if err != nil {
pipeWriter.Close()
return h.sendHTTPError(header.StreamID, requestID, http.StatusBadGateway, fmt.Sprintf("build request: %v", err))
}
origHost := ""
for key, values := range httpReqHead.Headers {
if key == "Content-Length" {
continue
}
for _, value := range values {
req.Header.Add(key, value)
}
}
if host := req.Header.Get("Host"); host != "" {
origHost = host
}
req.ContentLength = -1
isLocalTarget := h.isLocalAddress(h.localHost)
if isLocalTarget {
if origHost != "" {
req.Host = origHost
req.Header.Set("Host", origHost)
} else {
localHostPort := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
req.Host = localHostPort
req.Header.Set("Host", localHostPort)
}
if origHost != "" {
req.Header.Set("X-Forwarded-Host", origHost)
}
} else {
targetHost := h.localHost
if h.localPort != 443 && h.localPort != 80 {
targetHost = fmt.Sprintf("%s:%d", h.localHost, h.localPort)
}
req.Host = targetHost
req.Header.Set("Host", targetHost)
if origHost != "" {
req.Header.Set("X-Forwarded-Host", origHost)
}
}
req.Header.Set("X-Forwarded-Proto", "https")
streamingReq := &StreamingRequest{
RequestID: requestID,
Writer: pipeWriter,
Done: make(chan struct{}),
chunkQueue: make(chan *chunkData, 512), // deeper buffer for bursty body chunks
}
h.streamingReqMu.Lock()
h.streamingRequests[requestID] = streamingReq
h.streamingReqMu.Unlock()
go func() {
defer pipeWriter.Close()
timeout := time.NewTimer(5 * time.Minute) // Timeout for receiving body chunks
defer timeout.Stop()
for {
select {
case chunk, ok := <-streamingReq.chunkQueue:
if !ok || chunk == nil {
return
}
// Reset timeout on each chunk
if !timeout.Stop() {
select {
case <-timeout.C:
default:
}
}
timeout.Reset(5 * time.Minute)
if len(chunk.data) > 0 {
if _, err := pipeWriter.Write(chunk.data); err != nil {
h.logger.Error("Failed to write to pipe",
zap.String("request_id", requestID),
zap.Error(err),
)
pipeWriter.CloseWithError(err)
return
}
}
if chunk.isLast {
return
}
case <-streamingReq.Done:
return
case <-timeout.C:
h.logger.Warn("Timeout waiting for request body chunks",
zap.String("request_id", requestID),
)
pipeWriter.CloseWithError(fmt.Errorf("timeout waiting for body chunks"))
return
}
}
}()
go func() {
defer func() {
h.closeStreamingRequest(requestID, streamingReq)
h.streamingReqMu.Lock()
delete(h.streamingRequests, requestID)
h.streamingReqMu.Unlock()
}()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
reqWithCtx := req.WithContext(ctx)
resp, err := h.httpClient.Do(reqWithCtx)
if err != nil {
h.sendHTTPError(header.StreamID, requestID, http.StatusBadGateway, fmt.Sprintf("local request failed: %v", err))
return
}
defer resp.Body.Close()
const bufferThreshold int64 = 1 * 1024 * 1024
if resp.ContentLength > bufferThreshold {
h.streamHTTPResponse(header.StreamID, requestID, resp)
} else {
h.adaptiveHTTPResponse(header.StreamID, requestID, resp, bufferThreshold)
}
}()
return nil
}
func (h *FrameHandler) handleHTTPRequestBodyChunk(header protocol.DataHeader, data []byte) error {
requestID := header.RequestID
if requestID == "" {
requestID = header.StreamID
}
h.streamingReqMu.RLock()
streamingReq, exists := h.streamingRequests[requestID]
h.streamingReqMu.RUnlock()
if !exists {
h.logger.Warn("Streaming request not found for body chunk",
zap.String("request_id", requestID),
)
return nil
}
streamingReq.mu.Lock()
if streamingReq.closed {
streamingReq.mu.Unlock()
h.logger.Debug("Streaming request already closed",
zap.String("request_id", requestID),
)
return nil
}
streamingReq.mu.Unlock()
chunk := &chunkData{
data: make([]byte, len(data)),
isLast: header.IsLast,
}
copy(chunk.data, data)
select {
case streamingReq.chunkQueue <- chunk:
case <-streamingReq.Done:
h.logger.Debug("Streaming request already closed",
zap.String("request_id", requestID),
)
return nil
}
if header.IsLast {
h.closeStreamingRequest(requestID, streamingReq)
h.streamingReqMu.Lock()
delete(h.streamingRequests, requestID)
h.streamingReqMu.Unlock()
}
return nil
}
// closeStreamingRequest marks a streaming request closed and signals goroutines.
func (h *FrameHandler) closeStreamingRequest(requestID string, streamingReq *StreamingRequest) {
streamingReq.mu.Lock()
if streamingReq.closed {
streamingReq.mu.Unlock()
return
}
streamingReq.closed = true
close(streamingReq.Done)
streamingReq.mu.Unlock()
}