mirror of
https://github.com/Gouryella/drip.git
synced 2026-04-28 05:10:55 +00:00
Added the `--short` flag to the `version` command for printing version information without styles. In this mode, only the version, Git commit hash, and build time in plain text format will be output, facilitating script parsing. Optimized Windows process detection logic to improve runtime accuracy. Removed redundant comments and simplified signal checking methods, making the code clearer and easier to maintain. refactor(protocol): Replaced string matching of data frame types with enumeration types. Unified the representation of data frame types in the protocol, using the `DataType` enumeration to improve performance and readability. Introduced a pooled buffer mechanism to improve memory efficiency in high-load scenarios. refactor(ui): Adjusted style definitions, removing hard-coded color values. Removed fixed color settings from some lipgloss styles, providing flexibility for future theme customization. docs(install): Improved the version extraction function in the installation script. Added the `get_version_from_binary` function to enhance version identification capabilities, prioritizing plain mode output, ensuring accurate version number acquisition for the drip client or server across different terminal environments. perf(tcp): Improved TCP processing performance and connection management capabilities. Adjusted HTTP client transmission parameter configuration, increasing the maximum number of idle connections to accommodate higher concurrent requests. Improved error handling logic, adding special checks for common cases such as closing network connections to avoid log pollution. chore(writer): Expanded the FrameWriter queue length to improve batch write stability. Increased the FrameWriter queue size from 1024 to 2048, and released pooled resources after flushing, better handling sudden traffic spikes and reducing memory usage fluctuations.
441 lines
11 KiB
Go
441 lines
11 KiB
Go
package tcp
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"drip/internal/shared/pool"
|
|
"drip/internal/shared/protocol"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// FrameHandler handles data frames and forwards to local service
|
|
type FrameHandler struct {
|
|
conn net.Conn
|
|
frameWriter *protocol.FrameWriter
|
|
localHost string
|
|
localPort int
|
|
logger *zap.Logger
|
|
streams map[string]*Stream
|
|
streamMu sync.RWMutex
|
|
tunnelType protocol.TunnelType
|
|
httpClient *http.Client
|
|
stats *TrafficStats
|
|
isClosedCheck func() bool
|
|
bufferPool *pool.BufferPool
|
|
headerPool *pool.HeaderPool
|
|
}
|
|
|
|
// Stream is the state kept for one tunneled TCP stream: its identifier, the
// connection to the local service, and its signalling channels.
type Stream struct {
	// ID is the stream identifier carried in data frame headers.
	ID string
	// LocalConn is the connection to the local service for this stream.
	LocalConn net.Conn
	// ResponseCh is a buffered channel allocated at stream creation.
	// NOTE(review): nothing in this file sends to or receives from it —
	// confirm whether it is still needed.
	ResponseCh chan []byte
	// Done is closed when the stream is torn down.
	Done chan struct{}
}
|
|
|
|
func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost string, localPort int, tunnelType protocol.TunnelType, logger *zap.Logger, isClosedCheck func() bool, bufferPool *pool.BufferPool) *FrameHandler {
|
|
var tlsConfig *tls.Config
|
|
if tunnelType == protocol.TunnelTypeHTTPS {
|
|
tlsConfig = &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
}
|
|
}
|
|
|
|
return &FrameHandler{
|
|
conn: conn,
|
|
frameWriter: frameWriter,
|
|
localHost: localHost,
|
|
localPort: localPort,
|
|
logger: logger,
|
|
streams: make(map[string]*Stream),
|
|
tunnelType: tunnelType,
|
|
stats: NewTrafficStats(),
|
|
isClosedCheck: isClosedCheck,
|
|
bufferPool: bufferPool,
|
|
headerPool: pool.NewHeaderPool(),
|
|
httpClient: &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
Transport: &http.Transport{
|
|
MaxIdleConns: 1000, // Optimized for both mid and high load scenarios
|
|
MaxIdleConnsPerHost: 500, // Sufficient for 400+ concurrent connections
|
|
MaxConnsPerHost: 0, // Unlimited
|
|
IdleConnTimeout: 180 * time.Second,
|
|
DisableCompression: true,
|
|
DisableKeepAlives: false,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
TLSClientConfig: tlsConfig,
|
|
ResponseHeaderTimeout: 15 * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
DialContext: (&net.Dialer{
|
|
Timeout: 5 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
}).DialContext,
|
|
},
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
return http.ErrUseLastResponse
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (h *FrameHandler) HandleDataFrame(frame *protocol.Frame) error {
|
|
h.stats.AddBytesIn(int64(len(frame.Payload)))
|
|
h.stats.AddRequest()
|
|
|
|
header, data, err := protocol.DecodeDataPayload(frame.Payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decode data payload: %w", err)
|
|
}
|
|
|
|
if header.Type == protocol.DataTypeHTTPRequest {
|
|
return h.handleHTTPFrame(header, data)
|
|
}
|
|
|
|
if header.Type == protocol.DataTypeClose {
|
|
h.closeStream(header.StreamID)
|
|
return nil
|
|
}
|
|
|
|
stream, err := h.getOrCreateStream(header.StreamID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get stream: %w", err)
|
|
}
|
|
|
|
h.forwardToLocal(stream, data)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *FrameHandler) getOrCreateStream(streamID string) (*Stream, error) {
|
|
h.streamMu.Lock()
|
|
defer h.streamMu.Unlock()
|
|
|
|
if stream, ok := h.streams[streamID]; ok {
|
|
return stream, nil
|
|
}
|
|
|
|
localAddr := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
|
localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to local service: %w", err)
|
|
}
|
|
|
|
stream := &Stream{
|
|
ID: streamID,
|
|
LocalConn: localConn,
|
|
ResponseCh: make(chan []byte, 10),
|
|
Done: make(chan struct{}),
|
|
}
|
|
|
|
h.streams[streamID] = stream
|
|
|
|
go h.handleLocalResponse(stream)
|
|
|
|
return stream, nil
|
|
}
|
|
|
|
func (h *FrameHandler) forwardToLocal(stream *Stream, data []byte) {
|
|
if _, err := stream.LocalConn.Write(data); err != nil {
|
|
h.logger.Error("Failed to write to local service",
|
|
zap.String("stream_id", stream.ID),
|
|
zap.Error(err),
|
|
)
|
|
h.closeStream(stream.ID)
|
|
}
|
|
}
|
|
|
|
func (h *FrameHandler) handleLocalResponse(stream *Stream) {
|
|
defer h.closeStream(stream.ID)
|
|
|
|
bufPtr := h.bufferPool.Get(pool.SizeMedium)
|
|
defer h.bufferPool.Put(bufPtr)
|
|
buf := (*bufPtr)[:pool.SizeMedium]
|
|
|
|
for {
|
|
n, err := stream.LocalConn.Read(buf)
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
if n > 0 {
|
|
if h.isClosedCheck != nil && h.isClosedCheck() {
|
|
break
|
|
}
|
|
|
|
header := protocol.DataHeader{
|
|
StreamID: stream.ID,
|
|
Type: protocol.DataTypeResponse,
|
|
IsLast: false,
|
|
}
|
|
|
|
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, buf[:n])
|
|
if err != nil {
|
|
h.logger.Error("Encode payload failed", zap.Error(err))
|
|
break
|
|
}
|
|
|
|
dataFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
|
err = h.frameWriter.WriteFrame(dataFrame)
|
|
if err != nil {
|
|
h.logger.Error("Send frame failed", zap.Error(err))
|
|
break
|
|
}
|
|
|
|
h.stats.AddBytesOut(int64(len(payload)))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *FrameHandler) handleHTTPFrame(header protocol.DataHeader, payload []byte) error {
|
|
if h.tunnelType != protocol.TunnelTypeHTTP && h.tunnelType != protocol.TunnelTypeHTTPS {
|
|
return nil
|
|
}
|
|
|
|
httpReq, err := protocol.DecodeHTTPRequest(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decode HTTP request: %w", err)
|
|
}
|
|
|
|
targetURL := httpReq.URL
|
|
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
|
|
scheme := "http"
|
|
if h.tunnelType == protocol.TunnelTypeHTTPS {
|
|
scheme = "https"
|
|
}
|
|
targetURL = fmt.Sprintf("%s://%s:%d%s", scheme, h.localHost, h.localPort, targetURL)
|
|
}
|
|
|
|
req, err := http.NewRequest(httpReq.Method, targetURL, bytes.NewReader(httpReq.Body))
|
|
if err != nil {
|
|
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("build request: %v", err))
|
|
}
|
|
|
|
origHost := ""
|
|
for key, values := range httpReq.Headers {
|
|
for _, value := range values {
|
|
req.Header.Add(key, value)
|
|
}
|
|
}
|
|
if host := req.Header.Get("Host"); host != "" {
|
|
origHost = host
|
|
}
|
|
|
|
isLocalTarget := h.isLocalAddress(h.localHost)
|
|
|
|
if isLocalTarget {
|
|
if origHost != "" {
|
|
req.Host = origHost
|
|
req.Header.Set("Host", origHost)
|
|
} else {
|
|
localHostPort := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
|
req.Host = localHostPort
|
|
req.Header.Set("Host", localHostPort)
|
|
}
|
|
if origHost != "" {
|
|
req.Header.Set("X-Forwarded-Host", origHost)
|
|
}
|
|
} else {
|
|
targetHost := h.localHost
|
|
if h.localPort != 443 && h.localPort != 80 {
|
|
targetHost = fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
|
}
|
|
req.Host = targetHost
|
|
req.Header.Set("Host", targetHost)
|
|
if origHost != "" {
|
|
req.Header.Set("X-Forwarded-Host", origHost)
|
|
}
|
|
}
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
|
|
resp, err := h.httpClient.Do(req)
|
|
if err != nil {
|
|
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("local request failed: %v", err))
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("read response: %v", err))
|
|
}
|
|
|
|
httpResp := protocol.HTTPResponse{
|
|
StatusCode: resp.StatusCode,
|
|
Status: resp.Status,
|
|
Headers: resp.Header,
|
|
Body: body,
|
|
}
|
|
|
|
return h.sendHTTPResponse(header.StreamID, header.RequestID, &httpResp)
|
|
}
|
|
|
|
func (h *FrameHandler) sendHTTPError(streamID, requestID string, status int, message string) error {
|
|
headers := h.headerPool.Get()
|
|
headers.Set("Content-Type", "text/plain")
|
|
|
|
httpResp := protocol.HTTPResponse{
|
|
StatusCode: status,
|
|
Status: http.StatusText(status),
|
|
Headers: headers,
|
|
Body: []byte(message),
|
|
}
|
|
|
|
err := h.sendHTTPResponse(streamID, requestID, &httpResp)
|
|
|
|
h.headerPool.Put(headers)
|
|
|
|
return err
|
|
}
|
|
|
|
func (h *FrameHandler) sendHTTPResponse(streamID, requestID string, resp *protocol.HTTPResponse) error {
|
|
if h.isClosedCheck != nil && h.isClosedCheck() {
|
|
return nil
|
|
}
|
|
|
|
header := protocol.DataHeader{
|
|
StreamID: streamID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPResponse,
|
|
IsLast: true,
|
|
}
|
|
|
|
respBytes, err := protocol.EncodeHTTPResponse(resp)
|
|
if err != nil {
|
|
return fmt.Errorf("encode http response: %w", err)
|
|
}
|
|
|
|
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, respBytes)
|
|
if err != nil {
|
|
return fmt.Errorf("encode payload: %w", err)
|
|
}
|
|
|
|
dataFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
|
|
|
h.stats.AddBytesOut(int64(len(payload)))
|
|
|
|
return h.frameWriter.WriteFrame(dataFrame)
|
|
}
|
|
|
|
func (h *FrameHandler) closeStream(streamID string) {
|
|
h.streamMu.Lock()
|
|
defer h.streamMu.Unlock()
|
|
|
|
stream, ok := h.streams[streamID]
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if stream.LocalConn != nil {
|
|
stream.LocalConn.Close()
|
|
}
|
|
|
|
close(stream.Done)
|
|
|
|
delete(h.streams, streamID)
|
|
|
|
if h.isClosedCheck != nil && h.isClosedCheck() {
|
|
return
|
|
}
|
|
|
|
header := protocol.DataHeader{
|
|
StreamID: streamID,
|
|
RequestID: streamID,
|
|
Type: protocol.DataTypeClose,
|
|
IsLast: true,
|
|
}
|
|
|
|
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
closeFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
|
|
|
h.frameWriter.WriteFrame(closeFrame)
|
|
}
|
|
|
|
// Close closes all streams
|
|
func (h *FrameHandler) Close() {
|
|
h.streamMu.Lock()
|
|
defer h.streamMu.Unlock()
|
|
|
|
for streamID, stream := range h.streams {
|
|
if stream.LocalConn != nil {
|
|
stream.LocalConn.Close()
|
|
}
|
|
close(stream.Done)
|
|
delete(h.streams, streamID)
|
|
}
|
|
}
|
|
|
|
// GetStats returns the traffic stats tracker
|
|
func (h *FrameHandler) GetStats() *TrafficStats {
|
|
return h.stats
|
|
}
|
|
|
|
func (h *FrameHandler) WarmupConnectionPool(numConnections int) {
|
|
if h.tunnelType != protocol.TunnelTypeHTTP {
|
|
return
|
|
}
|
|
|
|
targetURL := fmt.Sprintf("http://%s:%d/", h.localHost, h.localPort)
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < numConnections; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
req, err := http.NewRequest("HEAD", targetURL, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
resp, err := h.httpClient.Do(req)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
io.Copy(io.Discard, resp.Body)
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func (h *FrameHandler) isLocalAddress(addr string) bool {
|
|
if addr == "localhost" || addr == "127.0.0.1" || addr == "::1" {
|
|
return true
|
|
}
|
|
|
|
if strings.HasPrefix(addr, "192.168.") ||
|
|
strings.HasPrefix(addr, "10.") ||
|
|
strings.HasPrefix(addr, "172.16.") ||
|
|
strings.HasPrefix(addr, "172.17.") ||
|
|
strings.HasPrefix(addr, "172.18.") ||
|
|
strings.HasPrefix(addr, "172.19.") ||
|
|
strings.HasPrefix(addr, "172.20.") ||
|
|
strings.HasPrefix(addr, "172.21.") ||
|
|
strings.HasPrefix(addr, "172.22.") ||
|
|
strings.HasPrefix(addr, "172.23.") ||
|
|
strings.HasPrefix(addr, "172.24.") ||
|
|
strings.HasPrefix(addr, "172.25.") ||
|
|
strings.HasPrefix(addr, "172.26.") ||
|
|
strings.HasPrefix(addr, "172.27.") ||
|
|
strings.HasPrefix(addr, "172.28.") ||
|
|
strings.HasPrefix(addr, "172.29.") ||
|
|
strings.HasPrefix(addr, "172.30.") ||
|
|
strings.HasPrefix(addr, "172.31.") {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|