mirror of
https://github.com/BEDOLAGA-DEV/remnawave-bedolaga-telegram-bot.git
synced 2026-02-28 23:35:59 +00:00
Revert "Merge pull request #2565 from BEDOLAGA-DEV/feat/traffic-filters-devices"
This reverts commitad6522f547, reversing changes made to61bb8fcafd.
This commit is contained in:
@@ -36,18 +36,14 @@ router = APIRouter(prefix='/admin/traffic', tags=['Admin Traffic'])
|
||||
|
||||
_ALLOWED_PERIODS = frozenset({1, 3, 7, 14, 30})
|
||||
_CONCURRENCY_LIMIT = 5 # Max parallel API calls to avoid rate limiting
|
||||
_DEVICE_CONCURRENCY_LIMIT = 10
|
||||
|
||||
# In-memory cache: {(start_str, end_str): (timestamp, aggregated_data, nodes_info, devices_map)}
|
||||
_traffic_cache: dict[
|
||||
tuple[str, str], tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo], dict[str, int]]
|
||||
] = {}
|
||||
# In-memory cache: {period_days: (timestamp, aggregated_data, nodes_info)}
|
||||
_traffic_cache: dict[int, tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo]]] = {}
|
||||
_CACHE_TTL = 300 # 5 minutes
|
||||
_cache_lock = asyncio.Lock()
|
||||
|
||||
# Valid sort fields for the GET endpoint
|
||||
_SORT_FIELDS = frozenset({'total_bytes', 'full_name', 'tariff_name', 'device_limit', 'traffic_limit_gb'})
|
||||
_MAX_DATE_RANGE_DAYS = 31
|
||||
|
||||
|
||||
def _validate_period(period: int) -> None:
|
||||
@@ -58,78 +54,9 @@ def _validate_period(period: int) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _resolve_date_range(period: int, start_date: str, end_date: str) -> tuple[str, str, int]:
|
||||
"""Resolve date range from either custom dates or period.
|
||||
|
||||
Returns (start_str, end_str, period_days) in ISO datetime format.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
if start_date and end_date:
|
||||
try:
|
||||
start_dt = datetime.strptime(start_date, '%Y-%m-%d').replace(tzinfo=UTC)
|
||||
end_dt = datetime.strptime(end_date, '%Y-%m-%d').replace(tzinfo=UTC, hour=23, minute=59, second=59)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid date format. Use YYYY-MM-DD.',
|
||||
)
|
||||
|
||||
if start_dt > end_dt:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='start_date must be before end_date.',
|
||||
)
|
||||
|
||||
end_dt = min(end_dt, now)
|
||||
|
||||
if start_dt > end_dt:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='start_date cannot be in the future.',
|
||||
)
|
||||
|
||||
delta = (end_dt - start_dt).days
|
||||
if delta > _MAX_DATE_RANGE_DAYS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f'Date range must not exceed {_MAX_DATE_RANGE_DAYS} days.',
|
||||
)
|
||||
|
||||
period_days = max(delta, 1)
|
||||
start_str = start_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
end_str = end_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
return start_str, end_str, period_days
|
||||
|
||||
_validate_period(period)
|
||||
end_dt = now
|
||||
start_dt = end_dt - timedelta(days=period)
|
||||
start_str = start_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
end_str = end_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
return start_str, end_str, period
|
||||
|
||||
|
||||
async def _fetch_devices(api, user_uuids: list[str]) -> dict[str, int]:
|
||||
"""Fetch connected device count for each user UUID. Returns {uuid: count}."""
|
||||
semaphore = asyncio.Semaphore(_DEVICE_CONCURRENCY_LIMIT)
|
||||
devices_map: dict[str, int] = {}
|
||||
|
||||
async def fetch_one(uuid: str):
|
||||
async with semaphore:
|
||||
try:
|
||||
result = await api.get_user_devices(uuid)
|
||||
devices_map[uuid] = result.get('total', 0)
|
||||
except Exception:
|
||||
logger.debug('Failed to get devices for user %s', uuid, exc_info=True)
|
||||
devices_map[uuid] = 0
|
||||
|
||||
await asyncio.gather(*(fetch_one(uid) for uid in user_uuids))
|
||||
return devices_map
|
||||
|
||||
|
||||
async def _aggregate_traffic(
|
||||
start_str: str, end_str: str, user_uuids: list[str]
|
||||
) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo], dict[str, int]]:
|
||||
period_days: int, user_uuids: list[str]
|
||||
) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo]]:
|
||||
"""Aggregate per-user traffic across all nodes for a given period.
|
||||
|
||||
Uses legacy per-node endpoint to fetch all users' traffic per node —
|
||||
@@ -137,30 +64,33 @@ async def _aggregate_traffic(
|
||||
{userUuid, nodeUuid, total} per entry (non-legacy only returns topUsers
|
||||
without userUuid).
|
||||
|
||||
Returns (user_traffic, nodes_info, devices_map) where:
|
||||
Returns (user_traffic, nodes_info) where:
|
||||
user_traffic = {remnawave_uuid: {node_uuid: total_bytes, ...}}
|
||||
nodes_info = [TrafficNodeInfo, ...]
|
||||
devices_map = {remnawave_uuid: connected_device_count}
|
||||
"""
|
||||
cache_key = (start_str, end_str)
|
||||
|
||||
# Quick check without lock
|
||||
now = time.time()
|
||||
cached = _traffic_cache.get(cache_key)
|
||||
cached = _traffic_cache.get(period_days)
|
||||
if cached and (now - cached[0]) < _CACHE_TTL:
|
||||
return cached[1], cached[2], cached[3]
|
||||
return cached[1], cached[2]
|
||||
|
||||
# Acquire lock for the slow path
|
||||
async with _cache_lock:
|
||||
# Re-check after acquiring lock
|
||||
now = time.time()
|
||||
cached = _traffic_cache.get(cache_key)
|
||||
cached = _traffic_cache.get(period_days)
|
||||
if cached and (now - cached[0]) < _CACHE_TTL:
|
||||
return cached[1], cached[2], cached[3]
|
||||
return cached[1], cached[2]
|
||||
|
||||
service = RemnaWaveService()
|
||||
if not service.is_configured:
|
||||
return {}, [], {}
|
||||
return {}, []
|
||||
|
||||
end_date = datetime.now(UTC)
|
||||
start_date = end_date - timedelta(days=period_days)
|
||||
# Legacy endpoint expects date-time format
|
||||
start_str = start_date.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
end_str = end_date.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
|
||||
user_uuids_set = set(user_uuids)
|
||||
|
||||
@@ -181,31 +111,24 @@ async def _aggregate_traffic(
|
||||
|
||||
results = await asyncio.gather(*(fetch_node_users(n) for n in nodes))
|
||||
|
||||
nodes_info: list[TrafficNodeInfo] = [
|
||||
TrafficNodeInfo(node_uuid=node.uuid, node_name=node.name, country_code=node.country_code)
|
||||
for node in nodes
|
||||
]
|
||||
nodes_info.sort(key=lambda n: n.node_name)
|
||||
nodes_info: list[TrafficNodeInfo] = [
|
||||
TrafficNodeInfo(node_uuid=node.uuid, node_name=node.name, country_code=node.country_code) for node in nodes
|
||||
]
|
||||
nodes_info.sort(key=lambda n: n.node_name)
|
||||
|
||||
# Legacy response: [{userUuid, username, nodeUuid, total, date}, ...]
|
||||
user_traffic: dict[str, dict[str, int]] = {}
|
||||
for node_uuid, entries in results:
|
||||
if not isinstance(entries, list):
|
||||
continue
|
||||
for entry in entries:
|
||||
uid = entry.get('userUuid', '')
|
||||
total = int(entry.get('total', 0))
|
||||
if uid and total > 0 and uid in user_uuids_set:
|
||||
user_traffic.setdefault(uid, {})[node_uuid] = (
|
||||
user_traffic.get(uid, {}).get(node_uuid, 0) + total
|
||||
)
|
||||
# Legacy response: [{userUuid, username, nodeUuid, total, date}, ...]
|
||||
user_traffic: dict[str, dict[str, int]] = {}
|
||||
for node_uuid, entries in results:
|
||||
if not isinstance(entries, list):
|
||||
continue
|
||||
for entry in entries:
|
||||
uid = entry.get('userUuid', '')
|
||||
total = int(entry.get('total', 0))
|
||||
if uid and total > 0 and uid in user_uuids_set:
|
||||
user_traffic.setdefault(uid, {})[node_uuid] = user_traffic.get(uid, {}).get(node_uuid, 0) + total
|
||||
|
||||
# Fetch devices for users that have traffic
|
||||
uuids_with_traffic = list(user_traffic.keys())
|
||||
devices_map = await _fetch_devices(api, uuids_with_traffic) if uuids_with_traffic else {}
|
||||
|
||||
_traffic_cache[cache_key] = (now, user_traffic, nodes_info, devices_map)
|
||||
return user_traffic, nodes_info, devices_map
|
||||
_traffic_cache[period_days] = (now, user_traffic, nodes_info)
|
||||
return user_traffic, nodes_info
|
||||
|
||||
|
||||
async def _load_user_map(db: AsyncSession) -> dict[str, User]:
|
||||
@@ -224,15 +147,12 @@ def _build_traffic_items(
|
||||
user_traffic: dict[str, dict[str, int]],
|
||||
user_map: dict[str, User],
|
||||
nodes_info: list[TrafficNodeInfo],
|
||||
devices_map: dict[str, int],
|
||||
search: str = '',
|
||||
sort_by: str = 'total_bytes',
|
||||
sort_desc: bool = True,
|
||||
tariff_filter: set[str] | None = None,
|
||||
node_filter: set[str] | None = None,
|
||||
status_filter: set[str] | None = None,
|
||||
) -> list[UserTrafficItem]:
|
||||
"""Merge traffic data with user data, apply search/tariff/node/status filters, return sorted list."""
|
||||
"""Merge traffic data with user data, apply search/tariff filters, return sorted list."""
|
||||
items: list[UserTrafficItem] = []
|
||||
search_lower = search.lower().strip()
|
||||
|
||||
@@ -243,11 +163,6 @@ def _build_traffic_items(
|
||||
continue
|
||||
|
||||
traffic = user_traffic.get(uuid, {})
|
||||
|
||||
# Apply node filter: keep only selected nodes, recalculate total
|
||||
if node_filter is not None:
|
||||
traffic = {nid: val for nid, val in traffic.items() if nid in node_filter}
|
||||
|
||||
total_bytes = sum(traffic.values())
|
||||
|
||||
full_name = user.full_name
|
||||
@@ -274,12 +189,6 @@ def _build_traffic_items(
|
||||
if (tariff_name or '') not in tariff_filter:
|
||||
continue
|
||||
|
||||
if status_filter is not None:
|
||||
if (subscription_status or '') not in status_filter:
|
||||
continue
|
||||
|
||||
connected_devices = devices_map.get(uuid, 0)
|
||||
|
||||
items.append(
|
||||
UserTrafficItem(
|
||||
user_id=user.id,
|
||||
@@ -290,7 +199,6 @@ def _build_traffic_items(
|
||||
subscription_status=subscription_status,
|
||||
traffic_limit_gb=traffic_limit_gb,
|
||||
device_limit=device_limit,
|
||||
connected_devices=connected_devices,
|
||||
node_traffic=traffic,
|
||||
total_bytes=total_bytes,
|
||||
)
|
||||
@@ -319,16 +227,12 @@ async def get_traffic_usage(
|
||||
sort_by: str = Query('total_bytes', max_length=100),
|
||||
sort_desc: bool = Query(True),
|
||||
tariffs: str = Query('', max_length=500),
|
||||
nodes: str = Query('', max_length=2000),
|
||||
statuses: str = Query('', max_length=200),
|
||||
start_date: str = Query('', max_length=10),
|
||||
end_date: str = Query('', max_length=10),
|
||||
):
|
||||
"""Get paginated per-user traffic usage by node."""
|
||||
start_str, end_str, period_days = _resolve_date_range(period, start_date, end_date)
|
||||
_validate_period(period)
|
||||
|
||||
user_map = await _load_user_map(db)
|
||||
user_traffic, nodes_info, devices_map = await _aggregate_traffic(start_str, end_str, list(user_map.keys()))
|
||||
user_traffic, nodes_info = await _aggregate_traffic(period, list(user_map.keys()))
|
||||
|
||||
# Collect all available tariff names (before filtering)
|
||||
available_tariffs = sorted(
|
||||
@@ -339,66 +243,30 @@ async def get_traffic_usage(
|
||||
}
|
||||
)
|
||||
|
||||
# Collect all available statuses (before filtering)
|
||||
available_statuses = sorted(
|
||||
{
|
||||
(u.subscription.actual_status if hasattr(u.subscription, 'actual_status') else u.subscription.status)
|
||||
for u in user_map.values()
|
||||
if u.subscription
|
||||
}
|
||||
)
|
||||
|
||||
# Parse tariff filter
|
||||
tariff_filter: set[str] | None = None
|
||||
if tariffs.strip():
|
||||
tariff_filter = {t.strip() for t in tariffs.split(',') if t.strip()}
|
||||
|
||||
# Parse node filter
|
||||
node_filter: set[str] | None = None
|
||||
if nodes.strip():
|
||||
node_filter = {n.strip() for n in nodes.split(',') if n.strip()}
|
||||
|
||||
# Parse status filter
|
||||
status_filter: set[str] | None = None
|
||||
if statuses.strip():
|
||||
status_filter = {s.strip() for s in statuses.split(',') if s.strip()}
|
||||
|
||||
# Validate sort_by: allow known fields + 'node_<uuid>' for dynamic node columns
|
||||
node_uuids = {n.node_uuid for n in nodes_info}
|
||||
is_node_sort = sort_by.startswith('node_') and sort_by[5:] in node_uuids
|
||||
if sort_by not in _SORT_FIELDS and not is_node_sort:
|
||||
sort_by = 'total_bytes'
|
||||
|
||||
items = _build_traffic_items(
|
||||
user_traffic,
|
||||
user_map,
|
||||
nodes_info,
|
||||
devices_map,
|
||||
search,
|
||||
sort_by,
|
||||
sort_desc,
|
||||
tariff_filter,
|
||||
node_filter,
|
||||
status_filter,
|
||||
)
|
||||
items = _build_traffic_items(user_traffic, user_map, nodes_info, search, sort_by, sort_desc, tariff_filter)
|
||||
|
||||
total = len(items)
|
||||
paginated = items[offset : offset + limit]
|
||||
|
||||
# Filter nodes_info to only selected nodes for frontend column display
|
||||
filtered_nodes = nodes_info
|
||||
if node_filter is not None:
|
||||
filtered_nodes = [n for n in nodes_info if n.node_uuid in node_filter]
|
||||
|
||||
return TrafficUsageResponse(
|
||||
items=paginated,
|
||||
nodes=filtered_nodes,
|
||||
nodes=nodes_info,
|
||||
total=total,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
period_days=period_days,
|
||||
period_days=period,
|
||||
available_tariffs=available_tariffs,
|
||||
available_statuses=available_statuses,
|
||||
)
|
||||
|
||||
|
||||
@@ -409,45 +277,17 @@ async def export_traffic_csv(
|
||||
db: AsyncSession = Depends(get_cabinet_db),
|
||||
):
|
||||
"""Generate CSV with traffic usage and send to admin's Telegram DM."""
|
||||
_validate_period(request.period)
|
||||
|
||||
if not admin.telegram_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Admin has no Telegram ID configured',
|
||||
)
|
||||
|
||||
start_str, end_str, period_days = _resolve_date_range(request.period, request.start_date, request.end_date)
|
||||
|
||||
user_map = await _load_user_map(db)
|
||||
user_traffic, nodes_info, devices_map = await _aggregate_traffic(start_str, end_str, list(user_map.keys()))
|
||||
|
||||
# Parse filters
|
||||
tariff_filter: set[str] | None = None
|
||||
if request.tariffs.strip():
|
||||
tariff_filter = {t.strip() for t in request.tariffs.split(',') if t.strip()}
|
||||
|
||||
node_filter: set[str] | None = None
|
||||
if request.nodes.strip():
|
||||
node_filter = {n.strip() for n in request.nodes.split(',') if n.strip()}
|
||||
|
||||
status_filter: set[str] | None = None
|
||||
if request.statuses.strip():
|
||||
status_filter = {s.strip() for s in request.statuses.split(',') if s.strip()}
|
||||
|
||||
items = _build_traffic_items(
|
||||
user_traffic,
|
||||
user_map,
|
||||
nodes_info,
|
||||
devices_map,
|
||||
search=request.search,
|
||||
tariff_filter=tariff_filter,
|
||||
node_filter=node_filter,
|
||||
status_filter=status_filter,
|
||||
)
|
||||
|
||||
# Filter node columns for CSV if node filter active
|
||||
csv_nodes = nodes_info
|
||||
if node_filter is not None:
|
||||
csv_nodes = [n for n in nodes_info if n.node_uuid in node_filter]
|
||||
user_traffic, nodes_info = await _aggregate_traffic(request.period, list(user_map.keys()))
|
||||
items = _build_traffic_items(user_traffic, user_map, nodes_info)
|
||||
|
||||
# Build CSV rows
|
||||
rows: list[dict] = []
|
||||
@@ -460,9 +300,9 @@ async def export_traffic_csv(
|
||||
'Tariff': item.tariff_name or '',
|
||||
'Status': item.subscription_status or '',
|
||||
'Traffic Limit (GB)': item.traffic_limit_gb,
|
||||
'Devices': f'{item.connected_devices}/{item.device_limit}',
|
||||
'Devices': item.device_limit,
|
||||
}
|
||||
for node in csv_nodes:
|
||||
for node in nodes_info:
|
||||
row[f'{node.node_name} (bytes)'] = item.node_traffic.get(node.node_uuid, 0)
|
||||
row['Total (bytes)'] = item.total_bytes
|
||||
row['Total (GB)'] = round(item.total_bytes / (1024**3), 2) if item.total_bytes else 0
|
||||
@@ -477,7 +317,7 @@ async def export_traffic_csv(
|
||||
csv_bytes = output.getvalue().encode('utf-8-sig')
|
||||
|
||||
timestamp = datetime.now(UTC).strftime('%Y%m%d_%H%M%S')
|
||||
filename = f'traffic_usage_{period_days}d_{timestamp}.csv'
|
||||
filename = f'traffic_usage_{request.period}d_{timestamp}.csv'
|
||||
|
||||
try:
|
||||
bot = Bot(
|
||||
@@ -488,7 +328,7 @@ async def export_traffic_csv(
|
||||
await bot.send_document(
|
||||
chat_id=admin.telegram_id,
|
||||
document=BufferedInputFile(csv_bytes, filename=filename),
|
||||
caption=f'Traffic usage report ({period_days}d)\nUsers: {len(rows)}',
|
||||
caption=f'Traffic usage report ({request.period}d)\nUsers: {len(rows)}',
|
||||
)
|
||||
except Exception:
|
||||
logger.error('Failed to send CSV to admin %s', admin.telegram_id, exc_info=True)
|
||||
|
||||
@@ -18,7 +18,6 @@ class UserTrafficItem(BaseModel):
|
||||
subscription_status: str | None
|
||||
traffic_limit_gb: float
|
||||
device_limit: int
|
||||
connected_devices: int = 0
|
||||
node_traffic: dict[str, int] # {node_uuid: total_bytes}
|
||||
total_bytes: int
|
||||
|
||||
@@ -31,17 +30,10 @@ class TrafficUsageResponse(BaseModel):
|
||||
limit: int
|
||||
period_days: int
|
||||
available_tariffs: list[str]
|
||||
available_statuses: list[str]
|
||||
|
||||
|
||||
class ExportCsvRequest(BaseModel):
|
||||
period: int = Field(30, ge=1, le=30)
|
||||
start_date: str = ''
|
||||
end_date: str = ''
|
||||
tariffs: str = ''
|
||||
nodes: str = ''
|
||||
statuses: str = ''
|
||||
search: str = ''
|
||||
|
||||
|
||||
class ExportCsvResponse(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user