mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
Merge pull request #10 from Gouryella/fix/websocket-error
fix(tcp): Fixed a connection reading issue during WebSocket upgrade
This commit is contained in:
@@ -78,7 +78,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
_ = stream.SetReadDeadline(time.Time{})
|
||||
|
||||
if httputil.IsWebSocketUpgrade(req) {
|
||||
c.handleWebSocketUpgrade(cc, req)
|
||||
c.handleWebSocketUpgrade(&bufferedConn{Conn: cc, reader: br}, req)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -96,6 +96,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "Bad Gateway")
|
||||
return
|
||||
}
|
||||
outReq.ContentLength = req.ContentLength
|
||||
|
||||
origHost := req.Host
|
||||
httputil.CopyHeaders(outReq.Header, req.Header)
|
||||
@@ -153,11 +154,6 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
}
|
||||
|
||||
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
scheme := "ws"
|
||||
if c.tunnelType == protocol.TunnelTypeHTTPS {
|
||||
scheme = "wss"
|
||||
}
|
||||
|
||||
targetAddr := net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort))
|
||||
localConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second)
|
||||
if err != nil {
|
||||
@@ -175,8 +171,11 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
localConn = tlsConn
|
||||
}
|
||||
|
||||
req.URL.Scheme = scheme
|
||||
req.URL.Host = targetAddr
|
||||
origHost := req.Host
|
||||
req.Host = targetAddr
|
||||
if origHost != "" {
|
||||
req.Header.Set("X-Forwarded-Host", origHost)
|
||||
}
|
||||
if err := req.Write(localConn); err != nil {
|
||||
httputil.WriteProxyError(cc, http.StatusBadGateway, "Failed to forward upgrade request")
|
||||
return
|
||||
@@ -194,10 +193,14 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusSwitchingProtocols {
|
||||
localRW := net.Conn(localConn)
|
||||
if localBr.Buffered() > 0 {
|
||||
localRW = &bufferedConn{Conn: localConn, reader: localBr}
|
||||
}
|
||||
_ = netutil.PipeWithCallbacksAndBufferSize(
|
||||
c.ctx,
|
||||
cc,
|
||||
localConn,
|
||||
localRW,
|
||||
pool.SizeLarge,
|
||||
func(n int64) { c.stats.AddBytesIn(n) },
|
||||
func(n int64) { c.stats.AddBytesOut(n) },
|
||||
@@ -205,6 +208,15 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client {
|
||||
var tlsConfig *tls.Config
|
||||
if tunnelType == protocol.TunnelTypeHTTPS {
|
||||
|
||||
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -188,7 +189,7 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, _, err := hj.Hijack()
|
||||
clientConn, clientBuf, err := hj.Hijack()
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
tconn.DecActiveConnections()
|
||||
@@ -208,7 +209,15 @@ func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn
|
||||
defer clientConn.Close()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
_ = netutil.PipeWithCallbacks(r.Context(), stream, clientConn,
|
||||
var clientRW io.ReadWriteCloser = clientConn
|
||||
if clientBuf != nil && clientBuf.Reader.Buffered() > 0 {
|
||||
clientRW = &bufferedReadWriteCloser{
|
||||
Reader: clientBuf.Reader,
|
||||
Conn: clientConn,
|
||||
}
|
||||
}
|
||||
|
||||
_ = netutil.PipeWithCallbacks(context.Background(), stream, clientRW,
|
||||
func(n int64) { tconn.AddBytesOut(n) },
|
||||
func(n int64) { tconn.AddBytesIn(n) },
|
||||
)
|
||||
@@ -386,3 +395,12 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
type bufferedReadWriteCloser struct {
|
||||
*bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (b *bufferedReadWriteCloser) Read(p []byte) (int, error) {
|
||||
return b.Reader.Read(p)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user