feat: add tariff filter, fix traffic data aggregation

- Switch from get_bandwidth_stats_node_users (broken UUID matching) to
  get_bandwidth_stats_user per user (same API as working detail page)
- Add tariff filter with available_tariffs in response
- Add concurrency-limited parallel per-user bandwidth stats fetching
This commit is contained in:
Fringg
2026-02-07 09:30:20 +03:00
parent eeed2d6369
commit fa01819674
2 changed files with 60 additions and 32 deletions

View File

@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix='/admin/traffic', tags=['Admin Traffic'])
_ALLOWED_PERIODS = frozenset({1, 3, 7, 14, 30})
_MAX_USERS_PER_NODE = 100_000
_CONCURRENCY_LIMIT = 20 # Max parallel per-user API calls
# In-memory cache: {period_days: (timestamp, aggregated_data, nodes_info)}
_traffic_cache: dict[int, tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo]]] = {}
@@ -54,9 +54,15 @@ def _validate_period(period: int) -> None:
)
async def _aggregate_traffic(period_days: int) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo]]:
async def _aggregate_traffic(
period_days: int, user_uuids: list[str]
) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo]]:
"""Aggregate per-user traffic across all nodes for a given period.
Uses get_bandwidth_stats_user() per user (same API as the working
AdminUserDetail page) instead of get_bandwidth_stats_node_users()
which returns UUIDs in a format that may not match the bot DB.
Returns (user_traffic, nodes_info) where:
user_traffic = {remnawave_uuid: {node_uuid: total_bytes, ...}}
nodes_info = [TrafficNodeInfo, ...]
@@ -85,42 +91,42 @@ async def _aggregate_traffic(period_days: int) -> tuple[dict[str, dict[str, int]
end_str = end_date.strftime('%Y-%m-%d')
async with service.get_api_client() as api:
# Get all nodes for column headers
nodes = await api.get_all_nodes()
async def fetch_node_users(node):
try:
return node, await api.get_bandwidth_stats_node_users(
node.uuid, start_str, end_str, top_users_limit=_MAX_USERS_PER_NODE
)
except Exception:
logger.warning('Failed to get traffic for node %s', node.uuid, exc_info=True)
return node, None
# Fetch per-user bandwidth stats in parallel with concurrency limit.
# Response format: {series: [{uuid: NODE_UUID, total: bytes, ...}, ...]}
semaphore = asyncio.Semaphore(_CONCURRENCY_LIMIT)
results = await asyncio.gather(*(fetch_node_users(n) for n in nodes))
async def fetch_user_stats(user_uuid: str):
async with semaphore:
try:
stats = await api.get_bandwidth_stats_user(user_uuid, start_str, end_str)
return user_uuid, stats
except Exception:
logger.warning('Failed to get traffic for user %s', user_uuid[:8], exc_info=True)
return user_uuid, None
results = await asyncio.gather(*(fetch_user_stats(uid) for uid in user_uuids))
nodes_info: list[TrafficNodeInfo] = [
TrafficNodeInfo(node_uuid=node.uuid, node_name=node.name, country_code=node.country_code) for node in nodes
]
nodes_info.sort(key=lambda n: n.node_name)
nodes_info: list[TrafficNodeInfo] = []
user_traffic: dict[str, dict[str, int]] = {}
for node, stats in results:
nodes_info.append(
TrafficNodeInfo(
node_uuid=node.uuid,
node_name=node.name,
country_code=node.country_code,
)
)
for user_uuid, stats in results:
if not isinstance(stats, dict):
continue
node_traffic: dict[str, int] = {}
for series_item in stats.get('series', []):
user_uuid = series_item.get('uuid', '')
node_uuid = series_item.get('uuid', '')
total = int(series_item.get('total', 0))
if not user_uuid or total == 0:
continue
if user_uuid not in user_traffic:
user_traffic[user_uuid] = {}
user_traffic[user_uuid][node.uuid] = user_traffic[user_uuid].get(node.uuid, 0) + total
if node_uuid and total > 0:
node_traffic[node_uuid] = node_traffic.get(node_uuid, 0) + total
if node_traffic:
user_traffic[user_uuid] = node_traffic
nodes_info.sort(key=lambda n: n.node_name)
_traffic_cache[period_days] = (now, user_traffic, nodes_info)
return user_traffic, nodes_info
@@ -144,8 +150,9 @@ def _build_traffic_items(
search: str = '',
sort_by: str = 'total_bytes',
sort_desc: bool = True,
tariff_filter: set[str] | None = None,
) -> list[UserTrafficItem]:
"""Merge traffic data with user data, apply search filter, return sorted list."""
"""Merge traffic data with user data, apply search/tariff filters, return sorted list."""
items: list[UserTrafficItem] = []
search_lower = search.lower().strip()
@@ -178,6 +185,10 @@ def _build_traffic_items(
if sub.tariff:
tariff_name = sub.tariff.name
if tariff_filter is not None:
if (tariff_name or '') not in tariff_filter:
continue
items.append(
UserTrafficItem(
user_id=user.id,
@@ -215,12 +226,27 @@ async def get_traffic_usage(
search: str = Query('', max_length=100),
sort_by: str = Query('total_bytes', max_length=100),
sort_desc: bool = Query(True),
tariffs: str = Query('', max_length=500),
):
"""Get paginated per-user traffic usage by node."""
_validate_period(period)
user_traffic, nodes_info = await _aggregate_traffic(period)
user_map = await _load_user_map(db)
user_traffic, nodes_info = await _aggregate_traffic(period, list(user_map.keys()))
# Collect all available tariff names (before filtering)
available_tariffs = sorted(
{
u.subscription.tariff.name
for u in user_map.values()
if u.subscription and u.subscription.tariff and u.subscription.tariff.name
}
)
# Parse tariff filter
tariff_filter: set[str] | None = None
if tariffs.strip():
tariff_filter = {t.strip() for t in tariffs.split(',') if t.strip()}
# Validate sort_by: allow known fields + 'node_<uuid>' for dynamic node columns
node_uuids = {n.node_uuid for n in nodes_info}
@@ -228,7 +254,7 @@ async def get_traffic_usage(
if sort_by not in _SORT_FIELDS and not is_node_sort:
sort_by = 'total_bytes'
items = _build_traffic_items(user_traffic, user_map, nodes_info, search, sort_by, sort_desc)
items = _build_traffic_items(user_traffic, user_map, nodes_info, search, sort_by, sort_desc, tariff_filter)
total = len(items)
paginated = items[offset : offset + limit]
@@ -240,6 +266,7 @@ async def get_traffic_usage(
offset=offset,
limit=limit,
period_days=period,
available_tariffs=available_tariffs,
)
@@ -258,8 +285,8 @@ async def export_traffic_csv(
detail='Admin has no Telegram ID configured',
)
user_traffic, nodes_info = await _aggregate_traffic(request.period)
user_map = await _load_user_map(db)
user_traffic, nodes_info = await _aggregate_traffic(request.period, list(user_map.keys()))
items = _build_traffic_items(user_traffic, user_map, nodes_info)
# Build CSV rows

View File

@@ -29,6 +29,7 @@ class TrafficUsageResponse(BaseModel):
offset: int
limit: int
period_days: int
available_tariffs: list[str]
class ExportCsvRequest(BaseModel):