Files
drip/internal/client/tcp/frame_handler.go
Gouryella 35e6c86e1f feat(client): Added the --short option to the version command to support plain text output.
Added the `--short` flag to the `version` command for printing version information without styles.

In this mode, only the version, Git commit hash, and build time in plain text format will be output, facilitating script parsing.

Optimized Windows process detection logic to improve runtime accuracy.

Removed redundant comments and simplified signal checking methods, making the code clearer and easier to maintain.

refactor(protocol): Replaced string matching of data frame types with enumeration types.

Unified the representation of data frame types in the protocol, using the `DataType` enumeration to improve performance and readability.

Introduced a pooled buffer mechanism to improve memory efficiency in high-load scenarios.

refactor(ui): Adjusted style definitions, removing hard-coded color values.

Removed fixed color settings from some lipgloss styles, providing flexibility for future theme customization.

docs(install): Improved the version extraction function in the installation script.

Added the `get_version_from_binary` function to make version detection more robust. It prefers the plain-mode output, so the version number of the drip client or server is extracted reliably across different terminal environments.

perf(tcp): Improved TCP processing performance and connection management capabilities.

Adjusted the HTTP client transport configuration, increasing the maximum number of idle connections to accommodate more concurrent requests.

Improved error handling, adding special checks for common cases such as closed network connections to avoid noisy logs.

chore(writer): Expanded the FrameWriter queue length to improve batch write stability.

Increased the FrameWriter queue size from 1024 to 2048, and released pooled resources after flushing, better handling sudden traffic spikes and reducing memory usage fluctuations.
2025-12-03 18:11:37 +08:00

441 lines
11 KiB
Go

package tcp
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// FrameHandler handles data frames and forwards to local service.
//
// Raw stream frames are relayed over one dedicated TCP connection per stream
// ID; HTTP request frames (HTTP/HTTPS tunnels only) are replayed through a
// shared http.Client. Everything sent back to the remote side goes through
// frameWriter.
type FrameHandler struct {
// conn is the connection supplied to NewFrameHandler (presumably the tunnel
// itself); none of the methods in this file read or write it.
conn net.Conn
// frameWriter queues outbound frames back through the tunnel.
frameWriter *protocol.FrameWriter
// localHost and localPort address the local service being exposed.
localHost string
localPort int
logger *zap.Logger
// streams holds the active raw streams keyed by stream ID; guarded by streamMu.
streams map[string]*Stream
streamMu sync.RWMutex
// tunnelType enables HTTP(S) request handling and picks the local URL scheme.
tunnelType protocol.TunnelType
// httpClient is shared by all HTTP(S) requests; it never follows redirects.
httpClient *http.Client
// stats accumulates inbound/outbound byte and request counters.
stats *TrafficStats
// isClosedCheck is optional; once it reports true, outbound frames are suppressed.
isClosedCheck func() bool
// bufferPool supplies read buffers for local stream responses.
bufferPool *pool.BufferPool
// headerPool supplies scratch header maps for generated error responses.
headerPool *pool.HeaderPool
}
// Stream represents a single request/response stream.
//
// A Stream is registered in FrameHandler.streams while active; removing it
// from that map (under streamMu) is what grants the right to tear it down.
type Stream struct {
// ID is the stream ID taken from the incoming frame header.
ID string
// LocalConn is the dedicated TCP connection to the local service.
LocalConn net.Conn
// ResponseCh is created with capacity 10.
// NOTE(review): nothing in this file sends to or receives from it.
ResponseCh chan []byte
// Done is closed exactly once, when the stream is torn down.
Done chan struct{}
}
// NewFrameHandler creates a FrameHandler that relays tunnel traffic of the
// given tunnel type to the local service at localHost:localPort.
//
// The handler gets its own http.Client tuned for many concurrent keep-alive
// connections to the local service. Redirects are never followed (3xx
// responses are returned as-is), and for HTTPS tunnels the local service's
// TLS certificate is not verified.
func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost string, localPort int, tunnelType protocol.TunnelType, logger *zap.Logger, isClosedCheck func() bool, bufferPool *pool.BufferPool) *FrameHandler {
	dialer := &net.Dialer{
		Timeout:   5 * time.Second,
		KeepAlive: 30 * time.Second,
	}

	transport := &http.Transport{
		DialContext:           dialer.DialContext,
		MaxIdleConns:          1000, // sized for both mid- and high-load scenarios
		MaxIdleConnsPerHost:   500,  // enough for 400+ concurrent connections
		MaxConnsPerHost:       0,    // unlimited
		IdleConnTimeout:       180 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ResponseHeaderTimeout: 15 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		DisableCompression:    true,
		DisableKeepAlives:     false,
	}
	if tunnelType == protocol.TunnelTypeHTTPS {
		// NOTE(review): verification is disabled, presumably so local
		// services with self-signed certificates can be tunneled.
		transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
	}

	client := &http.Client{
		Timeout:   30 * time.Second,
		Transport: transport,
		// Hand 3xx responses back to the remote caller instead of following them.
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}

	return &FrameHandler{
		conn:          conn,
		frameWriter:   frameWriter,
		localHost:     localHost,
		localPort:     localPort,
		logger:        logger,
		streams:       make(map[string]*Stream),
		tunnelType:    tunnelType,
		httpClient:    client,
		stats:         NewTrafficStats(),
		isClosedCheck: isClosedCheck,
		bufferPool:    bufferPool,
		headerPool:    pool.NewHeaderPool(),
	}
}
// HandleDataFrame dispatches one data frame received from the tunnel.
//
// Every frame, whatever its type, is counted as one request and its payload
// size as inbound bytes. HTTP request frames are replayed against the local
// service, close frames tear down the matching stream, and any other frame is
// treated as raw stream data and written to the stream's local connection
// (dialing it on first use).
func (h *FrameHandler) HandleDataFrame(frame *protocol.Frame) error {
	h.stats.AddBytesIn(int64(len(frame.Payload)))
	h.stats.AddRequest()

	hdr, body, decodeErr := protocol.DecodeDataPayload(frame.Payload)
	if decodeErr != nil {
		return fmt.Errorf("failed to decode data payload: %w", decodeErr)
	}

	switch hdr.Type {
	case protocol.DataTypeHTTPRequest:
		return h.handleHTTPFrame(hdr, body)
	case protocol.DataTypeClose:
		h.closeStream(hdr.StreamID)
		return nil
	}

	target, streamErr := h.getOrCreateStream(hdr.StreamID)
	if streamErr != nil {
		return fmt.Errorf("failed to get stream: %w", streamErr)
	}
	h.forwardToLocal(target, body)
	return nil
}
// getOrCreateStream returns the registered stream for streamID. If none
// exists, it dials a new TCP connection to the local service (5s timeout),
// registers the stream, and starts its response pump.
//
// The dial runs without holding streamMu, so a slow or unreachable local
// service cannot block closeStream, Close, or lookups of other streams for
// the duration of the dial. If another caller registered the same stream ID
// while this one was dialing, the fresh connection is discarded and the
// already-registered stream is returned.
func (h *FrameHandler) getOrCreateStream(streamID string) (*Stream, error) {
	h.streamMu.RLock()
	stream, ok := h.streams[streamID]
	h.streamMu.RUnlock()
	if ok {
		return stream, nil
	}

	// net.JoinHostPort brackets IPv6 literals ("[::1]:8080"); "%s:%d" would
	// yield "::1:8080", which the dialer rejects. Caller-supplied brackets are
	// stripped first so "[::1]" is not double-bracketed.
	host := strings.TrimSuffix(strings.TrimPrefix(h.localHost, "["), "]")
	localAddr := net.JoinHostPort(host, fmt.Sprint(h.localPort))
	localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to local service: %w", err)
	}

	h.streamMu.Lock()
	if existing, found := h.streams[streamID]; found {
		// Lost the race to a concurrent creator; keep theirs.
		h.streamMu.Unlock()
		localConn.Close()
		return existing, nil
	}
	stream = &Stream{
		ID:         streamID,
		LocalConn:  localConn,
		ResponseCh: make(chan []byte, 10),
		Done:       make(chan struct{}),
	}
	h.streams[streamID] = stream
	h.streamMu.Unlock()

	go h.handleLocalResponse(stream)
	return stream, nil
}
// forwardToLocal writes raw tunnel bytes to the stream's local connection.
// A write failure is logged and the stream is torn down via closeStream; the
// error is not returned to the caller.
func (h *FrameHandler) forwardToLocal(stream *Stream, data []byte) {
	_, writeErr := stream.LocalConn.Write(data)
	if writeErr == nil {
		return
	}
	h.logger.Error("Failed to write to local service",
		zap.String("stream_id", stream.ID),
		zap.Error(writeErr),
	)
	h.closeStream(stream.ID)
}
// handleLocalResponse pumps bytes read from the stream's local connection back
// through the tunnel as DataTypeResponse frames. It stops when:
//   - the local read fails (including EOF),
//   - encoding or writing a frame fails, or
//   - the tunnel reports itself closed.
//
// It always tears the stream down on exit via closeStream.
//
// Following the io.Reader contract, bytes returned together with a non-nil
// error (e.g. a final chunk delivered alongside io.EOF) are forwarded before
// the error ends the loop.
func (h *FrameHandler) handleLocalResponse(stream *Stream) {
	defer h.closeStream(stream.ID)

	bufPtr := h.bufferPool.Get(pool.SizeMedium)
	defer h.bufferPool.Put(bufPtr)
	buf := (*bufPtr)[:pool.SizeMedium]

	for {
		n, readErr := stream.LocalConn.Read(buf)
		if n > 0 {
			if h.isClosedCheck != nil && h.isClosedCheck() {
				return
			}
			header := protocol.DataHeader{
				StreamID: stream.ID,
				Type:     protocol.DataTypeResponse,
				IsLast:   false,
			}
			payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, buf[:n])
			if err != nil {
				h.logger.Error("Encode payload failed", zap.Error(err))
				return
			}
			dataFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
			if err = h.frameWriter.WriteFrame(dataFrame); err != nil {
				h.logger.Error("Send frame failed", zap.Error(err))
				return
			}
			h.stats.AddBytesOut(int64(len(payload)))
		}
		// Only consider the read error after any returned bytes were sent.
		if readErr != nil {
			return
		}
	}
}
// handleHTTPFrame replays an HTTP request received from the tunnel against the
// local service. It sends the complete, buffered response back under the same
// stream and request IDs.
//
// It does nothing for tunnels that are neither HTTP nor HTTPS. Failures to
// build, send, or read the local request are reported to the remote caller
// as a 502 Bad Gateway response. Only payload-decoding errors and errors from
// writing the response frame are returned.
func (h *FrameHandler) handleHTTPFrame(header protocol.DataHeader, payload []byte) error {
	if h.tunnelType != protocol.TunnelTypeHTTP && h.tunnelType != protocol.TunnelTypeHTTPS {
		return nil
	}

	httpReq, err := protocol.DecodeHTTPRequest(payload)
	if err != nil {
		return fmt.Errorf("failed to decode HTTP request: %w", err)
	}

	// net.JoinHostPort brackets IPv6 literals ("[::1]:8080"); "%s:%d" would
	// yield the malformed "::1:8080". Caller-supplied brackets are stripped
	// first so they are not doubled.
	bareHost := strings.TrimSuffix(strings.TrimPrefix(h.localHost, "["), "]")
	localHostPort := net.JoinHostPort(bareHost, fmt.Sprint(h.localPort))

	// Origin-form targets ("/path?q") are resolved against the local service.
	// NOTE(review): absolute URLs are used verbatim, so this client will
	// contact whatever host the remote side names. Confirm the server never
	// forwards absolute-form request targets from untrusted clients.
	targetURL := httpReq.URL
	if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
		scheme := "http"
		if h.tunnelType == protocol.TunnelTypeHTTPS {
			scheme = "https"
		}
		targetURL = scheme + "://" + localHostPort + targetURL
	}

	req, err := http.NewRequest(httpReq.Method, targetURL, bytes.NewReader(httpReq.Body))
	if err != nil {
		return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("build request: %v", err))
	}
	for key, values := range httpReq.Headers {
		for _, value := range values {
			req.Header.Add(key, value)
		}
	}

	// Host named by the incoming request; empty if it carried none.
	origHost := req.Header.Get("Host")

	// Choose the Host presented to the target:
	//   - local/private targets see the incoming Host (falling back to
	//     host:port);
	//   - any other target sees its own name, without the port when it is
	//     80 or 443.
	var targetHost string
	switch {
	case !h.isLocalAddress(h.localHost):
		targetHost = localHostPort
		if h.localPort == 80 || h.localPort == 443 {
			targetHost = h.localHost
			if strings.Contains(bareHost, ":") {
				targetHost = "[" + bareHost + "]"
			}
		}
	case origHost != "":
		targetHost = origHost
	default:
		targetHost = localHostPort
	}
	req.Host = targetHost
	req.Header.Set("Host", targetHost)
	if origHost != "" {
		req.Header.Set("X-Forwarded-Host", origHost)
	}
	// NOTE(review): always "https", presumably because the public endpoint
	// terminates TLS. Confirm for plain-HTTP ingress.
	req.Header.Set("X-Forwarded-Proto", "https")

	resp, err := h.httpClient.Do(req)
	if err != nil {
		return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("local request failed: %v", err))
	}
	defer resp.Body.Close()

	// The response is sent as a single IsLast frame, so the body is buffered.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("read response: %v", err))
	}

	return h.sendHTTPResponse(header.StreamID, header.RequestID, &protocol.HTTPResponse{
		StatusCode: resp.StatusCode,
		Status:     resp.Status,
		Headers:    resp.Header,
		Body:       body,
	})
}
// sendHTTPError replies to the remote side with a text/plain response carrying
// the given status code and message. The header map is borrowed from
// headerPool and handed back once sendHTTPResponse has returned.
func (h *FrameHandler) sendHTTPError(streamID, requestID string, status int, message string) error {
	hdrs := h.headerPool.Get()
	defer h.headerPool.Put(hdrs)

	hdrs.Set("Content-Type", "text/plain")
	return h.sendHTTPResponse(streamID, requestID, &protocol.HTTPResponse{
		StatusCode: status,
		Status:     http.StatusText(status),
		Headers:    hdrs,
		Body:       []byte(message),
	})
}
// sendHTTPResponse encodes resp and writes it to the tunnel as the final
// (IsLast) HTTP-response frame for the given stream and request IDs.
//
// It returns nil without sending anything once the tunnel reports itself
// closed. The outbound byte counter is updated before the write is attempted,
// so it also counts frames whose write fails.
func (h *FrameHandler) sendHTTPResponse(streamID, requestID string, resp *protocol.HTTPResponse) error {
	if closed := h.isClosedCheck; closed != nil && closed() {
		return nil
	}

	dataHeader := protocol.DataHeader{
		StreamID:  streamID,
		RequestID: requestID,
		Type:      protocol.DataTypeHTTPResponse,
		IsLast:    true,
	}

	encoded, encErr := protocol.EncodeHTTPResponse(resp)
	if encErr != nil {
		return fmt.Errorf("encode http response: %w", encErr)
	}

	framed, buffer, wrapErr := protocol.EncodeDataPayloadPooled(dataHeader, encoded)
	if wrapErr != nil {
		return fmt.Errorf("encode payload: %w", wrapErr)
	}

	out := protocol.NewFramePooled(protocol.FrameTypeData, framed, buffer)
	h.stats.AddBytesOut(int64(len(framed)))
	return h.frameWriter.WriteFrame(out)
}
// closeStream tears down the stream identified by streamID if it is still
// registered. It closes the local connection and the Done channel, then sends
// a DataTypeClose frame to the remote side unless the tunnel reports itself
// closed.
//
// Only the map removal happens under streamMu. Removing the entry is what
// grants ownership of the teardown, so exactly one caller (this function or
// Close) ever reaches it, and Done is closed once. The socket close and the
// frame write then run without the lock, so a blocked FrameWriter cannot
// stall every other stream operation.
func (h *FrameHandler) closeStream(streamID string) {
	h.streamMu.Lock()
	stream, ok := h.streams[streamID]
	if ok {
		delete(h.streams, streamID)
	}
	h.streamMu.Unlock()
	if !ok {
		return
	}

	if stream.LocalConn != nil {
		stream.LocalConn.Close()
	}
	close(stream.Done)

	if h.isClosedCheck != nil && h.isClosedCheck() {
		return
	}

	header := protocol.DataHeader{
		StreamID:  streamID,
		RequestID: streamID,
		Type:      protocol.DataTypeClose,
		IsLast:    true,
	}
	payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
	if err != nil {
		h.logger.Debug("Encode close payload failed", zap.String("stream_id", streamID), zap.Error(err))
		return
	}
	closeFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
	// Best effort: the stream is already gone locally, so a failed notify is
	// only logged (at debug level, to keep shutdown noise out of the logs).
	if err := h.frameWriter.WriteFrame(closeFrame); err != nil {
		h.logger.Debug("Send close frame failed", zap.String("stream_id", streamID), zap.Error(err))
	}
}
// Close tears down every registered stream. For each one it closes the local
// connection (when set) and the Done channel, and it leaves the stream map
// empty. No close frames are sent to the remote side.
func (h *FrameHandler) Close() {
	h.streamMu.Lock()
	defer h.streamMu.Unlock()

	for id := range h.streams {
		s := h.streams[id]
		delete(h.streams, id)
		if c := s.LocalConn; c != nil {
			c.Close()
		}
		close(s.Done)
	}
}
// GetStats returns the TrafficStats instance this handler updates with its
// inbound/outbound byte counts and request count. The pointer is shared, not
// a snapshot.
func (h *FrameHandler) GetStats() *TrafficStats {
	return h.stats
}
// WarmupConnectionPool issues numConnections concurrent "HEAD /" requests to
// the local service so the HTTP client opens keep-alive connections ahead of
// real traffic; later requests can then reuse them from the transport's idle
// pool.
//
// It only runs for plain-HTTP tunnels, ignores every error, and blocks until
// all warm-up requests have finished.
func (h *FrameHandler) WarmupConnectionPool(numConnections int) {
	if h.tunnelType != protocol.TunnelTypeHTTP {
		return
	}

	// JoinHostPort brackets IPv6 literals; "%s:%d" would build an invalid URL
	// such as "http://::1:8080/". Existing brackets are stripped first.
	host := strings.TrimSuffix(strings.TrimPrefix(h.localHost, "["), "]")
	targetURL := "http://" + net.JoinHostPort(host, fmt.Sprint(h.localPort)) + "/"

	var wg sync.WaitGroup
	for i := 0; i < numConnections; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// A fresh request per goroutine: http.Request must not be reused
			// concurrently while a round trip is in flight.
			req, err := http.NewRequest(http.MethodHead, targetURL, nil)
			if err != nil {
				return
			}
			resp, err := h.httpClient.Do(req)
			if err != nil {
				return
			}
			defer resp.Body.Close()
			// Drain so the transport can return the connection to its idle pool.
			_, _ = io.Copy(io.Discard, resp.Body)
		}()
	}
	wg.Wait()
}
// isLocalAddress reports whether addr names a loopback or private-network
// address. That means one of:
//   - "localhost" (any letter case);
//   - a loopback IP (127.0.0.0/8, ::1);
//   - a private IP range: 10/8, 172.16/12, 192.168/16 or fc00::/7.
//
// A bracketed IPv6 literal such as "[::1]" is accepted. The check is purely
// syntactic: no DNS lookup is done, so any hostname other than localhost
// (including ones that merely start with digits, like "10.example.com") is
// not considered local.
func (h *FrameHandler) isLocalAddress(addr string) bool {
	if strings.EqualFold(addr, "localhost") {
		return true
	}
	ip := net.ParseIP(strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]"))
	if ip == nil {
		return false
	}
	return ip.IsLoopback() || ip.IsPrivate()
}