feat (tunnel/manager): Optimized concurrency security and resource management for tunnel registration.

A CAS loop is used to implement atomic operations on the global tunnel counter, avoiding race conditions.
Add a rollback mechanism to ensure that the occupied counter resources are properly released when registration fails.
Concurrency safety for IP rate limiting is achieved by using atomic operations and locks in combination.
Add appropriate resource rollback logic at each faulty branch to prevent resource leaks.
This commit is contained in:
Gouryella
2025-12-24 10:13:30 +08:00
parent 88e4525bf6
commit e05f128a9c

View File

@@ -168,20 +168,33 @@ func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string
// RegisterWithIP registers a new tunnel with IP tracking
func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, remoteIP string) (string, error) {
// Check global limits first (lock-free read)
if m.tunnelCount.Load() >= int64(m.maxTunnels) {
m.logger.Warn("Maximum tunnel limit reached",
zap.Int64("current", m.tunnelCount.Load()),
zap.Int("max", m.maxTunnels),
)
return "", ErrTooManyTunnels
// Reserve a global slot atomically using CAS loop
for {
current := m.tunnelCount.Load()
if current >= int64(m.maxTunnels) {
m.logger.Warn("Maximum tunnel limit reached",
zap.Int64("current", current),
zap.Int("max", m.maxTunnels),
)
return "", ErrTooManyTunnels
}
if m.tunnelCount.CompareAndSwap(current, current+1) {
break
}
// CAS failed, another goroutine modified the counter, retry
}
// Check per-IP limits
// Rollback helper for global counter
rollbackGlobal := func() {
m.tunnelCount.Add(-1)
}
// Check per-IP limits and reserve slot atomically
if remoteIP != "" {
m.ipMu.Lock()
if !m.checkRateLimitLocked(remoteIP) {
m.ipMu.Unlock()
rollbackGlobal()
m.logger.Warn("Rate limit exceeded",
zap.String("ip", remoteIP),
zap.Int("limit", m.rateLimit),
@@ -190,25 +203,48 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r
}
if m.tunnelsByIP[remoteIP] >= m.maxTunnelsPerIP {
currentPerIP := m.tunnelsByIP[remoteIP]
m.ipMu.Unlock()
rollbackGlobal()
m.logger.Warn("Per-IP tunnel limit reached",
zap.String("ip", remoteIP),
zap.Int("current", m.tunnelsByIP[remoteIP]),
zap.Int("current", currentPerIP),
zap.Int("max", m.maxTunnelsPerIP),
)
return "", ErrTooManyPerIP
}
// Reserve per-IP slot while still holding the lock
m.tunnelsByIP[remoteIP]++
m.ipMu.Unlock()
}
// Rollback helper for per-IP counter
rollbackPerIP := func() {
if remoteIP != "" {
m.ipMu.Lock()
if m.tunnelsByIP[remoteIP] > 0 {
m.tunnelsByIP[remoteIP]--
if m.tunnelsByIP[remoteIP] == 0 {
delete(m.tunnelsByIP, remoteIP)
}
}
m.ipMu.Unlock()
}
}
var subdomain string
if customSubdomain != "" {
// Validate custom subdomain
if !utils.ValidateSubdomain(customSubdomain) {
rollbackPerIP()
rollbackGlobal()
return "", ErrInvalidSubdomain
}
if utils.IsReserved(customSubdomain) {
rollbackPerIP()
rollbackGlobal()
return "", ErrReservedSubdomain
}
@@ -217,6 +253,8 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r
s.mu.Lock()
if s.used[customSubdomain] {
s.mu.Unlock()
rollbackPerIP()
rollbackGlobal()
return "", ErrSubdomainTaken
}
subdomain = customSubdomain
@@ -240,14 +278,6 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r
s.mu.Unlock()
}
// Update counters
m.tunnelCount.Add(1)
if remoteIP != "" {
m.ipMu.Lock()
m.tunnelsByIP[remoteIP]++
m.ipMu.Unlock()
}
// Get connection and start write pump
s := m.getShard(subdomain)
s.mu.RLock()