Revert "Merge pull request #2565 from BEDOLAGA-DEV/feat/traffic-filters-devices"

This reverts commit ad6522f547, reversing
changes made to 61bb8fcafd.

Author: Fringg
Date: 2026-02-07 11:29:31 +03:00
Parent: ad6522f547
Commit: 3fd3bce2cf
2 changed files with 46 additions and 214 deletions

View File

@@ -36,18 +36,14 @@ router = APIRouter(prefix='/admin/traffic', tags=['Admin Traffic'])
 _ALLOWED_PERIODS = frozenset({1, 3, 7, 14, 30})
 _CONCURRENCY_LIMIT = 5  # Max parallel API calls to avoid rate limiting
-_DEVICE_CONCURRENCY_LIMIT = 10
-# In-memory cache: {(start_str, end_str): (timestamp, aggregated_data, nodes_info, devices_map)}
-_traffic_cache: dict[
-    tuple[str, str], tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo], dict[str, int]]
-] = {}
+# In-memory cache: {period_days: (timestamp, aggregated_data, nodes_info)}
+_traffic_cache: dict[int, tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo]]] = {}
 _CACHE_TTL = 300  # 5 minutes
 _cache_lock = asyncio.Lock()

 # Valid sort fields for the GET endpoint
 _SORT_FIELDS = frozenset({'total_bytes', 'full_name', 'tariff_name', 'device_limit', 'traffic_limit_gb'})
-_MAX_DATE_RANGE_DAYS = 31


 def _validate_period(period: int) -> None:
@@ -58,78 +54,9 @@ def _validate_period(period: int) -> None:
         )


-def _resolve_date_range(period: int, start_date: str, end_date: str) -> tuple[str, str, int]:
-    """Resolve date range from either custom dates or period.
-
-    Returns (start_str, end_str, period_days) in ISO datetime format.
-    """
-    now = datetime.now(UTC)
-    if start_date and end_date:
-        try:
-            start_dt = datetime.strptime(start_date, '%Y-%m-%d').replace(tzinfo=UTC)
-            end_dt = datetime.strptime(end_date, '%Y-%m-%d').replace(tzinfo=UTC, hour=23, minute=59, second=59)
-        except ValueError:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail='Invalid date format. Use YYYY-MM-DD.',
-            )
-        if start_dt > end_dt:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail='start_date must be before end_date.',
-            )
-        end_dt = min(end_dt, now)
-        if start_dt > end_dt:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail='start_date cannot be in the future.',
-            )
-        delta = (end_dt - start_dt).days
-        if delta > _MAX_DATE_RANGE_DAYS:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=f'Date range must not exceed {_MAX_DATE_RANGE_DAYS} days.',
-            )
-        period_days = max(delta, 1)
-        start_str = start_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
-        end_str = end_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
-        return start_str, end_str, period_days
-
-    _validate_period(period)
-    end_dt = now
-    start_dt = end_dt - timedelta(days=period)
-    start_str = start_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
-    end_str = end_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
-    return start_str, end_str, period
-
-
-async def _fetch_devices(api, user_uuids: list[str]) -> dict[str, int]:
-    """Fetch connected device count for each user UUID. Returns {uuid: count}."""
-    semaphore = asyncio.Semaphore(_DEVICE_CONCURRENCY_LIMIT)
-    devices_map: dict[str, int] = {}
-
-    async def fetch_one(uuid: str):
-        async with semaphore:
-            try:
-                result = await api.get_user_devices(uuid)
-                devices_map[uuid] = result.get('total', 0)
-            except Exception:
-                logger.debug('Failed to get devices for user %s', uuid, exc_info=True)
-                devices_map[uuid] = 0
-
-    await asyncio.gather(*(fetch_one(uid) for uid in user_uuids))
-    return devices_map
-
-
 async def _aggregate_traffic(
-    start_str: str, end_str: str, user_uuids: list[str]
-) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo], dict[str, int]]:
+    period_days: int, user_uuids: list[str]
+) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo]]:
     """Aggregate per-user traffic across all nodes for a given period.

     Uses legacy per-node endpoint to fetch all users' traffic per node —
@@ -137,30 +64,33 @@ async def _aggregate_traffic(
     {userUuid, nodeUuid, total} per entry (non-legacy only returns topUsers
     without userUuid).

-    Returns (user_traffic, nodes_info, devices_map) where:
+    Returns (user_traffic, nodes_info) where:
         user_traffic = {remnawave_uuid: {node_uuid: total_bytes, ...}}
         nodes_info = [TrafficNodeInfo, ...]
-        devices_map = {remnawave_uuid: connected_device_count}
     """
-    cache_key = (start_str, end_str)
     # Quick check without lock
     now = time.time()
-    cached = _traffic_cache.get(cache_key)
+    cached = _traffic_cache.get(period_days)
     if cached and (now - cached[0]) < _CACHE_TTL:
-        return cached[1], cached[2], cached[3]
+        return cached[1], cached[2]

     # Acquire lock for the slow path
    async with _cache_lock:
         # Re-check after acquiring lock
         now = time.time()
-        cached = _traffic_cache.get(cache_key)
+        cached = _traffic_cache.get(period_days)
         if cached and (now - cached[0]) < _CACHE_TTL:
-            return cached[1], cached[2], cached[3]
+            return cached[1], cached[2]

         service = RemnaWaveService()
         if not service.is_configured:
-            return {}, [], {}
+            return {}, []

+        end_date = datetime.now(UTC)
+        start_date = end_date - timedelta(days=period_days)
+        # Legacy endpoint expects date-time format
+        start_str = start_date.strftime('%Y-%m-%dT%H:%M:%SZ')
+        end_str = end_date.strftime('%Y-%m-%dT%H:%M:%SZ')
+
         user_uuids_set = set(user_uuids)
@@ -181,31 +111,24 @@ async def _aggregate_traffic(
         results = await asyncio.gather(*(fetch_node_users(n) for n in nodes))

-        nodes_info: list[TrafficNodeInfo] = [
-            TrafficNodeInfo(node_uuid=node.uuid, node_name=node.name, country_code=node.country_code)
-            for node in nodes
-        ]
-        nodes_info.sort(key=lambda n: n.node_name)
+        nodes_info: list[TrafficNodeInfo] = [
+            TrafficNodeInfo(node_uuid=node.uuid, node_name=node.name, country_code=node.country_code) for node in nodes
+        ]
+        nodes_info.sort(key=lambda n: n.node_name)

-        # Legacy response: [{userUuid, username, nodeUuid, total, date}, ...]
-        user_traffic: dict[str, dict[str, int]] = {}
-        for node_uuid, entries in results:
-            if not isinstance(entries, list):
-                continue
-            for entry in entries:
-                uid = entry.get('userUuid', '')
-                total = int(entry.get('total', 0))
-                if uid and total > 0 and uid in user_uuids_set:
-                    user_traffic.setdefault(uid, {})[node_uuid] = (
-                        user_traffic.get(uid, {}).get(node_uuid, 0) + total
-                    )
+        # Legacy response: [{userUuid, username, nodeUuid, total, date}, ...]
+        user_traffic: dict[str, dict[str, int]] = {}
+        for node_uuid, entries in results:
+            if not isinstance(entries, list):
+                continue
+            for entry in entries:
+                uid = entry.get('userUuid', '')
+                total = int(entry.get('total', 0))
+                if uid and total > 0 and uid in user_uuids_set:
+                    user_traffic.setdefault(uid, {})[node_uuid] = user_traffic.get(uid, {}).get(node_uuid, 0) + total

-        # Fetch devices for users that have traffic
-        uuids_with_traffic = list(user_traffic.keys())
-        devices_map = await _fetch_devices(api, uuids_with_traffic) if uuids_with_traffic else {}
-
-        _traffic_cache[cache_key] = (now, user_traffic, nodes_info, devices_map)
-        return user_traffic, nodes_info, devices_map
+        _traffic_cache[period_days] = (now, user_traffic, nodes_info)
+        return user_traffic, nodes_info


 async def _load_user_map(db: AsyncSession) -> dict[str, User]:
@@ -224,15 +147,12 @@ def _build_traffic_items(
     user_traffic: dict[str, dict[str, int]],
     user_map: dict[str, User],
     nodes_info: list[TrafficNodeInfo],
-    devices_map: dict[str, int],
     search: str = '',
     sort_by: str = 'total_bytes',
     sort_desc: bool = True,
     tariff_filter: set[str] | None = None,
-    node_filter: set[str] | None = None,
-    status_filter: set[str] | None = None,
 ) -> list[UserTrafficItem]:
-    """Merge traffic data with user data, apply search/tariff/node/status filters, return sorted list."""
+    """Merge traffic data with user data, apply search/tariff filters, return sorted list."""
     items: list[UserTrafficItem] = []
     search_lower = search.lower().strip()
@@ -243,11 +163,6 @@
             continue

         traffic = user_traffic.get(uuid, {})
-
-        # Apply node filter: keep only selected nodes, recalculate total
-        if node_filter is not None:
-            traffic = {nid: val for nid, val in traffic.items() if nid in node_filter}
-
         total_bytes = sum(traffic.values())

         full_name = user.full_name
@@ -274,12 +189,6 @@
             if (tariff_name or '') not in tariff_filter:
                 continue

-        if status_filter is not None:
-            if (subscription_status or '') not in status_filter:
-                continue
-
-        connected_devices = devices_map.get(uuid, 0)
-
         items.append(
             UserTrafficItem(
                 user_id=user.id,
@@ -290,7 +199,6 @@
                 subscription_status=subscription_status,
                 traffic_limit_gb=traffic_limit_gb,
                 device_limit=device_limit,
-                connected_devices=connected_devices,
                 node_traffic=traffic,
                 total_bytes=total_bytes,
             )
@@ -319,16 +227,12 @@
     sort_by: str = Query('total_bytes', max_length=100),
     sort_desc: bool = Query(True),
     tariffs: str = Query('', max_length=500),
-    nodes: str = Query('', max_length=2000),
-    statuses: str = Query('', max_length=200),
-    start_date: str = Query('', max_length=10),
-    end_date: str = Query('', max_length=10),
 ):
     """Get paginated per-user traffic usage by node."""
-    start_str, end_str, period_days = _resolve_date_range(period, start_date, end_date)
+    _validate_period(period)

     user_map = await _load_user_map(db)
-    user_traffic, nodes_info, devices_map = await _aggregate_traffic(start_str, end_str, list(user_map.keys()))
+    user_traffic, nodes_info = await _aggregate_traffic(period, list(user_map.keys()))

     # Collect all available tariff names (before filtering)
     available_tariffs = sorted(
@@ -339,66 +243,30 @@
         }
     )

-    # Collect all available statuses (before filtering)
-    available_statuses = sorted(
-        {
-            (u.subscription.actual_status if hasattr(u.subscription, 'actual_status') else u.subscription.status)
-            for u in user_map.values()
-            if u.subscription
-        }
-    )
-
     # Parse tariff filter
     tariff_filter: set[str] | None = None
     if tariffs.strip():
         tariff_filter = {t.strip() for t in tariffs.split(',') if t.strip()}

-    # Parse node filter
-    node_filter: set[str] | None = None
-    if nodes.strip():
-        node_filter = {n.strip() for n in nodes.split(',') if n.strip()}
-
-    # Parse status filter
-    status_filter: set[str] | None = None
-    if statuses.strip():
-        status_filter = {s.strip() for s in statuses.split(',') if s.strip()}
-
     # Validate sort_by: allow known fields + 'node_<uuid>' for dynamic node columns
     node_uuids = {n.node_uuid for n in nodes_info}
     is_node_sort = sort_by.startswith('node_') and sort_by[5:] in node_uuids
     if sort_by not in _SORT_FIELDS and not is_node_sort:
         sort_by = 'total_bytes'

-    items = _build_traffic_items(
-        user_traffic,
-        user_map,
-        nodes_info,
-        devices_map,
-        search,
-        sort_by,
-        sort_desc,
-        tariff_filter,
-        node_filter,
-        status_filter,
-    )
+    items = _build_traffic_items(user_traffic, user_map, nodes_info, search, sort_by, sort_desc, tariff_filter)
     total = len(items)
     paginated = items[offset : offset + limit]

-    # Filter nodes_info to only selected nodes for frontend column display
-    filtered_nodes = nodes_info
-    if node_filter is not None:
-        filtered_nodes = [n for n in nodes_info if n.node_uuid in node_filter]
-
     return TrafficUsageResponse(
         items=paginated,
-        nodes=filtered_nodes,
+        nodes=nodes_info,
         total=total,
         offset=offset,
         limit=limit,
-        period_days=period_days,
+        period_days=period,
         available_tariffs=available_tariffs,
-        available_statuses=available_statuses,
     )
@@ -409,45 +277,17 @@
     db: AsyncSession = Depends(get_cabinet_db),
 ):
     """Generate CSV with traffic usage and send to admin's Telegram DM."""
+    _validate_period(request.period)

     if not admin.telegram_id:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail='Admin has no Telegram ID configured',
         )

-    start_str, end_str, period_days = _resolve_date_range(request.period, request.start_date, request.end_date)
-
     user_map = await _load_user_map(db)
-    user_traffic, nodes_info, devices_map = await _aggregate_traffic(start_str, end_str, list(user_map.keys()))
-
-    # Parse filters
-    tariff_filter: set[str] | None = None
-    if request.tariffs.strip():
-        tariff_filter = {t.strip() for t in request.tariffs.split(',') if t.strip()}
-
-    node_filter: set[str] | None = None
-    if request.nodes.strip():
-        node_filter = {n.strip() for n in request.nodes.split(',') if n.strip()}
-
-    status_filter: set[str] | None = None
-    if request.statuses.strip():
-        status_filter = {s.strip() for s in request.statuses.split(',') if s.strip()}
-
-    items = _build_traffic_items(
-        user_traffic,
-        user_map,
-        nodes_info,
-        devices_map,
-        search=request.search,
-        tariff_filter=tariff_filter,
-        node_filter=node_filter,
-        status_filter=status_filter,
-    )
-
-    # Filter node columns for CSV if node filter active
-    csv_nodes = nodes_info
-    if node_filter is not None:
-        csv_nodes = [n for n in nodes_info if n.node_uuid in node_filter]
+    user_traffic, nodes_info = await _aggregate_traffic(request.period, list(user_map.keys()))
+    items = _build_traffic_items(user_traffic, user_map, nodes_info)

     # Build CSV rows
     rows: list[dict] = []
@@ -460,9 +300,9 @@
             'Tariff': item.tariff_name or '',
             'Status': item.subscription_status or '',
             'Traffic Limit (GB)': item.traffic_limit_gb,
-            'Devices': f'{item.connected_devices}/{item.device_limit}',
+            'Devices': item.device_limit,
         }
-        for node in csv_nodes:
+        for node in nodes_info:
             row[f'{node.node_name} (bytes)'] = item.node_traffic.get(node.node_uuid, 0)
         row['Total (bytes)'] = item.total_bytes
         row['Total (GB)'] = round(item.total_bytes / (1024**3), 2) if item.total_bytes else 0
@@ -477,7 +317,7 @@
     csv_bytes = output.getvalue().encode('utf-8-sig')

     timestamp = datetime.now(UTC).strftime('%Y%m%d_%H%M%S')
-    filename = f'traffic_usage_{period_days}d_{timestamp}.csv'
+    filename = f'traffic_usage_{request.period}d_{timestamp}.csv'

     try:
         bot = Bot(
@@ -488,7 +328,7 @@
         await bot.send_document(
             chat_id=admin.telegram_id,
             document=BufferedInputFile(csv_bytes, filename=filename),
-            caption=f'Traffic usage report ({period_days}d)\nUsers: {len(rows)}',
+            caption=f'Traffic usage report ({request.period}d)\nUsers: {len(rows)}',
         )
     except Exception:
         logger.error('Failed to send CSV to admin %s', admin.telegram_id, exc_info=True)

View File

@@ -18,7 +18,6 @@ class UserTrafficItem(BaseModel):
     subscription_status: str | None
     traffic_limit_gb: float
     device_limit: int
-    connected_devices: int = 0
     node_traffic: dict[str, int]  # {node_uuid: total_bytes}
     total_bytes: int
@@ -31,17 +30,10 @@
     limit: int
     period_days: int
     available_tariffs: list[str]
-    available_statuses: list[str]


 class ExportCsvRequest(BaseModel):
     period: int = Field(30, ge=1, le=30)
-    start_date: str = ''
-    end_date: str = ''
     tariffs: str = ''
-    nodes: str = ''
-    statuses: str = ''
     search: str = ''


 class ExportCsvResponse(BaseModel):