mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
feat (tunnel/manager): Optimized concurrency security and resource management for tunnel registration.
A CAS loop is used to implement atomic operations on the global tunnel counter, avoiding race conditions. Add a rollback mechanism to ensure that the occupied counter resources are properly released when registration fails. Concurrency safety for IP rate limiting is achieved by using atomic operations and locks in combination. Add appropriate resource rollback logic at each faulty branch to prevent resource leaks.
This commit is contained in:
@@ -168,20 +168,33 @@ func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string
|
||||
|
||||
// RegisterWithIP registers a new tunnel with IP tracking
|
||||
func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, remoteIP string) (string, error) {
|
||||
// Check global limits first (lock-free read)
|
||||
if m.tunnelCount.Load() >= int64(m.maxTunnels) {
|
||||
m.logger.Warn("Maximum tunnel limit reached",
|
||||
zap.Int64("current", m.tunnelCount.Load()),
|
||||
zap.Int("max", m.maxTunnels),
|
||||
)
|
||||
return "", ErrTooManyTunnels
|
||||
// Reserve a global slot atomically using CAS loop
|
||||
for {
|
||||
current := m.tunnelCount.Load()
|
||||
if current >= int64(m.maxTunnels) {
|
||||
m.logger.Warn("Maximum tunnel limit reached",
|
||||
zap.Int64("current", current),
|
||||
zap.Int("max", m.maxTunnels),
|
||||
)
|
||||
return "", ErrTooManyTunnels
|
||||
}
|
||||
if m.tunnelCount.CompareAndSwap(current, current+1) {
|
||||
break
|
||||
}
|
||||
// CAS failed, another goroutine modified the counter, retry
|
||||
}
|
||||
|
||||
// Check per-IP limits
|
||||
// Rollback helper for global counter
|
||||
rollbackGlobal := func() {
|
||||
m.tunnelCount.Add(-1)
|
||||
}
|
||||
|
||||
// Check per-IP limits and reserve slot atomically
|
||||
if remoteIP != "" {
|
||||
m.ipMu.Lock()
|
||||
if !m.checkRateLimitLocked(remoteIP) {
|
||||
m.ipMu.Unlock()
|
||||
rollbackGlobal()
|
||||
m.logger.Warn("Rate limit exceeded",
|
||||
zap.String("ip", remoteIP),
|
||||
zap.Int("limit", m.rateLimit),
|
||||
@@ -190,25 +203,48 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r
|
||||
}
|
||||
|
||||
if m.tunnelsByIP[remoteIP] >= m.maxTunnelsPerIP {
|
||||
currentPerIP := m.tunnelsByIP[remoteIP]
|
||||
m.ipMu.Unlock()
|
||||
rollbackGlobal()
|
||||
m.logger.Warn("Per-IP tunnel limit reached",
|
||||
zap.String("ip", remoteIP),
|
||||
zap.Int("current", m.tunnelsByIP[remoteIP]),
|
||||
zap.Int("current", currentPerIP),
|
||||
zap.Int("max", m.maxTunnelsPerIP),
|
||||
)
|
||||
return "", ErrTooManyPerIP
|
||||
}
|
||||
|
||||
// Reserve per-IP slot while still holding the lock
|
||||
m.tunnelsByIP[remoteIP]++
|
||||
m.ipMu.Unlock()
|
||||
}
|
||||
|
||||
// Rollback helper for per-IP counter
|
||||
rollbackPerIP := func() {
|
||||
if remoteIP != "" {
|
||||
m.ipMu.Lock()
|
||||
if m.tunnelsByIP[remoteIP] > 0 {
|
||||
m.tunnelsByIP[remoteIP]--
|
||||
if m.tunnelsByIP[remoteIP] == 0 {
|
||||
delete(m.tunnelsByIP, remoteIP)
|
||||
}
|
||||
}
|
||||
m.ipMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
var subdomain string
|
||||
|
||||
if customSubdomain != "" {
|
||||
// Validate custom subdomain
|
||||
if !utils.ValidateSubdomain(customSubdomain) {
|
||||
rollbackPerIP()
|
||||
rollbackGlobal()
|
||||
return "", ErrInvalidSubdomain
|
||||
}
|
||||
if utils.IsReserved(customSubdomain) {
|
||||
rollbackPerIP()
|
||||
rollbackGlobal()
|
||||
return "", ErrReservedSubdomain
|
||||
}
|
||||
|
||||
@@ -217,6 +253,8 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r
|
||||
s.mu.Lock()
|
||||
if s.used[customSubdomain] {
|
||||
s.mu.Unlock()
|
||||
rollbackPerIP()
|
||||
rollbackGlobal()
|
||||
return "", ErrSubdomainTaken
|
||||
}
|
||||
subdomain = customSubdomain
|
||||
@@ -240,14 +278,6 @@ func (m *Manager) RegisterWithIP(conn *websocket.Conn, customSubdomain string, r
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// Update counters
|
||||
m.tunnelCount.Add(1)
|
||||
if remoteIP != "" {
|
||||
m.ipMu.Lock()
|
||||
m.tunnelsByIP[remoteIP]++
|
||||
m.ipMu.Unlock()
|
||||
}
|
||||
|
||||
// Get connection and start write pump
|
||||
s := m.getShard(subdomain)
|
||||
s.mu.RLock()
|
||||
|
||||
Reference in New Issue
Block a user