mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
2
go.mod
2
go.mod
@@ -10,6 +10,7 @@ require (
|
||||
github.com/spf13/cobra v1.10.1
|
||||
go.uber.org/zap v1.27.1
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/sys v0.38.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -30,6 +31,5 @@ require (
|
||||
github.com/stretchr/testify v1.11.1 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
)
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// handleStream routes incoming stream to appropriate handler.
|
||||
func (c *PoolClient) handleStream(h *sessionHandle, stream net.Conn) {
|
||||
defer c.wg.Done()
|
||||
defer func() {
|
||||
@@ -35,7 +34,6 @@ func (c *PoolClient) handleStream(h *sessionHandle, stream net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleTCPStream handles raw TCP tunneling.
|
||||
func (c *PoolClient) handleTCPStream(stream net.Conn) {
|
||||
localConn, err := net.DialTimeout("tcp", net.JoinHostPort(c.localHost, fmt.Sprintf("%d", c.localPort)), 10*time.Second)
|
||||
if err != nil {
|
||||
@@ -62,7 +60,6 @@ func (c *PoolClient) handleTCPStream(stream net.Conn) {
|
||||
)
|
||||
}
|
||||
|
||||
// handleHTTPStream handles HTTP/HTTPS proxy requests.
|
||||
func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
_ = stream.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
@@ -104,6 +101,8 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
httputil.CopyHeaders(outReq.Header, req.Header)
|
||||
httputil.CleanHopByHopHeaders(outReq.Header)
|
||||
|
||||
outReq.Header.Del("Accept-Encoding")
|
||||
|
||||
targetHost := c.localHost
|
||||
if c.localPort != 80 && c.localPort != 443 {
|
||||
targetHost = fmt.Sprintf("%s:%d", c.localHost, c.localPort)
|
||||
@@ -153,7 +152,6 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
close(done)
|
||||
}
|
||||
|
||||
// handleWebSocketUpgrade handles WebSocket upgrade requests.
|
||||
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
scheme := "ws"
|
||||
if c.tunnelType == protocol.TunnelTypeHTTPS {
|
||||
@@ -207,7 +205,6 @@ func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// newLocalHTTPClient creates an HTTP client for local service requests.
|
||||
func newLocalHTTPClient(tunnelType protocol.TunnelType) *http.Client {
|
||||
var tlsConfig *tls.Config
|
||||
if tunnelType == protocol.TunnelTypeHTTPS {
|
||||
|
||||
@@ -39,7 +39,6 @@ func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, auth
|
||||
}
|
||||
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Always handle /health and /stats directly, regardless of subdomain.
|
||||
if r.URL.Path == "/health" {
|
||||
h.serveHealth(w, r)
|
||||
return
|
||||
@@ -71,13 +70,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for WebSocket upgrade
|
||||
if r.Method == http.MethodConnect {
|
||||
http.Error(w, "CONNECT not supported for HTTP tunnels", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if httputil.IsWebSocketUpgrade(r) {
|
||||
h.handleWebSocket(w, r, tconn)
|
||||
return
|
||||
}
|
||||
|
||||
// Open stream with timeout
|
||||
stream, err := h.openStreamWithTimeout(tconn)
|
||||
if err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
@@ -86,17 +88,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
// Track active connections
|
||||
tconn.IncActiveConnections()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
// Wrap stream with counting for traffic stats
|
||||
countingStream := netutil.NewCountingConn(stream,
|
||||
tconn.AddBytesOut, // Data read from stream = bytes out to client
|
||||
tconn.AddBytesIn, // Data written to stream = bytes in from client
|
||||
tconn.AddBytesOut,
|
||||
tconn.AddBytesIn,
|
||||
)
|
||||
|
||||
// 1) Write request over the stream (net/http handles large bodies correctly).
|
||||
if err := r.Write(countingStream); err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
_ = r.Body.Close()
|
||||
@@ -104,7 +103,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// 2) Read response from stream.
|
||||
resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
|
||||
if err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
@@ -113,7 +111,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 3) Copy headers (strip hop-by-hop).
|
||||
h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
|
||||
|
||||
statusCode := resp.StatusCode
|
||||
@@ -121,9 +118,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
statusCode = http.StatusOK
|
||||
}
|
||||
|
||||
// Ensure message delimiting works with our custom ResponseWriter:
|
||||
// - If Content-Length is known, send it.
|
||||
// - Otherwise, re-chunk the decoded body ourselves.
|
||||
if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
|
||||
if resp.ContentLength >= 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
|
||||
@@ -136,28 +130,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if resp.ContentLength >= 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
ctx := r.Context()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stream.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
close(done)
|
||||
stream.Close()
|
||||
return
|
||||
} else {
|
||||
w.Header().Del("Content-Length")
|
||||
}
|
||||
|
||||
w.Header().Del("Content-Length")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
if len(resp.Trailer) > 0 {
|
||||
w.Header().Set("Trailer", trailerKeys(resp.Trailer))
|
||||
}
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
ctx := r.Context()
|
||||
@@ -170,9 +146,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := writeChunked(w, resp.Body, resp.Trailer); err != nil {
|
||||
h.logger.Debug("Write chunked response failed", zap.Error(err))
|
||||
}
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
close(done)
|
||||
stream.Close()
|
||||
}
|
||||
@@ -267,53 +241,6 @@ func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHos
|
||||
}
|
||||
}
|
||||
|
||||
func trailerKeys(hdr http.Header) string {
|
||||
keys := make([]string, 0, len(hdr))
|
||||
for k := range hdr {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
// Deterministic order is nicer for debugging; no semantic impact.
|
||||
sortStrings(keys)
|
||||
return strings.Join(keys, ", ")
|
||||
}
|
||||
|
||||
func writeChunked(w io.Writer, r io.Reader, trailer http.Header) error {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
if _, werr := fmt.Fprintf(w, "%x\r\n", n); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if _, werr := w.Write(buf[:n]); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if _, werr := io.WriteString(w, "\r\n"); werr != nil {
|
||||
return werr
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := io.WriteString(w, "0\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
for k, vv := range trailer {
|
||||
for _, v := range vv {
|
||||
if _, err := io.WriteString(w, fmt.Sprintf("%s: %s\r\n", k, v)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err := io.WriteString(w, "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
|
||||
if !strings.HasPrefix(location, "http://") && !strings.HasPrefix(location, "https://") {
|
||||
return location
|
||||
@@ -459,13 +386,3 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func sortStrings(s []string) {
|
||||
for i := 0; i < len(s); i++ {
|
||||
for j := i + 1; j < len(s); j++ {
|
||||
if s[j] < s[i] {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Connection represents a client TCP connection
|
||||
type Connection struct {
|
||||
conn net.Conn
|
||||
authToken string
|
||||
@@ -40,21 +39,19 @@ type Connection struct {
|
||||
mu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
httpHandler http.Handler
|
||||
tunnelType protocol.TunnelType // Track tunnel type
|
||||
tunnelType protocol.TunnelType
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// gost-like TCP tunnel (yamux)
|
||||
session *yamux.Session
|
||||
proxy *Proxy
|
||||
|
||||
// Multi-connection support
|
||||
tunnelID string
|
||||
groupManager *ConnectionGroupManager
|
||||
session *yamux.Session
|
||||
proxy *Proxy
|
||||
tunnelID string
|
||||
groupManager *ConnectionGroupManager
|
||||
httpListener *connQueueListener
|
||||
handedOff bool
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager) *Connection {
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, groupManager *ConnectionGroupManager, httpListener *connQueueListener) *Connection {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c := &Connection{
|
||||
conn: conn,
|
||||
@@ -70,25 +67,18 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
groupManager: groupManager,
|
||||
httpListener: httpListener,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// Handle handles the connection lifecycle
|
||||
func (c *Connection) Handle() error {
|
||||
// Register connection for adaptive load tracking
|
||||
protocol.RegisterConnection()
|
||||
|
||||
// Ensure cleanup of control connection, proxy, port, and registry on exit.
|
||||
defer c.Close()
|
||||
|
||||
// Set initial read timeout for protocol detection
|
||||
c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
// Use buffered reader to support peeking
|
||||
reader := bufio.NewReader(c.conn)
|
||||
|
||||
// Peek first 4 bytes to detect protocol (HTTP methods are 4 bytes).
|
||||
peek, err := reader.Peek(4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to peek connection: %w", err)
|
||||
@@ -109,8 +99,6 @@ func (c *Connection) Handle() error {
|
||||
return c.handleHTTPRequest(reader)
|
||||
}
|
||||
|
||||
// Continue with drip protocol
|
||||
// Wait for registration frame
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read registration frame: %w", err)
|
||||
@@ -118,7 +106,6 @@ func (c *Connection) Handle() error {
|
||||
sf := protocol.WithFrame(frame)
|
||||
defer sf.Close()
|
||||
|
||||
// Handle data connection request (for multi-connection pool)
|
||||
if sf.Frame.Type == protocol.FrameTypeDataConnect {
|
||||
return c.handleDataConnect(sf.Frame, reader)
|
||||
}
|
||||
@@ -139,7 +126,6 @@ func (c *Connection) Handle() error {
|
||||
return fmt.Errorf("authentication failed")
|
||||
}
|
||||
|
||||
// Allocate TCP port only for TCP tunnels
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
if c.portAlloc == nil {
|
||||
return fmt.Errorf("port allocator not configured")
|
||||
@@ -152,7 +138,6 @@ func (c *Connection) Handle() error {
|
||||
}
|
||||
c.port = port
|
||||
|
||||
// For TCP tunnels, prefer deterministic subdomain tied to port when not provided by client.
|
||||
if req.CustomSubdomain == "" {
|
||||
req.CustomSubdomain = fmt.Sprintf("tcp-%d", port)
|
||||
}
|
||||
@@ -174,8 +159,7 @@ func (c *Connection) Handle() error {
|
||||
}
|
||||
c.tunnelConn = tunnelConn
|
||||
|
||||
// Store TCP connection reference and metadata for HTTP proxy routing
|
||||
c.tunnelConn.Conn = nil // We're using TCP, not WebSocket
|
||||
c.tunnelConn.Conn = nil
|
||||
c.tunnelConn.SetTunnelType(req.TunnelType)
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
@@ -186,35 +170,27 @@ func (c *Connection) Handle() error {
|
||||
zap.Int("remote_port", c.port),
|
||||
)
|
||||
|
||||
// Send registration acknowledgment
|
||||
// Generate appropriate URL based on tunnel type
|
||||
var tunnelURL string
|
||||
|
||||
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
|
||||
// HTTP/HTTPS tunnels use HTTPS with subdomain
|
||||
// Use publicPort for URL generation (configured via --public-port flag)
|
||||
if c.publicPort == 443 {
|
||||
tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.domain)
|
||||
} else {
|
||||
tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.domain, c.publicPort)
|
||||
}
|
||||
} else {
|
||||
// TCP tunnels use tcp:// with port
|
||||
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
|
||||
}
|
||||
|
||||
// Generate TunnelID for multi-connection support if client supports it
|
||||
var tunnelID string
|
||||
var supportsDataConn bool
|
||||
recommendedConns := 0
|
||||
|
||||
if req.PoolCapabilities != nil && req.ConnectionType == "primary" && c.groupManager != nil {
|
||||
// Client supports connection pooling
|
||||
group := c.groupManager.CreateGroup(subdomain, req.Token, c, req.TunnelType)
|
||||
tunnelID = group.TunnelID
|
||||
c.tunnelID = tunnelID
|
||||
supportsDataConn = true
|
||||
recommendedConns = 4 // Recommend 4 data connections
|
||||
recommendedConns = 4
|
||||
|
||||
c.logger.Info("Created connection group for multi-connection support",
|
||||
zap.String("tunnel_id", tunnelID),
|
||||
@@ -235,16 +211,13 @@ func (c *Connection) Handle() error {
|
||||
respData, _ := json.Marshal(resp)
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
|
||||
|
||||
// Send registration ack (sync write before frameWriter is created)
|
||||
err = protocol.WriteFrame(c.conn, ackFrame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send registration ack: %w", err)
|
||||
}
|
||||
|
||||
// Clear deadline for tunnel data-plane.
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// gost-like tunnels: switch to yamux after RegisterAck.
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
return c.handleTCPTunnel(reader)
|
||||
}
|
||||
@@ -265,6 +238,44 @@ func (c *Connection) Handle() error {
|
||||
}
|
||||
|
||||
func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
if c.httpListener == nil {
|
||||
return c.handleHTTPRequestLegacy(reader)
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
wrappedConn := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
if !c.httpListener.Enqueue(wrappedConn) {
|
||||
c.logger.Warn("HTTP listener queue full, rejecting connection")
|
||||
response := "HTTP/1.1 503 Service Unavailable\r\n" +
|
||||
"Content-Type: text/plain\r\n" +
|
||||
"Content-Length: 32\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"\r\n" +
|
||||
"Server busy, please retry later\r\n"
|
||||
c.conn.Write([]byte(response))
|
||||
return fmt.Errorf("http listener queue full")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = nil
|
||||
c.handedOff = true
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) IsHandedOff() bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.handedOff
|
||||
}
|
||||
|
||||
func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
|
||||
if c.httpHandler == nil {
|
||||
c.logger.Warn("HTTP request received but no HTTP handler configured")
|
||||
response := "HTTP/1.1 503 Service Unavailable\r\n" +
|
||||
@@ -276,18 +287,13 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
return fmt.Errorf("HTTP handler not configured")
|
||||
}
|
||||
|
||||
// Clear read deadline for HTTP processing
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Handle multiple HTTP requests on the same connection (HTTP/1.1 keep-alive)
|
||||
for {
|
||||
// Set a read deadline for each request to avoid hanging forever
|
||||
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
|
||||
// Parse HTTP request
|
||||
req, err := http.ReadRequest(reader)
|
||||
if err != nil {
|
||||
// EOF or timeout is normal when client closes connection or no more requests
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
c.logger.Debug("Client closed HTTP connection")
|
||||
return nil
|
||||
@@ -296,7 +302,6 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
c.logger.Debug("HTTP keep-alive timeout")
|
||||
return nil
|
||||
}
|
||||
// Connection reset by peer is normal - client closed connection abruptly
|
||||
errStr := err.Error()
|
||||
if errors.Is(err, net.ErrClosed) || strings.Contains(errStr, "use of closed network connection") {
|
||||
c.logger.Debug("HTTP connection closed during read", zap.Error(err))
|
||||
@@ -308,13 +313,11 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
c.logger.Debug("Client disconnected abruptly", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
// Check if it looks like garbage data (not a valid HTTP request)
|
||||
if strings.Contains(errStr, "malformed HTTP") {
|
||||
c.logger.Warn("Received malformed HTTP request, possibly due to pipelined requests or protocol error",
|
||||
c.logger.Warn("Received malformed HTTP request",
|
||||
zap.Error(err),
|
||||
zap.String("error_snippet", errStr[:min(len(errStr), 100)]),
|
||||
)
|
||||
// Close connection on malformed request to prevent further errors
|
||||
return nil
|
||||
}
|
||||
c.logger.Error("Failed to parse HTTP request", zap.Error(err))
|
||||
@@ -370,8 +373,6 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
c.logger.Debug("Closing connection as requested by client or server")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Continue to next request on the same connection
|
||||
}
|
||||
}
|
||||
|
||||
@@ -382,7 +383,6 @@ func min(a, b int) int {
|
||||
return b
|
||||
}
|
||||
|
||||
// handleFrames handles incoming frames
|
||||
func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
for {
|
||||
select {
|
||||
@@ -391,7 +391,6 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
default:
|
||||
}
|
||||
|
||||
// Read frame with timeout
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
@@ -399,15 +398,12 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
c.logger.Warn("Read timeout, connection may be dead")
|
||||
return fmt.Errorf("read timeout")
|
||||
}
|
||||
// EOF is normal when client closes connection gracefully
|
||||
if err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" {
|
||||
c.logger.Info("Client disconnected")
|
||||
return nil
|
||||
}
|
||||
// Check if connection was closed (during shutdown)
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
// Connection was closed intentionally, don't log as error
|
||||
c.logger.Debug("Connection closed during shutdown")
|
||||
return nil
|
||||
default:
|
||||
@@ -415,7 +411,6 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle frame based on type
|
||||
sf := protocol.WithFrame(frame)
|
||||
|
||||
switch sf.Frame.Type {
|
||||
@@ -437,22 +432,18 @@ func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
}
|
||||
}
|
||||
|
||||
// handleHeartbeat handles heartbeat frame
|
||||
func (c *Connection) handleHeartbeat() {
|
||||
c.mu.Lock()
|
||||
c.lastHeartbeat = time.Now()
|
||||
c.mu.Unlock()
|
||||
|
||||
// Send heartbeat ack
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeHeartbeatAck, nil)
|
||||
|
||||
err := c.frameWriter.WriteControl(ackFrame)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to send heartbeat ack", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeatChecker checks for heartbeat timeout
|
||||
func (c *Connection) heartbeatChecker() {
|
||||
ticker := time.NewTicker(constants.HeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
@@ -496,16 +487,19 @@ func (c *Connection) sendError(code, message string) {
|
||||
func (c *Connection) Close() {
|
||||
c.once.Do(func() {
|
||||
protocol.UnregisterConnection()
|
||||
|
||||
close(c.stopCh)
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
// Ensure any in-flight writes return quickly on shutdown to avoid hanging.
|
||||
if c.conn != nil {
|
||||
_ = c.conn.SetDeadline(time.Now())
|
||||
// Prevent race with handleHTTPRequest setting c.conn = nil
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn != nil {
|
||||
_ = conn.SetDeadline(time.Now())
|
||||
}
|
||||
|
||||
if c.frameWriter != nil {
|
||||
@@ -520,8 +514,8 @@ func (c *Connection) Close() {
|
||||
_ = c.session.Close()
|
||||
}
|
||||
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
if c.port > 0 && c.portAlloc != nil {
|
||||
@@ -530,9 +524,6 @@ func (c *Connection) Close() {
|
||||
|
||||
if c.subdomain != "" {
|
||||
c.manager.Unregister(c.subdomain)
|
||||
|
||||
// Clean up connection group when PRIMARY connection closes
|
||||
// (only primary connections have subdomain set)
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
c.groupManager.RemoveGroup(c.tunnelID)
|
||||
}
|
||||
@@ -544,10 +535,9 @@ func (c *Connection) Close() {
|
||||
})
|
||||
}
|
||||
|
||||
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
|
||||
type httpResponseWriter struct {
|
||||
conn net.Conn
|
||||
writer *bufio.Writer // Buffered writer for efficient I/O
|
||||
writer *bufio.Writer
|
||||
header http.Header
|
||||
statusCode int
|
||||
headerWritten bool
|
||||
@@ -594,7 +584,6 @@ func (w *httpResponseWriter) Write(data []byte) (int, error) {
|
||||
return w.writer.Write(data)
|
||||
}
|
||||
|
||||
// handleDataConnect handles a data connection join request
|
||||
func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Reader) error {
|
||||
var req protocol.DataConnectRequest
|
||||
if err := json.Unmarshal(frame.Payload, &req); err != nil {
|
||||
@@ -607,13 +596,11 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Validate the request
|
||||
if c.groupManager == nil {
|
||||
c.sendDataConnectError("not_supported", "Multi-connection not supported")
|
||||
return fmt.Errorf("group manager not available")
|
||||
}
|
||||
|
||||
// Validate auth token
|
||||
if c.authToken != "" && req.Token != c.authToken {
|
||||
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed for data connection")
|
||||
@@ -625,75 +612,13 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
|
||||
return fmt.Errorf("tunnel not found: %s", req.TunnelID)
|
||||
}
|
||||
|
||||
// Validate token against the primary registration token.
|
||||
if group.Token != "" && req.Token != group.Token {
|
||||
c.sendDataConnectError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed for data connection")
|
||||
}
|
||||
|
||||
// Store tunnelID for cleanup
|
||||
c.tunnelID = req.TunnelID
|
||||
|
||||
// For TCP tunnels, the data connection is upgraded to a yamux session and used for
|
||||
// stream forwarding, not framed request/response routing.
|
||||
if group.TunnelType == protocol.TunnelTypeTCP {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: true,
|
||||
ConnectionID: req.ConnectionID,
|
||||
Message: "Data connection accepted",
|
||||
}
|
||||
|
||||
respData, _ := json.Marshal(resp)
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
|
||||
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
|
||||
return fmt.Errorf("failed to send data connect ack: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("TCP data connection established",
|
||||
zap.String("tunnel_id", req.TunnelID),
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Clear deadline for yamux data-plane.
|
||||
_ = c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Public server acts as yamux Client, client connector acts as yamux Server.
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
group.AddSession(req.ConnectionID, session)
|
||||
defer group.RemoveSession(req.ConnectionID)
|
||||
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Add data connection to group
|
||||
dataConn, err := c.groupManager.AddDataConnection(&req, c.conn)
|
||||
if err != nil {
|
||||
c.sendDataConnectError("join_failed", err.Error())
|
||||
return fmt.Errorf("failed to join connection group: %w", err)
|
||||
}
|
||||
|
||||
// Send success response
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: true,
|
||||
ConnectionID: req.ConnectionID,
|
||||
@@ -712,56 +637,33 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
|
||||
zap.String("connection_id", req.ConnectionID),
|
||||
)
|
||||
|
||||
// Handle data frames on this connection
|
||||
return c.handleDataConnectionFrames(dataConn, reader)
|
||||
}
|
||||
_ = c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// handleDataConnectionFrames handles frames on a data connection
|
||||
func (c *Connection) handleDataConnectionFrames(dataConn *DataConnection, reader *bufio.Reader) error {
|
||||
defer func() {
|
||||
// Get the group and remove this data connection
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok {
|
||||
group.RemoveDataConnection(dataConn.ID)
|
||||
}
|
||||
}()
|
||||
// Server acts as yamux Client, client connector acts as yamux Server
|
||||
bc := &bufferedConn{
|
||||
Conn: c.conn,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-dataConn.stopCh:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
cfg := yamux.DefaultConfig()
|
||||
cfg.EnableKeepAlive = false
|
||||
cfg.LogOutput = io.Discard
|
||||
cfg.AcceptBacklog = constants.YamuxAcceptBacklog
|
||||
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
// Timeout is OK, continue
|
||||
if isTimeoutError(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
session, err := yamux.Client(bc, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init yamux session: %w", err)
|
||||
}
|
||||
c.session = session
|
||||
|
||||
dataConn.mu.Lock()
|
||||
dataConn.LastActive = time.Now()
|
||||
dataConn.mu.Unlock()
|
||||
group.AddSession(req.ConnectionID, session)
|
||||
defer group.RemoveSession(req.ConnectionID)
|
||||
|
||||
sf := protocol.WithFrame(frame)
|
||||
|
||||
switch sf.Frame.Type {
|
||||
case protocol.FrameTypeClose:
|
||||
sf.Close()
|
||||
c.logger.Info("Data connection closed by client",
|
||||
zap.String("connection_id", dataConn.ID))
|
||||
return nil
|
||||
|
||||
default:
|
||||
sf.Close()
|
||||
c.logger.Warn("Unexpected frame type on data connection",
|
||||
zap.String("type", sf.Frame.Type.String()),
|
||||
zap.String("connection_id", dataConn.ID),
|
||||
)
|
||||
}
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
case <-session.CloseChan():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -773,11 +675,9 @@ func isTimeoutError(err error) bool {
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
// Fallback for wrapped errors without net.Error
|
||||
return strings.Contains(err.Error(), "i/o timeout")
|
||||
}
|
||||
|
||||
// sendDataConnectError sends a data connect error response
|
||||
func (c *Connection) sendDataConnectError(code, message string) {
|
||||
resp := protocol.DataConnectResponse{
|
||||
Accepted: false,
|
||||
|
||||
@@ -15,23 +15,11 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
|
||||
type DataConnection struct {
|
||||
ID string
|
||||
Conn net.Conn
|
||||
LastActive time.Time
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
stopCh chan struct{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type ConnectionGroup struct {
|
||||
TunnelID string
|
||||
Subdomain string
|
||||
Token string
|
||||
PrimaryConn *Connection
|
||||
DataConns map[string]*DataConnection
|
||||
Sessions map[string]*yamux.Session
|
||||
TunnelType protocol.TunnelType
|
||||
RegisteredAt time.Time
|
||||
@@ -50,7 +38,6 @@ func NewConnectionGroup(tunnelID, subdomain, token string, primaryConn *Connecti
|
||||
Subdomain: subdomain,
|
||||
Token: token,
|
||||
PrimaryConn: primaryConn,
|
||||
DataConns: make(map[string]*DataConnection),
|
||||
Sessions: make(map[string]*yamux.Session),
|
||||
TunnelType: tunnelType,
|
||||
RegisteredAt: time.Now(),
|
||||
@@ -146,46 +133,6 @@ func (g *ConnectionGroup) heartbeatLoop(interval, timeout time.Duration) {
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) AddDataConnection(connID string, conn net.Conn) *DataConnection {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
dataConn := &DataConnection{
|
||||
ID: connID,
|
||||
Conn: conn,
|
||||
LastActive: time.Now(),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
g.DataConns[connID] = dataConn
|
||||
g.LastActivity = time.Now()
|
||||
return dataConn
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) RemoveDataConnection(connID string) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if dataConn, ok := g.DataConns[connID]; ok {
|
||||
dataConn.closedMu.Lock()
|
||||
if !dataConn.closed {
|
||||
dataConn.closed = true
|
||||
close(dataConn.stopCh)
|
||||
if dataConn.Conn != nil {
|
||||
_ = dataConn.Conn.SetDeadline(time.Now())
|
||||
dataConn.Conn.Close()
|
||||
}
|
||||
}
|
||||
dataConn.closedMu.Unlock()
|
||||
delete(g.DataConns, connID)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) DataConnectionCount() int {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
return len(g.DataConns)
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) Close() {
|
||||
g.mu.Lock()
|
||||
|
||||
@@ -197,12 +144,6 @@ func (g *ConnectionGroup) Close() {
|
||||
close(g.stopCh)
|
||||
}
|
||||
|
||||
dataConns := make([]*DataConnection, 0, len(g.DataConns))
|
||||
for _, dataConn := range g.DataConns {
|
||||
dataConns = append(dataConns, dataConn)
|
||||
}
|
||||
g.DataConns = make(map[string]*DataConnection)
|
||||
|
||||
sessions := make([]*yamux.Session, 0, len(g.Sessions))
|
||||
for _, session := range g.Sessions {
|
||||
if session != nil {
|
||||
@@ -213,19 +154,6 @@ func (g *ConnectionGroup) Close() {
|
||||
|
||||
g.mu.Unlock()
|
||||
|
||||
for _, dataConn := range dataConns {
|
||||
dataConn.closedMu.Lock()
|
||||
if !dataConn.closed {
|
||||
dataConn.closed = true
|
||||
close(dataConn.stopCh)
|
||||
if dataConn.Conn != nil {
|
||||
_ = dataConn.Conn.SetDeadline(time.Now())
|
||||
_ = dataConn.Conn.Close()
|
||||
}
|
||||
}
|
||||
dataConn.closedMu.Unlock()
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
_ = session.Close()
|
||||
}
|
||||
@@ -302,7 +230,13 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
|
||||
default:
|
||||
}
|
||||
|
||||
sessions := g.sessionsSnapshot()
|
||||
// Prefer data sessions for data-plane traffic; keep the primary session
|
||||
// as control-plane (client ping/latency), and only fall back to primary
|
||||
// when no data session exists.
|
||||
sessions := g.sessionsSnapshot(false)
|
||||
if len(sessions) == 0 {
|
||||
sessions = g.sessionsSnapshot(true)
|
||||
}
|
||||
if len(sessions) == 0 {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
@@ -380,7 +314,10 @@ func (g *ConnectionGroup) OpenStream() (net.Conn, error) {
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) selectSession() *yamux.Session {
|
||||
sessions := g.sessionsSnapshot()
|
||||
sessions := g.sessionsSnapshot(false)
|
||||
if len(sessions) == 0 {
|
||||
sessions = g.sessionsSnapshot(true)
|
||||
}
|
||||
if len(sessions) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -403,7 +340,7 @@ func (g *ConnectionGroup) selectSession() *yamux.Session {
|
||||
return best
|
||||
}
|
||||
|
||||
func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
|
||||
func (g *ConnectionGroup) sessionsSnapshot(includePrimary bool) []*yamux.Session {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
@@ -417,6 +354,9 @@ func (g *ConnectionGroup) sessionsSnapshot() []*yamux.Session {
|
||||
delete(g.Sessions, id)
|
||||
continue
|
||||
}
|
||||
if id == "primary" && !includePrimary {
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,6 @@ package tcp
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -84,26 +82,6 @@ func (m *ConnectionGroupManager) RemoveGroup(tunnelID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// AddDataConnection adds a data connection to a group
|
||||
func (m *ConnectionGroupManager) AddDataConnection(req *protocol.DataConnectRequest, conn net.Conn) (*DataConnection, error) {
|
||||
m.mu.RLock()
|
||||
group, ok := m.groups[req.TunnelID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tunnel not found: %s", req.TunnelID)
|
||||
}
|
||||
|
||||
// Validate token
|
||||
if group.Token != "" && req.Token != group.Token {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
dataConn := group.AddDataConnection(req.ConnectionID, conn)
|
||||
|
||||
return dataConn, nil
|
||||
}
|
||||
|
||||
// cleanupLoop periodically cleans up stale groups
|
||||
func (m *ConnectionGroupManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(m.cleanupInterval)
|
||||
|
||||
83
internal/server/tcp/http_conn_listener.go
Normal file
83
internal/server/tcp/http_conn_listener.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// connQueueListener is a net.Listener backed by a channel of pre-accepted conns.
|
||||
// It lets the TCP/TLS multiplexer hand off HTTP connections to a standard http.Server.
|
||||
type connQueueListener struct {
|
||||
addr net.Addr
|
||||
conns chan net.Conn
|
||||
done chan struct{}
|
||||
once sync.Once
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func newConnQueueListener(addr net.Addr, buffer int) *connQueueListener {
|
||||
if buffer <= 0 {
|
||||
buffer = 1024
|
||||
}
|
||||
return &connQueueListener{
|
||||
addr: addr,
|
||||
conns: make(chan net.Conn, buffer),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *connQueueListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-l.done:
|
||||
return nil, net.ErrClosed
|
||||
case conn := <-l.conns:
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (l *connQueueListener) Close() error {
|
||||
l.once.Do(func() {
|
||||
l.closed.Store(true)
|
||||
close(l.done)
|
||||
l.drain()
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *connQueueListener) Addr() net.Addr { return l.addr }
|
||||
|
||||
func (l *connQueueListener) Enqueue(conn net.Conn) bool {
|
||||
if conn == nil {
|
||||
return false
|
||||
}
|
||||
if l.closed.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
select {
|
||||
case l.conns <- conn:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (l *connQueueListener) drain() {
|
||||
for {
|
||||
select {
|
||||
case conn := <-l.conns:
|
||||
if conn == nil {
|
||||
continue
|
||||
}
|
||||
_ = conn.SetDeadline(time.Now())
|
||||
_ = conn.Close()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -14,29 +15,30 @@ import (
|
||||
"drip/internal/shared/recovery"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// Listener handles TCP connections with TLS 1.3
|
||||
type Listener struct {
|
||||
address string
|
||||
tlsConfig *tls.Config
|
||||
authToken string
|
||||
manager *tunnel.Manager
|
||||
portAlloc *PortAllocator
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
publicPort int
|
||||
httpHandler http.Handler
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
connections map[string]*Connection
|
||||
connMu sync.RWMutex
|
||||
workerPool *pool.WorkerPool // Worker pool for connection handling
|
||||
recoverer *recovery.Recoverer
|
||||
panicMetrics *recovery.PanicMetrics
|
||||
|
||||
address string
|
||||
tlsConfig *tls.Config
|
||||
authToken string
|
||||
manager *tunnel.Manager
|
||||
portAlloc *PortAllocator
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
publicPort int
|
||||
httpHandler http.Handler
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
connections map[string]*Connection
|
||||
connMu sync.RWMutex
|
||||
workerPool *pool.WorkerPool
|
||||
recoverer *recovery.Recoverer
|
||||
panicMetrics *recovery.PanicMetrics
|
||||
groupManager *ConnectionGroupManager
|
||||
httpServer *http.Server
|
||||
httpListener *connQueueListener
|
||||
}
|
||||
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler) *Listener {
|
||||
@@ -73,7 +75,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the TCP listener
|
||||
func (l *Listener) Start() error {
|
||||
var err error
|
||||
|
||||
@@ -87,13 +88,39 @@ func (l *Listener) Start() error {
|
||||
zap.String("tls_version", "TLS 1.3"),
|
||||
)
|
||||
|
||||
l.httpListener = newConnQueueListener(l.listener.Addr(), 4096)
|
||||
|
||||
l.httpServer = &http.Server{
|
||||
Handler: l.httpHandler,
|
||||
ReadHeaderTimeout: 30 * time.Second,
|
||||
ReadTimeout: 0,
|
||||
WriteTimeout: 0,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
}
|
||||
|
||||
if err := http2.ConfigureServer(l.httpServer, &http2.Server{
|
||||
MaxConcurrentStreams: 1000,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}); err != nil {
|
||||
l.logger.Warn("Failed to configure HTTP/2", zap.Error(err))
|
||||
}
|
||||
|
||||
l.wg.Add(1)
|
||||
go func() {
|
||||
defer l.wg.Done()
|
||||
l.logger.Info("HTTP server started (with context cancellation support)")
|
||||
if err := l.httpServer.Serve(l.httpListener); err != nil && err != http.ErrServerClosed {
|
||||
l.logger.Error("HTTP server error", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
l.wg.Add(1)
|
||||
go l.acceptLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop accepts incoming connections
|
||||
func (l *Listener) acceptLoop() {
|
||||
defer l.wg.Done()
|
||||
defer l.recoverer.Recover("acceptLoop")
|
||||
@@ -105,7 +132,6 @@ func (l *Listener) acceptLoop() {
|
||||
default:
|
||||
}
|
||||
|
||||
// Set accept deadline to allow checking stopCh
|
||||
if tcpListener, ok := l.listener.(*net.TCPListener); ok {
|
||||
tcpListener.SetDeadline(time.Now().Add(1 * time.Second))
|
||||
}
|
||||
@@ -113,7 +139,7 @@ func (l *Listener) acceptLoop() {
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue // Timeout is expected due to deadline
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
@@ -143,10 +169,8 @@ func (l *Listener) acceptLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection handles a single client connection
|
||||
func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
defer l.wg.Done()
|
||||
defer netConn.Close()
|
||||
defer l.recoverer.RecoverWithCallback("handleConnection", func(p interface{}) {
|
||||
connID := netConn.RemoteAddr().String()
|
||||
l.connMu.Lock()
|
||||
@@ -160,7 +184,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
// Set read deadline before handshake to prevent slow handshake attacks
|
||||
if err := tlsConn.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||
l.logger.Warn("Failed to set read deadline",
|
||||
zap.String("remote_addr", netConn.RemoteAddr().String()),
|
||||
@@ -177,7 +200,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear the read deadline after successful handshake
|
||||
if err := tlsConn.SetReadDeadline(time.Time{}); err != nil {
|
||||
l.logger.Warn("Failed to clear read deadline",
|
||||
zap.String("remote_addr", netConn.RemoteAddr().String()),
|
||||
@@ -208,7 +230,7 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager)
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
|
||||
|
||||
connID := netConn.RemoteAddr().String()
|
||||
l.connMu.Lock()
|
||||
@@ -219,12 +241,15 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
l.connMu.Lock()
|
||||
delete(l.connections, connID)
|
||||
l.connMu.Unlock()
|
||||
|
||||
if !conn.IsHandedOff() {
|
||||
netConn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := conn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
// Client disconnection errors - normal network behavior, ignore
|
||||
if strings.Contains(errStr, "EOF") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
@@ -232,7 +257,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
// Protocol errors (invalid clients, scanners) are expected - log as WARN
|
||||
if strings.Contains(errStr, "payload too large") ||
|
||||
strings.Contains(errStr, "failed to read registration frame") ||
|
||||
strings.Contains(errStr, "expected register frame") ||
|
||||
@@ -243,7 +267,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
// Legitimate errors (auth failures, registration failures, etc.)
|
||||
l.logger.Error("Connection handling failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
@@ -252,12 +275,24 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the listener and closes all connections
|
||||
func (l *Listener) Stop() error {
|
||||
l.logger.Info("Stopping TCP listener")
|
||||
|
||||
close(l.stopCh)
|
||||
|
||||
if l.httpServer != nil {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := l.httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
l.logger.Warn("HTTP server shutdown error", zap.Error(err))
|
||||
}
|
||||
l.logger.Info("HTTP server shutdown complete")
|
||||
}
|
||||
|
||||
if l.httpListener != nil {
|
||||
l.httpListener.Close()
|
||||
}
|
||||
|
||||
if l.listener != nil {
|
||||
if err := l.listener.Close(); err != nil {
|
||||
l.logger.Error("Failed to close listener", zap.Error(err))
|
||||
@@ -284,7 +319,6 @@ func (l *Listener) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveConnections returns the number of active connections
|
||||
func (l *Listener) GetActiveConnections() int {
|
||||
l.connMu.RLock()
|
||||
defer l.connMu.RUnlock()
|
||||
|
||||
@@ -39,7 +39,7 @@ func (c *Connection) handleTCPTunnel(reader *bufio.Reader) error {
|
||||
c.session = session
|
||||
|
||||
openStream := session.Open
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
if c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
group.AddSession("primary", session)
|
||||
openStream = group.OpenStream
|
||||
@@ -78,7 +78,7 @@ func (c *Connection) handleHTTPProxyTunnel(reader *bufio.Reader) error {
|
||||
c.session = session
|
||||
|
||||
openStream := session.Open
|
||||
if c.tunnelID != "" && c.groupManager != nil {
|
||||
if c.groupManager != nil {
|
||||
if group, ok := c.groupManager.GetGroup(c.tunnelID); ok && group != nil {
|
||||
group.AddSession("primary", session)
|
||||
openStream = group.OpenStream
|
||||
|
||||
Reference in New Issue
Block a user