diff --git a/internal/server/tunnel/manager.go b/internal/server/tunnel/manager.go index 9632ff5..44ffcee 100644 --- a/internal/server/tunnel/manager.go +++ b/internal/server/tunnel/manager.go @@ -168,20 +168,33 @@ func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string // RegisterWithIP registers a new tunnel with IP tracking func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, remoteIP string) (string, error) { - // Check global limits first (lock-free read) - if m.tunnelCount.Load() >= int64(m.maxTunnels) { - m.logger.Warn("Maximum tunnel limit reached", - zap.Int64("current", m.tunnelCount.Load()), - zap.Int("max", m.maxTunnels), - ) - return "", ErrTooManyTunnels + // Reserve a global slot atomically using CAS loop + for { + current := m.tunnelCount.Load() + if current >= int64(m.maxTunnels) { + m.logger.Warn("Maximum tunnel limit reached", + zap.Int64("current", current), + zap.Int("max", m.maxTunnels), + ) + return "", ErrTooManyTunnels + } + if m.tunnelCount.CompareAndSwap(current, current+1) { + break + } + // CAS failed, another goroutine modified the counter, retry } - // Check per-IP limits + // Rollback helper for global counter + rollbackGlobal := func() { + m.tunnelCount.Add(-1) + } + + // Check per-IP limits and reserve slot atomically if remoteIP != "" { m.ipMu.Lock() if !m.checkRateLimitLocked(remoteIP) { m.ipMu.Unlock() + rollbackGlobal() m.logger.Warn("Rate limit exceeded", zap.String("ip", remoteIP), zap.Int("limit", m.rateLimit), @@ -190,25 +203,48 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r } if m.tunnelsByIP[remoteIP] >= m.maxTunnelsPerIP { + currentPerIP := m.tunnelsByIP[remoteIP] m.ipMu.Unlock() + rollbackGlobal() m.logger.Warn("Per-IP tunnel limit reached", zap.String("ip", remoteIP), - zap.Int("current", m.tunnelsByIP[remoteIP]), + zap.Int("current", currentPerIP), zap.Int("max", m.maxTunnelsPerIP), ) return "", ErrTooManyPerIP } + + // Reserve per-IP slot while still holding the lock + m.tunnelsByIP[remoteIP]++ m.ipMu.Unlock() } + // Rollback helper for per-IP counter + rollbackPerIP := func() { + if remoteIP != "" { + m.ipMu.Lock() + if m.tunnelsByIP[remoteIP] > 0 { + m.tunnelsByIP[remoteIP]-- + if m.tunnelsByIP[remoteIP] == 0 { + delete(m.tunnelsByIP, remoteIP) + } + } + m.ipMu.Unlock() + } + } + var subdomain string if customSubdomain != "" { // Validate custom subdomain if !utils.ValidateSubdomain(customSubdomain) { + rollbackPerIP() + rollbackGlobal() return "", ErrInvalidSubdomain } if utils.IsReserved(customSubdomain) { + rollbackPerIP() + rollbackGlobal() return "", ErrReservedSubdomain } @@ -217,6 +253,8 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r s.mu.Lock() if s.used[customSubdomain] { s.mu.Unlock() + rollbackPerIP() + rollbackGlobal() return "", ErrSubdomainTaken } subdomain = customSubdomain @@ -240,14 +278,6 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r s.mu.Unlock() } - // Update counters - m.tunnelCount.Add(1) - if remoteIP != "" { - m.ipMu.Lock() - m.tunnelsByIP[remoteIP]++ - m.ipMu.Unlock() - } - // Get connection and start write pump s := m.getShard(subdomain) s.mu.RLock()