Files
remnawave-bedolaga-telegram…/app/webapi/routes/servers.py

419 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Маршруты управления серверами в административном API."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Iterable, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Security, status
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database.crud.server_squad import (
create_server_squad,
delete_server_squad,
get_server_connected_users,
get_server_squad_by_id,
get_server_squad_by_uuid,
get_server_statistics,
sync_server_user_counts,
sync_with_remnawave,
update_server_squad,
update_server_squad_promo_groups,
)
from app.database.models import PromoGroup, ServerSquad, User
from app.utils.cache import cache
from ..dependencies import get_db_session, require_api_token
from ..schemas.servers import (
ServerConnectedUser,
ServerConnectedUsersResponse,
ServerCountsSyncResponse,
ServerCreateRequest,
ServerDeleteResponse,
ServerListResponse,
ServerResponse,
ServerStatisticsResponse,
ServerSyncResponse,
ServerUpdateRequest,
)
from ..schemas.users import PromoGroupSummary
# The RemnaWave service is an optional integration. If importing it fails for
# any reason, the name is bound to None; _get_remnawave_service() and
# _ensure_service_configured() then answer with HTTP 503.
try:  # pragma: no cover - import may fail when optional dependencies are missing
    from app.services.remnawave_service import RemnaWaveService  # type: ignore
except Exception:  # pragma: no cover - hide the feature if the service is unavailable
    RemnaWaveService = None  # type: ignore[assignment]

# Alias used only in annotations: the real class for type checkers, Any at runtime
# (so annotations stay valid even when the import above failed).
if TYPE_CHECKING:  # pragma: no cover - only for IDE / type-checker hints
    from app.services.remnawave_service import (  # type: ignore
        RemnaWaveService as RemnaWaveServiceType,
    )
else:
    RemnaWaveServiceType = Any

# Router for the server-management endpoints below. Routes use relative paths
# such as "" and "/stats", so it is presumably mounted under a prefix
# elsewhere — confirm at the include_router call site.
router = APIRouter()
def _serialize_promo_group(group: PromoGroup) -> PromoGroupSummary:
    """Convert a ``PromoGroup`` ORM row into its API summary schema.

    ``apply_discounts_to_addons`` falls back to ``True`` when the model
    instance has no such attribute.
    """
    addons_flag = getattr(group, "apply_discounts_to_addons", True)
    return PromoGroupSummary(
        id=group.id,
        name=group.name,
        server_discount_percent=group.server_discount_percent,
        traffic_discount_percent=group.traffic_discount_percent,
        device_discount_percent=group.device_discount_percent,
        apply_discounts_to_addons=addons_flag,
    )
def _serialize_server(server: ServerSquad) -> ServerResponse:
    """Build the API representation of a ``ServerSquad`` row.

    Allowed promo groups are sorted case-insensitively by name; groups
    without a (truthy) name get an empty sort key and so come first.
    A null price, sort order or current-user count is reported as ``0``.
    """

    def _name_key(pg) -> str:
        # Missing or empty names map to "" so the sort key is always a str.
        pg_name = getattr(pg, "name", None)
        return pg_name.lower() if pg_name else ""

    raw_groups = getattr(server, "allowed_promo_groups", []) or []
    ordered_groups = sorted(raw_groups, key=_name_key)
    serialized_groups = [_serialize_promo_group(pg) for pg in ordered_groups]

    price = server.price_kopeks or 0

    return ServerResponse(
        id=server.id,
        squad_uuid=server.squad_uuid,
        display_name=server.display_name,
        original_name=server.original_name,
        country_code=server.country_code,
        is_available=bool(server.is_available),
        is_trial_eligible=bool(server.is_trial_eligible),
        price_kopeks=int(price),
        price_rubles=round(price / 100, 2),
        description=server.description,
        sort_order=int(server.sort_order or 0),
        max_users=server.max_users,
        current_users=int(server.current_users or 0),
        created_at=getattr(server, "created_at", None),
        updated_at=getattr(server, "updated_at", None),
        promo_groups=serialized_groups,
    )
def _serialize_connected_user(user: User) -> ServerConnectedUser:
    """Build the API representation of a user connected to a server.

    Status values that expose a ``.value`` attribute (enum-like) are
    unwrapped to that value. All subscription fields are ``None`` when the
    user has no ``subscription`` attribute value.
    """
    subscription = getattr(user, "subscription", None)

    raw_sub_status = getattr(subscription, "status", None)
    sub_status = getattr(raw_sub_status, "value", raw_sub_status)

    raw_user_status = user.status
    user_status = getattr(raw_user_status, "value", raw_user_status)

    kopeks = user.balance_kopeks or 0

    return ServerConnectedUser(
        id=user.id,
        telegram_id=user.telegram_id,
        username=user.username,
        first_name=user.first_name,
        last_name=user.last_name,
        status=user_status,
        balance_kopeks=int(kopeks),
        balance_rubles=round(kopeks / 100, 2),
        subscription_id=getattr(subscription, "id", None),
        subscription_status=sub_status,
        subscription_end_date=getattr(subscription, "end_date", None),
    )
def _apply_filters(
filters: Iterable[Any],
query,
):
for condition in filters:
query = query.where(condition)
return query
def _get_remnawave_service() -> "RemnaWaveServiceType":
    """Return a new RemnaWave service instance.

    Raises:
        HTTPException: 503 when the optional service module could not be
            imported (``RemnaWaveService`` is ``None``).
    """
    service_cls = RemnaWaveService
    if service_cls is not None:
        return service_cls()
    raise HTTPException(  # pragma: no cover - dependency unavailable
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="RemnaWave сервис недоступен",
    )
def _ensure_service_configured(service: "RemnaWaveServiceType") -> None:
    """Reject the request unless the RemnaWave integration is usable.

    Raises:
        HTTPException: 503 if the service module is unavailable, or if
            ``service.is_configured`` is falsy. In the latter case the
            detail is ``service.configuration_error`` when it is truthy.
    """
    problem = None
    if RemnaWaveService is None:  # pragma: no cover - dependency unavailable
        problem = "RemnaWave сервис недоступен"
    elif not service.is_configured:
        problem = service.configuration_error or "RemnaWave API не настроен"

    if problem is not None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=problem,
        )
async def _validate_promo_group_ids(
    db: AsyncSession, promo_group_ids: Iterable[int]
) -> List[int]:
    """Check that every requested promo group exists and return their IDs.

    IDs are coerced to ``int`` *before* de-duplication, so ``1`` and ``"1"``
    count once. They are returned in ascending order.

    Raises:
        HTTPException: 400 if no IDs were given, if none of them exist, or
            if only some of them exist (the missing IDs are listed).
    """
    # Coerce first, then de-duplicate: set() over raw values would keep
    # e.g. 1 and "1" as two separate entries.
    unique_ids = sorted({int(pg_id) for pg_id in promo_group_ids})
    if not unique_ids:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            "Нужно выбрать хотя бы одну промогруппу",
        )

    result = await db.execute(
        select(PromoGroup.id).where(PromoGroup.id.in_(unique_ids))
    )
    found_ids = set(result.scalars().all())
    if not found_ids:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            "Не найдены промогруппы для обновления сервера",
        )

    # Also reject partial matches. Otherwise unknown IDs would pass
    # "validation" and the client would never learn about them.
    missing_ids = [pg_id for pg_id in unique_ids if pg_id not in found_ids]
    if missing_ids:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            "Промогруппы не найдены: " + ", ".join(map(str, missing_ids)),
        )

    return unique_ids
_LIKE_ESCAPE_CHAR = "\\"


def _escape_like(value: str, escape_char: str = _LIKE_ESCAPE_CHAR) -> str:
    """Escape LIKE wildcards (``%``, ``_``) and the escape character itself.

    Use the result with ``.like(pattern, escape=escape_char)`` so the text
    is matched literally.
    """
    # The escape character must be doubled first, or the escapes added for
    # % and _ would themselves be escaped.
    return (
        value.replace(escape_char, escape_char * 2)
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )


@router.get("", response_model=ServerListResponse)
async def list_servers(
    _: Any = Security(require_api_token),
    db: AsyncSession = Depends(get_db_session),
    page: int = Query(1, ge=1),
    limit: int = Query(50, ge=1, le=200),
    available_only: bool = Query(False, alias="available"),
    search: Optional[str] = Query(default=None),
) -> ServerListResponse:
    """Return a page of servers ordered by ``sort_order``, then ``display_name``.

    Query parameters:
        page / limit: 1-based pagination (``limit`` between 1 and 200).
        available: when true, only servers with ``is_available`` set.
        search: case-insensitive substring matched against display name,
            original name, squad UUID and country code. Surrounding
            whitespace is ignored, and ``%`` / ``_`` are matched literally.

    ``total`` counts all servers matching the filters, not just this page.
    """
    filters = []
    if available_only:
        filters.append(ServerSquad.is_available.is_(True))

    term = (search or "").strip()
    if term:
        # Escape the user's text so LIKE wildcards in it cannot widen the match.
        pattern = f"%{_escape_like(term.lower())}%"
        searchable_columns = (
            ServerSquad.display_name,
            ServerSquad.original_name,
            ServerSquad.squad_uuid,
            ServerSquad.country_code,
        )
        filters.append(
            or_(
                *(
                    func.lower(column).like(pattern, escape=_LIKE_ESCAPE_CHAR)
                    for column in searchable_columns
                )
            )
        )

    base_query = (
        select(ServerSquad)
        .options(selectinload(ServerSquad.allowed_promo_groups))
        .order_by(ServerSquad.sort_order, ServerSquad.display_name)
    )
    count_query = select(func.count(ServerSquad.id))
    if filters:
        base_query = _apply_filters(filters, base_query)
        count_query = _apply_filters(filters, count_query)

    total = await db.scalar(count_query) or 0
    result = await db.execute(
        base_query.offset((page - 1) * limit).limit(limit)
    )
    servers = result.scalars().unique().all()

    return ServerListResponse(
        items=[_serialize_server(server) for server in servers],
        total=int(total),
        page=page,
        limit=limit,
    )
@router.get("/stats", response_model=ServerStatisticsResponse)
async def get_servers_statistics(
    _: Any = Security(require_api_token),
    db: AsyncSession = Depends(get_db_session),
) -> ServerStatisticsResponse:
    """Return the aggregate numbers produced by ``get_server_statistics``.

    Keys that are missing from the statistics mapping, or that hold a
    falsy value, are reported as zero.
    """
    stats = await get_server_statistics(db)

    def _as_int(key: str) -> int:
        return int(stats.get(key, 0) or 0)

    return ServerStatisticsResponse(
        total_servers=_as_int("total_servers"),
        available_servers=_as_int("available_servers"),
        unavailable_servers=_as_int("unavailable_servers"),
        servers_with_connections=_as_int("servers_with_connections"),
        total_revenue_kopeks=_as_int("total_revenue_kopeks"),
        total_revenue_rubles=float(stats.get("total_revenue_rubles", 0) or 0),
    )
@router.post("", response_model=ServerResponse, status_code=status.HTTP_201_CREATED)
async def create_server_endpoint(
    payload: ServerCreateRequest,
    _: Any = Security(require_api_token),
    db: AsyncSession = Depends(get_db_session),
) -> ServerResponse:
    """Create a server squad from *payload* and return its API representation.

    Raises:
        HTTPException: 400 if a server with ``payload.squad_uuid`` already
            exists, or if ``create_server_squad`` raises ``ValueError``.
            500 if the newly created row cannot be read back.
    """
    existing = await get_server_squad_by_uuid(db, payload.squad_uuid)
    if existing:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            "Server with this UUID already exists",
        )

    try:
        server = await create_server_squad(
            db,
            squad_uuid=payload.squad_uuid,
            display_name=payload.display_name,
            original_name=payload.original_name,
            country_code=payload.country_code,
            price_kopeks=payload.price_kopeks,
            description=payload.description,
            max_users=payload.max_users,
            is_available=payload.is_available,
            is_trial_eligible=payload.is_trial_eligible,
            sort_order=payload.sort_order,
            promo_group_ids=payload.promo_group_ids,
        )
    except ValueError as error:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, str(error)) from error

    # Drop cached "available_countries*" entries; they are presumably derived
    # from the server list and would otherwise miss the new server.
    await cache.delete_pattern("available_countries*")

    # Re-read by id before serializing (presumably so relationships such as
    # allowed_promo_groups are loaded).
    created = await get_server_squad_by_id(db, server.id)
    if created is None:
        # Explicit check rather than `assert`, which is stripped under
        # `python -O` and would otherwise surface as an opaque 500.
        raise HTTPException(
            status.HTTP_500_INTERNAL_SERVER_ERROR,
            "Server was created but could not be loaded",
        )
    return _serialize_server(created)
@router.get("/{server_id}", response_model=ServerResponse)
async def get_server_endpoint(
    server_id: int,
    _: Any = Security(require_api_token),
    db: AsyncSession = Depends(get_db_session),
) -> ServerResponse:
    """Return a single server looked up by its numeric id.

    Raises:
        HTTPException: 404 when no server with ``server_id`` exists.
    """
    found = await get_server_squad_by_id(db, server_id)
    if found:
        return _serialize_server(found)
    raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found")
@router.patch("/{server_id}", response_model=ServerResponse)
async def update_server_endpoint(
    server_id: int,
    payload: ServerUpdateRequest,
    _: Any = Security(require_api_token),
    db: AsyncSession = Depends(get_db_session),
) -> ServerResponse:
    """Partially update a server; only fields present in the payload change.

    If ``promo_group_ids`` is sent, it is validated before any other field
    is written. It then replaces the server's promo groups after the
    scalar fields have been updated.

    Raises:
        HTTPException: 404 if the server does not exist, or is gone when
            re-read after the update. 400 for invalid promo group IDs, or
            when ``update_server_squad_promo_groups`` raises ``ValueError``.
    """
    server = await get_server_squad_by_id(db, server_id)
    if not server:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found")

    updates = payload.model_dump(exclude_unset=True, by_alias=False)
    promo_group_ids = updates.pop("promo_group_ids", None)

    # Validate promo groups up front so a bad list is rejected before the
    # scalar fields below are written.
    validated_promo_group_ids: Optional[List[int]] = None
    if promo_group_ids is not None:
        validated_promo_group_ids = await _validate_promo_group_ids(
            db, promo_group_ids
        )

    # The return values of the CRUD calls are not needed: the row is
    # re-read below before serializing.
    if updates:
        await update_server_squad(db, server_id, **updates)

    if validated_promo_group_ids is not None:
        try:
            await update_server_squad_promo_groups(
                db, server_id, validated_promo_group_ids
            )
        except ValueError as error:
            raise HTTPException(status.HTTP_400_BAD_REQUEST, str(error)) from error

    await cache.delete_pattern("available_countries*")

    refreshed = await get_server_squad_by_id(db, server_id)
    if refreshed is None:
        # Explicit check rather than `assert` (stripped under `python -O`).
        # The row can be missing here, e.g. if it was deleted concurrently.
        raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found")
    return _serialize_server(refreshed)
@router.delete("/{server_id}", response_model=ServerDeleteResponse)
async def delete_server_endpoint(
    server_id: int,
    _: Any = Security(require_api_token),
    db: AsyncSession = Depends(get_db_session),
) -> ServerDeleteResponse:
    """Delete a server and drop cached ``available_countries*`` entries.

    Raises:
        HTTPException: 404 if the server does not exist. 400 if
            ``delete_server_squad`` returns a falsy result; the error
            message attributes this to active connections.
    """
    if not await get_server_squad_by_id(db, server_id):
        raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found")

    was_deleted = await delete_server_squad(db, server_id)
    if was_deleted:
        await cache.delete_pattern("available_countries*")
        return ServerDeleteResponse(success=True, message="Server deleted")

    raise HTTPException(
        status.HTTP_400_BAD_REQUEST,
        "Server cannot be deleted because it has active connections",
    )
@router.get(
    "/{server_id}/users",
    response_model=ServerConnectedUsersResponse,
)
async def get_server_connected_users_endpoint(
    server_id: int,
    _: Any = Security(require_api_token),
    db: AsyncSession = Depends(get_db_session),
    limit: int = Query(100, ge=1, le=500),
    offset: int = Query(0, ge=0),
) -> ServerConnectedUsersResponse:
    """List users connected to a server, paginated with ``limit``/``offset``.

    ``total`` is the number of connected users before pagination.

    NOTE(review): every connected user is loaded and then sliced in memory.
    Consider pushing limit/offset into ``get_server_connected_users`` if
    these lists get large.

    Raises:
        HTTPException: 404 when the server does not exist.
    """
    if not await get_server_squad_by_id(db, server_id):
        raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found")

    connected = await get_server_connected_users(db, server_id)
    window = connected[offset : offset + limit]
    return ServerConnectedUsersResponse(
        items=list(map(_serialize_connected_user, window)),
        total=len(connected),
        limit=limit,
        offset=offset,
    )
@router.post("/sync", response_model=ServerSyncResponse)
async def sync_servers_with_remnawave(
    _: Any = Security(require_api_token),
    db: AsyncSession = Depends(get_db_session),
) -> ServerSyncResponse:
    """Fetch all squads from RemnaWave and reconcile them with local servers.

    If RemnaWave returns no squads (an empty or falsy result),
    ``sync_with_remnawave`` is not called and all counters are zero. Cached
    ``available_countries*`` entries are dropped in either case.

    Raises:
        HTTPException: 503 if the RemnaWave service is unavailable or not
            configured.
    """
    service = _get_remnawave_service()
    _ensure_service_configured(service)

    squads = await service.get_all_squads()
    squad_count = len(squads)

    counters = (0, 0, 0)
    if squads:
        counters = await sync_with_remnawave(db, squads)
    created, updated, removed = counters

    await cache.delete_pattern("available_countries*")

    return ServerSyncResponse(
        created=created,
        updated=updated,
        removed=removed,
        total=squad_count,
    )
@router.post("/sync-counts", response_model=ServerCountsSyncResponse)
async def sync_server_counts(
    _: Any = Security(require_api_token),
    db: AsyncSession = Depends(get_db_session),
) -> ServerCountsSyncResponse:
    """Recalculate stored per-server user counts.

    Delegates to ``sync_server_user_counts`` and returns whatever it
    reports as ``updated``.
    """
    return ServerCountsSyncResponse(updated=await sync_server_user_counts(db))