From fa01819674b2d2abb0d05b470559b09eb43abef8 Mon Sep 17 00:00:00 2001
From: Fringg
Date: Sat, 7 Feb 2026 09:30:20 +0300
Subject: [PATCH] feat: add tariff filter, fix traffic data aggregation

- Switch from get_bandwidth_stats_node_users (broken UUID matching) to
  get_bandwidth_stats_user per user (same API as the working detail page)
- Add tariff filter with available_tariffs in response
- Add concurrency-limited parallel per-user bandwidth stats fetching
---
 app/cabinet/routes/admin_traffic.py | 91 +++++++++++++++++++----------
 app/cabinet/schemas/traffic.py      |  1 +
 2 files changed, 60 insertions(+), 32 deletions(-)

diff --git a/app/cabinet/routes/admin_traffic.py b/app/cabinet/routes/admin_traffic.py
index ac62aa14..c412f9d3 100644
--- a/app/cabinet/routes/admin_traffic.py
+++ b/app/cabinet/routes/admin_traffic.py
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
 
 router = APIRouter(prefix='/admin/traffic', tags=['Admin Traffic'])
 
 _ALLOWED_PERIODS = frozenset({1, 3, 7, 14, 30})
-_MAX_USERS_PER_NODE = 100_000
+_CONCURRENCY_LIMIT = 20  # Max parallel per-user API calls
 # In-memory cache: {period_days: (timestamp, aggregated_data, nodes_info)}
 _traffic_cache: dict[int, tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo]]] = {}
@@ -54,9 +54,15 @@ def _validate_period(period: int) -> None:
         )
 
 
-async def _aggregate_traffic(period_days: int) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo]]:
+async def _aggregate_traffic(
+    period_days: int, user_uuids: list[str]
+) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo]]:
     """Aggregate per-user traffic across all nodes for a given period.
 
+    Uses get_bandwidth_stats_user() per user (same API as the working
+    AdminUserDetail page) instead of get_bandwidth_stats_node_users(),
+    which returns UUIDs in a format that may not match the bot DB.
+
     Returns (user_traffic, nodes_info) where:
         user_traffic = {remnawave_uuid: {node_uuid: total_bytes, ...}}
         nodes_info = [TrafficNodeInfo, ...]
@@ -85,42 +91,42 @@ async def _aggregate_traffic(period_days: int) -> tuple[dict[str, dict[str, int]
     end_str = end_date.strftime('%Y-%m-%d')
 
     async with service.get_api_client() as api:
+        # Get all nodes for column headers
         nodes = await api.get_all_nodes()
 
-        async def fetch_node_users(node):
-            try:
-                return node, await api.get_bandwidth_stats_node_users(
-                    node.uuid, start_str, end_str, top_users_limit=_MAX_USERS_PER_NODE
-                )
-            except Exception:
-                logger.warning('Failed to get traffic for node %s', node.uuid, exc_info=True)
-                return node, None
+        # Fetch per-user bandwidth stats in parallel with concurrency limit.
+        # Response format: {series: [{uuid: NODE_UUID, total: bytes, ...}, ...]}
+        semaphore = asyncio.Semaphore(_CONCURRENCY_LIMIT)
 
-        results = await asyncio.gather(*(fetch_node_users(n) for n in nodes))
+        async def fetch_user_stats(user_uuid: str):
+            async with semaphore:
+                try:
+                    stats = await api.get_bandwidth_stats_user(user_uuid, start_str, end_str)
+                    return user_uuid, stats
+                except Exception:
+                    logger.warning('Failed to get traffic for user %s', user_uuid[:8], exc_info=True)
+                    return user_uuid, None
+
+        results = await asyncio.gather(*(fetch_user_stats(uid) for uid in user_uuids))
+
+        nodes_info: list[TrafficNodeInfo] = [
+            TrafficNodeInfo(node_uuid=node.uuid, node_name=node.name, country_code=node.country_code) for node in nodes
+        ]
+        nodes_info.sort(key=lambda n: n.node_name)
 
-        nodes_info: list[TrafficNodeInfo] = []
         user_traffic: dict[str, dict[str, int]] = {}
-
-        for node, stats in results:
-            nodes_info.append(
-                TrafficNodeInfo(
-                    node_uuid=node.uuid,
-                    node_name=node.name,
-                    country_code=node.country_code,
-                )
-            )
+        for user_uuid, stats in results:
             if not isinstance(stats, dict):
                 continue
+            node_traffic: dict[str, int] = {}
             for series_item in stats.get('series', []):
-                user_uuid = series_item.get('uuid', '')
+                node_uuid = series_item.get('uuid', '')
                 total = int(series_item.get('total', 0))
-                if not user_uuid or total == 0:
-                    continue
-                if user_uuid not in user_traffic:
-                    user_traffic[user_uuid] = {}
-                user_traffic[user_uuid][node.uuid] = user_traffic[user_uuid].get(node.uuid, 0) + total
+                if node_uuid and total > 0:
+                    node_traffic[node_uuid] = node_traffic.get(node_uuid, 0) + total
+            if node_traffic:
+                user_traffic[user_uuid] = node_traffic
 
-    nodes_info.sort(key=lambda n: n.node_name)
     _traffic_cache[period_days] = (now, user_traffic, nodes_info)
     return user_traffic, nodes_info
 
@@ -144,8 +150,9 @@ def _build_traffic_items(
     search: str = '',
     sort_by: str = 'total_bytes',
     sort_desc: bool = True,
+    tariff_filter: set[str] | None = None,
 ) -> list[UserTrafficItem]:
-    """Merge traffic data with user data, apply search filter, return sorted list."""
+    """Merge traffic data with user data, apply search/tariff filters, return sorted list."""
 
     items: list[UserTrafficItem] = []
     search_lower = search.lower().strip()
@@ -178,6 +185,10 @@ def _build_traffic_items(
         if sub.tariff:
             tariff_name = sub.tariff.name
 
+        if tariff_filter is not None:
+            if (tariff_name or '') not in tariff_filter:
+                continue
+
         items.append(
             UserTrafficItem(
                 user_id=user.id,
@@ -215,12 +226,27 @@ async def get_traffic_usage(
     search: str = Query('', max_length=100),
     sort_by: str = Query('total_bytes', max_length=100),
     sort_desc: bool = Query(True),
+    tariffs: str = Query('', max_length=500),
 ):
     """Get paginated per-user traffic usage by node."""
     _validate_period(period)
 
-    user_traffic, nodes_info = await _aggregate_traffic(period)
     user_map = await _load_user_map(db)
+    user_traffic, nodes_info = await _aggregate_traffic(period, list(user_map.keys()))
+
+    # Collect all available tariff names (before filtering)
+    available_tariffs = sorted(
+        {
+            u.subscription.tariff.name
+            for u in user_map.values()
+            if u.subscription and u.subscription.tariff and u.subscription.tariff.name
+        }
+    )
+
+    # Parse tariff filter
+    tariff_filter: set[str] | None = None
+    if tariffs.strip():
+        tariff_filter = {t.strip() for t in tariffs.split(',') if t.strip()}
 
     # Validate sort_by: allow known fields + 'node_' for dynamic node columns
     node_uuids = {n.node_uuid for n in nodes_info}
@@ -228,7 +254,7 @@ async def get_traffic_usage(
     if sort_by not in _SORT_FIELDS and not is_node_sort:
         sort_by = 'total_bytes'
 
-    items = _build_traffic_items(user_traffic, user_map, nodes_info, search, sort_by, sort_desc)
+    items = _build_traffic_items(user_traffic, user_map, nodes_info, search, sort_by, sort_desc, tariff_filter)
 
     total = len(items)
     paginated = items[offset : offset + limit]
@@ -240,6 +266,7 @@ async def get_traffic_usage(
         offset=offset,
         limit=limit,
         period_days=period,
+        available_tariffs=available_tariffs,
     )
 
 
@@ -258,8 +285,8 @@ async def export_traffic_csv(
             detail='Admin has no Telegram ID configured',
         )
 
-    user_traffic, nodes_info = await _aggregate_traffic(request.period)
     user_map = await _load_user_map(db)
+    user_traffic, nodes_info = await _aggregate_traffic(request.period, list(user_map.keys()))
    items = _build_traffic_items(user_traffic, user_map, nodes_info)
 
    # Build CSV rows
diff --git a/app/cabinet/schemas/traffic.py b/app/cabinet/schemas/traffic.py
index e68df89b..0c5bb087 100644
--- a/app/cabinet/schemas/traffic.py
+++ b/app/cabinet/schemas/traffic.py
@@ -29,6 +29,7 @@ class TrafficUsageResponse(BaseModel):
     offset: int
     limit: int
     period_days: int
+    available_tariffs: list[str]
 
 
 class ExportCsvRequest(BaseModel):
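
For reviewers: a minimal client sketch exercising the new tariff filter end to end. Everything host-related below is a placeholder — the base URL, the bearer token, and the assumption that get_traffic_usage is mounted at GET /admin/traffic/usage (the route decorator sits outside these hunks). The query parameter names and the available_tariffs response field come from the patch itself; the tariff names are made up.

    import asyncio

    import httpx  # any async HTTP client works; httpx is just an example


    async def main() -> None:
        async with httpx.AsyncClient(base_url='http://localhost:8000') as client:
            resp = await client.get(
                '/admin/traffic/usage',  # assumed mount point for get_traffic_usage
                params={
                    'period': 7,                 # must be in _ALLOWED_PERIODS {1, 3, 7, 14, 30}
                    'tariffs': 'Basic,Premium',  # comma-separated tariff names; empty = no filter
                    'sort_by': 'total_bytes',
                    'limit': 50,
                    'offset': 0,
                },
                headers={'Authorization': 'Bearer <admin-token>'},  # placeholder credentials
            )
            resp.raise_for_status()
            data = resp.json()
            # available_tariffs is collected before the filter is applied, so the
            # full list stays stable while a filter is active (handy for dropdowns).
            print(data['available_tariffs'], data['total'])


    asyncio.run(main())

Design note: fetch_user_stats catches exceptions per user and returns None for that user, so a single failing lookup drops one row instead of failing the whole aggregation, while the semaphore keeps at most _CONCURRENCY_LIMIT panel API calls in flight regardless of how many users exist.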