fix(tcp): Fixed a connection reading issue during WebSocket upgrade processing.

When processing HTTP streams, support for buffered readers has been added for WebSocket upgrade requests.
This ensures that data not fully read before connection switching is not lost. The forwarding logic for the Host header has also been optimized.
Add the X-Forwarded-Host header to preserve the original host information.
This commit is contained in:
Gouryella
2025-12-19 17:48:15 +08:00
parent b1393e5e0f
commit bad099d0f3
2 changed files with 41 additions and 11 deletions

View File

@@ -78,7 +78,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
_ = stream.SetReadDeadline(time.Time{})
if httputil.IsWebSocketUpgrade(req) {
c.handleWebSocketUpgrade(cc, req)
c.handleWebSocketUpgrade(&bufferedConn{Conn: cc, reader: br}, req)
return
}
@@ -96,6 +96,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
httputil.WriteProxyError(cc, http.StatusBadGateway, "Bad Gateway")
return
}
outReq.ContentLength = req.ContentLength
origHost := req.Host
httputil.CopyHeaders(outReq.Header, req.Header)
@@ -153,11 +154,6 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
}
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
scheme := "ws"
if c.tunnelType == protocol.TunnelTypeHTTPS {
scheme = "wss"
}
targetAddr := net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort))
localConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second)
if err != nil {
@@ -175,8 +171,11 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
localConn = tlsConn
}
req.URL.Scheme = scheme
req.URL.Host = targetAddr
origHost := req.Host
req.Host = targetAddr
if origHost != "" {
req.Header.Set("X-Forwarded-Host", origHost)
}
if err := req.Write(localConn); err != nil {
httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to forward upgrade request")
return
@@ -194,10 +193,14 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
}
if resp.StatusCode == http.StatusSwitchingProtocols {
localRW := net.Conn(localConn)
if localBr.Buffered() > 0 {
localRW = &bufferedConn{Conn: localConn, reader: localBr}
}
_ = netutil.PipeWithCallbacksAndBufferSize(
c.ctx,
cc,
localConn,
localRW,
pool.SizeLarge,
func(n int64) { c.stats.AddBytesIn(n) },
func(n int64) { c.stats.AddBytesOut(n) },
@@ -205,6 +208,15 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
}
}
type bufferedConn struct {
net.Conn
reader *bufio.Reader
}
func (c *bufferedConn) Read(p []byte) (int, error) {
return c.reader.Read(p)
}
func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client {
var tlsConfig *tls.Config
if tunnelType == protocol.TunnelTypeHTTPS {

View File

@@ -2,6 +2,7 @@ package proxy
import (
"bufio"
"context"
"fmt"
"io"
"net"
@@ -188,7 +189,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
return
}
clientConn, _, err := hj.Hijack()
clientConn, clientBuf, err := hj.Hijack()
if err != nil {
stream.Close()
tconn.DecActiveConnections()
@@ -208,7 +209,15 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
defer clientConn.Close()
defer tconn.DecActiveConnections()
_ = netutil.PipeWithCallbacks(r.Context(), stream, clientConn,
var clientRW io.ReadWriteCloser = clientConn
if clientBuf != nil && clientBuf.Reader.Buffered() > 0 {
clientRW = &bufferedReadWriteCloser{
Reader: clientBuf.Reader,
Conn: clientConn,
}
}
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
func(n int64) { tconn.AddBytesOut(n) },
func(n int64) { tconn.AddBytesIn(n) },
)
@@ -386,3 +395,12 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Write(data)
}
type bufferedReadWriteCloser struct {
*bufio.Reader
net.Conn
}
func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) {
return b.Reader.Read(p)
}