mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-24 05:10:43 +00:00
- Added handling for the HTTP CONNECT method, supporting HTTPS tunneling proxies. - Introducing connQueueListener to hand over HTTP connections to standard http.Server handling. - Optimized Connection struct fields and lifecycle management logic - Remove redundant comments and streamline some response writing logic - Upgrade the golang.org/x/net dependency version to support new features. - Enhanced HTTP request parsing stability and improved error logging methods. - Adjusted the TCP listener startup process to integrate HTTP/2 configuration support. - Improve the connection closing mechanism to avoid resource leakage issues.
251 lines
6.1 KiB
Go
251 lines
6.1 KiB
Go
package tcp
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"time"
|
|
|
|
"drip/internal/shared/httputil"
|
|
"drip/internal/shared/netutil"
|
|
"drip/internal/shared/pool"
|
|
"drip/internal/shared/protocol"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
func (c *PoolClient) handleStream(h *sessionHandle, stream net.Conn) {
|
|
defer c.wg.Done()
|
|
defer func() {
|
|
h.active.Add(-1)
|
|
c.stats.DecActiveConnections()
|
|
}()
|
|
defer stream.Close()
|
|
|
|
switch c.tunnelType {
|
|
case protocol.TunnelTypeHTTP, protocol.TunnelTypeHTTPS:
|
|
c.handleHTTPStream(stream)
|
|
default:
|
|
c.handleTCPStream(stream)
|
|
}
|
|
}
|
|
|
|
func (c *PoolClient) handleTCPStream(stream net.Conn) {
|
|
localConn, err := net.DialTimeout("tcp", net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort)), 10*time.Second)
|
|
if err != nil {
|
|
c.logger.Debug("Dial local failed", zap.Error(err))
|
|
return
|
|
}
|
|
defer localConn.Close()
|
|
|
|
if tcpConn, ok := localConn.(*net.TCPConn); ok {
|
|
_ = tcpConn.SetNoDelay(true)
|
|
_ = tcpConn.SetKeepAlive(true)
|
|
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
|
_ = tcpConn.SetReadBuffer(256 * 1024)
|
|
_ = tcpConn.SetWriteBuffer(256 * 1024)
|
|
}
|
|
|
|
_ = netutil.PipeWithCallbacksAndBufferSize(
|
|
c.ctx,
|
|
stream,
|
|
localConn,
|
|
pool.SizeLarge,
|
|
func(n int64) { c.stats.AddBytesIn(n) },
|
|
func(n int64) { c.stats.AddBytesOut(n) },
|
|
)
|
|
}
|
|
|
|
func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
|
_ = stream.SetReadDeadline(time.Now().Add(30 * time.Second))
|
|
|
|
cc := netutil.NewCountingConn(stream,
|
|
func(n int64) { c.stats.AddBytesIn(n) },
|
|
func(n int64) { c.stats.AddBytesOut(n) },
|
|
)
|
|
|
|
br := bufio.NewReaderSize(cc, 32*1024)
|
|
req, err := http.ReadRequest(br)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer req.Body.Close()
|
|
|
|
_ = stream.SetReadDeadline(time.Time{})
|
|
|
|
if httputil.IsWebSocketUpgrade(req) {
|
|
c.handleWebSocketUpgrade(cc, req)
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(c.ctx)
|
|
defer cancel()
|
|
|
|
scheme := "http"
|
|
if c.tunnelType == protocol.TunnelTypeHTTPS {
|
|
scheme = "https"
|
|
}
|
|
|
|
targetURL := fmt.Sprintf("%s://%s:%d%s", scheme, c.localHost, c.localPort, req.URL.RequestURI())
|
|
outReq, err := http.NewRequestWithContext(ctx, req.Method, targetURL, req.Body)
|
|
if err != nil {
|
|
httputil.WriteProxyError(cc, http.StatusBadGateway, "Bad Gateway")
|
|
return
|
|
}
|
|
|
|
origHost := req.Host
|
|
httputil.CopyHeaders(outReq.Header, req.Header)
|
|
httputil.CleanHopByHopHeaders(outReq.Header)
|
|
|
|
outReq.Header.Del("Accept-Encoding")
|
|
|
|
targetHost := c.localHost
|
|
if c.localPort != 80 && c.localPort != 443 {
|
|
targetHost = fmt.Sprintf("%s:%d", c.localHost, c.localPort)
|
|
}
|
|
outReq.Host = targetHost
|
|
outReq.Header.Set("Host", targetHost)
|
|
if origHost != "" {
|
|
outReq.Header.Set("X-Forwarded-Host", origHost)
|
|
}
|
|
outReq.Header.Set("X-Forwarded-Proto", "https")
|
|
|
|
resp, err := c.httpClient.Do(outReq)
|
|
if err != nil {
|
|
httputil.WriteProxyError(cc, http.StatusBadGateway, "Local service unavailable")
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
_ = stream.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
|
if err := writeResponseHeader(cc, resp); err != nil {
|
|
return
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
stream.Close()
|
|
case <-done:
|
|
}
|
|
}()
|
|
|
|
buf := make([]byte, 32*1024)
|
|
for {
|
|
nr, er := resp.Body.Read(buf)
|
|
if nr > 0 {
|
|
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
nw, ew := cc.Write(buf[:nr])
|
|
if ew != nil || nr != nw {
|
|
break
|
|
}
|
|
}
|
|
if er != nil {
|
|
break
|
|
}
|
|
}
|
|
close(done)
|
|
}
|
|
|
|
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
|
scheme := "ws"
|
|
if c.tunnelType == protocol.TunnelTypeHTTPS {
|
|
scheme = "wss"
|
|
}
|
|
|
|
targetAddr := net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort))
|
|
localConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second)
|
|
if err != nil {
|
|
httputil.WriteProxyError(cc, http.StatusBadGateway, "WebSocket backend unavailable")
|
|
return
|
|
}
|
|
defer localConn.Close()
|
|
|
|
if c.tunnelType == protocol.TunnelTypeHTTPS {
|
|
tlsConn := tls.Client(localConn, &tls.Config{InsecureSkipVerify: true})
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
httputil.WriteProxyError(cc, http.StatusBadGateway, "TLS handshake failed")
|
|
return
|
|
}
|
|
localConn = tlsConn
|
|
}
|
|
|
|
req.URL.Scheme = scheme
|
|
req.URL.Host = targetAddr
|
|
if err := req.Write(localConn); err != nil {
|
|
httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to forward upgrade request")
|
|
return
|
|
}
|
|
|
|
localBr := bufio.NewReader(localConn)
|
|
resp, err := http.ReadResponse(localBr, req)
|
|
if err != nil {
|
|
httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to read upgrade response")
|
|
return
|
|
}
|
|
|
|
if err := resp.Write(cc); err != nil {
|
|
return
|
|
}
|
|
|
|
if resp.StatusCode == http.StatusSwitchingProtocols {
|
|
_ = netutil.PipeWithCallbacksAndBufferSize(
|
|
c.ctx,
|
|
cc,
|
|
localConn,
|
|
pool.SizeLarge,
|
|
func(n int64) { c.stats.AddBytesIn(n) },
|
|
func(n int64) { c.stats.AddBytesOut(n) },
|
|
)
|
|
}
|
|
}
|
|
|
|
func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client {
|
|
var tlsConfig *tls.Config
|
|
if tunnelType == protocol.TunnelTypeHTTPS {
|
|
tlsConfig = &tls.Config{InsecureSkipVerify: true}
|
|
}
|
|
return &http.Client{
|
|
Transport: &http.Transport{
|
|
MaxIdleConns: 2000,
|
|
MaxIdleConnsPerHost: 1000,
|
|
MaxConnsPerHost: 0,
|
|
IdleConnTimeout: 180 * time.Second,
|
|
DisableCompression: true,
|
|
DisableKeepAlives: false,
|
|
TLSHandshakeTimeout: 5 * time.Second,
|
|
TLSClientConfig: tlsConfig,
|
|
ResponseHeaderTimeout: 15 * time.Second,
|
|
ExpectContinueTimeout: 500 * time.Millisecond,
|
|
WriteBufferSize: 32 * 1024,
|
|
ReadBufferSize: 32 * 1024,
|
|
DialContext: (&net.Dialer{
|
|
Timeout: 3 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
}).DialContext,
|
|
},
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
return http.ErrUseLastResponse
|
|
},
|
|
}
|
|
}
|
|
|
|
func writeResponseHeader(w io.Writer, resp *http.Response) error {
|
|
statusLine := fmt.Sprintf("HTTP/%d.%d %d %s\r\n",
|
|
resp.ProtoMajor, resp.ProtoMinor,
|
|
resp.StatusCode, http.StatusText(resp.StatusCode))
|
|
if _, err := io.WriteString(w, statusLine); err != nil {
|
|
return err
|
|
}
|
|
if err := resp.Header.Write(w); err != nil {
|
|
return err
|
|
}
|
|
_, err := io.WriteString(w, "\r\n")
|
|
return err
|
|
}
|