mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-28 23:36:00 +00:00
feat(server): Supports HTTP CONNECT proxy and connection pooling.
- Added handling for the HTTP CONNECT method, supporting HTTPS tunneling proxies. - Introducing connQueueListener to hand over HTTP connections to standard http.Server handling. - Optimized Connection struct fields and lifecycle management logic - Remove redundant comments and streamline some response writing logic - Upgrade the golang.org/x/net dependency version to support new features. - Enhanced HTTP request parsing stability and improved error logging methods. - Adjusted the TCP listener startup process to integrate HTTP/2 configuration support. - Improve the connection closing mechanism to avoid resource leakage issues.
This commit is contained in:
@@ -39,7 +39,11 @@ func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, auth
|
||||
}
|
||||
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Always handle /health and /stats directly, regardless of subdomain.
|
||||
if r.Method == http.MethodConnect {
|
||||
h.handleConnect(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/health" {
|
||||
h.serveHealth(w, r)
|
||||
return
|
||||
@@ -71,13 +75,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for WebSocket upgrade
|
||||
if httputil.IsWebSocketUpgrade(r) {
|
||||
h.handleWebSocket(w, r, tconn)
|
||||
return
|
||||
}
|
||||
|
||||
// Open stream with timeout
|
||||
stream, err := h.openStreamWithTimeout(tconn)
|
||||
if err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
@@ -86,17 +88,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
// Track active connections
|
||||
tconn.IncActiveConnections()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
// Wrap stream with counting for traffic stats
|
||||
countingStream := netutil.NewCountingConn(stream,
|
||||
tconn.AddBytesOut, // Data read from stream = bytes out to client
|
||||
tconn.AddBytesIn, // Data written to stream = bytes in from client
|
||||
tconn.AddBytesOut,
|
||||
tconn.AddBytesIn,
|
||||
)
|
||||
|
||||
// 1) Write request over the stream (net/http handles large bodies correctly).
|
||||
if err := r.Write(countingStream); err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
_ = r.Body.Close()
|
||||
@@ -104,7 +103,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// 2) Read response from stream.
|
||||
resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
|
||||
if err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
@@ -113,7 +111,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 3) Copy headers (strip hop-by-hop).
|
||||
h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
|
||||
|
||||
statusCode := resp.StatusCode
|
||||
@@ -121,9 +118,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
statusCode = http.StatusOK
|
||||
}
|
||||
|
||||
// Ensure message delimiting works with our custom ResponseWriter:
|
||||
// - If Content-Length is known, send it.
|
||||
// - Otherwise, re-chunk the decoded body ourselves.
|
||||
if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
|
||||
if resp.ContentLength >= 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
|
||||
@@ -136,28 +130,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if resp.ContentLength >= 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
ctx := r.Context()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stream.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
close(done)
|
||||
stream.Close()
|
||||
return
|
||||
} else {
|
||||
w.Header().Del("Content-Length")
|
||||
}
|
||||
|
||||
w.Header().Del("Content-Length")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
if len(resp.Trailer) > 0 {
|
||||
w.Header().Set("Trailer", trailerKeys(resp.Trailer))
|
||||
}
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
ctx := r.Context()
|
||||
@@ -170,9 +146,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := writeChunked(w, resp.Body, resp.Trailer); err != nil {
|
||||
h.logger.Debug("Write chunked response failed", zap.Error(err))
|
||||
}
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
close(done)
|
||||
stream.Close()
|
||||
}
|
||||
@@ -267,53 +241,6 @@ func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHos
|
||||
}
|
||||
}
|
||||
|
||||
func trailerKeys(hdr http.Header) string {
|
||||
keys := make([]string, 0, len(hdr))
|
||||
for k := range hdr {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
// Deterministic order is nicer for debugging; no semantic impact.
|
||||
sortStrings(keys)
|
||||
return strings.Join(keys, ", ")
|
||||
}
|
||||
|
||||
func writeChunked(w io.Writer, r io.Reader, trailer http.Header) error {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
if _, werr := fmt.Fprintf(w, "%x\r\n", n); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if _, werr := w.Write(buf[:n]); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if _, werr := io.WriteString(w, "\r\n"); werr != nil {
|
||||
return werr
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := io.WriteString(w, "0\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
for k, vv := range trailer {
|
||||
for _, v := range vv {
|
||||
if _, err := io.WriteString(w, fmt.Sprintf("%s: %s\r\n", k, v)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err := io.WriteString(w, "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
|
||||
if !strings.HasPrefix(location, "http://") && !strings.HasPrefix(location, "https://") {
|
||||
return location
|
||||
@@ -460,12 +387,65 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func sortStrings(s []string) {
|
||||
for i := 0; i < len(s); i++ {
|
||||
for j := i + 1; j < len(s); j++ {
|
||||
if s[j] < s[i] {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
}
|
||||
func (h *Handler) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||
targetAddr := r.Host
|
||||
if targetAddr == "" {
|
||||
targetAddr = r.URL.Host
|
||||
}
|
||||
if targetAddr == "" {
|
||||
http.Error(w, "Bad Request: missing target host", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.Contains(targetAddr, ":") {
|
||||
targetAddr = targetAddr + ":443"
|
||||
}
|
||||
|
||||
h.logger.Info("CONNECT proxy request",
|
||||
zap.String("target", targetAddr),
|
||||
zap.String("remote", r.RemoteAddr),
|
||||
)
|
||||
|
||||
targetConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second)
|
||||
if err != nil {
|
||||
h.logger.Warn("Failed to connect to target",
|
||||
zap.String("target", targetAddr),
|
||||
zap.Error(err),
|
||||
)
|
||||
http.Error(w, "Bad Gateway: failed to connect to target", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
targetConn.Close()
|
||||
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
targetConn.Close()
|
||||
http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))
|
||||
if err != nil {
|
||||
clientConn.Close()
|
||||
targetConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer targetConn.Close()
|
||||
defer clientConn.Close()
|
||||
_, _ = io.Copy(targetConn, clientConn)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer targetConn.Close()
|
||||
defer clientConn.Close()
|
||||
_, _ = io.Copy(clientConn, targetConn)
|
||||
}()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user