From e05f128a9c89d73af14caae4c27a6f1ebf6cc8f9 Mon Sep 17 00:00:00 2001 From: Gouryella Date: Wed, 24 Dec 2025 10:13:30 +0800 Subject: [PATCH] feat (tunnel/manager): Optimize concurrency safety and resource management for tunnel registration. A CAS loop is used to implement atomic operations on the global tunnel counter, avoiding race conditions. Add a rollback mechanism to ensure that reserved counter slots are properly released when registration fails. Concurrency safety for IP rate limiting is achieved by using atomic operations and locks in combination. Add appropriate resource rollback logic at each failure branch to prevent resource leaks. --- internal/server/tunnel/manager.go | 64 +++++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/internal/server/tunnel/manager.go b/internal/server/tunnel/manager.go index 9632ff5..44ffcee 100644 --- a/internal/server/tunnel/manager.go +++ b/internal/server/tunnel/manager.go @@ -168,20 +168,33 @@ func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string // RegisterWithIP registers a new tunnel with IP tracking func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, remoteIP string) (string, error) { - // Check global limits first (lock-free read) - if m.tunnelCount.Load() >= int64(m.maxTunnels) { - m.logger.Warn("Maximum tunnel limit reached", - zap.Int64("current", m.tunnelCount.Load()), - zap.Int("max", m.maxTunnels), - ) - return "", ErrTooManyTunnels + // Reserve a global slot atomically using CAS loop + for { + current := m.tunnelCount.Load() + if current >= int64(m.maxTunnels) { + m.logger.Warn("Maximum tunnel limit reached", + zap.Int64("current", current), + zap.Int("max", m.maxTunnels), + ) + return "", ErrTooManyTunnels + } + if m.tunnelCount.CompareAndSwap(current, current+1) { + break + } + // CAS failed, another goroutine modified the counter, retry } - // Check per-IP limits + // Rollback helper for global counter 
+ rollbackGlobal := func() { + m.tunnelCount.Add(-1) + } + + // Check per-IP limits and reserve slot atomically if remoteIP != "" { m.ipMu.Lock() if !m.checkRateLimitLocked(remoteIP) { m.ipMu.Unlock() + rollbackGlobal() m.logger.Warn("Rate limit exceeded", zap.String("ip", remoteIP), zap.Int("limit", m.rateLimit), @@ -190,25 +203,48 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r } if m.tunnelsByIP[remoteIP] >= m.maxTunnelsPerIP { + currentPerIP := m.tunnelsByIP[remoteIP] m.ipMu.Unlock() + rollbackGlobal() m.logger.Warn("Per-IP tunnel limit reached", zap.String("ip", remoteIP), - zap.Int("current", m.tunnelsByIP[remoteIP]), + zap.Int("current", currentPerIP), zap.Int("max", m.maxTunnelsPerIP), ) return "", ErrTooManyPerIP } + + // Reserve per-IP slot while still holding the lock + m.tunnelsByIP[remoteIP]++ m.ipMu.Unlock() } + // Rollback helper for per-IP counter + rollbackPerIP := func() { + if remoteIP != "" { + m.ipMu.Lock() + if m.tunnelsByIP[remoteIP] > 0 { + m.tunnelsByIP[remoteIP]-- + if m.tunnelsByIP[remoteIP] == 0 { + delete(m.tunnelsByIP, remoteIP) + } + } + m.ipMu.Unlock() + } + } + var subdomain string if customSubdomain != "" { // Validate custom subdomain if !utils.ValidateSubdomain(customSubdomain) { + rollbackPerIP() + rollbackGlobal() return "", ErrInvalidSubdomain } if utils.IsReserved(customSubdomain) { + rollbackPerIP() + rollbackGlobal() return "", ErrReservedSubdomain } @@ -217,6 +253,8 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r s.mu.Lock() if s.used[customSubdomain] { s.mu.Unlock() + rollbackPerIP() + rollbackGlobal() return "", ErrSubdomainTaken } subdomain = customSubdomain @@ -240,14 +278,6 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r s.mu.Unlock() } - // Update counters - m.tunnelCount.Add(1) - if remoteIP != "" { - m.ipMu.Lock() - m.tunnelsByIP[remoteIP]++ - m.ipMu.Unlock() - } - // Get connection and start write pump s := 
m.getShard(subdomain) s.mu.RLock()