From bad099d0f3b142f94734149292b3ff618e698892 Mon Sep 17 00:00:00 2001 From: Gouryella Date: Fri, 19 Dec 2025 17:48:15 +0800 Subject: [PATCH] fix(tcp): Fixed a connection reading issue during WebSocket upgrade processing. When processing HTTP streams, support for buffered readers has been added for WebSocket upgrade requests. This ensures that data not fully read before connection switching is not lost. The forwarding logic for the Host header has also been optimized. Add the X-Forwarded-Host header to preserve the original host information. --- internal/client/tcp/pool_handler.go | 30 ++++++++++++++++++++--------- internal/server/proxy/handler.go | 22 +++++++++++++++++++-- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/internal/client/tcp/pool_handler.go b/internal/client/tcp/pool_handler.go index 1f60354..e9ecc0a 100644 --- a/internal/client/tcp/pool_handler.go +++ b/internal/client/tcp/pool_handler.go @@ -78,7 +78,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) { _ = stream.SetReadDeadline(time.Time{}) if httputil.IsWebSocketUpgrade(req) { - c.handleWebSocketUpgrade(cc, req) + c.handleWebSocketUpgrade(&bufferedConn{Conn: cc, reader: br}, req) return } @@ -96,6 +96,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) { httputil.WriteProxyError(cc, http.StatusBadGateway, "Bad Gateway") return } + outReq.ContentLength = req.ContentLength origHost := req.Host httputil.CopyHeaders(outReq.Header, req.Header) @@ -153,11 +154,6 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) { } func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) { - scheme := "ws" - if c.tunnelType == protocol.TunnelTypeHTTPS { - scheme = "wss" - } - targetAddr := net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort)) localConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second) if err != nil { @@ -175,8 +171,11 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) { localConn = tlsConn } - req.URL.Scheme = scheme - req.URL.Host = targetAddr + origHost := req.Host + req.Host = targetAddr + if origHost != "" { + req.Header.Set("X-Forwarded-Host", origHost) + } if err := req.Write(localConn); err != nil { httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to forward upgrade request") return @@ -194,10 +193,14 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) { } if resp.StatusCode == http.StatusSwitchingProtocols { + localRW := net.Conn(localConn) + if localBr.Buffered() > 0 { + localRW = &bufferedConn{Conn: localConn, reader: localBr} + } _ = netutil.PipeWithCallbacksAndBufferSize( c.ctx, cc, - localConn, + localRW, pool.SizeLarge, func(n int64) { c.stats.AddBytesIn(n) }, func(n int64) { c.stats.AddBytesOut(n) }, @@ -205,6 +208,15 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) { } } +type bufferedConn struct { + net.Conn + reader *bufio.Reader +} + +func (c *bufferedConn) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client { var tlsConfig *tls.Config if tunnelType == protocol.TunnelTypeHTTPS { diff --git a/internal/server/proxy/handler.go b/internal/server/proxy/handler.go index 0d8fa42..98b7b0b 100644 --- a/internal/server/proxy/handler.go +++ b/internal/server/proxy/handler.go @@ -2,6 +2,7 @@ package proxy import ( "bufio" + "context" "fmt" "io" "net" @@ -188,7 +189,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn return } - clientConn, _, err := hj.Hijack() + clientConn, clientBuf, err := hj.Hijack() if err != nil { stream.Close() tconn.DecActiveConnections() @@ -208,7 +209,15 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn defer clientConn.Close() defer tconn.DecActiveConnections() - _ = netutil.PipeWithCallbacks(r.Context(), stream, clientConn, + var clientRW io.ReadWriteCloser = clientConn + if clientBuf != nil && clientBuf.Reader.Buffered() > 0 { + clientRW = &bufferedReadWriteCloser{ + Reader: clientBuf.Reader, + Conn: clientConn, + } + } + + _ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW, func(n int64) { tconn.AddBytesOut(n) }, func(n int64) { tconn.AddBytesIn(n) }, ) @@ -386,3 +395,12 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) w.Write(data) } + +type bufferedReadWriteCloser struct { + *bufio.Reader + net.Conn +} + +func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) { + return b.Reader.Read(p) +}