mirror of
https://github.com/BEDOLAGA-DEV/remnawave-bedolaga-telegram-bot.git
synced 2026-05-01 18:27:25 +00:00
Merge pull request #2561 from BEDOLAGA-DEV/fix/traffic-429-rate-limit
fix: resolve 429 rate limiting on traffic page
This commit is contained in:
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix='/admin/traffic', tags=['Admin Traffic'])
|
||||
|
||||
_ALLOWED_PERIODS = frozenset({1, 3, 7, 14, 30})
|
||||
_CONCURRENCY_LIMIT = 20 # Max parallel per-user API calls
|
||||
_CONCURRENCY_LIMIT = 5 # Max parallel API calls to avoid rate limiting
|
||||
|
||||
# In-memory cache: {period_days: (timestamp, aggregated_data, nodes_info)}
|
||||
_traffic_cache: dict[int, tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo]]] = {}
|
||||
@@ -59,9 +59,8 @@ async def _aggregate_traffic(
|
||||
) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo]]:
|
||||
"""Aggregate per-user traffic across all nodes for a given period.
|
||||
|
||||
Uses get_bandwidth_stats_user() per user (same API as the working
|
||||
AdminUserDetail page) instead of get_bandwidth_stats_node_users()
|
||||
which returns UUIDs in a format that may not match the bot DB.
|
||||
Uses per-node endpoint get_bandwidth_stats_node_users() to fetch traffic
|
||||
for all users at once per node — O(nodes) API calls instead of O(users).
|
||||
|
||||
Returns (user_traffic, nodes_info) where:
|
||||
user_traffic = {remnawave_uuid: {node_uuid: total_bytes, ...}}
|
||||
@@ -90,24 +89,29 @@ async def _aggregate_traffic(
|
||||
start_str = start_date.strftime('%Y-%m-%d')
|
||||
end_str = end_date.strftime('%Y-%m-%d')
|
||||
|
||||
user_uuids_set = set(user_uuids)
|
||||
|
||||
async with service.get_api_client() as api:
|
||||
# Get all nodes for column headers
|
||||
nodes = await api.get_all_nodes()
|
||||
|
||||
# Fetch per-user bandwidth stats in parallel with concurrency limit.
|
||||
# Response format: {series: [{uuid: NODE_UUID, total: bytes, ...}, ...]}
|
||||
# Fetch per-node user stats — O(nodes) calls instead of O(users)
|
||||
semaphore = asyncio.Semaphore(_CONCURRENCY_LIMIT)
|
||||
|
||||
async def fetch_user_stats(user_uuid: str):
|
||||
async def fetch_node_users(node):
|
||||
async with semaphore:
|
||||
try:
|
||||
stats = await api.get_bandwidth_stats_user(user_uuid, start_str, end_str)
|
||||
return user_uuid, stats
|
||||
stats = await api.get_bandwidth_stats_node_users(
|
||||
node.uuid,
|
||||
start_str,
|
||||
end_str,
|
||||
top_users_limit=9999,
|
||||
)
|
||||
return node.uuid, stats
|
||||
except Exception:
|
||||
logger.warning('Failed to get traffic for user %s', user_uuid[:8], exc_info=True)
|
||||
return user_uuid, None
|
||||
logger.warning('Failed to get traffic for node %s', node.name, exc_info=True)
|
||||
return node.uuid, None
|
||||
|
||||
results = await asyncio.gather(*(fetch_user_stats(uid) for uid in user_uuids))
|
||||
results = await asyncio.gather(*(fetch_node_users(n) for n in nodes))
|
||||
|
||||
nodes_info: list[TrafficNodeInfo] = [
|
||||
TrafficNodeInfo(node_uuid=node.uuid, node_name=node.name, country_code=node.country_code) for node in nodes
|
||||
@@ -115,17 +119,18 @@ async def _aggregate_traffic(
|
||||
nodes_info.sort(key=lambda n: n.node_name)
|
||||
|
||||
user_traffic: dict[str, dict[str, int]] = {}
|
||||
for user_uuid, stats in results:
|
||||
for node_uuid, stats in results:
|
||||
if not isinstance(stats, dict):
|
||||
continue
|
||||
node_traffic: dict[str, int] = {}
|
||||
for series_item in stats.get('series', []):
|
||||
node_uuid = series_item.get('uuid', '')
|
||||
total = int(series_item.get('total', 0))
|
||||
if node_uuid and total > 0:
|
||||
node_traffic[node_uuid] = node_traffic.get(node_uuid, 0) + total
|
||||
if node_traffic:
|
||||
user_traffic[user_uuid] = node_traffic
|
||||
# The node-users endpoint may return user data under 'users' or 'series'
|
||||
entries = stats.get('users', []) or stats.get('series', [])
|
||||
if not entries and stats:
|
||||
logger.debug('Unexpected node-users response keys: %s', list(stats.keys()))
|
||||
for user_entry in entries:
|
||||
uid = user_entry.get('uuid', '')
|
||||
total = int(user_entry.get('total', 0))
|
||||
if uid and total > 0 and uid in user_uuids_set:
|
||||
user_traffic.setdefault(uid, {})[node_uuid] = user_traffic.get(uid, {}).get(node_uuid, 0) + total
|
||||
|
||||
_traffic_cache[period_days] = (now, user_traffic, nodes_info)
|
||||
return user_traffic, nodes_info
|
||||
|
||||
70
app/external/remnawave_api.py
vendored
70
app/external/remnawave_api.py
vendored
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
@@ -366,32 +367,63 @@ class RemnaWaveAPI:
|
||||
raise RemnaWaveAPIError('Session not initialized. Use async context manager.')
|
||||
|
||||
url = f'{self.base_url}{endpoint}'
|
||||
max_retries = 3
|
||||
base_delay = 1.0
|
||||
|
||||
try:
|
||||
kwargs = {'url': url, 'params': params}
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
kwargs = {'url': url, 'params': params}
|
||||
|
||||
if data:
|
||||
kwargs['json'] = data
|
||||
if data:
|
||||
kwargs['json'] = data
|
||||
|
||||
async with self.session.request(method, **kwargs) as response:
|
||||
response_text = await response.text()
|
||||
async with self.session.request(method, **kwargs) as response:
|
||||
response_text = await response.text()
|
||||
|
||||
try:
|
||||
response_data = json.loads(response_text) if response_text else {}
|
||||
except json.JSONDecodeError:
|
||||
response_data = {'raw_response': response_text}
|
||||
try:
|
||||
response_data = json.loads(response_text) if response_text else {}
|
||||
except json.JSONDecodeError:
|
||||
response_data = {'raw_response': response_text}
|
||||
|
||||
if response.status >= 400:
|
||||
error_message = response_data.get('message', f'HTTP {response.status}')
|
||||
logger.error(f'API Error {response.status}: {error_message}')
|
||||
logger.error(f'Response: {response_text[:500]}')
|
||||
raise RemnaWaveAPIError(error_message, response.status, response_data)
|
||||
if response.status == 429 and attempt < max_retries:
|
||||
retry_after = float(response.headers.get('Retry-After', base_delay * (2**attempt)))
|
||||
logger.warning(
|
||||
'Rate limited (429) on %s %s, retry %d/%d after %.1fs',
|
||||
method,
|
||||
endpoint,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
retry_after,
|
||||
)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
|
||||
return response_data
|
||||
if response.status >= 400:
|
||||
error_message = response_data.get('message', f'HTTP {response.status}')
|
||||
logger.error(f'API Error {response.status}: {error_message}')
|
||||
logger.error(f'Response: {response_text[:500]}')
|
||||
raise RemnaWaveAPIError(error_message, response.status, response_data)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f'Request failed: {e}')
|
||||
raise RemnaWaveAPIError(f'Request failed: {e!s}')
|
||||
return response_data
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
if attempt < max_retries:
|
||||
delay = base_delay * (2**attempt)
|
||||
logger.warning(
|
||||
'Request failed on %s %s: %s, retry %d/%d after %.1fs',
|
||||
method,
|
||||
endpoint,
|
||||
e,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
logger.error(f'Request failed: {e}')
|
||||
raise RemnaWaveAPIError(f'Request failed: {e!s}')
|
||||
|
||||
raise RemnaWaveAPIError(f'Max retries exceeded for {method} {endpoint}')
|
||||
|
||||
async def create_user(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user