Merge pull request #2566 from BEDOLAGA-DEV/feat/traffic-filters-daterange

feat: node/status filters + custom date range for traffic page
This commit is contained in:
Egor
2026-02-07 11:54:54 +03:00
committed by GitHub
2 changed files with 163 additions and 29 deletions

View File

@@ -37,8 +37,8 @@ router = APIRouter(prefix='/admin/traffic', tags=['Admin Traffic'])
_ALLOWED_PERIODS = frozenset({1, 3, 7, 14, 30})
_CONCURRENCY_LIMIT = 5 # Max parallel API calls to avoid rate limiting
# In-memory cache: {period_days: (timestamp, aggregated_data, nodes_info)}
_traffic_cache: dict[int, tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo]]] = {}
# In-memory cache: {(start_str, end_str): (timestamp, aggregated_data, nodes_info)}
_traffic_cache: dict[tuple[str, str], tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo]]] = {}
_CACHE_TTL = 300 # 5 minutes
_cache_lock = asyncio.Lock()
@@ -46,6 +46,11 @@ _cache_lock = asyncio.Lock()
_SORT_FIELDS = frozenset({'total_bytes', 'full_name', 'tariff_name', 'device_limit', 'traffic_limit_gb'})
def _get_status(sub) -> str | None:
"""Get subscription status via actual_status property."""
return sub.actual_status
def _validate_period(period: int) -> None:
if period not in _ALLOWED_PERIODS:
raise HTTPException(
@@ -55,9 +60,9 @@ def _validate_period(period: int) -> None:
async def _aggregate_traffic(
period_days: int, user_uuids: list[str]
start_str: str, end_str: str, user_uuids: list[str]
) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo]]:
"""Aggregate per-user traffic across all nodes for a given period.
"""Aggregate per-user traffic across all nodes for a given date range.
Uses legacy per-node endpoint to fetch all users' traffic per node —
O(nodes) API calls instead of O(users). The legacy endpoint returns
@@ -68,9 +73,11 @@ async def _aggregate_traffic(
user_traffic = {remnawave_uuid: {node_uuid: total_bytes, ...}}
nodes_info = [TrafficNodeInfo, ...]
"""
cache_key = (start_str, end_str)
# Quick check without lock
now = time.time()
cached = _traffic_cache.get(period_days)
cached = _traffic_cache.get(cache_key)
if cached and (now - cached[0]) < _CACHE_TTL:
return cached[1], cached[2]
@@ -78,7 +85,7 @@ async def _aggregate_traffic(
async with _cache_lock:
# Re-check after acquiring lock
now = time.time()
cached = _traffic_cache.get(period_days)
cached = _traffic_cache.get(cache_key)
if cached and (now - cached[0]) < _CACHE_TTL:
return cached[1], cached[2]
@@ -86,12 +93,6 @@ async def _aggregate_traffic(
if not service.is_configured:
return {}, []
end_date = datetime.now(UTC)
start_date = end_date - timedelta(days=period_days)
# Legacy endpoint expects date-time format
start_str = start_date.strftime('%Y-%m-%dT%H:%M:%SZ')
end_str = end_date.strftime('%Y-%m-%dT%H:%M:%SZ')
user_uuids_set = set(user_uuids)
async with service.get_api_client() as api:
@@ -127,10 +128,27 @@ async def _aggregate_traffic(
if uid and total > 0 and uid in user_uuids_set:
user_traffic.setdefault(uid, {})[node_uuid] = user_traffic.get(uid, {}).get(node_uuid, 0) + total
_traffic_cache[period_days] = (now, user_traffic, nodes_info)
_traffic_cache[cache_key] = (now, user_traffic, nodes_info)
# Evict expired entries to prevent unbounded growth
expired = [k for k, (ts, _, _) in _traffic_cache.items() if (now - ts) >= _CACHE_TTL]
for k in expired:
del _traffic_cache[k]
return user_traffic, nodes_info
def _compute_date_range(period_days: int) -> tuple[str, str]:
"""Compute ISO date-time range from period days.
Truncates to 5-minute intervals for stable cache keys.
"""
end_dt = datetime.now(UTC).replace(second=0, microsecond=0)
end_dt = end_dt.replace(minute=(end_dt.minute // 5) * 5)
start_dt = end_dt - timedelta(days=period_days)
return start_dt.strftime('%Y-%m-%dT%H:%M:%SZ'), end_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
async def _load_user_map(db: AsyncSession) -> dict[str, User]:
"""Load all users with remnawave_uuid, eagerly loading subscription + tariff."""
stmt = (
@@ -151,8 +169,10 @@ def _build_traffic_items(
sort_by: str = 'total_bytes',
sort_desc: bool = True,
tariff_filter: set[str] | None = None,
status_filter: set[str] | None = None,
node_filter: set[str] | None = None,
) -> list[UserTrafficItem]:
"""Merge traffic data with user data, apply search/tariff filters, return sorted list."""
"""Merge traffic data with user data, apply search/tariff/status/node filters, return sorted list."""
items: list[UserTrafficItem] = []
search_lower = search.lower().strip()
@@ -163,7 +183,6 @@ def _build_traffic_items(
continue
traffic = user_traffic.get(uuid, {})
total_bytes = sum(traffic.values())
full_name = user.full_name
username = user.username
@@ -179,7 +198,7 @@ def _build_traffic_items(
device_limit = 1
if sub:
subscription_status = sub.actual_status if hasattr(sub, 'actual_status') else sub.status
subscription_status = _get_status(sub)
traffic_limit_gb = float(sub.traffic_limit_gb or 0)
device_limit = sub.device_limit or 1
if sub.tariff:
@@ -189,6 +208,16 @@ def _build_traffic_items(
if (tariff_name or '') not in tariff_filter:
continue
if status_filter is not None:
if (subscription_status or '') not in status_filter:
continue
# Apply node filter: keep only selected nodes, recalculate total
if node_filter is not None:
traffic = {k: v for k, v in traffic.items() if k in node_filter}
total_bytes = sum(traffic.values())
items.append(
UserTrafficItem(
user_id=user.id,
@@ -227,12 +256,39 @@ async def get_traffic_usage(
sort_by: str = Query('total_bytes', max_length=100),
sort_desc: bool = Query(True),
tariffs: str = Query('', max_length=500),
statuses: str = Query('', max_length=500),
nodes: str = Query('', max_length=2000),
start_date: str = Query('', max_length=10),
end_date: str = Query('', max_length=10),
):
"""Get paginated per-user traffic usage by node."""
_validate_period(period)
# Determine date range: custom dates or period-based
if start_date.strip() and end_date.strip():
try:
start_dt = datetime.strptime(start_date.strip(), '%Y-%m-%d').replace(tzinfo=UTC)
end_dt = datetime.strptime(end_date.strip(), '%Y-%m-%d').replace(tzinfo=UTC, hour=23, minute=59, second=59)
except ValueError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Invalid date format. Use YYYY-MM-DD.')
now = datetime.now(UTC)
end_dt = min(end_dt, now)
if start_dt > end_dt:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='start_date must be before end_date.')
if (end_dt - start_dt).days > 31:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Date range cannot exceed 31 days.')
start_str = start_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
end_str = end_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
effective_period = (end_dt - start_dt).days or 1
else:
_validate_period(period)
start_str, end_str = _compute_date_range(period)
effective_period = period
user_map = await _load_user_map(db)
user_traffic, nodes_info = await _aggregate_traffic(period, list(user_map.keys()))
user_traffic, nodes_info = await _aggregate_traffic(start_str, end_str, list(user_map.keys()))
# Collect all available tariff names (before filtering)
available_tariffs = sorted(
@@ -243,18 +299,37 @@ async def get_traffic_usage(
}
)
# Collect all available statuses (before filtering)
available_statuses = sorted(
{_get_status(sub) for u in user_map.values() if (sub := u.subscription) and _get_status(sub)}
)
# Parse tariff filter
tariff_filter: set[str] | None = None
if tariffs.strip():
tariff_filter = {t.strip() for t in tariffs.split(',') if t.strip()}
# Parse status filter
status_filter: set[str] | None = None
if statuses.strip():
status_filter = {s.strip() for s in statuses.split(',') if s.strip()}
# Parse node filter
node_filter: set[str] | None = None
all_node_uuids = {n.node_uuid for n in nodes_info}
if nodes.strip():
node_filter = {n.strip() for n in nodes.split(',') if n.strip()} & all_node_uuids
if not node_filter:
node_filter = None # No valid nodes matched, treat as "all nodes"
# Validate sort_by: allow known fields + 'node_<uuid>' for dynamic node columns
node_uuids = {n.node_uuid for n in nodes_info}
is_node_sort = sort_by.startswith('node_') and sort_by[5:] in node_uuids
is_node_sort = sort_by.startswith('node_') and sort_by[5:] in all_node_uuids
if sort_by not in _SORT_FIELDS and not is_node_sort:
sort_by = 'total_bytes'
items = _build_traffic_items(user_traffic, user_map, nodes_info, search, sort_by, sort_desc, tariff_filter)
items = _build_traffic_items(
user_traffic, user_map, nodes_info, search, sort_by, sort_desc, tariff_filter, status_filter, node_filter
)
total = len(items)
paginated = items[offset : offset + limit]
@@ -265,8 +340,9 @@ async def get_traffic_usage(
total=total,
offset=offset,
limit=limit,
period_days=period,
period_days=effective_period,
available_tariffs=available_tariffs,
available_statuses=available_statuses,
)
@@ -277,17 +353,69 @@ async def export_traffic_csv(
db: AsyncSession = Depends(get_cabinet_db),
):
"""Generate CSV with traffic usage and send to admin's Telegram DM."""
_validate_period(request.period)
if not admin.telegram_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Admin has no Telegram ID configured',
)
# Determine date range: custom dates or period-based
if request.start_date and request.end_date:
try:
start_dt = datetime.strptime(request.start_date.strip(), '%Y-%m-%d').replace(tzinfo=UTC)
end_dt = datetime.strptime(request.end_date.strip(), '%Y-%m-%d').replace(
tzinfo=UTC, hour=23, minute=59, second=59
)
except ValueError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Invalid date format. Use YYYY-MM-DD.')
now = datetime.now(UTC)
end_dt = min(end_dt, now)
if start_dt > end_dt:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='start_date must be before end_date.')
if (end_dt - start_dt).days > 31:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Date range cannot exceed 31 days.')
start_str = start_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
end_str = end_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
period_label = f'{request.start_date}_{request.end_date}'
else:
_validate_period(request.period)
start_str, end_str = _compute_date_range(request.period)
period_label = f'{request.period}d'
user_map = await _load_user_map(db)
user_traffic, nodes_info = await _aggregate_traffic(request.period, list(user_map.keys()))
items = _build_traffic_items(user_traffic, user_map, nodes_info)
user_traffic, nodes_info = await _aggregate_traffic(start_str, end_str, list(user_map.keys()))
# Parse filters
tariff_filter: set[str] | None = None
if request.tariffs and request.tariffs.strip():
tariff_filter = {t.strip() for t in request.tariffs.split(',') if t.strip()}
status_filter: set[str] | None = None
if request.statuses and request.statuses.strip():
status_filter = {s.strip() for s in request.statuses.split(',') if s.strip()}
node_filter: set[str] | None = None
all_node_uuids = {n.node_uuid for n in nodes_info}
if request.nodes and request.nodes.strip():
node_filter = {n.strip() for n in request.nodes.split(',') if n.strip()} & all_node_uuids
if not node_filter:
node_filter = None
items = _build_traffic_items(
user_traffic,
user_map,
nodes_info,
sort_by='total_bytes',
sort_desc=True,
tariff_filter=tariff_filter,
status_filter=status_filter,
node_filter=node_filter,
)
# Determine which nodes to include in CSV columns
csv_nodes = [n for n in nodes_info if n.node_uuid in node_filter] if node_filter else nodes_info
# Build CSV rows
rows: list[dict] = []
@@ -302,7 +430,7 @@ async def export_traffic_csv(
'Traffic Limit (GB)': item.traffic_limit_gb,
'Devices': item.device_limit,
}
for node in nodes_info:
for node in csv_nodes:
row[f'{node.node_name} (bytes)'] = item.node_traffic.get(node.node_uuid, 0)
row['Total (bytes)'] = item.total_bytes
row['Total (GB)'] = round(item.total_bytes / (1024**3), 2) if item.total_bytes else 0
@@ -317,7 +445,7 @@ async def export_traffic_csv(
csv_bytes = output.getvalue().encode('utf-8-sig')
timestamp = datetime.now(UTC).strftime('%Y%m%d_%H%M%S')
filename = f'traffic_usage_{request.period}d_{timestamp}.csv'
filename = f'traffic_usage_{period_label}_{timestamp}.csv'
try:
bot = Bot(
@@ -328,7 +456,7 @@ async def export_traffic_csv(
await bot.send_document(
chat_id=admin.telegram_id,
document=BufferedInputFile(csv_bytes, filename=filename),
caption=f'Traffic usage report ({request.period}d)\nUsers: {len(rows)}',
caption=f'Traffic usage report ({period_label})\nUsers: {len(rows)}',
)
except Exception:
logger.error('Failed to send CSV to admin %s', admin.telegram_id, exc_info=True)

View File

@@ -30,10 +30,16 @@ class TrafficUsageResponse(BaseModel):
limit: int
period_days: int
available_tariffs: list[str]
available_statuses: list[str]
class ExportCsvRequest(BaseModel):
period: int = Field(30, ge=1, le=30)
start_date: str | None = None
end_date: str | None = None
tariffs: str | None = None
statuses: str | None = None
nodes: str | None = None
class ExportCsvResponse(BaseModel):