Merge pull request #2561 from BEDOLAGA-DEV/fix/traffic-429-rate-limit

fix: resolve 429 rate limiting on traffic page
This commit is contained in:
Egor
2026-02-07 09:49:21 +03:00
committed by GitHub
2 changed files with 78 additions and 41 deletions

View File

@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix='/admin/traffic', tags=['Admin Traffic'])
_ALLOWED_PERIODS = frozenset({1, 3, 7, 14, 30})
_CONCURRENCY_LIMIT = 20 # Max parallel per-user API calls
_CONCURRENCY_LIMIT = 5 # Max parallel API calls to avoid rate limiting
# In-memory cache: {period_days: (timestamp, aggregated_data, nodes_info)}
_traffic_cache: dict[int, tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo]]] = {}
@@ -59,9 +59,8 @@ async def _aggregate_traffic(
) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo]]:
"""Aggregate per-user traffic across all nodes for a given period.
Uses get_bandwidth_stats_user() per user (same API as the working
AdminUserDetail page) instead of get_bandwidth_stats_node_users()
which returns UUIDs in a format that may not match the bot DB.
Uses per-node endpoint get_bandwidth_stats_node_users() to fetch traffic
for all users at once per node — O(nodes) API calls instead of O(users).
Returns (user_traffic, nodes_info) where:
user_traffic = {remnawave_uuid: {node_uuid: total_bytes, ...}}
@@ -90,24 +89,29 @@ async def _aggregate_traffic(
start_str = start_date.strftime('%Y-%m-%d')
end_str = end_date.strftime('%Y-%m-%d')
user_uuids_set = set(user_uuids)
async with service.get_api_client() as api:
# Get all nodes for column headers
nodes = await api.get_all_nodes()
# Fetch per-user bandwidth stats in parallel with concurrency limit.
# Response format: {series: [{uuid: NODE_UUID, total: bytes, ...}, ...]}
# Fetch per-node user stats — O(nodes) calls instead of O(users)
semaphore = asyncio.Semaphore(_CONCURRENCY_LIMIT)
async def fetch_user_stats(user_uuid: str):
async def fetch_node_users(node):
async with semaphore:
try:
stats = await api.get_bandwidth_stats_user(user_uuid, start_str, end_str)
return user_uuid, stats
stats = await api.get_bandwidth_stats_node_users(
node.uuid,
start_str,
end_str,
top_users_limit=9999,
)
return node.uuid, stats
except Exception:
logger.warning('Failed to get traffic for user %s', user_uuid[:8], exc_info=True)
return user_uuid, None
logger.warning('Failed to get traffic for node %s', node.name, exc_info=True)
return node.uuid, None
results = await asyncio.gather(*(fetch_user_stats(uid) for uid in user_uuids))
results = await asyncio.gather(*(fetch_node_users(n) for n in nodes))
nodes_info: list[TrafficNodeInfo] = [
TrafficNodeInfo(node_uuid=node.uuid, node_name=node.name, country_code=node.country_code) for node in nodes
@@ -115,17 +119,18 @@ async def _aggregate_traffic(
nodes_info.sort(key=lambda n: n.node_name)
user_traffic: dict[str, dict[str, int]] = {}
for user_uuid, stats in results:
for node_uuid, stats in results:
if not isinstance(stats, dict):
continue
node_traffic: dict[str, int] = {}
for series_item in stats.get('series', []):
node_uuid = series_item.get('uuid', '')
total = int(series_item.get('total', 0))
if node_uuid and total > 0:
node_traffic[node_uuid] = node_traffic.get(node_uuid, 0) + total
if node_traffic:
user_traffic[user_uuid] = node_traffic
# The node-users endpoint may return user data under 'users' or 'series'
entries = stats.get('users', []) or stats.get('series', [])
if not entries and stats:
logger.debug('Unexpected node-users response keys: %s', list(stats.keys()))
for user_entry in entries:
uid = user_entry.get('uuid', '')
total = int(user_entry.get('total', 0))
if uid and total > 0 and uid in user_uuids_set:
user_traffic.setdefault(uid, {})[node_uuid] = user_traffic.get(uid, {}).get(node_uuid, 0) + total
_traffic_cache[period_days] = (now, user_traffic, nodes_info)
return user_traffic, nodes_info

View File

@@ -1,3 +1,4 @@
import asyncio
import base64
import json
import logging
@@ -366,32 +367,63 @@ class RemnaWaveAPI:
raise RemnaWaveAPIError('Session not initialized. Use async context manager.')
url = f'{self.base_url}{endpoint}'
max_retries = 3
base_delay = 1.0
try:
kwargs = {'url': url, 'params': params}
for attempt in range(max_retries + 1):
try:
kwargs = {'url': url, 'params': params}
if data:
kwargs['json'] = data
if data:
kwargs['json'] = data
async with self.session.request(method, **kwargs) as response:
response_text = await response.text()
async with self.session.request(method, **kwargs) as response:
response_text = await response.text()
try:
response_data = json.loads(response_text) if response_text else {}
except json.JSONDecodeError:
response_data = {'raw_response': response_text}
try:
response_data = json.loads(response_text) if response_text else {}
except json.JSONDecodeError:
response_data = {'raw_response': response_text}
if response.status >= 400:
error_message = response_data.get('message', f'HTTP {response.status}')
logger.error(f'API Error {response.status}: {error_message}')
logger.error(f'Response: {response_text[:500]}')
raise RemnaWaveAPIError(error_message, response.status, response_data)
if response.status == 429 and attempt < max_retries:
retry_after = float(response.headers.get('Retry-After', base_delay * (2**attempt)))
logger.warning(
'Rate limited (429) on %s %s, retry %d/%d after %.1fs',
method,
endpoint,
attempt + 1,
max_retries,
retry_after,
)
await asyncio.sleep(retry_after)
continue
return response_data
if response.status >= 400:
error_message = response_data.get('message', f'HTTP {response.status}')
logger.error(f'API Error {response.status}: {error_message}')
logger.error(f'Response: {response_text[:500]}')
raise RemnaWaveAPIError(error_message, response.status, response_data)
except aiohttp.ClientError as e:
logger.error(f'Request failed: {e}')
raise RemnaWaveAPIError(f'Request failed: {e!s}')
return response_data
except aiohttp.ClientError as e:
if attempt < max_retries:
delay = base_delay * (2**attempt)
logger.warning(
'Request failed on %s %s: %s, retry %d/%d after %.1fs',
method,
endpoint,
e,
attempt + 1,
max_retries,
delay,
)
await asyncio.sleep(delay)
continue
logger.error(f'Request failed: {e}')
raise RemnaWaveAPIError(f'Request failed: {e!s}')
raise RemnaWaveAPIError(f'Max retries exceeded for {method} {endpoint}')
async def create_user(
self,