Merge pull request #10 from Gouryella/fix/websocket-error

fix(tcp): Fixed a connection reading issue during WebSocket upgrade
This commit is contained in:
Gouryella
2025-12-19 17:56:39 +08:00
committed by GitHub
2 changed files with 41 additions and 11 deletions

View File

@@ -78,7 +78,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
_ = stream.SetReadDeadline(time.Time{})
if httputil.IsWebSocketUpgrade(req) {
c.handleWebSocketUpgrade(cc, req)
c.handleWebSocketUpgrade(&bufferedConn{Conn: cc, reader: br}, req)
return
}
@@ -96,6 +96,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
httputil.WriteProxyError(cc, http.StatusBadGateway, "Bad Gateway")
return
}
outReq.ContentLength = req.ContentLength
origHost := req.Host
httputil.CopyHeaders(outReq.Header, req.Header)
@@ -153,11 +154,6 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
}
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
scheme := "ws"
if c.tunnelType == protocol.TunnelTypeHTTPS {
scheme = "wss"
}
targetAddr := net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort))
localConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second)
if err != nil {
@@ -175,8 +171,11 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
localConn = tlsConn
}
req.URL.Scheme = scheme
req.URL.Host = targetAddr
origHost := req.Host
req.Host = targetAddr
if origHost != "" {
req.Header.Set("X-Forwarded-Host", origHost)
}
if err := req.Write(localConn); err != nil {
httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to forward upgrade request")
return
@@ -194,10 +193,14 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
}
if resp.StatusCode == http.StatusSwitchingProtocols {
localRW := net.Conn(localConn)
if localBr.Buffered() > 0 {
localRW = &bufferedConn{Conn: localConn, reader: localBr}
}
_ = netutil.PipeWithCallbacksAndBufferSize(
c.ctx,
cc,
localConn,
localRW,
pool.SizeLarge,
func(n int64) { c.stats.AddBytesIn(n) },
func(n int64) { c.stats.AddBytesOut(n) },
@@ -205,6 +208,15 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
}
}
type bufferedConn struct {
net.Conn
reader *bufio.Reader
}
func (c *bufferedConn) Read(p []byte) (int, error) {
return c.reader.Read(p)
}
func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client {
var tlsConfig *tls.Config
if tunnelType == protocol.TunnelTypeHTTPS {

View File

@@ -2,6 +2,7 @@ package proxy
import (
"bufio"
"context"
"fmt"
"io"
"net"
@@ -188,7 +189,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
return
}
clientConn, _, err := hj.Hijack()
clientConn, clientBuf, err := hj.Hijack()
if err != nil {
stream.Close()
tconn.DecActiveConnections()
@@ -208,7 +209,15 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
defer clientConn.Close()
defer tconn.DecActiveConnections()
_ = netutil.PipeWithCallbacks(r.Context(), stream, clientConn,
var clientRW io.ReadWriteCloser = clientConn
if clientBuf != nil && clientBuf.Reader.Buffered() > 0 {
clientRW = &bufferedReadWriteCloser{
Reader: clientBuf.Reader,
Conn: clientConn,
}
}
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
func(n int64) { tconn.AddBytesOut(n) },
func(n int64) { tconn.AddBytesIn(n) },
)
@@ -386,3 +395,12 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Write(data)
}
type bufferedReadWriteCloser struct {
*bufio.Reader
net.Conn
}
func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) {
return b.Reader.Read(p)
}