Validate promo groups before updating server

This commit is contained in:
Egor
2025-11-03 07:19:03 +03:00
parent 23cb6dd8d0
commit 08e0b3a657
6 changed files with 611 additions and 9 deletions

View File

@@ -47,6 +47,7 @@ async def create_server_squad(
max_users: int = None,
is_available: bool = True,
is_trial_eligible: bool = False,
sort_order: int = 0,
promo_group_ids: Optional[Iterable[int]] = None,
) -> ServerSquad:
@@ -80,6 +81,7 @@ async def create_server_squad(
max_users=max_users,
is_available=is_available,
is_trial_eligible=is_trial_eligible,
sort_order=sort_order,
allowed_promo_groups=promo_groups,
)
@@ -260,8 +262,15 @@ async def update_server_squad(
) -> Optional[ServerSquad]:
valid_fields = {
'display_name', 'country_code', 'price_kopeks', 'description',
'max_users', 'is_available', 'sort_order', 'is_trial_eligible'
"display_name",
"original_name",
"country_code",
"price_kopeks",
"description",
"max_users",
"is_available",
"sort_order",
"is_trial_eligible",
}
filtered_updates = {k: v for k, v in updates.items() if k in valid_fields}

View File

@@ -18,7 +18,6 @@ from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession
from app.database.crud.user import (
create_user,
create_user_no_commit,
get_users_list,
get_user_by_telegram_id,
update_user,
@@ -303,27 +302,32 @@ class RemnaWaveService:
first_name_from_desc, last_name_from_desc, username_from_desc = self._extract_user_data_from_description(description)
# Используем извлеченное имя или дефолтное значение
fallback_first_name = f"Panel User {telegram_id}"
full_first_name = fallback_first_name
full_last_name = None
if first_name_from_desc and last_name_from_desc:
full_first_name = first_name_from_desc
full_last_name = last_name_from_desc
elif first_name_from_desc:
full_first_name = first_name_from_desc
full_last_name = last_name_from_desc
else:
full_first_name = f"User {telegram_id}"
full_last_name = None
username = username_from_desc or panel_user.get("username")
try:
db_user = await create_user_no_commit(
create_kwargs = dict(
db=db,
telegram_id=telegram_id,
username=username,
first_name=full_first_name,
last_name=full_last_name,
language="ru",
)
if full_last_name:
create_kwargs["last_name"] = full_last_name
db_user = await create_user(**create_kwargs)
return db_user, True
except IntegrityError as create_error:
logger.info(

View File

@@ -20,6 +20,7 @@ from .routes import (
promo_offers,
pages,
remnawave,
servers,
stats,
subscriptions,
tickets,
@@ -67,6 +68,13 @@ OPENAPI_TAGS = [
"name": "promo-groups",
"description": "Создание и управление промо-группами и их участниками.",
},
{
"name": "servers",
"description": (
"Управление серверами RemnaWave, их доступностью, промогруппами и "
"ручная синхронизация данных.",
),
},
{
"name": "promo-offers",
"description": "Управление промо-предложениями, шаблонами и журналом событий.",
@@ -137,6 +145,7 @@ def create_web_api_app() -> FastAPI:
app.include_router(transactions.router, prefix="/transactions", tags=["transactions"])
app.include_router(promo_groups.router, prefix="/promo-groups", tags=["promo-groups"])
app.include_router(promo_offers.router, prefix="/promo-offers", tags=["promo-offers"])
app.include_router(servers.router, prefix="/servers", tags=["servers"])
app.include_router(
main_menu_buttons.router,
prefix="/main-menu/buttons",

View File

@@ -7,6 +7,7 @@ from . import (
promo_offers,
pages,
promo_groups,
servers,
remnawave,
stats,
subscriptions,
@@ -26,6 +27,7 @@ __all__ = [
"promo_offers",
"pages",
"promo_groups",
"servers",
"remnawave",
"stats",
"subscriptions",

View File

@@ -0,0 +1,418 @@
"""Маршруты управления серверами в административном API."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Iterable, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Security, status
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database.crud.server_squad import (
create_server_squad,
delete_server_squad,
get_server_connected_users,
get_server_squad_by_id,
get_server_squad_by_uuid,
get_server_statistics,
sync_server_user_counts,
sync_with_remnawave,
update_server_squad,
update_server_squad_promo_groups,
)
from app.database.models import PromoGroup, ServerSquad, User
from app.utils.cache import cache
from ..dependencies import get_db_session, require_api_token
from ..schemas.servers import (
ServerConnectedUser,
ServerConnectedUsersResponse,
ServerCountsSyncResponse,
ServerCreateRequest,
ServerDeleteResponse,
ServerListResponse,
ServerResponse,
ServerStatisticsResponse,
ServerSyncResponse,
ServerUpdateRequest,
)
from ..schemas.users import PromoGroupSummary
try: # pragma: no cover - импорт может провалиться без optional-зависимостей
from app.services.remnawave_service import RemnaWaveService # type: ignore
except Exception: # pragma: no cover - скрываем функционал, если сервис недоступен
RemnaWaveService = None # type: ignore[assignment]
if TYPE_CHECKING: # pragma: no cover - только для подсказок типов в IDE
from app.services.remnawave_service import ( # type: ignore
RemnaWaveService as RemnaWaveServiceType,
)
else:
RemnaWaveServiceType = Any
router = APIRouter()
def _serialize_promo_group(group: PromoGroup) -> PromoGroupSummary:
return PromoGroupSummary(
id=group.id,
name=group.name,
server_discount_percent=group.server_discount_percent,
traffic_discount_percent=group.traffic_discount_percent,
device_discount_percent=group.device_discount_percent,
apply_discounts_to_addons=getattr(group, "apply_discounts_to_addons", True),
)
def _serialize_server(server: ServerSquad) -> ServerResponse:
promo_groups = [
_serialize_promo_group(group)
for group in sorted(
getattr(server, "allowed_promo_groups", []) or [],
key=lambda pg: pg.name.lower() if getattr(pg, "name", None) else "",
)
]
return ServerResponse(
id=server.id,
squad_uuid=server.squad_uuid,
display_name=server.display_name,
original_name=server.original_name,
country_code=server.country_code,
is_available=bool(server.is_available),
is_trial_eligible=bool(server.is_trial_eligible),
price_kopeks=int(server.price_kopeks or 0),
price_rubles=round((server.price_kopeks or 0) / 100, 2),
description=server.description,
sort_order=int(server.sort_order or 0),
max_users=server.max_users,
current_users=int(server.current_users or 0),
created_at=getattr(server, "created_at", None),
updated_at=getattr(server, "updated_at", None),
promo_groups=promo_groups,
)
def _serialize_connected_user(user: User) -> ServerConnectedUser:
subscription = getattr(user, "subscription", None)
subscription_status = getattr(subscription, "status", None)
if hasattr(subscription_status, "value"):
subscription_status = subscription_status.value
return ServerConnectedUser(
id=user.id,
telegram_id=user.telegram_id,
username=user.username,
first_name=user.first_name,
last_name=user.last_name,
status=getattr(getattr(user, "status", None), "value", user.status),
balance_kopeks=int(user.balance_kopeks or 0),
balance_rubles=round((user.balance_kopeks or 0) / 100, 2),
subscription_id=getattr(subscription, "id", None),
subscription_status=subscription_status,
subscription_end_date=getattr(subscription, "end_date", None),
)
def _apply_filters(
filters: Iterable[Any],
query,
):
for condition in filters:
query = query.where(condition)
return query
def _get_remnawave_service() -> "RemnaWaveServiceType":
if RemnaWaveService is None: # pragma: no cover - зависимость не доступна
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="RemnaWave сервис недоступен",
)
return RemnaWaveService()
def _ensure_service_configured(service: "RemnaWaveServiceType") -> None:
if RemnaWaveService is None: # pragma: no cover - зависимость не доступна
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="RemnaWave сервис недоступен",
)
if not service.is_configured:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=service.configuration_error or "RemnaWave API не настроен",
)
async def _validate_promo_group_ids(
db: AsyncSession, promo_group_ids: Iterable[int]
) -> List[int]:
unique_ids = [int(pg_id) for pg_id in set(promo_group_ids)]
if not unique_ids:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"Нужно выбрать хотя бы одну промогруппу",
)
result = await db.execute(
select(PromoGroup.id).where(PromoGroup.id.in_(unique_ids))
)
found_ids = result.scalars().all()
if not found_ids:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"Не найдены промогруппы для обновления сервера",
)
return unique_ids
@router.get("", response_model=ServerListResponse)
async def list_servers(
_: Any = Security(require_api_token),
db: AsyncSession = Depends(get_db_session),
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=200),
available_only: bool = Query(False, alias="available"),
search: Optional[str] = Query(default=None),
) -> ServerListResponse:
filters = []
if available_only:
filters.append(ServerSquad.is_available.is_(True))
if search:
pattern = f"%{search.lower()}%"
filters.append(
or_(
func.lower(ServerSquad.display_name).like(pattern),
func.lower(ServerSquad.original_name).like(pattern),
func.lower(ServerSquad.squad_uuid).like(pattern),
func.lower(ServerSquad.country_code).like(pattern),
)
)
base_query = (
select(ServerSquad)
.options(selectinload(ServerSquad.allowed_promo_groups))
.order_by(ServerSquad.sort_order, ServerSquad.display_name)
)
count_query = select(func.count(ServerSquad.id))
if filters:
base_query = _apply_filters(filters, base_query)
count_query = _apply_filters(filters, count_query)
total = await db.scalar(count_query) or 0
result = await db.execute(
base_query.offset((page - 1) * limit).limit(limit)
)
servers = result.scalars().unique().all()
return ServerListResponse(
items=[_serialize_server(server) for server in servers],
total=int(total),
page=page,
limit=limit,
)
@router.get("/stats", response_model=ServerStatisticsResponse)
async def get_servers_statistics(
_: Any = Security(require_api_token),
db: AsyncSession = Depends(get_db_session),
) -> ServerStatisticsResponse:
stats = await get_server_statistics(db)
return ServerStatisticsResponse(
total_servers=int(stats.get("total_servers", 0) or 0),
available_servers=int(stats.get("available_servers", 0) or 0),
unavailable_servers=int(stats.get("unavailable_servers", 0) or 0),
servers_with_connections=int(stats.get("servers_with_connections", 0) or 0),
total_revenue_kopeks=int(stats.get("total_revenue_kopeks", 0) or 0),
total_revenue_rubles=float(stats.get("total_revenue_rubles", 0) or 0),
)
@router.post("", response_model=ServerResponse, status_code=status.HTTP_201_CREATED)
async def create_server_endpoint(
payload: ServerCreateRequest,
_: Any = Security(require_api_token),
db: AsyncSession = Depends(get_db_session),
) -> ServerResponse:
existing = await get_server_squad_by_uuid(db, payload.squad_uuid)
if existing:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"Server with this UUID already exists",
)
try:
server = await create_server_squad(
db,
squad_uuid=payload.squad_uuid,
display_name=payload.display_name,
original_name=payload.original_name,
country_code=payload.country_code,
price_kopeks=payload.price_kopeks,
description=payload.description,
max_users=payload.max_users,
is_available=payload.is_available,
is_trial_eligible=payload.is_trial_eligible,
sort_order=payload.sort_order,
promo_group_ids=payload.promo_group_ids,
)
except ValueError as error:
raise HTTPException(status.HTTP_400_BAD_REQUEST, str(error)) from error
await cache.delete_pattern("available_countries*")
server = await get_server_squad_by_id(db, server.id)
assert server is not None
return _serialize_server(server)
@router.get("/{server_id}", response_model=ServerResponse)
async def get_server_endpoint(
server_id: int,
_: Any = Security(require_api_token),
db: AsyncSession = Depends(get_db_session),
) -> ServerResponse:
server = await get_server_squad_by_id(db, server_id)
if not server:
raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found")
return _serialize_server(server)
@router.patch("/{server_id}", response_model=ServerResponse)
async def update_server_endpoint(
server_id: int,
payload: ServerUpdateRequest,
_: Any = Security(require_api_token),
db: AsyncSession = Depends(get_db_session),
) -> ServerResponse:
server = await get_server_squad_by_id(db, server_id)
if not server:
raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found")
updates = payload.model_dump(exclude_unset=True, by_alias=False)
promo_group_ids = updates.pop("promo_group_ids", None)
validated_promo_group_ids: Optional[List[int]] = None
if promo_group_ids is not None:
validated_promo_group_ids = await _validate_promo_group_ids(
db, promo_group_ids
)
if updates:
server = await update_server_squad(db, server_id, **updates) or server
if promo_group_ids is not None:
try:
assert validated_promo_group_ids is not None
server = await update_server_squad_promo_groups(
db, server_id, validated_promo_group_ids
) or server
except ValueError as error:
raise HTTPException(status.HTTP_400_BAD_REQUEST, str(error)) from error
await cache.delete_pattern("available_countries*")
server = await get_server_squad_by_id(db, server_id)
assert server is not None
return _serialize_server(server)
@router.delete("/{server_id}", response_model=ServerDeleteResponse)
async def delete_server_endpoint(
server_id: int,
_: Any = Security(require_api_token),
db: AsyncSession = Depends(get_db_session),
) -> ServerDeleteResponse:
server = await get_server_squad_by_id(db, server_id)
if not server:
raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found")
deleted = await delete_server_squad(db, server_id)
if not deleted:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"Server cannot be deleted because it has active connections",
)
await cache.delete_pattern("available_countries*")
return ServerDeleteResponse(success=True, message="Server deleted")
@router.get(
"/{server_id}/users",
response_model=ServerConnectedUsersResponse,
)
async def get_server_connected_users_endpoint(
server_id: int,
_: Any = Security(require_api_token),
db: AsyncSession = Depends(get_db_session),
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
) -> ServerConnectedUsersResponse:
server = await get_server_squad_by_id(db, server_id)
if not server:
raise HTTPException(status.HTTP_404_NOT_FOUND, "Server not found")
users = await get_server_connected_users(db, server_id)
total = len(users)
sliced = users[offset : offset + limit]
return ServerConnectedUsersResponse(
items=[_serialize_connected_user(user) for user in sliced],
total=total,
limit=limit,
offset=offset,
)
@router.post("/sync", response_model=ServerSyncResponse)
async def sync_servers_with_remnawave(
_: Any = Security(require_api_token),
db: AsyncSession = Depends(get_db_session),
) -> ServerSyncResponse:
service = _get_remnawave_service()
_ensure_service_configured(service)
squads = await service.get_all_squads()
total = len(squads)
created = updated = removed = 0
if squads:
created, updated, removed = await sync_with_remnawave(db, squads)
await cache.delete_pattern("available_countries*")
return ServerSyncResponse(
created=created,
updated=updated,
removed=removed,
total=total,
)
@router.post("/sync-counts", response_model=ServerCountsSyncResponse)
async def sync_server_counts(
_: Any = Security(require_api_token),
db: AsyncSession = Depends(get_db_session),
) -> ServerCountsSyncResponse:
updated = await sync_server_user_counts(db)
return ServerCountsSyncResponse(updated=updated)

View File

@@ -0,0 +1,160 @@
"""Pydantic-схемы для управления серверами через Web API."""
from __future__ import annotations
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
from .users import PromoGroupSummary
class ServerResponse(BaseModel):
"""Полная информация о сервере."""
model_config = ConfigDict(from_attributes=True, populate_by_name=True)
id: int
squad_uuid: str = Field(alias="squadUuid")
display_name: str = Field(alias="displayName")
original_name: Optional[str] = Field(default=None, alias="originalName")
country_code: Optional[str] = Field(default=None, alias="countryCode")
is_available: bool = Field(alias="isAvailable")
is_trial_eligible: bool = Field(default=False, alias="isTrialEligible")
price_kopeks: int = Field(alias="priceKopeks")
price_rubles: float = Field(alias="priceRubles")
description: Optional[str] = None
sort_order: int = Field(default=0, alias="sortOrder")
max_users: Optional[int] = Field(default=None, alias="maxUsers")
current_users: int = Field(default=0, alias="currentUsers")
created_at: Optional[datetime] = Field(default=None, alias="createdAt")
updated_at: Optional[datetime] = Field(default=None, alias="updatedAt")
promo_groups: List[PromoGroupSummary] = Field(
default_factory=list, alias="promoGroups"
)
class ServerListResponse(BaseModel):
"""Список серверов с пагинацией."""
items: List[ServerResponse]
total: int
page: int
limit: int
class ServerCreateRequest(BaseModel):
"""Запрос на создание сервера."""
squad_uuid: str = Field(alias="squadUuid")
display_name: str = Field(alias="displayName")
original_name: Optional[str] = Field(default=None, alias="originalName")
country_code: Optional[str] = Field(default=None, alias="countryCode")
price_kopeks: int = Field(default=0, alias="priceKopeks")
description: Optional[str] = None
max_users: Optional[int] = Field(default=None, alias="maxUsers")
is_available: bool = Field(default=True, alias="isAvailable")
is_trial_eligible: bool = Field(default=False, alias="isTrialEligible")
sort_order: int = Field(default=0, alias="sortOrder")
promo_group_ids: Optional[List[int]] = Field(
default=None,
alias="promoGroupIds",
description="Список идентификаторов промогрупп, доступных на сервере.",
)
class ServerUpdateRequest(BaseModel):
"""Запрос на обновление свойств сервера."""
display_name: Optional[str] = Field(default=None, alias="displayName")
original_name: Optional[str] = Field(default=None, alias="originalName")
country_code: Optional[str] = Field(default=None, alias="countryCode")
price_kopeks: Optional[int] = Field(default=None, alias="priceKopeks")
description: Optional[str] = None
max_users: Optional[int] = Field(default=None, alias="maxUsers")
is_available: Optional[bool] = Field(default=None, alias="isAvailable")
is_trial_eligible: Optional[bool] = Field(
default=None, alias="isTrialEligible"
)
sort_order: Optional[int] = Field(default=None, alias="sortOrder")
promo_group_ids: Optional[List[int]] = Field(
default=None,
alias="promoGroupIds",
description="Если передан список, он заменит текущие промогруппы сервера.",
)
class ServerSyncResponse(BaseModel):
"""Результат синхронизации серверов с RemnaWave."""
model_config = ConfigDict(populate_by_name=True)
created: int
updated: int
removed: int
total: int
class ServerStatisticsResponse(BaseModel):
"""Агрегированная статистика по серверам."""
model_config = ConfigDict(populate_by_name=True)
total_servers: int = Field(alias="totalServers")
available_servers: int = Field(alias="availableServers")
unavailable_servers: int = Field(alias="unavailableServers")
servers_with_connections: int = Field(alias="serversWithConnections")
total_revenue_kopeks: int = Field(alias="totalRevenueKopeks")
total_revenue_rubles: float = Field(alias="totalRevenueRubles")
class ServerCountsSyncResponse(BaseModel):
"""Результат обновления счетчиков пользователей серверов."""
model_config = ConfigDict(populate_by_name=True)
updated: int
class ServerConnectedUser(BaseModel):
"""Краткая информация о пользователе, подключенном к серверу."""
model_config = ConfigDict(populate_by_name=True)
id: int
telegram_id: int = Field(alias="telegramId")
username: Optional[str] = None
first_name: Optional[str] = Field(default=None, alias="firstName")
last_name: Optional[str] = Field(default=None, alias="lastName")
status: str
balance_kopeks: int = Field(alias="balanceKopeks")
balance_rubles: float = Field(alias="balanceRubles")
subscription_id: Optional[int] = Field(default=None, alias="subscriptionId")
subscription_status: Optional[str] = Field(
default=None, alias="subscriptionStatus"
)
subscription_end_date: Optional[datetime] = Field(
default=None, alias="subscriptionEndDate"
)
class ServerConnectedUsersResponse(BaseModel):
"""Список пользователей, подключенных к серверу."""
model_config = ConfigDict(populate_by_name=True)
items: List[ServerConnectedUser]
total: int
limit: int
offset: int
class ServerDeleteResponse(BaseModel):
"""Ответ при удалении сервера."""
model_config = ConfigDict(populate_by_name=True)
success: bool
message: str