feat(server): support HTTP CONNECT proxy and connection pooling

- Add handling for the HTTP CONNECT method, supporting HTTPS tunneling proxies.
- Introduce connQueueListener to hand off HTTP connections to standard http.Server handling.
- Optimize Connection struct fields and lifecycle management logic.
- Remove redundant comments and streamline some response-writing logic.
- Upgrade the golang.org/x/net dependency version to support new features.
- Enhance HTTP request parsing stability and improve error logging.
- Adjust the TCP listener startup process to integrate HTTP/2 configuration support.
- Improve the connection closing mechanism to avoid resource leaks.
This commit is contained in:
Gouryella
2025-12-16 02:24:20 +08:00
parent 7431d821d8
commit 1c733de303
6 changed files with 284 additions and 209 deletions

2
go.mod
View File

@@ -10,6 +10,7 @@ require (
github.com/spf13/cobra v1.10.1
go.uber.org/zap v1.27.1
golang.org/x/crypto v0.45.0
golang.org/x/net v0.47.0
golang.org/x/sys v0.38.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -30,6 +31,5 @@ require (
github.com/stretchr/testify v1.11.1 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/text v0.31.0 // indirect
)

View File

@@ -18,7 +18,6 @@ import (
"go.uber.org/zap"
)
// handleStream routes incoming stream to appropriate handler.
func (c *PoolClient) handleStream(h *sessionHandle, stream net.Conn) {
defer c.wg.Done()
defer func() {
@@ -35,7 +34,6 @@ func (c *PoolClient) handleStream(h *sessionHandle, stream net.Conn) {
}
}
// handleTCPStream handles raw TCP tunneling.
func (c *PoolClient) handleTCPStream(stream net.Conn) {
localConn, err := net.DialTimeout("tcp", net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort)), 10*time.Second)
if err != nil {
@@ -62,7 +60,6 @@ func (c *PoolClient) handleTCPStream(stream net.Conn) {
)
}
// handleHTTPStream handles HTTP/HTTPS proxy requests.
func (c *PoolClient) handleHTTPStream(stream net.Conn) {
_ = stream.SetReadDeadline(time.Now().Add(30 * time.Second))
@@ -104,6 +101,8 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
httputil.CopyHeaders(outReq.Header, req.Header)
httputil.CleanHopByHopHeaders(outReq.Header)
outReq.Header.Del("Accept-Encoding")
targetHost := c.localHost
if c.localPort != 80 && c.localPort != 443 {
targetHost = fmt.Sprintf("%s:%d", c.localHost, c.localPort)
@@ -153,7 +152,6 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
close(done)
}
// handleWebSocketUpgrade handles WebSocket upgrade requests.
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
scheme := "ws"
if c.tunnelType == protocol.TunnelTypeHTTPS {
@@ -207,7 +205,6 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
}
}
// newLocalHTTPClient creates an HTTP client for local service requests.
func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client {
var tlsConfig *tls.Config
if tunnelType == protocol.TunnelTypeHTTPS {

View File

@@ -39,7 +39,11 @@ func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, auth
}
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Always handle /health and /stats directly, regardless of subdomain.
if r.Method == http.MethodConnect {
h.handleConnect(w, r)
return
}
if r.URL.Path == "/health" {
h.serveHealth(w, r)
return
@@ -71,13 +75,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Check for WebSocket upgrade
if httputil.IsWebSocketUpgrade(r) {
h.handleWebSocket(w, r, tconn)
return
}
// Open stream with timeout
stream, err := h.openStreamWithTimeout(tconn)
if err != nil {
w.Header().Set("Connection", "close")
@@ -86,17 +88,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
defer stream.Close()
// Track active connections
tconn.IncActiveConnections()
defer tconn.DecActiveConnections()
// Wrap stream with counting for traffic stats
countingStream := netutil.NewCountingConn(stream,
tconn.AddBytesOut, // Data read from stream = bytes out to client
tconn.AddBytesIn, // Data written to stream = bytes in from client
tconn.AddBytesOut,
tconn.AddBytesIn,
)
// 1) Write request over the stream (net/http handles large bodies correctly).
if err := r.Write(countingStream); err != nil {
w.Header().Set("Connection", "close")
_ = r.Body.Close()
@@ -104,7 +103,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// 2) Read response from stream.
resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
if err != nil {
w.Header().Set("Connection", "close")
@@ -113,7 +111,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
defer resp.Body.Close()
// 3) Copy headers (strip hop-by-hop).
h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
statusCode := resp.StatusCode
@@ -121,9 +118,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
statusCode = http.StatusOK
}
// Ensure message delimiting works with our custom ResponseWriter:
// - If Content-Length is known, send it.
// - Otherwise, re-chunk the decoded body ourselves.
if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
if resp.ContentLength >= 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
@@ -136,28 +130,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if resp.ContentLength >= 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
w.WriteHeader(statusCode)
ctx := r.Context()
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
stream.Close()
case <-done:
}
}()
_, _ = io.Copy(w, resp.Body)
close(done)
stream.Close()
return
} else {
w.Header().Del("Content-Length")
}
w.Header().Del("Content-Length")
w.Header().Set("Transfer-Encoding", "chunked")
if len(resp.Trailer) > 0 {
w.Header().Set("Trailer", trailerKeys(resp.Trailer))
}
w.WriteHeader(statusCode)
ctx := r.Context()
@@ -170,9 +146,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}()
if err := writeChunked(w, resp.Body, resp.Trailer); err != nil {
h.logger.Debug("Write chunked response failed", zap.Error(err))
}
_, _ = io.Copy(w, resp.Body)
close(done)
stream.Close()
}
@@ -267,53 +241,6 @@ func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHos
}
}
// trailerKeys renders the header's keys as a comma-separated list in
// ascending order, suitable for use as a "Trailer" header value.
func trailerKeys(hdr http.Header) string {
	names := make([]string, 0, len(hdr))
	for name := range hdr {
		names = append(names, name)
	}
	// Sorted output keeps debugging deterministic; the order itself
	// carries no semantic meaning.
	for i := 1; i < len(names); i++ {
		for j := i; j > 0 && names[j] < names[j-1]; j-- {
			names[j], names[j-1] = names[j-1], names[j]
		}
	}
	return strings.Join(names, ", ")
}
func writeChunked(w io.Writer, r io.Reader, trailer http.Header) error {
buf := make([]byte, 32*1024)
for {
n, err := r.Read(buf)
if n > 0 {
if _, werr := fmt.Fprintf(w, "%x\r\n", n); werr != nil {
return werr
}
if _, werr := w.Write(buf[:n]); werr != nil {
return werr
}
if _, werr := io.WriteString(w, "\r\n"); werr != nil {
return werr
}
}
if err == io.EOF {
break
}
if err != nil {
return err
}
}
if _, err := io.WriteString(w, "0\r\n"); err != nil {
return err
}
for k, vv := range trailer {
for _, v := range vv {
if _, err := io.WriteString(w, fmt.Sprintf("%s: %s\r\n", k, v)); err != nil {
return err
}
}
}
_, err := io.WriteString(w, "\r\n")
return err
}
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
if !strings.HasPrefix(location, "http://") && !strings.HasPrefix(location, "https://") {
return location
@@ -460,12 +387,65 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
w.Write(data)
}
func sortStrings(s []string) {
for i := 0; i < len(s); i++ {
for j := i + 1; j < len(s); j++ {
if s[j] < s[i] {
s[i], s[j] = s[j], s[i]
}
}
// handleConnect implements the HTTP CONNECT method: it dials the requested
// target, hijacks the client connection, and splices the two together as a
// raw TCP tunnel (the mechanism HTTPS forward proxies use).
func (h *Handler) handleConnect(w http.ResponseWriter, r *http.Request) {
	targetAddr := r.Host
	if targetAddr == "" {
		targetAddr = r.URL.Host
	}
	if targetAddr == "" {
		http.Error(w, "Bad Request: missing target host", http.StatusBadRequest)
		return
	}
	// Default to port 443 when none is given. Use SplitHostPort rather than
	// scanning for ":" so bare IPv6 literals (which contain colons but no
	// port) are still given a port, correctly bracketed by JoinHostPort.
	if _, _, err := net.SplitHostPort(targetAddr); err != nil {
		targetAddr = net.JoinHostPort(targetAddr, "443")
	}
	h.logger.Info("CONNECT proxy request",
		zap.String("target", targetAddr),
		zap.String("remote", r.RemoteAddr),
	)
	targetConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second)
	if err != nil {
		h.logger.Warn("Failed to connect to target",
			zap.String("target", targetAddr),
			zap.Error(err),
		)
		http.Error(w, "Bad Gateway: failed to connect to target", http.StatusBadGateway)
		return
	}
	hj, ok := w.(http.Hijacker)
	if !ok {
		targetConn.Close()
		http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
		return
	}
	clientConn, clientBuf, err := hj.Hijack()
	if err != nil {
		targetConn.Close()
		http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
		return
	}
	if _, err = clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
		clientConn.Close()
		targetConn.Close()
		return
	}
	go func() {
		defer targetConn.Close()
		defer clientConn.Close()
		// Read through the hijacked bufio.Reader so any client bytes that
		// were already buffered before the hijack are not silently dropped.
		_, _ = io.Copy(targetConn, clientBuf)
	}()
	go func() {
		defer targetConn.Close()
		defer clientConn.Close()
		_, _ = io.Copy(clientConn, targetConn)
	}()
}

View File

@@ -22,7 +22,6 @@ import (
"go.uber.org/zap"
)
// Connection represents a client TCP connection
type Connection struct {
conn net.Conn
authToken string
@@ -40,21 +39,19 @@ type Connection struct {
mu sync.RWMutex
frameWriter *protocol.FrameWriter
httpHandler http.Handler
tunnelType protocol.TunnelType // Track tunnel type
tunnelType protocol.TunnelType
ctx context.Context
cancel context.CancelFunc
// gost-like TCP tunnel (yamux)
session *yamux.Session
proxy *Proxy
// Multi-connection support
tunnelID string
groupManager *ConnectionGroupManager
session *yamux.Session
proxy *Proxy
tunnelID string
groupManager *ConnectionGroupManager
httpListener *connQueueListener
handedOff bool
}
// NewConnection creates a new connection handler
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager) *Connection {
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
ctx, cancel := context.WithCancel(context.Background())
c := &Connection{
conn: conn,
@@ -70,25 +67,18 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
ctx: ctx,
cancel: cancel,
groupManager: groupManager,
httpListener: httpListener,
}
return c
}
// Handle handles the connection lifecycle
func (c *Connection) Handle() error {
// Register connection for adaptive load tracking
protocol.RegisterConnection()
// Ensure cleanup of control connection, proxy, port, and registry on exit.
defer c.Close()
// Set initial read timeout for protocol detection
c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))
// Use buffered reader to support peeking
reader := bufio.NewReader(c.conn)
// Peek first 4 bytes to detect protocol (HTTP methods are 4 bytes).
peek, err := reader.Peek(4)
if err != nil {
return fmt.Errorf("failed to peek connection: %w", err)
@@ -109,8 +99,6 @@ func (c *Connection) Handle() error {
return c.handleHTTPRequest(reader)
}
// Continue with drip protocol
// Wait for registration frame
frame, err := protocol.ReadFrame(reader)
if err != nil {
return fmt.Errorf("failed to read registration frame: %w", err)
@@ -118,7 +106,6 @@ func (c *Connection) Handle() error {
sf := protocol.WithFrame(frame)
defer sf.Close()
// Handle data connection request (for multi-connection pool)
if sf.Frame.Type == protocol.FrameTypeDataConnect {
return c.handleDataConnect(sf.Frame, reader)
}
@@ -139,7 +126,6 @@ func (c *Connection) Handle() error {
return fmt.Errorf("authentication failed")
}
// Allocate TCP port only for TCP tunnels
if req.TunnelType == protocol.TunnelTypeTCP {
if c.portAlloc == nil {
return fmt.Errorf("port allocator not configured")
@@ -152,7 +138,6 @@ func (c *Connection) Handle() error {
}
c.port = port
// For TCP tunnels, prefer deterministic subdomain tied to port when not provided by client.
if req.CustomSubdomain == "" {
req.CustomSubdomain = fmt.Sprintf("tcp-%d", port)
}
@@ -174,8 +159,7 @@ func (c *Connection) Handle() error {
}
c.tunnelConn = tunnelConn
// Store TCP connection reference and metadata for HTTP proxy routing
c.tunnelConn.Conn = nil // We're using TCP, not WebSocket
c.tunnelConn.Conn = nil
c.tunnelConn.SetTunnelType(req.TunnelType)
c.tunnelType = req.TunnelType
@@ -186,35 +170,27 @@ func (c *Connection) Handle() error {
zap.Int("remote_port", c.port),
)
// Send registration acknowledgment
// Generate appropriate URL based on tunnel type
var tunnelURL string
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
// HTTP/HTTPS tunnels use HTTPS with subdomain
// Use publicPort for URL generation (configured via --public-port flag)
if c.publicPort == 443 {
tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.domain)
} else {
tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.domain, c.publicPort)
}
} else {
// TCP tunnels use tcp:// with port
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
}
// Generate TunnelID for multi-connection support if client supports it
var tunnelID string
var supportsDataConn bool
recommendedConns := 0
if req.PoolCapabilities != nil && req.ConnectionType == "primary" && c.groupManager != nil {
// Client supports connection pooling
group := c.groupManager.CreateGroup(subdomain, req.Token, c, req.TunnelType)
tunnelID = group.TunnelID
c.tunnelID = tunnelID
supportsDataConn = true
recommendedConns = 4 // Recommend 4 data connections
recommendedConns = 4
c.logger.Info("Created connection group for multi-connection support",
zap.String("tunnel_id", tunnelID),
@@ -235,16 +211,13 @@ func (c *Connection) Handle() error {
respData, _ := json.Marshal(resp)
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
// Send registration ack (sync write before frameWriter is created)
err = protocol.WriteFrame(c.conn, ackFrame)
if err != nil {
return fmt.Errorf("failed to send registration ack: %w", err)
}
// Clear deadline for tunnel data-plane.
c.conn.SetReadDeadline(time.Time{})
// gost-like tunnels: switch to yamux after RegisterAck.
if req.TunnelType == protocol.TunnelTypeTCP {
return c.handleTCPTunnel(reader)
}
@@ -265,6 +238,44 @@ func (c *Connection) Handle() error {
}
// handleHTTPRequest hands an incoming HTTP connection over to the shared
// http.Server via the connection queue listener. When no listener is
// configured it falls back to the legacy inline request loop.
func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
	if c.httpListener == nil {
		return c.handleHTTPRequestLegacy(reader)
	}
	// The http.Server manages its own read deadlines from here on.
	c.conn.SetReadDeadline(time.Time{})
	// Wrap the conn so bytes already consumed into the bufio.Reader during
	// protocol detection are replayed to the http.Server.
	wrappedConn := &bufferedConn{
		Conn:   c.conn,
		reader: reader,
	}
	if !c.httpListener.Enqueue(wrappedConn) {
		c.logger.Warn("HTTP listener queue full, rejecting connection")
		// Compute Content-Length from the actual body (33 bytes); the
		// previous hard-coded value of 32 was off by one.
		body := "Server busy, please retry later\r\n"
		response := "HTTP/1.1 503 Service Unavailable\r\n" +
			"Content-Type: text/plain\r\n" +
			fmt.Sprintf("Content-Length: %d\r\n", len(body)) +
			"Connection: close\r\n" +
			"\r\n" +
			body
		// Best-effort write: the peer may already be gone.
		_, _ = c.conn.Write([]byte(response))
		return fmt.Errorf("http listener queue full")
	}
	// Ownership of the conn has moved to the http.Server; clear it so
	// Close does not double-close the handed-off connection.
	c.mu.Lock()
	c.conn = nil
	c.handedOff = true
	c.mu.Unlock()
	return nil
}
// IsHandedOff reports whether this connection's underlying conn has been
// handed off to the shared HTTP server and is no longer owned here.
func (c *Connection) IsHandedOff() bool {
	c.mu.RLock()
	handed := c.handedOff
	c.mu.RUnlock()
	return handed
}
func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
if c.httpHandler == nil {
c.logger.Warn("HTTP request received but no HTTP handler configured")
response := "HTTP/1.1 503 Service Unavailable\r\n" +
@@ -276,18 +287,13 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
return fmt.Errorf("HTTP handler not configured")
}
// Clear read deadline for HTTP processing
c.conn.SetReadDeadline(time.Time{})
// Handle multiple HTTP requests on the same connection (HTTP/1.1 keep-alive)
for {
// Set a read deadline for each request to avoid hanging forever
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
// Parse HTTP request
req, err := http.ReadRequest(reader)
if err != nil {
// EOF or timeout is normal when client closes connection or no more requests
if err == io.EOF || err == io.ErrUnexpectedEOF {
c.logger.Debug("Client closed HTTP connection")
return nil
@@ -296,7 +302,6 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
c.logger.Debug("HTTP keep-alive timeout")
return nil
}
// Connection reset by peer is normal - client closed connection abruptly
errStr := err.Error()
if errors.Is(err, net.ErrClosed) || strings.Contains(errStr, "use of closed network connection") {
c.logger.Debug("HTTP connection closed during read", zap.Error(err))
@@ -308,13 +313,11 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
c.logger.Debug("Client disconnected abruptly", zap.Error(err))
return nil
}
// Check if it looks like garbage data (not a valid HTTP request)
if strings.Contains(errStr, "malformed HTTP") {
c.logger.Warn("Received malformed HTTP request, possibly due to pipelined requests or protocol error",
c.logger.Warn("Received malformed HTTP request",
zap.Error(err),
zap.String("error_snippet", errStr[:min(len(errStr), 100)]),
)
// Close connection on malformed request to prevent further errors
return nil
}
c.logger.Error("Failed to parse HTTP request", zap.Error(err))
@@ -370,8 +373,6 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
c.logger.Debug("Closing connection as requested by client or server")
return nil
}
// Continue to next request on the same connection
}
}
@@ -382,7 +383,6 @@ func min(a, b int) int {
return b
}
// handleFrames handles incoming frames
func (c *Connection) handleFrames(reader *bufio.Reader) error {
for {
select {
@@ -391,7 +391,6 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
default:
}
// Read frame with timeout
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
frame, err := protocol.ReadFrame(reader)
if err != nil {
@@ -399,15 +398,12 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
c.logger.Warn("Read timeout, connection may be dead")
return fmt.Errorf("read timeout")
}
// EOF is normal when client closes connection gracefully
if err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" {
c.logger.Info("Client disconnected")
return nil
}
// Check if connection was closed (during shutdown)
select {
case <-c.stopCh:
// Connection was closed intentionally, don't log as error
c.logger.Debug("Connection closed during shutdown")
return nil
default:
@@ -415,7 +411,6 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
}
}
// Handle frame based on type
sf := protocol.WithFrame(frame)
switch sf.Frame.Type {
@@ -437,22 +432,18 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
}
}
// handleHeartbeat handles heartbeat frame
func (c *Connection) handleHeartbeat() {
c.mu.Lock()
c.lastHeartbeat = time.Now()
c.mu.Unlock()
// Send heartbeat ack
ackFrame := protocol.NewFrame(protocol.FrameTypeHeartbeatAck, nil)
err := c.frameWriter.WriteControl(ackFrame)
if err != nil {
c.logger.Error("Failed to send heartbeat ack", zap.Error(err))
}
}
// heartbeatChecker checks for heartbeat timeout
func (c *Connection) heartbeatChecker() {
ticker := time.NewTicker(constants.HeartbeatInterval)
defer ticker.Stop()
@@ -496,16 +487,19 @@ func (c *Connection) sendError(code, message string) {
func (c *Connection) Close() {
c.once.Do(func() {
protocol.UnregisterConnection()
close(c.stopCh)
if c.cancel != nil {
c.cancel()
}
// Ensure any in-flight writes return quickly on shutdown to avoid hanging.
if c.conn != nil {
_ = c.conn.SetDeadline(time.Now())
// Prevent race with handleHTTPRequest setting c.conn = nil
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn != nil {
_ = conn.SetDeadline(time.Now())
}
if c.frameWriter != nil {
@@ -520,8 +514,8 @@ func (c *Connection) Close() {
_ = c.session.Close()
}
if c.conn != nil {
c.conn.Close()
if conn != nil {
conn.Close()
}
if c.port > 0 && c.portAlloc != nil {
@@ -530,9 +524,6 @@ func (c *Connection) Close() {
if c.subdomain != "" {
c.manager.Unregister(c.subdomain)
// Clean up connection group when PRIMARY connection closes
// (only primary connections have subdomain set)
if c.tunnelID != "" && c.groupManager != nil {
c.groupManager.RemoveGroup(c.tunnelID)
}
@@ -544,10 +535,9 @@ func (c *Connection) Close() {
})
}
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
type httpResponseWriter struct {
conn net.Conn
writer *bufio.Writer // Buffered writer for efficient I/O
writer *bufio.Writer
header http.Header
statusCode int
headerWritten bool
@@ -594,7 +584,6 @@ func (w *httpResponseWriter) Write(data []byte) (int, error) {
return w.writer.Write(data)
}
// handleDataConnect handles a data connection join request
func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Reader) error {
var req protocol.DataConnectRequest
if err := json.Unmarshal(frame.Payload, &req); err != nil {
@@ -607,13 +596,11 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
zap.String("connection_id", req.ConnectionID),
)
// Validate the request
if c.groupManager == nil {
c.sendDataConnectError("not_supported", "Multi-connection not supported")
return fmt.Errorf("group manager not available")
}
// Validate auth token
if c.authToken != "" && req.Token != c.authToken {
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed for data connection")
@@ -625,16 +612,13 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
return fmt.Errorf("tunnel not found: %s", req.TunnelID)
}
// Validate token against the primary registration token.
if group.Token != "" && req.Token != group.Token {
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed for data connection")
}
// Store tunnelID for cleanup
c.tunnelID = req.TunnelID
// Send success response before upgrading the connection to yamux.
resp := protocol.DataConnectResponse{
Accepted: true,
ConnectionID: req.ConnectionID,
@@ -653,10 +637,9 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
zap.String("connection_id", req.ConnectionID),
)
// Clear deadline for yamux data-plane.
_ = c.conn.SetReadDeadline(time.Time{})
// Public server acts as yamux Client, client connector acts as yamux Server.
// Server acts as yamux Client, client connector acts as yamux Server
bc := &bufferedConn{
Conn: c.conn,
reader: reader,
@@ -692,11 +675,9 @@ func isTimeoutError(err error) bool {
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
// Fallback for wrapped errors without net.Error
return strings.Contains(err.Error(), "i/o timeout")
}
// sendDataConnectError sends a data connect error response
func (c *Connection) sendDataConnectError(code, message string) {
resp := protocol.DataConnectResponse{
Accepted: false,

View File

@@ -0,0 +1,83 @@
package tcp
import (
"net"
"sync"
"sync/atomic"
"time"
)
// connQueueListener is a net.Listener backed by a channel of pre-accepted conns.
// It lets the TCP/TLS multiplexer hand off HTTP connections to a standard http.Server.
type connQueueListener struct {
addr net.Addr
conns chan net.Conn
done chan struct{}
once sync.Once
closed atomic.Bool
}
func newConnQueueListener(addr net.Addr, buffer int) *connQueueListener {
if buffer <= 0 {
buffer = 1024
}
return &connQueueListener{
addr: addr,
conns: make(chan net.Conn, buffer),
done: make(chan struct{}),
}
}
func (l *connQueueListener) Accept() (net.Conn, error) {
select {
case <-l.done:
return nil, net.ErrClosed
case conn := <-l.conns:
if conn == nil {
return nil, net.ErrClosed
}
return conn, nil
}
}
func (l *connQueueListener) Close() error {
l.once.Do(func() {
l.closed.Store(true)
close(l.done)
l.drain()
})
return nil
}
func (l *connQueueListener) Addr() net.Addr { return l.addr }
func (l *connQueueListener) Enqueue(conn net.Conn) bool {
if conn == nil {
return false
}
if l.closed.Load() {
return false
}
select {
case l.conns <- conn:
return true
default:
return false
}
}
func (l *connQueueListener) drain() {
for {
select {
case conn := <-l.conns:
if conn == nil {
continue
}
_ = conn.SetDeadline(time.Now())
_ = conn.Close()
default:
return
}
}
}

View File

@@ -1,6 +1,7 @@
package tcp
import (
"context"
"crypto/tls"
"fmt"
"net"
@@ -14,29 +15,30 @@ import (
"drip/internal/shared/recovery"
"go.uber.org/zap"
"golang.org/x/net/http2"
)
// Listener handles TCP connections with TLS 1.3
type Listener struct {
address string
tlsConfig *tls.Config
authToken string
manager *tunnel.Manager
portAlloc *PortAllocator
logger *zap.Logger
domain string
publicPort int
httpHandler http.Handler
listener net.Listener
stopCh chan struct{}
wg sync.WaitGroup
connections map[string]*Connection
connMu sync.RWMutex
workerPool *pool.WorkerPool // Worker pool for connection handling
recoverer *recovery.Recoverer
panicMetrics *recovery.PanicMetrics
address string
tlsConfig *tls.Config
authToken string
manager *tunnel.Manager
portAlloc *PortAllocator
logger *zap.Logger
domain string
publicPort int
httpHandler http.Handler
listener net.Listener
stopCh chan struct{}
wg sync.WaitGroup
connections map[string]*Connection
connMu sync.RWMutex
workerPool *pool.WorkerPool
recoverer *recovery.Recoverer
panicMetrics *recovery.PanicMetrics
groupManager *ConnectionGroupManager
httpServer *http.Server
httpListener *connQueueListener
}
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler) *Listener {
@@ -73,7 +75,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
}
}
// Start starts the TCP listener
func (l *Listener) Start() error {
var err error
@@ -87,13 +88,39 @@ func (l *Listener) Start() error {
zap.String("tls_version", "TLS 1.3"),
)
l.httpListener = newConnQueueListener(l.listener.Addr(), 4096)
l.httpServer = &http.Server{
Handler: l.httpHandler,
ReadHeaderTimeout: 30 * time.Second,
ReadTimeout: 0,
WriteTimeout: 0,
IdleTimeout: 120 * time.Second,
MaxHeaderBytes: 1 << 20,
}
if err := http2.ConfigureServer(l.httpServer, &http2.Server{
MaxConcurrentStreams: 1000,
IdleTimeout: 120 * time.Second,
}); err != nil {
l.logger.Warn("Failed to configure HTTP/2", zap.Error(err))
}
l.wg.Add(1)
go func() {
defer l.wg.Done()
l.logger.Info("HTTP server started (with context cancellation support)")
if err := l.httpServer.Serve(l.httpListener); err != nil && err != http.ErrServerClosed {
l.logger.Error("HTTP server error", zap.Error(err))
}
}()
l.wg.Add(1)
go l.acceptLoop()
return nil
}
// acceptLoop accepts incoming connections
func (l *Listener) acceptLoop() {
defer l.wg.Done()
defer l.recoverer.Recover("acceptLoop")
@@ -105,7 +132,6 @@ func (l *Listener) acceptLoop() {
default:
}
// Set accept deadline to allow checking stopCh
if tcpListener, ok := l.listener.(*net.TCPListener); ok {
tcpListener.SetDeadline(time.Now().Add(1 * time.Second))
}
@@ -113,7 +139,7 @@ func (l *Listener) acceptLoop() {
conn, err := l.listener.Accept()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue // Timeout is expected due to deadline
continue
}
select {
case <-l.stopCh:
@@ -143,10 +169,8 @@ func (l *Listener) acceptLoop() {
}
}
// handleConnection handles a single client connection
func (l *Listener) handleConnection(netConn net.Conn) {
defer l.wg.Done()
defer netConn.Close()
defer l.recoverer.RecoverWithCallback("handleConnection", func(p interface{}) {
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
@@ -160,7 +184,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
// Set read deadline before handshake to prevent slow handshake attacks
if err := tlsConn.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
l.logger.Warn("Failed to set read deadline",
zap.String("remote_addr", netConn.RemoteAddr().String()),
@@ -177,7 +200,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
// Clear the read deadline after successful handshake
if err := tlsConn.SetReadDeadline(time.Time{}); err != nil {
l.logger.Warn("Failed to clear read deadline",
zap.String("remote_addr", netConn.RemoteAddr().String()),
@@ -208,7 +230,7 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager)
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
@@ -219,12 +241,15 @@ func (l *Listener) handleConnection(netConn net.Conn) {
l.connMu.Lock()
delete(l.connections, connID)
l.connMu.Unlock()
if !conn.IsHandedOff() {
netConn.Close()
}
}()
if err := conn.Handle(); err != nil {
errStr := err.Error()
// Client disconnection errors - normal network behavior, ignore
if strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
@@ -232,7 +257,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
// Protocol errors (invalid clients, scanners) are expected - log as WARN
if strings.Contains(errStr, "payload too large") ||
strings.Contains(errStr, "failed to read registration frame") ||
strings.Contains(errStr, "expected register frame") ||
@@ -243,7 +267,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
zap.Error(err),
)
} else {
// Legitimate errors (auth failures, registration failures, etc.)
l.logger.Error("Connection handling failed",
zap.String("remote_addr", connID),
zap.Error(err),
@@ -252,12 +275,24 @@ func (l *Listener) handleConnection(netConn net.Conn) {
}
}
// Stop stops the listener and closes all connections
func (l *Listener) Stop() error {
l.logger.Info("Stopping TCP listener")
close(l.stopCh)
if l.httpServer != nil {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := l.httpServer.Shutdown(shutdownCtx); err != nil {
l.logger.Warn("HTTP server shutdown error", zap.Error(err))
}
l.logger.Info("HTTP server shutdown complete")
}
if l.httpListener != nil {
l.httpListener.Close()
}
if l.listener != nil {
if err := l.listener.Close(); err != nil {
l.logger.Error("Failed to close listener", zap.Error(err))
@@ -284,7 +319,6 @@ func (l *Listener) Stop() error {
return nil
}
// GetActiveConnections returns the number of active connections
func (l *Listener) GetActiveConnections() int {
l.connMu.RLock()
defer l.connMu.RUnlock()