Merge pull request #9 from Gouryella/fix/http-proxy

Fix/http proxy
This commit is contained in:
Gouryella
2025-12-19 10:34:34 +08:00
committed by GitHub
9 changed files with 262 additions and 413 deletions

2
go.mod
View File

@@ -10,6 +10,7 @@ require (
github.com/spf13/cobra v1.10.1
go.uber.org/zap v1.27.1
golang.org/x/crypto v0.45.0
golang.org/x/net v0.47.0
golang.org/x/sys v0.38.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -30,6 +31,5 @@ require (
github.com/stretchr/testify v1.11.1 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/text v0.31.0 // indirect
)

View File

@@ -18,7 +18,6 @@ import (
"go.uber.org/zap"
)
// handleStream routes incoming stream to appropriate handler.
func (c *PoolClient) handleStream(h *sessionHandle, stream net.Conn) {
defer c.wg.Done()
defer func() {
@@ -35,7 +34,6 @@ func (c *PoolClient) handleStream(h *sessionHandle, stream net.Conn) {
}
}
// handleTCPStream handles raw TCP tunneling.
func (c *PoolClient) handleTCPStream(stream net.Conn) {
localConn, err := net.DialTimeout("tcp", net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort)), 10*time.Second)
if err != nil {
@@ -62,7 +60,6 @@ func (c *PoolClient) handleTCPStream(stream net.Conn) {
)
}
// handleHTTPStream handles HTTP/HTTPS proxy requests.
func (c *PoolClient) handleHTTPStream(stream net.Conn) {
_ = stream.SetReadDeadline(time.Now().Add(30 * time.Second))
@@ -104,6 +101,8 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
httputil.CopyHeaders(outReq.Header, req.Header)
httputil.CleanHopByHopHeaders(outReq.Header)
outReq.Header.Del("Accept-Encoding")
targetHost := c.localHost
if c.localPort != 80 && c.localPort != 443 {
targetHost = fmt.Sprintf("%s:%d", c.localHost, c.localPort)
@@ -153,7 +152,6 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
close(done)
}
// handleWebSocketUpgrade handles WebSocket upgrade requests.
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
scheme := "ws"
if c.tunnelType == protocol.TunnelTypeHTTPS {
@@ -207,7 +205,6 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
}
}
// newLocalHTTPClient creates an HTTP client for local service requests.
func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client {
var tlsConfig *tls.Config
if tunnelType == protocol.TunnelTypeHTTPS {

View File

@@ -39,7 +39,6 @@ func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, auth
}
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Always handle /health and /stats directly, regardless of subdomain.
if r.URL.Path == "/health" {
h.serveHealth(w, r)
return
@@ -71,13 +70,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Check for WebSocket upgrade
if r.Method == http.MethodConnect {
http.Error(w, "CONNECT not supported for HTTP tunnels", http.StatusMethodNotAllowed)
return
}
if httputil.IsWebSocketUpgrade(r) {
h.handleWebSocket(w, r, tconn)
return
}
// Open stream with timeout
stream, err := h.openStreamWithTimeout(tconn)
if err != nil {
w.Header().Set("Connection", "close")
@@ -86,17 +88,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
defer stream.Close()
// Track active connections
tconn.IncActiveConnections()
defer tconn.DecActiveConnections()
// Wrap stream with counting for traffic stats
countingStream := netutil.NewCountingConn(stream,
tconn.AddBytesOut, // Data read from stream = bytes out to client
tconn.AddBytesIn, // Data written to stream = bytes in from client
tconn.AddBytesOut,
tconn.AddBytesIn,
)
// 1) Write request over the stream (net/http handles large bodies correctly).
if err := r.Write(countingStream); err != nil {
w.Header().Set("Connection", "close")
_ = r.Body.Close()
@@ -104,7 +103,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// 2) Read response from stream.
resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
if err != nil {
w.Header().Set("Connection", "close")
@@ -113,7 +111,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
defer resp.Body.Close()
// 3) Copy headers (strip hop-by-hop).
h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
statusCode := resp.StatusCode
@@ -121,9 +118,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
statusCode = http.StatusOK
}
// Ensure message delimiting works with our custom ResponseWriter:
// - If Content-Length is known, send it.
// - Otherwise, re-chunk the decoded body ourselves.
if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
if resp.ContentLength >= 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
@@ -136,28 +130,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if resp.ContentLength >= 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
w.WriteHeader(statusCode)
ctx := r.Context()
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
stream.Close()
case <-done:
}
}()
_, _ = io.Copy(w, resp.Body)
close(done)
stream.Close()
return
} else {
w.Header().Del("Content-Length")
}
w.Header().Del("Content-Length")
w.Header().Set("Transfer-Encoding", "chunked")
if len(resp.Trailer) > 0 {
w.Header().Set("Trailer", trailerKeys(resp.Trailer))
}
w.WriteHeader(statusCode)
ctx := r.Context()
@@ -170,9 +146,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}()
if err := writeChunked(w, resp.Body, resp.Trailer); err != nil {
h.logger.Debug("Write chunked response failed", zap.Error(err))
}
_, _ = io.Copy(w, resp.Body)
close(done)
stream.Close()
}
@@ -267,53 +241,6 @@ func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHos
}
}
// trailerKeys renders the trailer header names as a single comma-separated
// list, sorted so the output is deterministic.
func trailerKeys(hdr http.Header) string {
	names := make([]string, 0, len(hdr))
	for name := range hdr {
		names = append(names, name)
	}
	// Deterministic order is nicer for debugging; no semantic impact.
	sortStrings(names)
	return strings.Join(names, ", ")
}
func writeChunked(w io.Writer, r io.Reader, trailer http.Header) error {
buf := make([]byte, 32*1024)
for {
n, err := r.Read(buf)
if n > 0 {
if _, werr := fmt.Fprintf(w, "%x\r\n", n); werr != nil {
return werr
}
if _, werr := w.Write(buf[:n]); werr != nil {
return werr
}
if _, werr := io.WriteString(w, "\r\n"); werr != nil {
return werr
}
}
if err == io.EOF {
break
}
if err != nil {
return err
}
}
if _, err := io.WriteString(w, "0\r\n"); err != nil {
return err
}
for k, vv := range trailer {
for _, v := range vv {
if _, err := io.WriteString(w, fmt.Sprintf("%s: %s\r\n", k, v)); err != nil {
return err
}
}
}
_, err := io.WriteString(w, "\r\n")
return err
}
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
if !strings.HasPrefix(location, "http://") && !strings.HasPrefix(location, "https://") {
return location
@@ -459,13 +386,3 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Write(data)
}
// sortStrings sorts s in ascending order, in place. Kept as a tiny local
// insertion sort; trailer maps are small, so simplicity beats pulling in
// another import for one call site.
func sortStrings(s []string) {
	for i := 1; i < len(s); i++ {
		for j := i; j > 0 && s[j] < s[j-1]; j-- {
			s[j], s[j-1] = s[j-1], s[j]
		}
	}
}

View File

@@ -22,7 +22,6 @@ import (
"go.uber.org/zap"
)
// Connection represents a client TCP connection
type Connection struct {
conn net.Conn
authToken string
@@ -40,21 +39,19 @@ type Connection struct {
mu sync.RWMutex
frameWriter *protocol.FrameWriter
httpHandler http.Handler
tunnelType protocol.TunnelType // Track tunnel type
tunnelType protocol.TunnelType
ctx context.Context
cancel context.CancelFunc
// gost-like TCP tunnel (yamux)
session *yamux.Session
proxy *Proxy
// Multi-connection support
tunnelID string
groupManager *ConnectionGroupManager
session *yamux.Session
proxy *Proxy
tunnelID string
groupManager *ConnectionGroupManager
httpListener *connQueueListener
handedOff bool
}
// NewConnection creates a new connection handler
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager) *Connection {
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
ctx, cancel := context.WithCancel(context.Background())
c := &Connection{
conn: conn,
@@ -70,25 +67,18 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
ctx: ctx,
cancel: cancel,
groupManager: groupManager,
httpListener: httpListener,
}
return c
}
// Handle handles the connection lifecycle
func (c *Connection) Handle() error {
// Register connection for adaptive load tracking
protocol.RegisterConnection()
// Ensure cleanup of control connection, proxy, port, and registry on exit.
defer c.Close()
// Set initial read timeout for protocol detection
c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))
// Use buffered reader to support peeking
reader := bufio.NewReader(c.conn)
// Peek first 4 bytes to detect protocol (HTTP methods are 4 bytes).
peek, err := reader.Peek(4)
if err != nil {
return fmt.Errorf("failed to peek connection: %w", err)
@@ -109,8 +99,6 @@ func (c *Connection) Handle() error {
return c.handleHTTPRequest(reader)
}
// Continue with drip protocol
// Wait for registration frame
frame, err := protocol.ReadFrame(reader)
if err != nil {
return fmt.Errorf("failed to read registration frame: %w", err)
@@ -118,7 +106,6 @@ func (c *Connection) Handle() error {
sf := protocol.WithFrame(frame)
defer sf.Close()
// Handle data connection request (for multi-connection pool)
if sf.Frame.Type == protocol.FrameTypeDataConnect {
return c.handleDataConnect(sf.Frame, reader)
}
@@ -139,7 +126,6 @@ func (c *Connection) Handle() error {
return fmt.Errorf("authentication failed")
}
// Allocate TCP port only for TCP tunnels
if req.TunnelType == protocol.TunnelTypeTCP {
if c.portAlloc == nil {
return fmt.Errorf("port allocator not configured")
@@ -152,7 +138,6 @@ func (c *Connection) Handle() error {
}
c.port = port
// For TCP tunnels, prefer deterministic subdomain tied to port when not provided by client.
if req.CustomSubdomain == "" {
req.CustomSubdomain = fmt.Sprintf("tcp-%d", port)
}
@@ -174,8 +159,7 @@ func (c *Connection) Handle() error {
}
c.tunnelConn = tunnelConn
// Store TCP connection reference and metadata for HTTP proxy routing
c.tunnelConn.Conn = nil // We're using TCP, not WebSocket
c.tunnelConn.Conn = nil
c.tunnelConn.SetTunnelType(req.TunnelType)
c.tunnelType = req.TunnelType
@@ -186,35 +170,27 @@ func (c *Connection) Handle() error {
zap.Int("remote_port", c.port),
)
// Send registration acknowledgment
// Generate appropriate URL based on tunnel type
var tunnelURL string
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
// HTTP/HTTPS tunnels use HTTPS with subdomain
// Use publicPort for URL generation (configured via --public-port flag)
if c.publicPort == 443 {
tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.domain)
} else {
tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.domain, c.publicPort)
}
} else {
// TCP tunnels use tcp:// with port
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
}
// Generate TunnelID for multi-connection support if client supports it
var tunnelID string
var supportsDataConn bool
recommendedConns := 0
if req.PoolCapabilities != nil && req.ConnectionType == "primary" && c.groupManager != nil {
// Client supports connection pooling
group := c.groupManager.CreateGroup(subdomain, req.Token, c, req.TunnelType)
tunnelID = group.TunnelID
c.tunnelID = tunnelID
supportsDataConn = true
recommendedConns = 4 // Recommend 4 data connections
recommendedConns = 4
c.logger.Info("Created connection group for multi-connection support",
zap.String("tunnel_id", tunnelID),
@@ -235,16 +211,13 @@ func (c *Connection) Handle() error {
respData, _ := json.Marshal(resp)
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
// Send registration ack (sync write before frameWriter is created)
err = protocol.WriteFrame(c.conn, ackFrame)
if err != nil {
return fmt.Errorf("failed to send registration ack: %w", err)
}
// Clear deadline for tunnel data-plane.
c.conn.SetReadDeadline(time.Time{})
// gost-like tunnels: switch to yamux after RegisterAck.
if req.TunnelType == protocol.TunnelTypeTCP {
return c.handleTCPTunnel(reader)
}
@@ -265,6 +238,44 @@ func (c *Connection) Handle() error {
}
// handleHTTPRequest hands an already-accepted plain-HTTP connection over to
// the shared http.Server via the connection queue listener. When no queue
// listener is configured it falls back to the legacy inline request loop.
//
// On a successful hand-off the HTTP server owns the socket from then on:
// c.conn is cleared and handedOff is set so Close() will not touch it.
func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
	if c.httpListener == nil {
		return c.handleHTTPRequestLegacy(reader)
	}

	// The HTTP server applies its own timeouts; drop the detection deadline.
	c.conn.SetReadDeadline(time.Time{})

	// Wrap the conn so bytes already buffered during protocol detection are
	// replayed to the HTTP server before further reads hit the socket.
	wrappedConn := &bufferedConn{
		Conn:   c.conn,
		reader: reader,
	}

	if !c.httpListener.Enqueue(wrappedConn) {
		c.logger.Warn("HTTP listener queue full, rejecting connection")
		// Best-effort 503 before the caller's deferred Close tears the
		// connection down. Compute Content-Length from the body: the previous
		// hard-coded value (32) was off by one for this 33-byte body.
		const body = "Server busy, please retry later\r\n"
		response := "HTTP/1.1 503 Service Unavailable\r\n" +
			"Content-Type: text/plain\r\n" +
			fmt.Sprintf("Content-Length: %d\r\n", len(body)) +
			"Connection: close\r\n" +
			"\r\n" +
			body
		c.conn.Write([]byte(response))
		return fmt.Errorf("http listener queue full")
	}

	// Hand-off succeeded: the queue listener / http.Server now owns the conn.
	c.mu.Lock()
	c.conn = nil
	c.handedOff = true
	c.mu.Unlock()
	return nil
}
// IsHandedOff reports whether the underlying connection has been handed off
// to the HTTP server (after which Close must not touch it).
func (c *Connection) IsHandedOff() bool {
	c.mu.RLock()
	handedOff := c.handedOff
	c.mu.RUnlock()
	return handedOff
}
func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
if c.httpHandler == nil {
c.logger.Warn("HTTP request received but no HTTP handler configured")
response := "HTTP/1.1 503 Service Unavailable\r\n" +
@@ -276,18 +287,13 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
return fmt.Errorf("HTTP handler not configured")
}
// Clear read deadline for HTTP processing
c.conn.SetReadDeadline(time.Time{})
// Handle multiple HTTP requests on the same connection (HTTP/1.1 keep-alive)
for {
// Set a read deadline for each request to avoid hanging forever
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
// Parse HTTP request
req, err := http.ReadRequest(reader)
if err != nil {
// EOF or timeout is normal when client closes connection or no more requests
if err == io.EOF || err == io.ErrUnexpectedEOF {
c.logger.Debug("Client closed HTTP connection")
return nil
@@ -296,7 +302,6 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
c.logger.Debug("HTTP keep-alive timeout")
return nil
}
// Connection reset by peer is normal - client closed connection abruptly
errStr := err.Error()
if errors.Is(err, net.ErrClosed) || strings.Contains(errStr, "use of closed network connection") {
c.logger.Debug("HTTP connection closed during read", zap.Error(err))
@@ -308,13 +313,11 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
c.logger.Debug("Client disconnected abruptly", zap.Error(err))
return nil
}
// Check if it looks like garbage data (not a valid HTTP request)
if strings.Contains(errStr, "malformed HTTP") {
c.logger.Warn("Received malformed HTTP request, possibly due to pipelined requests or protocol error",
c.logger.Warn("Received malformed HTTP request",
zap.Error(err),
zap.String("error_snippet", errStr[:min(len(errStr), 100)]),
)
// Close connection on malformed request to prevent further errors
return nil
}
c.logger.Error("Failed to parse HTTP request", zap.Error(err))
@@ -370,8 +373,6 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
c.logger.Debug("Closing connection as requested by client or server")
return nil
}
// Continue to next request on the same connection
}
}
@@ -382,7 +383,6 @@ func min(a, b int) int {
return b
}
// handleFrames handles incoming frames
func (c *Connection) handleFrames(reader *bufio.Reader) error {
for {
select {
@@ -391,7 +391,6 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
default:
}
// Read frame with timeout
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
frame, err := protocol.ReadFrame(reader)
if err != nil {
@@ -399,15 +398,12 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
c.logger.Warn("Read timeout, connection may be dead")
return fmt.Errorf("read timeout")
}
// EOF is normal when client closes connection gracefully
if err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" {
c.logger.Info("Client disconnected")
return nil
}
// Check if connection was closed (during shutdown)
select {
case <-c.stopCh:
// Connection was closed intentionally, don't log as error
c.logger.Debug("Connection closed during shutdown")
return nil
default:
@@ -415,7 +411,6 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
}
}
// Handle frame based on type
sf := protocol.WithFrame(frame)
switch sf.Frame.Type {
@@ -437,22 +432,18 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
}
}
// handleHeartbeat handles heartbeat frame
func (c *Connection) handleHeartbeat() {
c.mu.Lock()
c.lastHeartbeat = time.Now()
c.mu.Unlock()
// Send heartbeat ack
ackFrame := protocol.NewFrame(protocol.FrameTypeHeartbeatAck, nil)
err := c.frameWriter.WriteControl(ackFrame)
if err != nil {
c.logger.Error("Failed to send heartbeat ack", zap.Error(err))
}
}
// heartbeatChecker checks for heartbeat timeout
func (c *Connection) heartbeatChecker() {
ticker := time.NewTicker(constants.HeartbeatInterval)
defer ticker.Stop()
@@ -496,16 +487,19 @@ func (c *Connection) sendError(code, message string) {
func (c *Connection) Close() {
c.once.Do(func() {
protocol.UnregisterConnection()
close(c.stopCh)
if c.cancel != nil {
c.cancel()
}
// Ensure any in-flight writes return quickly on shutdown to avoid hanging.
if c.conn != nil {
_ = c.conn.SetDeadline(time.Now())
// Prevent race with handleHTTPRequest setting c.conn = nil
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn != nil {
_ = conn.SetDeadline(time.Now())
}
if c.frameWriter != nil {
@@ -520,8 +514,8 @@ func (c *Connection) Close() {
_ = c.session.Close()
}
if c.conn != nil {
c.conn.Close()
if conn != nil {
conn.Close()
}
if c.port > 0 && c.portAlloc != nil {
@@ -530,9 +524,6 @@ func (c *Connection) Close() {
if c.subdomain != "" {
c.manager.Unregister(c.subdomain)
// Clean up connection group when PRIMARY connection closes
// (only primary connections have subdomain set)
if c.tunnelID != "" && c.groupManager != nil {
c.groupManager.RemoveGroup(c.tunnelID)
}
@@ -544,10 +535,9 @@ func (c *Connection) Close() {
})
}
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
type httpResponseWriter struct {
conn net.Conn
writer *bufio.Writer // Buffered writer for efficient I/O
writer *bufio.Writer
header http.Header
statusCode int
headerWritten bool
@@ -594,7 +584,6 @@ func (w *httpResponseWriter) Write(data []byte) (int, error) {
return w.writer.Write(data)
}
// handleDataConnect handles a data connection join request
func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Reader) error {
var req protocol.DataConnectRequest
if err := json.Unmarshal(frame.Payload, &req); err != nil {
@@ -607,13 +596,11 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
zap.String("connection_id", req.ConnectionID),
)
// Validate the request
if c.groupManager == nil {
c.sendDataConnectError("not_supported", "Multi-connection not supported")
return fmt.Errorf("group manager not available")
}
// Validate auth token
if c.authToken != "" && req.Token != c.authToken {
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed for data connection")
@@ -625,75 +612,13 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
return fmt.Errorf("tunnel not found: %s", req.TunnelID)
}
// Validate token against the primary registration token.
if group.Token != "" && req.Token != group.Token {
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed for data connection")
}
// Store tunnelID for cleanup
c.tunnelID = req.TunnelID
// For TCP tunnels, the data connection is upgraded to a yamux session and used for
// stream forwarding, not framed request/response routing.
if group.TunnelType == protocol.TunnelTypeTCP {
resp := protocol.DataConnectResponse{
Accepted: true,
ConnectionID: req.ConnectionID,
Message: "Data connection accepted",
}
respData, _ := json.Marshal(resp)
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
return fmt.Errorf("failed to send data connect ack: %w", err)
}
c.logger.Info("TCP data connection established",
zap.String("tunnel_id", req.TunnelID),
zap.String("connection_id", req.ConnectionID),
)
// Clear deadline for yamux data-plane.
_ = c.conn.SetReadDeadline(time.Time{})
// Public server acts as yamux Client, client connector acts as yamux Server.
bc := &bufferedConn{
Conn: c.conn,
reader: reader,
}
cfg := yamux.DefaultConfig()
cfg.EnableKeepAlive = false
cfg.LogOutput = io.Discard
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
session, err := yamux.Client(bc, cfg)
if err != nil {
return fmt.Errorf("failed to init yamux session: %w", err)
}
c.session = session
group.AddSession(req.ConnectionID, session)
defer group.RemoveSession(req.ConnectionID)
select {
case <-c.stopCh:
return nil
case <-session.CloseChan():
return nil
}
}
// Add data connection to group
dataConn, err := c.groupManager.AddDataConnection(&req, c.conn)
if err != nil {
c.sendDataConnectError("join_failed", err.Error())
return fmt.Errorf("failed to join connection group: %w", err)
}
// Send success response
resp := protocol.DataConnectResponse{
Accepted: true,
ConnectionID: req.ConnectionID,
@@ -712,56 +637,33 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
zap.String("connection_id", req.ConnectionID),
)
// Handle data frames on this connection
return c.handleDataConnectionFrames(dataConn, reader)
}
_ = c.conn.SetReadDeadline(time.Time{})
// handleDataConnectionFrames handles frames on a data connection
func (c *Connection) handleDataConnectionFrames(dataConn *DataConnection, reader *bufio.Reader) error {
defer func() {
// Get the group and remove this data connection
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok {
group.RemoveDataConnection(dataConn.ID)
}
}()
// Server acts as yamux Client, client connector acts as yamux Server
bc := &bufferedConn{
Conn: c.conn,
reader: reader,
}
for {
select {
case <-dataConn.stopCh:
return nil
default:
}
cfg := yamux.DefaultConfig()
cfg.EnableKeepAlive = false
cfg.LogOutput = io.Discard
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
frame, err := protocol.ReadFrame(reader)
if err != nil {
// Timeout is OK, continue
if isTimeoutError(err) {
continue
}
return err
}
session, err := yamux.Client(bc, cfg)
if err != nil {
return fmt.Errorf("failed to init yamux session: %w", err)
}
c.session = session
dataConn.mu.Lock()
dataConn.LastActive = time.Now()
dataConn.mu.Unlock()
group.AddSession(req.ConnectionID, session)
defer group.RemoveSession(req.ConnectionID)
sf := protocol.WithFrame(frame)
switch sf.Frame.Type {
case protocol.FrameTypeClose:
sf.Close()
c.logger.Info("Data connection closed by client",
zap.String("connection_id", dataConn.ID))
return nil
default:
sf.Close()
c.logger.Warn("Unexpected frame type on data connection",
zap.String("type", sf.Frame.Type.String()),
zap.String("connection_id", dataConn.ID),
)
}
select {
case <-c.stopCh:
return nil
case <-session.CloseChan():
return nil
}
}
@@ -773,11 +675,9 @@ func isTimeoutError(err error) bool {
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
// Fallback for wrapped errors without net.Error
return strings.Contains(err.Error(), "i/o timeout")
}
// sendDataConnectError sends a data connect error response
func (c *Connection) sendDataConnectError(code, message string) {
resp := protocol.DataConnectResponse{
Accepted: false,

View File

@@ -15,23 +15,11 @@ import (
"go.uber.org/zap"
)
// DataConnection tracks a single pooled data-plane connection that has
// joined a ConnectionGroup.
type DataConnection struct {
	ID         string        // connection identifier supplied by the client
	Conn       net.Conn      // underlying network connection
	LastActive time.Time     // last observed activity; guarded by mu
	closed     bool          // set once shutdown has run; guarded by closedMu
	closedMu   sync.RWMutex  // guards closed and the one-time close of stopCh/Conn
	stopCh     chan struct{} // closed to signal shutdown to readers
	mu         sync.RWMutex  // guards LastActive
}
type ConnectionGroup struct {
TunnelID string
Subdomain string
Token string
PrimaryConn *Connection
DataConns map[string]*DataConnection
Sessions map[string]*yamux.Session
TunnelType protocol.TunnelType
RegisteredAt time.Time
@@ -50,7 +38,6 @@ func NewConnectionGroup(tunnelID, subdomain, token string, primaryConn *Connecti
Subdomain: subdomain,
Token: token,
PrimaryConn: primaryConn,
DataConns: make(map[string]*DataConnection),
Sessions: make(map[string]*yamux.Session),
TunnelType: tunnelType,
RegisteredAt: time.Now(),
@@ -146,46 +133,6 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
}
}
// AddDataConnection registers a new data-plane connection in the group and
// returns its tracking record. It also refreshes the group's LastActivity.
func (g *ConnectionGroup) AddDataConnection(connID string, conn net.Conn) *DataConnection {
	// Build the record outside the lock; only the map insert and the
	// activity timestamp need mutual exclusion.
	dc := &DataConnection{
		ID:         connID,
		Conn:       conn,
		LastActive: time.Now(),
		stopCh:     make(chan struct{}),
	}

	g.mu.Lock()
	g.DataConns[connID] = dc
	g.LastActivity = time.Now()
	g.mu.Unlock()

	return dc
}
// RemoveDataConnection closes and unregisters the data connection with the
// given ID. Unknown IDs are a no-op. Closing is idempotent: the closed flag
// (guarded by closedMu) ensures stopCh is closed and the socket torn down at
// most once, even if called concurrently with ConnectionGroup.Close.
func (g *ConnectionGroup) RemoveDataConnection(connID string) {
	g.mu.Lock()
	defer g.mu.Unlock()
	if dataConn, ok := g.DataConns[connID]; ok {
		dataConn.closedMu.Lock()
		if !dataConn.closed {
			dataConn.closed = true
			close(dataConn.stopCh)
			if dataConn.Conn != nil {
				// Immediate deadline aborts any in-flight reads/writes
				// promptly before the close.
				_ = dataConn.Conn.SetDeadline(time.Now())
				dataConn.Conn.Close()
			}
		}
		dataConn.closedMu.Unlock()
		delete(g.DataConns, connID)
	}
}
// DataConnectionCount returns how many data connections are currently
// registered in the group.
func (g *ConnectionGroup) DataConnectionCount() int {
	g.mu.RLock()
	n := len(g.DataConns)
	g.mu.RUnlock()
	return n
}
func (g *ConnectionGroup) Close() {
g.mu.Lock()
@@ -197,12 +144,6 @@ func (g *ConnectionGroup) Close() {
close(g.stopCh)
}
dataConns := make([]*DataConnection, 0, len(g.DataConns))
for _, dataConn := range g.DataConns {
dataConns = append(dataConns, dataConn)
}
g.DataConns = make(map[string]*DataConnection)
sessions := make([]*yamux.Session, 0, len(g.Sessions))
for _, session := range g.Sessions {
if session != nil {
@@ -213,19 +154,6 @@ func (g *ConnectionGroup) Close() {
g.mu.Unlock()
for _, dataConn := range dataConns {
dataConn.closedMu.Lock()
if !dataConn.closed {
dataConn.closed = true
close(dataConn.stopCh)
if dataConn.Conn != nil {
_ = dataConn.Conn.SetDeadline(time.Now())
_ = dataConn.Conn.Close()
}
}
dataConn.closedMu.Unlock()
}
for _, session := range sessions {
_ = session.Close()
}
@@ -302,7 +230,13 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
default:
}
sessions := g.sessionsSnapshot()
// Prefer data sessions for data-plane traffic; keep the primary session
// as control-plane (client ping/latency), and only fall back to primary
// when no data session exists.
sessions := g.sessionsSnapshot(false)
if len(sessions) == 0 {
sessions = g.sessionsSnapshot(true)
}
if len(sessions) == 0 {
return nil, net.ErrClosed
}
@@ -380,7 +314,10 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
}
func (g *ConnectionGroup) selectSession() *yamux.Session {
sessions := g.sessionsSnapshot()
sessions := g.sessionsSnapshot(false)
if len(sessions) == 0 {
sessions = g.sessionsSnapshot(true)
}
if len(sessions) == 0 {
return nil
}
@@ -403,7 +340,7 @@ func (g *ConnectionGroup) selectSession() *yamux.Session {
return best
}
func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
func (g *ConnectionGroup) sessionsSnapshot(includePrimary bool) []*yamux.Session {
g.mu.Lock()
defer g.mu.Unlock()
@@ -417,6 +354,9 @@ func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
delete(g.Sessions, id)
continue
}
if id == "primary" && !includePrimary {
continue
}
sessions = append(sessions, session)
}

View File

@@ -3,8 +3,6 @@ package tcp
import (
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"sync"
"time"
@@ -84,26 +82,6 @@ func (m *ConnectionGroupManager) RemoveGroup(tunnelID string) {
}
}
// AddDataConnection adds a data connection to a group
func (m *ConnectionGroupManager) AddDataConnection(req *protocol.DataConnectRequest, conn net.Conn) (*DataConnection, error) {
m.mu.RLock()
group, ok := m.groups[req.TunnelID]
m.mu.RUnlock()
if !ok {
return nil, fmt.Errorf("tunnel not found: %s", req.TunnelID)
}
// Validate token
if group.Token != "" && req.Token != group.Token {
return nil, fmt.Errorf("invalid token")
}
dataConn := group.AddDataConnection(req.ConnectionID, conn)
return dataConn, nil
}
// cleanupLoop periodically cleans up stale groups
func (m *ConnectionGroupManager) cleanupLoop() {
ticker := time.NewTicker(m.cleanupInterval)

View File

@@ -0,0 +1,83 @@
package tcp
import (
"net"
"sync"
"sync/atomic"
"time"
)
// connQueueListener is a net.Listener backed by a channel of pre-accepted conns.
// It lets the TCP/TLS multiplexer hand off HTTP connections to a standard http.Server.
type connQueueListener struct {
addr net.Addr
conns chan net.Conn
done chan struct{}
once sync.Once
closed atomic.Bool
}
func newConnQueueListener(addr net.Addr, buffer int) *connQueueListener {
if buffer <= 0 {
buffer = 1024
}
return &connQueueListener{
addr: addr,
conns: make(chan net.Conn, buffer),
done: make(chan struct{}),
}
}
func (l *connQueueListener) Accept() (net.Conn, error) {
select {
case <-l.done:
return nil, net.ErrClosed
case conn := <-l.conns:
if conn == nil {
return nil, net.ErrClosed
}
return conn, nil
}
}
func (l *connQueueListener) Close() error {
l.once.Do(func() {
l.closed.Store(true)
close(l.done)
l.drain()
})
return nil
}
func (l *connQueueListener) Addr() net.Addr { return l.addr }
func (l *connQueueListener) Enqueue(conn net.Conn) bool {
if conn == nil {
return false
}
if l.closed.Load() {
return false
}
select {
case l.conns <- conn:
return true
default:
return false
}
}
func (l *connQueueListener) drain() {
for {
select {
case conn := <-l.conns:
if conn == nil {
continue
}
_ = conn.SetDeadline(time.Now())
_ = conn.Close()
default:
return
}
}
}

View File

@@ -1,6 +1,7 @@
package tcp
import (
"context"
"crypto/tls"
"fmt"
"net"
@@ -14,29 +15,30 @@ import (
"drip/internal/shared/recovery"
"go.uber.org/zap"
"golang.org/x/net/http2"
)
// Listener handles TCP connections with TLS 1.3
type Listener struct {
address string
tlsConfig *tls.Config
authToken string
manager *tunnel.Manager
portAlloc *PortAllocator
logger *zap.Logger
domain string
publicPort int
httpHandler http.Handler
listener net.Listener
stopCh chan struct{}
wg sync.WaitGroup
connections map[string]*Connection
connMu sync.RWMutex
workerPool *pool.WorkerPool // Worker pool for connection handling
recoverer *recovery.Recoverer
panicMetrics *recovery.PanicMetrics
address string
tlsConfig *tls.Config
authToken string
manager *tunnel.Manager
portAlloc *PortAllocator
logger *zap.Logger
domain string
publicPort int
httpHandler http.Handler
listener net.Listener
stopCh chan struct{}
wg sync.WaitGroup
connections map[string]*Connection
connMu sync.RWMutex
workerPool *pool.WorkerPool
recoverer *recovery.Recoverer
panicMetrics *recovery.PanicMetrics
groupManager *ConnectionGroupManager
httpServer *http.Server
httpListener *connQueueListener
}
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler) *Listener {
@@ -73,7 +75,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
}
}
// Start starts the TCP listener
func (l *Listener) Start() error {
var err error
@@ -87,13 +88,39 @@ func (l *Listener) Start() error {
zap.String("tls_version", "TLS 1.3"),
)
l.httpListener = newConnQueueListener(l.listener.Addr(), 4096)
l.httpServer = &http.Server{
Handler: l.httpHandler,
ReadHeaderTimeout: 30 * time.Second,
ReadTimeout: 0,
WriteTimeout: 0,
IdleTimeout: 120 * time.Second,
MaxHeaderBytes: 1 << 20,
}
if err := http2.ConfigureServer(l.httpServer, &http2.Server{
MaxConcurrentStreams: 1000,
IdleTimeout: 120 * time.Second,
}); err != nil {
l.logger.Warn("Failed to configure HTTP/2", zap.Error(err))
}
l.wg.Add(1)
go func() {
defer l.wg.Done()
l.logger.Info("HTTP server started (with context cancellation support)")
if err := l.httpServer.Serve(l.httpListener); err != nil && err != http.ErrServerClosed {
l.logger.Error("HTTP server error", zap.Error(err))
}
}()
l.wg.Add(1)
go l.acceptLoop()
return nil
}
// acceptLoop accepts incoming connections
func (l *Listener) acceptLoop() {
defer l.wg.Done()
defer l.recoverer.Recover("acceptLoop")
@@ -105,7 +132,6 @@ func (l *Listener) acceptLoop() {
default:
}
// Set accept deadline to allow checking stopCh
if tcpListener, ok := l.listener.(*net.TCPListener); ok {
tcpListener.SetDeadline(time.Now().Add(1 * time.Second))
}
@@ -113,7 +139,7 @@ func (l *Listener) acceptLoop() {
conn, err := l.listener.Accept()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue // Timeout is expected due to deadline
continue
}
select {
case <-l.stopCh:
@@ -143,10 +169,8 @@ func (l *Listener) acceptLoop() {
}
}
// handleConnection handles a single client connection
func (l *Listener) handleConnection(netConn net.Conn) {
defer l.wg.Done()
defer netConn.Close()
defer l.recoverer.RecoverWithCallback("handleConnection", func(p interface{}) {
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
@@ -160,7 +184,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
// Set read deadline before handshake to prevent slow handshake attacks
if err := tlsConn.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
l.logger.Warn("Failed to set read deadline",
zap.String("remote_addr", netConn.RemoteAddr().String()),
@@ -177,7 +200,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
// Clear the read deadline after successful handshake
if err := tlsConn.SetReadDeadline(time.Time{}); err != nil {
l.logger.Warn("Failed to clear read deadline",
zap.String("remote_addr", netConn.RemoteAddr().String()),
@@ -208,7 +230,7 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager)
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
@@ -219,12 +241,15 @@ func (l *Listener) handleConnection(netConn net.Conn) {
l.connMu.Lock()
delete(l.connections, connID)
l.connMu.Unlock()
if !conn.IsHandedOff() {
netConn.Close()
}
}()
if err := conn.Handle(); err != nil {
errStr := err.Error()
// Client disconnection errors - normal network behavior, ignore
if strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
@@ -232,7 +257,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
// Protocol errors (invalid clients, scanners) are expected - log as WARN
if strings.Contains(errStr, "payload too large") ||
strings.Contains(errStr, "failed to read registration frame") ||
strings.Contains(errStr, "expected register frame") ||
@@ -243,7 +267,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
zap.Error(err),
)
} else {
// Legitimate errors (auth failures, registration failures, etc.)
l.logger.Error("Connection handling failed",
zap.String("remote_addr", connID),
zap.Error(err),
@@ -252,12 +275,24 @@ func (l *Listener) handleConnection(netConn net.Conn) {
}
}
// Stop stops the listener and closes all connections
func (l *Listener) Stop() error {
l.logger.Info("Stopping TCP listener")
close(l.stopCh)
if l.httpServer != nil {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := l.httpServer.Shutdown(shutdownCtx); err != nil {
l.logger.Warn("HTTP server shutdown error", zap.Error(err))
}
l.logger.Info("HTTP server shutdown complete")
}
if l.httpListener != nil {
l.httpListener.Close()
}
if l.listener != nil {
if err := l.listener.Close(); err != nil {
l.logger.Error("Failed to close listener", zap.Error(err))
@@ -284,7 +319,6 @@ func (l *Listener) Stop() error {
return nil
}
// GetActiveConnections returns the number of active connections
func (l *Listener) GetActiveConnections() int {
l.connMu.RLock()
defer l.connMu.RUnlock()

View File

@@ -39,7 +39,7 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
c.session = session
openStream := session.Open
if c.tunnelID != "" && c.groupManager != nil {
if c.groupManager != nil {
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
group.AddSession("primary", session)
openStream = group.OpenStream
@@ -78,7 +78,7 @@ func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
c.session = session
openStream := session.Open
if c.tunnelID != "" && c.groupManager != nil {
if c.groupManager != nil {
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
group.AddSession("primary", session)
openStream = group.OpenStream