mirror of
https://github.com/BEDOLAGA-DEV/remnawave-bedolaga-telegram-bot.git
synced 2026-02-28 23:35:59 +00:00
feat: add tariff filter, fix traffic data aggregation
- Switch from get_bandwidth_stats_node_users (broken UUID matching) to get_bandwidth_stats_user per user (same API as working detail page) - Add tariff filter with available_tariffs in response - Add concurrency-limited parallel per-user bandwidth stats fetching
This commit is contained in:
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix='/admin/traffic', tags=['Admin Traffic'])
|
||||
|
||||
_ALLOWED_PERIODS = frozenset({1, 3, 7, 14, 30})
|
||||
_MAX_USERS_PER_NODE = 100_000
|
||||
_CONCURRENCY_LIMIT = 20 # Max parallel per-user API calls
|
||||
|
||||
# In-memory cache: {period_days: (timestamp, aggregated_data, nodes_info)}
|
||||
_traffic_cache: dict[int, tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo]]] = {}
|
||||
@@ -54,9 +54,15 @@ def _validate_period(period: int) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def _aggregate_traffic(period_days: int) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo]]:
|
||||
async def _aggregate_traffic(
|
||||
period_days: int, user_uuids: list[str]
|
||||
) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo]]:
|
||||
"""Aggregate per-user traffic across all nodes for a given period.
|
||||
|
||||
Uses get_bandwidth_stats_user() per user (same API as the working
|
||||
AdminUserDetail page) instead of get_bandwidth_stats_node_users()
|
||||
which returns UUIDs in a format that may not match the bot DB.
|
||||
|
||||
Returns (user_traffic, nodes_info) where:
|
||||
user_traffic = {remnawave_uuid: {node_uuid: total_bytes, ...}}
|
||||
nodes_info = [TrafficNodeInfo, ...]
|
||||
@@ -85,42 +91,42 @@ async def _aggregate_traffic(period_days: int) -> tuple[dict[str, dict[str, int]
|
||||
end_str = end_date.strftime('%Y-%m-%d')
|
||||
|
||||
async with service.get_api_client() as api:
|
||||
# Get all nodes for column headers
|
||||
nodes = await api.get_all_nodes()
|
||||
|
||||
async def fetch_node_users(node):
|
||||
try:
|
||||
return node, await api.get_bandwidth_stats_node_users(
|
||||
node.uuid, start_str, end_str, top_users_limit=_MAX_USERS_PER_NODE
|
||||
)
|
||||
except Exception:
|
||||
logger.warning('Failed to get traffic for node %s', node.uuid, exc_info=True)
|
||||
return node, None
|
||||
# Fetch per-user bandwidth stats in parallel with concurrency limit.
|
||||
# Response format: {series: [{uuid: NODE_UUID, total: bytes, ...}, ...]}
|
||||
semaphore = asyncio.Semaphore(_CONCURRENCY_LIMIT)
|
||||
|
||||
results = await asyncio.gather(*(fetch_node_users(n) for n in nodes))
|
||||
async def fetch_user_stats(user_uuid: str):
|
||||
async with semaphore:
|
||||
try:
|
||||
stats = await api.get_bandwidth_stats_user(user_uuid, start_str, end_str)
|
||||
return user_uuid, stats
|
||||
except Exception:
|
||||
logger.warning('Failed to get traffic for user %s', user_uuid[:8], exc_info=True)
|
||||
return user_uuid, None
|
||||
|
||||
results = await asyncio.gather(*(fetch_user_stats(uid) for uid in user_uuids))
|
||||
|
||||
nodes_info: list[TrafficNodeInfo] = [
|
||||
TrafficNodeInfo(node_uuid=node.uuid, node_name=node.name, country_code=node.country_code) for node in nodes
|
||||
]
|
||||
nodes_info.sort(key=lambda n: n.node_name)
|
||||
|
||||
nodes_info: list[TrafficNodeInfo] = []
|
||||
user_traffic: dict[str, dict[str, int]] = {}
|
||||
|
||||
for node, stats in results:
|
||||
nodes_info.append(
|
||||
TrafficNodeInfo(
|
||||
node_uuid=node.uuid,
|
||||
node_name=node.name,
|
||||
country_code=node.country_code,
|
||||
)
|
||||
)
|
||||
for user_uuid, stats in results:
|
||||
if not isinstance(stats, dict):
|
||||
continue
|
||||
node_traffic: dict[str, int] = {}
|
||||
for series_item in stats.get('series', []):
|
||||
user_uuid = series_item.get('uuid', '')
|
||||
node_uuid = series_item.get('uuid', '')
|
||||
total = int(series_item.get('total', 0))
|
||||
if not user_uuid or total == 0:
|
||||
continue
|
||||
if user_uuid not in user_traffic:
|
||||
user_traffic[user_uuid] = {}
|
||||
user_traffic[user_uuid][node.uuid] = user_traffic[user_uuid].get(node.uuid, 0) + total
|
||||
if node_uuid and total > 0:
|
||||
node_traffic[node_uuid] = node_traffic.get(node_uuid, 0) + total
|
||||
if node_traffic:
|
||||
user_traffic[user_uuid] = node_traffic
|
||||
|
||||
nodes_info.sort(key=lambda n: n.node_name)
|
||||
_traffic_cache[period_days] = (now, user_traffic, nodes_info)
|
||||
return user_traffic, nodes_info
|
||||
|
||||
@@ -144,8 +150,9 @@ def _build_traffic_items(
|
||||
search: str = '',
|
||||
sort_by: str = 'total_bytes',
|
||||
sort_desc: bool = True,
|
||||
tariff_filter: set[str] | None = None,
|
||||
) -> list[UserTrafficItem]:
|
||||
"""Merge traffic data with user data, apply search filter, return sorted list."""
|
||||
"""Merge traffic data with user data, apply search/tariff filters, return sorted list."""
|
||||
items: list[UserTrafficItem] = []
|
||||
search_lower = search.lower().strip()
|
||||
|
||||
@@ -178,6 +185,10 @@ def _build_traffic_items(
|
||||
if sub.tariff:
|
||||
tariff_name = sub.tariff.name
|
||||
|
||||
if tariff_filter is not None:
|
||||
if (tariff_name or '') not in tariff_filter:
|
||||
continue
|
||||
|
||||
items.append(
|
||||
UserTrafficItem(
|
||||
user_id=user.id,
|
||||
@@ -215,12 +226,27 @@ async def get_traffic_usage(
|
||||
search: str = Query('', max_length=100),
|
||||
sort_by: str = Query('total_bytes', max_length=100),
|
||||
sort_desc: bool = Query(True),
|
||||
tariffs: str = Query('', max_length=500),
|
||||
):
|
||||
"""Get paginated per-user traffic usage by node."""
|
||||
_validate_period(period)
|
||||
|
||||
user_traffic, nodes_info = await _aggregate_traffic(period)
|
||||
user_map = await _load_user_map(db)
|
||||
user_traffic, nodes_info = await _aggregate_traffic(period, list(user_map.keys()))
|
||||
|
||||
# Collect all available tariff names (before filtering)
|
||||
available_tariffs = sorted(
|
||||
{
|
||||
u.subscription.tariff.name
|
||||
for u in user_map.values()
|
||||
if u.subscription and u.subscription.tariff and u.subscription.tariff.name
|
||||
}
|
||||
)
|
||||
|
||||
# Parse tariff filter
|
||||
tariff_filter: set[str] | None = None
|
||||
if tariffs.strip():
|
||||
tariff_filter = {t.strip() for t in tariffs.split(',') if t.strip()}
|
||||
|
||||
# Validate sort_by: allow known fields + 'node_<uuid>' for dynamic node columns
|
||||
node_uuids = {n.node_uuid for n in nodes_info}
|
||||
@@ -228,7 +254,7 @@ async def get_traffic_usage(
|
||||
if sort_by not in _SORT_FIELDS and not is_node_sort:
|
||||
sort_by = 'total_bytes'
|
||||
|
||||
items = _build_traffic_items(user_traffic, user_map, nodes_info, search, sort_by, sort_desc)
|
||||
items = _build_traffic_items(user_traffic, user_map, nodes_info, search, sort_by, sort_desc, tariff_filter)
|
||||
|
||||
total = len(items)
|
||||
paginated = items[offset : offset + limit]
|
||||
@@ -240,6 +266,7 @@ async def get_traffic_usage(
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
period_days=period,
|
||||
available_tariffs=available_tariffs,
|
||||
)
|
||||
|
||||
|
||||
@@ -258,8 +285,8 @@ async def export_traffic_csv(
|
||||
detail='Admin has no Telegram ID configured',
|
||||
)
|
||||
|
||||
user_traffic, nodes_info = await _aggregate_traffic(request.period)
|
||||
user_map = await _load_user_map(db)
|
||||
user_traffic, nodes_info = await _aggregate_traffic(request.period, list(user_map.keys()))
|
||||
items = _build_traffic_items(user_traffic, user_map, nodes_info)
|
||||
|
||||
# Build CSV rows
|
||||
|
||||
@@ -29,6 +29,7 @@ class TrafficUsageResponse(BaseModel):
|
||||
offset: int
|
||||
limit: int
|
||||
period_days: int
|
||||
available_tariffs: list[str]
|
||||
|
||||
|
||||
class ExportCsvRequest(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user