mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
fix(tcp): Fixed a connection reading issue during WebSocket upgrade processing.
When processing HTTP streams, support for buffered readers has been added for WebSocket upgrade requests. This ensures that data not fully read before connection switching is not lost. The forwarding logic for the Host header has also been optimized. Add the X-Forwarded-Host header to preserve the original host information.
This commit is contained in:
@@ -78,7 +78,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
_ = stream.SetReadDeadline(time.Time{})
|
||||
|
||||
if httputil.IsWebSocketUpgrade(req) {
|
||||
c.handleWebSocketUpgrade(cc, req)
|
||||
c.handleWebSocketUpgrade(&bufferedConn{Conn: cc, reader: br}, req)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -96,6 +96,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "Bad Gateway")
|
||||
return
|
||||
}
|
||||
outReq.ContentLength = req.ContentLength
|
||||
|
||||
origHost := req.Host
|
||||
httputil.CopyHeaders(outReq.Header, req.Header)
|
||||
@@ -153,11 +154,6 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
}
|
||||
|
||||
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
scheme := "ws"
|
||||
if c.tunnelType == protocol.TunnelTypeHTTPS {
|
||||
scheme = "wss"
|
||||
}
|
||||
|
||||
targetAddr := net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort))
|
||||
localConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second)
|
||||
if err != nil {
|
||||
@@ -175,8 +171,11 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
localConn = tlsConn
|
||||
}
|
||||
|
||||
req.URL.Scheme = scheme
|
||||
req.URL.Host = targetAddr
|
||||
origHost := req.Host
|
||||
req.Host = targetAddr
|
||||
if origHost != "" {
|
||||
req.Header.Set("X-Forwarded-Host", origHost)
|
||||
}
|
||||
if err := req.Write(localConn); err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to forward upgrade request")
|
||||
return
|
||||
@@ -194,10 +193,14 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusSwitchingProtocols {
|
||||
localRW := net.Conn(localConn)
|
||||
if localBr.Buffered() > 0 {
|
||||
localRW = &bufferedConn{Conn: localConn, reader: localBr}
|
||||
}
|
||||
_ = netutil.PipeWithCallbacksAndBufferSize(
|
||||
c.ctx,
|
||||
cc,
|
||||
localConn,
|
||||
localRW,
|
||||
pool.SizeLarge,
|
||||
func(n int64) { c.stats.AddBytesIn(n) },
|
||||
func(n int64) { c.stats.AddBytesOut(n) },
|
||||
@@ -205,6 +208,15 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client {
|
||||
var tlsConfig *tls.Config
|
||||
if tunnelType == protocol.TunnelTypeHTTPS {
|
||||
|
||||
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -188,7 +189,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, _, err := hj.Hijack()
|
||||
clientConn, clientBuf, err := hj.Hijack()
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
tconn.DecActiveConnections()
|
||||
@@ -208,7 +209,15 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
|
||||
defer clientConn.Close()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
_ = netutil.PipeWithCallbacks(r.Context(), stream, clientConn,
|
||||
var clientRW io.ReadWriteCloser = clientConn
|
||||
if clientBuf != nil && clientBuf.Reader.Buffered() > 0 {
|
||||
clientRW = &bufferedReadWriteCloser{
|
||||
Reader: clientBuf.Reader,
|
||||
Conn: clientConn,
|
||||
}
|
||||
}
|
||||
|
||||
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
|
||||
func(n int64) { tconn.AddBytesOut(n) },
|
||||
func(n int64) { tconn.AddBytesIn(n) },
|
||||
)
|
||||
@@ -386,3 +395,12 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
type bufferedReadWriteCloser struct {
|
||||
*bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) {
|
||||
return b.Reader.Read(p)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user