diff --git a/internal/client/tcp/pool_handler.go b/internal/client/tcp/pool_handler.go index 1f60354..e9ecc0a 100644 --- a/internal/client/tcp/pool_handler.go +++ b/internal/client/tcp/pool_handler.go @@ -78,7 +78,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) { _ = stream.SetReadDeadline(time.Time{}) if httputil.IsWebSocketUpgrade(req) { - c.handleWebSocketUpgrade(cc, req) + c.handleWebSocketUpgrade(&bufferedConn{Conn: cc, reader: br}, req) return } @@ -96,6 +96,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) { httputil.WriteProxyError(cc, http.StatusBadGateway, "Bad Gateway") return } + outReq.ContentLength = req.ContentLength origHost := req.Host httputil.CopyHeaders(outReq.Header, req.Header) @@ -153,11 +154,6 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) { } func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) { - scheme := "ws" - if c.tunnelType == protocol.TunnelTypeHTTPS { - scheme = "wss" - } - targetAddr := net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort)) localConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second) if err != nil { @@ -175,8 +171,11 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) { localConn = tlsConn } - req.URL.Scheme = scheme - req.URL.Host = targetAddr + origHost := req.Host + req.Host = targetAddr + if origHost != "" { + req.Header.Set("X-Forwarded-Host", origHost) + } if err := req.Write(localConn); err != nil { httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to forward upgrade request") return @@ -194,10 +193,14 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) { } if resp.StatusCode == http.StatusSwitchingProtocols { + localRW := net.Conn(localConn) + if localBr.Buffered() > 0 { + localRW = &bufferedConn{Conn: localConn, reader: localBr} + } _ = netutil.PipeWithCallbacksAndBufferSize( c.ctx, cc, - localConn, + localRW, pool.SizeLarge, func(n int64) { c.stats.AddBytesIn(n) }, func(n int64) { c.stats.AddBytesOut(n) }, @@ -205,6 +208,15 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) { } } +type bufferedConn struct { + net.Conn + reader *bufio.Reader +} + +func (c *bufferedConn) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client { var tlsConfig *tls.Config if tunnelType == protocol.TunnelTypeHTTPS { diff --git a/internal/server/proxy/handler.go b/internal/server/proxy/handler.go index 0d8fa42..98b7b0b 100644 --- a/internal/server/proxy/handler.go +++ b/internal/server/proxy/handler.go @@ -2,6 +2,7 @@ package proxy import ( "bufio" + "context" "fmt" "io" "net" @@ -188,7 +189,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn return } - clientConn, _, err := hj.Hijack() + clientConn, clientBuf, err := hj.Hijack() if err != nil { stream.Close() tconn.DecActiveConnections() @@ -208,7 +209,15 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn defer clientConn.Close() defer tconn.DecActiveConnections() - _ = netutil.PipeWithCallbacks(r.Context(), stream, clientConn, + var clientRW io.ReadWriteCloser = clientConn + if clientBuf != nil && clientBuf.Reader.Buffered() > 0 { + clientRW = &bufferedReadWriteCloser{ + Reader: clientBuf.Reader, + Conn: clientConn, + } + } + + _ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW, func(n int64) { tconn.AddBytesOut(n) }, func(n int64) { tconn.AddBytesIn(n) }, ) @@ -386,3 +395,12 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) w.Write(data) } + +type bufferedReadWriteCloser struct { + *bufio.Reader + net.Conn +} + +func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) { + return b.Reader.Read(p) +}