"""Маршруты управления серверами в административном API.""" from __future__ import annotations from typing import TYPE_CHECKING, Any, Iterable, List, Optional from fastapi import APIRouter, Depends, HTTPException, Query, Security, status from sqlalchemy import func, or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.database.crud.server_squad import ( create_server_squad, delete_server_squad, get_server_connected_users, get_server_squad_by_id, get_server_squad_by_uuid, get_server_statistics, sync_server_user_counts, sync_with_remnawave, update_server_squad, update_server_squad_promo_groups, ) from app.database.models import PromoGroup, ServerSquad, User from app.utils.cache import cache from ..dependencies import get_db_session, require_api_token from ..schemas.servers import ( ServerConnectedUser, ServerConnectedUsersResponse, ServerCountsSyncResponse, ServerCreateRequest, ServerDeleteResponse, ServerListResponse, ServerResponse, ServerStatisticsResponse, ServerSyncResponse, ServerUpdateRequest, ) from ..schemas.users import PromoGroupSummary try: # pragma: no cover - импорт может провалиться без optional-зависимостей from app.services.remnawave_service import RemnaWaveService # type: ignore except Exception: # pragma: no cover - скрываем функционал, если сервис недоступен RemnaWaveService = None # type: ignore[assignment] if TYPE_CHECKING: # pragma: no cover - только для подсказок типов в IDE from app.services.remnawave_service import ( # type: ignore RemnaWaveService as RemnaWaveServiceType, ) else: RemnaWaveServiceType = Any router = APIRouter() def _serialize_promo_group(group: PromoGroup) -> PromoGroupSummary: return PromoGroupSummary( id=group.id, name=group.name, server_discount_percent=group.server_discount_percent, traffic_discount_percent=group.traffic_discount_percent, device_discount_percent=group.device_discount_percent, apply_discounts_to_addons=getattr(group, "apply_discounts_to_addons", True), ) def _serialize_server(server: ServerSquad) -> ServerResponse: promo_groups = [ _serialize_promo_group(group) for group in sorted( getattr(server, "allowed_promo_groups", []) or [], key=lambda pg: pg.name.lower() if getattr(pg, "name", None) else "", ) ] return ServerResponse( id=server.id, squad_uuid=server.squad_uuid, display_name=server.display_name, original_name=server.original_name, country_code=server.country_code, is_available=bool(server.is_available), is_trial_eligible=bool(server.is_trial_eligible), price_kopeks=int(server.price_kopeks or 0), price_rubles=round((server.price_kopeks or 0) / 100, 2), description=server.description, sort_order=int(server.sort_order or 0), max_users=server.max_users, current_users=int(server.current_users or 0), created_at=getattr(server, "created_at", None), updated_at=getattr(server, "updated_at", None), promo_groups=promo_groups, ) def _serialize_connected_user(user: User) -> ServerConnectedUser: subscription = getattr(user, "subscription", None) subscription_status = getattr(subscription, "status", None) if hasattr(subscription_status, "value"): subscription_status = subscription_status.value return ServerConnectedUser( id=user.id, telegram_id=user.telegram_id, username=user.username, first_name=user.first_name, last_name=user.last_name, status=getattr(getattr(user, "status", None), "value", user.status), balance_kopeks=int(user.balance_kopeks or 0), balance_rubles=round((user.balance_kopeks or 0) / 100, 2), subscription_id=getattr(subscription, "id", None), subscription_status=subscription_status, subscription_end_date=getattr(subscription, "end_date", None), ) def _apply_filters( filters: Iterable[Any], query, ): for condition in filters: query = query.where(condition) return query def _get_remnawave_service() -> "RemnaWaveServiceType": if RemnaWaveService is None: # pragma: no cover - зависимость не доступна raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="RemnaWave сервис недоступен", ) return RemnaWaveService() def _ensure_service_configured(service: "RemnaWaveServiceType") -> None: if RemnaWaveService is None: # pragma: no cover - зависимость не доступна raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="RemnaWave сервис недоступен", ) if not service.is_configured: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=service.configuration_error or "RemnaWave API не настроен", ) async def _validate_promo_group_ids( db: AsyncSession, promo_group_ids: Iterable[int] ) -> List[int]: unique_ids = [int(pg_id) for pg_id in set(promo_group_ids)] if not unique_ids: raise HTTPException( status.HTTP_400_BAD_REQUEST, "Нужно выбрать хотя бы одну промогруппу", ) result = await db.execute( select(PromoGroup.id).where(PromoGroup.id.in_(unique_ids)) ) found_ids = result.scalars().all() if not found_ids: raise HTTPException( status.HTTP_400_BAD_REQUEST, "Не найдены промогруппы для обновления сервера", ) return unique_ids @router.get("", response_model=ServerListResponse) async def list_servers( _: Any = Security(require_api_token), db: AsyncSession = Depends(get_db_session), page: int = Query(1, ge=1), limit: int = Query(50, ge=1, le=200), available_only: bool = Query(False, alias="available"), search: Optional[str] = Query(default=None), ) -> ServerListResponse: filters = [] if available_only: filters.append(ServerSquad.is_available.is_(True)) if search: pattern = f"%{search.lower()}%" filters.append( or_( func.lower(ServerSquad.display_name).like(pattern), func.lower(ServerSquad.original_name).like(pattern), func.lower(ServerSquad.squad_uuid).like(pattern), func.lower(ServerSquad.country_code).like(pattern), ) ) base_query = ( select(ServerSquad) .options(selectinload(ServerSquad.allowed_promo_groups)) .order_by(ServerSquad.sort_order, ServerSquad.display_name) ) count_query = select(func.count(ServerSquad.id)) if filters: base_query = _apply_filters(filters, base_query) count_query = _apply_filters(filters, count_query) total = await db.scalar(count_query) or 0 result = await db.execute( base_query.offset((page - 1) * limit).limit(limit) ) servers = result.scalars().unique().all() return ServerListResponse( items=[_serialize_server(server) for server in servers], total=int(total), page=page, limit=limit, ) @router.get("/stats", response_model=ServerStatisticsResponse) async def get_servers_statistics( _: Any = Security(require_api_token), db: AsyncSession = Depends(get_db_session), ) -> ServerStatisticsResponse: stats = await get_server_statistics(db) return ServerStatisticsResponse( total_servers=int(stats.get("total_servers", 0) or 0), available_servers=int(stats.get("available_servers", 0) or 0), unavailable_servers=int(stats.get("unavailable_servers", 0) or 0), servers_with_connections=int(stats.get("servers_with_connections", 0) or 0), total_revenue_kopeks=int(stats.get("total_revenue_kopeks", 0) or 0), total_revenue_rubles=float(stats.get("total_revenue_rubles", 0) or 0), ) @router.post("", response_model=ServerResponse, status_code=status.HTTP_201_CREATED) async def create_server_endpoint( payload: ServerCreateRequest, _: Any = Security(require_api_token), db: AsyncSession = Depends(get_db_session), ) -> ServerResponse: existing = await get_server_squad_by_uuid(db, payload.squad_uuid) if existing: raise HTTPException( status.HTTP_400_BAD_REQUEST, "Server with this UUID already exists", ) try: server = await create_server_squad( db, squad_uuid=payload.squad_uuid, display_name=payload.display_name, original_name=payload.original_name, country_code=payload.country_code, price_kopeks=payload.price_kopeks, description=payload.description, max_users=payload.max_users, is_available=payload.is_available, is_trial_eligible=payload.is_trial_eligible, sort_order=payload.sort_order, promo_group_ids=payload.promo_group_ids, ) except ValueError as error: raise HTTPException(status.HTTP_400_BAD_REQUEST, str(error)) from error await cache.delete_pattern("available_countries*") server = await get_server_squad_by_id(db, server.id) assert server is not None return _serialize_server(server) @router.get("/{server_id}", response_model=ServerResponse) async def get_server_endpoint( server_id: int, _: Any = Security(require_api_token), db: AsyncSession = Depends(get_db_session), ) -> ServerResponse: server = await get_server_squad_by_id(db, server_id) if not server: raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found") return _serialize_server(server) @router.patch("/{server_id}", response_model=ServerResponse) async def update_server_endpoint( server_id: int, payload: ServerUpdateRequest, _: Any = Security(require_api_token), db: AsyncSession = Depends(get_db_session), ) -> ServerResponse: server = await get_server_squad_by_id(db, server_id) if not server: raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found") updates = payload.model_dump(exclude_unset=True, by_alias=False) promo_group_ids = updates.pop("promo_group_ids", None) validated_promo_group_ids: Optional[List[int]] = None if promo_group_ids is not None: validated_promo_group_ids = await _validate_promo_group_ids( db, promo_group_ids ) if updates: server = await update_server_squad(db, server_id, **updates) or server if promo_group_ids is not None: try: assert validated_promo_group_ids is not None server = await update_server_squad_promo_groups( db, server_id, validated_promo_group_ids ) or server except ValueError as error: raise HTTPException(status.HTTP_400_BAD_REQUEST, str(error)) from error await cache.delete_pattern("available_countries*") server = await get_server_squad_by_id(db, server_id) assert server is not None return _serialize_server(server) @router.delete("/{server_id}", response_model=ServerDeleteResponse) async def delete_server_endpoint( server_id: int, _: Any = Security(require_api_token), db: AsyncSession = Depends(get_db_session), ) -> ServerDeleteResponse: server = await get_server_squad_by_id(db, server_id) if not server: raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found") deleted = await delete_server_squad(db, server_id) if not deleted: raise HTTPException( status.HTTP_400_BAD_REQUEST, "Server cannot be deleted because it has active connections", ) await cache.delete_pattern("available_countries*") return ServerDeleteResponse(success=True, message="Server deleted") @router.get( "/{server_id}/users", response_model=ServerConnectedUsersResponse, ) async def get_server_connected_users_endpoint( server_id: int, _: Any = Security(require_api_token), db: AsyncSession = Depends(get_db_session), limit: int = Query(100, ge=1, le=500), offset: int = Query(0, ge=0), ) -> ServerConnectedUsersResponse: server = await get_server_squad_by_id(db, server_id) if not server: raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found") users = await get_server_connected_users(db, server_id) total = len(users) sliced = users[offset : offset + limit] return ServerConnectedUsersResponse( items=[_serialize_connected_user(user) for user in sliced], total=total, limit=limit, offset=offset, ) @router.post("/sync", response_model=ServerSyncResponse) async def sync_servers_with_remnawave( _: Any = Security(require_api_token), db: AsyncSession = Depends(get_db_session), ) -> ServerSyncResponse: service = _get_remnawave_service() _ensure_service_configured(service) squads = await service.get_all_squads() total = len(squads) created = updated = removed = 0 if squads: created, updated, removed = await sync_with_remnawave(db, squads) await cache.delete_pattern("available_countries*") return ServerSyncResponse( created=created, updated=updated, removed=removed, total=total, ) @router.post("/sync-counts", response_model=ServerCountsSyncResponse) async def sync_server_counts( _: Any = Security(require_api_token), db: AsyncSession = Depends(get_db_session), ) -> ServerCountsSyncResponse: updated = await sync_server_user_counts(db) return ServerCountsSyncResponse(updated=updated)