mirror of
https://github.com/BEDOLAGA-DEV/remnawave-bedolaga-telegram-bot.git
synced 2026-02-28 15:23:35 +00:00
Merge pull request #2566 from BEDOLAGA-DEV/feat/traffic-filters-daterange
feat: node/status filters + custom date range for traffic page
This commit is contained in:
@@ -37,8 +37,8 @@ router = APIRouter(prefix='/admin/traffic', tags=['Admin Traffic'])
|
||||
_ALLOWED_PERIODS = frozenset({1, 3, 7, 14, 30})
|
||||
_CONCURRENCY_LIMIT = 5 # Max parallel API calls to avoid rate limiting
|
||||
|
||||
# In-memory cache: {period_days: (timestamp, aggregated_data, nodes_info)}
|
||||
_traffic_cache: dict[int, tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo]]] = {}
|
||||
# In-memory cache: {(start_str, end_str): (timestamp, aggregated_data, nodes_info)}
|
||||
_traffic_cache: dict[tuple[str, str], tuple[float, dict[str, dict[str, int]], list[TrafficNodeInfo]]] = {}
|
||||
_CACHE_TTL = 300 # 5 minutes
|
||||
_cache_lock = asyncio.Lock()
|
||||
|
||||
@@ -46,6 +46,11 @@ _cache_lock = asyncio.Lock()
|
||||
_SORT_FIELDS = frozenset({'total_bytes', 'full_name', 'tariff_name', 'device_limit', 'traffic_limit_gb'})
|
||||
|
||||
|
||||
def _get_status(sub) -> str | None:
|
||||
"""Get subscription status via actual_status property."""
|
||||
return sub.actual_status
|
||||
|
||||
|
||||
def _validate_period(period: int) -> None:
|
||||
if period not in _ALLOWED_PERIODS:
|
||||
raise HTTPException(
|
||||
@@ -55,9 +60,9 @@ def _validate_period(period: int) -> None:
|
||||
|
||||
|
||||
async def _aggregate_traffic(
|
||||
period_days: int, user_uuids: list[str]
|
||||
start_str: str, end_str: str, user_uuids: list[str]
|
||||
) -> tuple[dict[str, dict[str, int]], list[TrafficNodeInfo]]:
|
||||
"""Aggregate per-user traffic across all nodes for a given period.
|
||||
"""Aggregate per-user traffic across all nodes for a given date range.
|
||||
|
||||
Uses legacy per-node endpoint to fetch all users' traffic per node —
|
||||
O(nodes) API calls instead of O(users). The legacy endpoint returns
|
||||
@@ -68,9 +73,11 @@ async def _aggregate_traffic(
|
||||
user_traffic = {remnawave_uuid: {node_uuid: total_bytes, ...}}
|
||||
nodes_info = [TrafficNodeInfo, ...]
|
||||
"""
|
||||
cache_key = (start_str, end_str)
|
||||
|
||||
# Quick check without lock
|
||||
now = time.time()
|
||||
cached = _traffic_cache.get(period_days)
|
||||
cached = _traffic_cache.get(cache_key)
|
||||
if cached and (now - cached[0]) < _CACHE_TTL:
|
||||
return cached[1], cached[2]
|
||||
|
||||
@@ -78,7 +85,7 @@ async def _aggregate_traffic(
|
||||
async with _cache_lock:
|
||||
# Re-check after acquiring lock
|
||||
now = time.time()
|
||||
cached = _traffic_cache.get(period_days)
|
||||
cached = _traffic_cache.get(cache_key)
|
||||
if cached and (now - cached[0]) < _CACHE_TTL:
|
||||
return cached[1], cached[2]
|
||||
|
||||
@@ -86,12 +93,6 @@ async def _aggregate_traffic(
|
||||
if not service.is_configured:
|
||||
return {}, []
|
||||
|
||||
end_date = datetime.now(UTC)
|
||||
start_date = end_date - timedelta(days=period_days)
|
||||
# Legacy endpoint expects date-time format
|
||||
start_str = start_date.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
end_str = end_date.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
|
||||
user_uuids_set = set(user_uuids)
|
||||
|
||||
async with service.get_api_client() as api:
|
||||
@@ -127,10 +128,27 @@ async def _aggregate_traffic(
|
||||
if uid and total > 0 and uid in user_uuids_set:
|
||||
user_traffic.setdefault(uid, {})[node_uuid] = user_traffic.get(uid, {}).get(node_uuid, 0) + total
|
||||
|
||||
_traffic_cache[period_days] = (now, user_traffic, nodes_info)
|
||||
_traffic_cache[cache_key] = (now, user_traffic, nodes_info)
|
||||
|
||||
# Evict expired entries to prevent unbounded growth
|
||||
expired = [k for k, (ts, _, _) in _traffic_cache.items() if (now - ts) >= _CACHE_TTL]
|
||||
for k in expired:
|
||||
del _traffic_cache[k]
|
||||
|
||||
return user_traffic, nodes_info
|
||||
|
||||
|
||||
def _compute_date_range(period_days: int) -> tuple[str, str]:
|
||||
"""Compute ISO date-time range from period days.
|
||||
|
||||
Truncates to 5-minute intervals for stable cache keys.
|
||||
"""
|
||||
end_dt = datetime.now(UTC).replace(second=0, microsecond=0)
|
||||
end_dt = end_dt.replace(minute=(end_dt.minute // 5) * 5)
|
||||
start_dt = end_dt - timedelta(days=period_days)
|
||||
return start_dt.strftime('%Y-%m-%dT%H:%M:%SZ'), end_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
|
||||
|
||||
async def _load_user_map(db: AsyncSession) -> dict[str, User]:
|
||||
"""Load all users with remnawave_uuid, eagerly loading subscription + tariff."""
|
||||
stmt = (
|
||||
@@ -151,8 +169,10 @@ def _build_traffic_items(
|
||||
sort_by: str = 'total_bytes',
|
||||
sort_desc: bool = True,
|
||||
tariff_filter: set[str] | None = None,
|
||||
status_filter: set[str] | None = None,
|
||||
node_filter: set[str] | None = None,
|
||||
) -> list[UserTrafficItem]:
|
||||
"""Merge traffic data with user data, apply search/tariff filters, return sorted list."""
|
||||
"""Merge traffic data with user data, apply search/tariff/status/node filters, return sorted list."""
|
||||
items: list[UserTrafficItem] = []
|
||||
search_lower = search.lower().strip()
|
||||
|
||||
@@ -163,7 +183,6 @@ def _build_traffic_items(
|
||||
continue
|
||||
|
||||
traffic = user_traffic.get(uuid, {})
|
||||
total_bytes = sum(traffic.values())
|
||||
|
||||
full_name = user.full_name
|
||||
username = user.username
|
||||
@@ -179,7 +198,7 @@ def _build_traffic_items(
|
||||
device_limit = 1
|
||||
|
||||
if sub:
|
||||
subscription_status = sub.actual_status if hasattr(sub, 'actual_status') else sub.status
|
||||
subscription_status = _get_status(sub)
|
||||
traffic_limit_gb = float(sub.traffic_limit_gb or 0)
|
||||
device_limit = sub.device_limit or 1
|
||||
if sub.tariff:
|
||||
@@ -189,6 +208,16 @@ def _build_traffic_items(
|
||||
if (tariff_name or '') not in tariff_filter:
|
||||
continue
|
||||
|
||||
if status_filter is not None:
|
||||
if (subscription_status or '') not in status_filter:
|
||||
continue
|
||||
|
||||
# Apply node filter: keep only selected nodes, recalculate total
|
||||
if node_filter is not None:
|
||||
traffic = {k: v for k, v in traffic.items() if k in node_filter}
|
||||
|
||||
total_bytes = sum(traffic.values())
|
||||
|
||||
items.append(
|
||||
UserTrafficItem(
|
||||
user_id=user.id,
|
||||
@@ -227,12 +256,39 @@ async def get_traffic_usage(
|
||||
sort_by: str = Query('total_bytes', max_length=100),
|
||||
sort_desc: bool = Query(True),
|
||||
tariffs: str = Query('', max_length=500),
|
||||
statuses: str = Query('', max_length=500),
|
||||
nodes: str = Query('', max_length=2000),
|
||||
start_date: str = Query('', max_length=10),
|
||||
end_date: str = Query('', max_length=10),
|
||||
):
|
||||
"""Get paginated per-user traffic usage by node."""
|
||||
_validate_period(period)
|
||||
# Determine date range: custom dates or period-based
|
||||
if start_date.strip() and end_date.strip():
|
||||
try:
|
||||
start_dt = datetime.strptime(start_date.strip(), '%Y-%m-%d').replace(tzinfo=UTC)
|
||||
end_dt = datetime.strptime(end_date.strip(), '%Y-%m-%d').replace(tzinfo=UTC, hour=23, minute=59, second=59)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Invalid date format. Use YYYY-MM-DD.')
|
||||
|
||||
now = datetime.now(UTC)
|
||||
end_dt = min(end_dt, now)
|
||||
|
||||
if start_dt > end_dt:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='start_date must be before end_date.')
|
||||
|
||||
if (end_dt - start_dt).days > 31:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Date range cannot exceed 31 days.')
|
||||
|
||||
start_str = start_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
end_str = end_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
effective_period = (end_dt - start_dt).days or 1
|
||||
else:
|
||||
_validate_period(period)
|
||||
start_str, end_str = _compute_date_range(period)
|
||||
effective_period = period
|
||||
|
||||
user_map = await _load_user_map(db)
|
||||
user_traffic, nodes_info = await _aggregate_traffic(period, list(user_map.keys()))
|
||||
user_traffic, nodes_info = await _aggregate_traffic(start_str, end_str, list(user_map.keys()))
|
||||
|
||||
# Collect all available tariff names (before filtering)
|
||||
available_tariffs = sorted(
|
||||
@@ -243,18 +299,37 @@ async def get_traffic_usage(
|
||||
}
|
||||
)
|
||||
|
||||
# Collect all available statuses (before filtering)
|
||||
available_statuses = sorted(
|
||||
{_get_status(sub) for u in user_map.values() if (sub := u.subscription) and _get_status(sub)}
|
||||
)
|
||||
|
||||
# Parse tariff filter
|
||||
tariff_filter: set[str] | None = None
|
||||
if tariffs.strip():
|
||||
tariff_filter = {t.strip() for t in tariffs.split(',') if t.strip()}
|
||||
|
||||
# Parse status filter
|
||||
status_filter: set[str] | None = None
|
||||
if statuses.strip():
|
||||
status_filter = {s.strip() for s in statuses.split(',') if s.strip()}
|
||||
|
||||
# Parse node filter
|
||||
node_filter: set[str] | None = None
|
||||
all_node_uuids = {n.node_uuid for n in nodes_info}
|
||||
if nodes.strip():
|
||||
node_filter = {n.strip() for n in nodes.split(',') if n.strip()} & all_node_uuids
|
||||
if not node_filter:
|
||||
node_filter = None # No valid nodes matched, treat as "all nodes"
|
||||
|
||||
# Validate sort_by: allow known fields + 'node_<uuid>' for dynamic node columns
|
||||
node_uuids = {n.node_uuid for n in nodes_info}
|
||||
is_node_sort = sort_by.startswith('node_') and sort_by[5:] in node_uuids
|
||||
is_node_sort = sort_by.startswith('node_') and sort_by[5:] in all_node_uuids
|
||||
if sort_by not in _SORT_FIELDS and not is_node_sort:
|
||||
sort_by = 'total_bytes'
|
||||
|
||||
items = _build_traffic_items(user_traffic, user_map, nodes_info, search, sort_by, sort_desc, tariff_filter)
|
||||
items = _build_traffic_items(
|
||||
user_traffic, user_map, nodes_info, search, sort_by, sort_desc, tariff_filter, status_filter, node_filter
|
||||
)
|
||||
|
||||
total = len(items)
|
||||
paginated = items[offset : offset + limit]
|
||||
@@ -265,8 +340,9 @@ async def get_traffic_usage(
|
||||
total=total,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
period_days=period,
|
||||
period_days=effective_period,
|
||||
available_tariffs=available_tariffs,
|
||||
available_statuses=available_statuses,
|
||||
)
|
||||
|
||||
|
||||
@@ -277,17 +353,69 @@ async def export_traffic_csv(
|
||||
db: AsyncSession = Depends(get_cabinet_db),
|
||||
):
|
||||
"""Generate CSV with traffic usage and send to admin's Telegram DM."""
|
||||
_validate_period(request.period)
|
||||
|
||||
if not admin.telegram_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Admin has no Telegram ID configured',
|
||||
)
|
||||
|
||||
# Determine date range: custom dates or period-based
|
||||
if request.start_date and request.end_date:
|
||||
try:
|
||||
start_dt = datetime.strptime(request.start_date.strip(), '%Y-%m-%d').replace(tzinfo=UTC)
|
||||
end_dt = datetime.strptime(request.end_date.strip(), '%Y-%m-%d').replace(
|
||||
tzinfo=UTC, hour=23, minute=59, second=59
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Invalid date format. Use YYYY-MM-DD.')
|
||||
|
||||
now = datetime.now(UTC)
|
||||
end_dt = min(end_dt, now)
|
||||
if start_dt > end_dt:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='start_date must be before end_date.')
|
||||
if (end_dt - start_dt).days > 31:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Date range cannot exceed 31 days.')
|
||||
|
||||
start_str = start_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
end_str = end_dt.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
period_label = f'{request.start_date}_{request.end_date}'
|
||||
else:
|
||||
_validate_period(request.period)
|
||||
start_str, end_str = _compute_date_range(request.period)
|
||||
period_label = f'{request.period}d'
|
||||
|
||||
user_map = await _load_user_map(db)
|
||||
user_traffic, nodes_info = await _aggregate_traffic(request.period, list(user_map.keys()))
|
||||
items = _build_traffic_items(user_traffic, user_map, nodes_info)
|
||||
user_traffic, nodes_info = await _aggregate_traffic(start_str, end_str, list(user_map.keys()))
|
||||
|
||||
# Parse filters
|
||||
tariff_filter: set[str] | None = None
|
||||
if request.tariffs and request.tariffs.strip():
|
||||
tariff_filter = {t.strip() for t in request.tariffs.split(',') if t.strip()}
|
||||
|
||||
status_filter: set[str] | None = None
|
||||
if request.statuses and request.statuses.strip():
|
||||
status_filter = {s.strip() for s in request.statuses.split(',') if s.strip()}
|
||||
|
||||
node_filter: set[str] | None = None
|
||||
all_node_uuids = {n.node_uuid for n in nodes_info}
|
||||
if request.nodes and request.nodes.strip():
|
||||
node_filter = {n.strip() for n in request.nodes.split(',') if n.strip()} & all_node_uuids
|
||||
if not node_filter:
|
||||
node_filter = None
|
||||
|
||||
items = _build_traffic_items(
|
||||
user_traffic,
|
||||
user_map,
|
||||
nodes_info,
|
||||
sort_by='total_bytes',
|
||||
sort_desc=True,
|
||||
tariff_filter=tariff_filter,
|
||||
status_filter=status_filter,
|
||||
node_filter=node_filter,
|
||||
)
|
||||
|
||||
# Determine which nodes to include in CSV columns
|
||||
csv_nodes = [n for n in nodes_info if n.node_uuid in node_filter] if node_filter else nodes_info
|
||||
|
||||
# Build CSV rows
|
||||
rows: list[dict] = []
|
||||
@@ -302,7 +430,7 @@ async def export_traffic_csv(
|
||||
'Traffic Limit (GB)': item.traffic_limit_gb,
|
||||
'Devices': item.device_limit,
|
||||
}
|
||||
for node in nodes_info:
|
||||
for node in csv_nodes:
|
||||
row[f'{node.node_name} (bytes)'] = item.node_traffic.get(node.node_uuid, 0)
|
||||
row['Total (bytes)'] = item.total_bytes
|
||||
row['Total (GB)'] = round(item.total_bytes / (1024**3), 2) if item.total_bytes else 0
|
||||
@@ -317,7 +445,7 @@ async def export_traffic_csv(
|
||||
csv_bytes = output.getvalue().encode('utf-8-sig')
|
||||
|
||||
timestamp = datetime.now(UTC).strftime('%Y%m%d_%H%M%S')
|
||||
filename = f'traffic_usage_{request.period}d_{timestamp}.csv'
|
||||
filename = f'traffic_usage_{period_label}_{timestamp}.csv'
|
||||
|
||||
try:
|
||||
bot = Bot(
|
||||
@@ -328,7 +456,7 @@ async def export_traffic_csv(
|
||||
await bot.send_document(
|
||||
chat_id=admin.telegram_id,
|
||||
document=BufferedInputFile(csv_bytes, filename=filename),
|
||||
caption=f'Traffic usage report ({request.period}d)\nUsers: {len(rows)}',
|
||||
caption=f'Traffic usage report ({period_label})\nUsers: {len(rows)}',
|
||||
)
|
||||
except Exception:
|
||||
logger.error('Failed to send CSV to admin %s', admin.telegram_id, exc_info=True)
|
||||
|
||||
@@ -30,10 +30,16 @@ class TrafficUsageResponse(BaseModel):
|
||||
limit: int
|
||||
period_days: int
|
||||
available_tariffs: list[str]
|
||||
available_statuses: list[str]
|
||||
|
||||
|
||||
class ExportCsvRequest(BaseModel):
|
||||
period: int = Field(30, ge=1, le=30)
|
||||
start_date: str | None = None
|
||||
end_date: str | None = None
|
||||
tariffs: str | None = None
|
||||
statuses: str | None = None
|
||||
nodes: str | None = None
|
||||
|
||||
|
||||
class ExportCsvResponse(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user