From a043fc0e4675552b2deca63b1920762dcb305abe Mon Sep 17 00:00:00 2001 From: Egor Date: Sun, 9 Nov 2025 04:36:33 +0300 Subject: [PATCH] feat: add backend support for server groups --- app/database/crud/server_group.py | 249 ++++++++++++++++++ app/database/models.py | 50 ++++ app/database/universal_migration.py | 164 ++++++++++++ app/services/remnawave_service.py | 85 +++++- app/services/server_group_service.py | 175 ++++++++++++ app/services/subscription_purchase_service.py | 135 +++++++++- app/states.py | 7 + 7 files changed, 854 insertions(+), 11 deletions(-) create mode 100644 app/database/crud/server_group.py create mode 100644 app/services/server_group_service.py diff --git a/app/database/crud/server_group.py b/app/database/crud/server_group.py new file mode 100644 index 00000000..18c698ab --- /dev/null +++ b/app/database/crud/server_group.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import logging +from typing import Dict, List, Optional, Sequence + +from sqlalchemy import and_, delete, func, select, update, cast, String +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.database.models import ( + ServerGroup, + ServerGroupServer, + ServerSquad, + Subscription, + SubscriptionStatus, +) + +logger = logging.getLogger(__name__) + + +async def get_server_groups( + db: AsyncSession, + *, + include_servers: bool = True, +) -> List[ServerGroup]: + """Возвращает список всех групп серверов.""" + + query = ( + select(ServerGroup) + .order_by(ServerGroup.sort_order, ServerGroup.name) + ) + + if include_servers: + query = query.options( + selectinload(ServerGroup.servers).selectinload(ServerGroupServer.server) + ) + + result = await db.execute(query) + return list(result.scalars().unique().all()) + + +async def get_server_group_by_id( + db: AsyncSession, + group_id: int, + *, + include_servers: bool = True, +) -> Optional[ServerGroup]: + query = select(ServerGroup).where(ServerGroup.id == group_id) + if include_servers: + query = query.options( + selectinload(ServerGroup.servers).selectinload(ServerGroupServer.server) + ) + + result = await db.execute(query) + return result.scalars().unique().one_or_none() + + +async def get_server_group_by_name(db: AsyncSession, name: str) -> Optional[ServerGroup]: + result = await db.execute( + select(ServerGroup).where(func.lower(ServerGroup.name) == func.lower(name)) + ) + return result.scalars().one_or_none() + + +async def create_server_group( + db: AsyncSession, + *, + name: str, + server_ids: Optional[Sequence[int]] = None, + sort_order: int = 0, + is_active: bool = True, +) -> ServerGroup: + existing = await get_server_group_by_name(db, name) + if existing: + raise ValueError("Группа с таким названием уже существует") + + group = ServerGroup(name=name.strip(), sort_order=sort_order, is_active=is_active) + db.add(group) + await db.flush() + + await _sync_group_servers(db, group, server_ids or []) + await db.commit() + await db.refresh(group) + return group + + +async def update_server_group( + db: AsyncSession, + group_id: int, + *, + name: Optional[str] = None, + server_ids: Optional[Sequence[int]] = None, + sort_order: Optional[int] = None, + is_active: Optional[bool] = None, +) -> Optional[ServerGroup]: + group = await get_server_group_by_id(db, group_id) + if not group: + return None + + updates: Dict = {} + if name is not None: + trimmed = name.strip() + if trimmed and trimmed.lower() != group.name.lower(): + duplicate = await get_server_group_by_name(db, trimmed) + if duplicate and duplicate.id != group.id: + raise ValueError("Группа с таким названием уже существует") + updates["name"] = trimmed + if sort_order is not None: + updates["sort_order"] = int(sort_order) + if is_active is not None: + updates["is_active"] = bool(is_active) + + if updates: + await db.execute( + update(ServerGroup) + .where(ServerGroup.id == group.id) + .values(**updates) + ) + + if server_ids is not None: + await _sync_group_servers(db, group, server_ids) + + await db.commit() + return await get_server_group_by_id(db, group_id) + + +async def delete_server_group(db: AsyncSession, group_id: int) -> bool: + group = await get_server_group_by_id(db, group_id) + if not group: + return False + + if await is_group_in_use(db, group): + raise ValueError("Нельзя удалить группу, пока её серверы используются активными подписками") + + await db.execute(delete(ServerGroup).where(ServerGroup.id == group.id)) + await db.commit() + return True + + +async def toggle_group_server( + db: AsyncSession, + group_id: int, + server_id: int, + *, + is_enabled: bool, +) -> bool: + result = await db.execute( + update(ServerGroupServer) + .where( + and_( + ServerGroupServer.group_id == group_id, + ServerGroupServer.server_squad_id == server_id, + ) + ) + .values(is_enabled=is_enabled) + ) + if getattr(result, "rowcount", 0): + await db.commit() + return True + return False + + +async def _sync_group_servers( + db: AsyncSession, + group: ServerGroup, + server_ids: Sequence[int], +) -> None: + unique_ids = {int(server_id) for server_id in server_ids if server_id} + + existing_ids = {member.server_squad_id for member in group.servers} + + to_remove = existing_ids - unique_ids + to_add = unique_ids - existing_ids + + if to_remove: + await db.execute( + delete(ServerGroupServer).where( + and_( + ServerGroupServer.group_id == group.id, + ServerGroupServer.server_squad_id.in_(to_remove), + ) + ) + ) + + if to_add: + new_relations = [ + ServerGroupServer(group_id=group.id, server_squad_id=server_id) + for server_id in to_add + ] + db.add_all(new_relations) + + await db.flush() + + +async def is_group_in_use( + db: AsyncSession, + group: ServerGroup | int, +) -> bool: + if isinstance(group, int): + group_obj = await get_server_group_by_id(db, group) + else: + group_obj = group + + if not group_obj: + return False + + squad_uuids = [ + member.server.squad_uuid + for member in group_obj.servers + if member.server and member.server.squad_uuid + ] + + if not squad_uuids: + return False + + like_filters = [ + cast(Subscription.connected_squads, String).like(f'%"{uuid}"%') + for uuid in squad_uuids + ] + + condition = like_filters[0] + for clause in like_filters[1:]: + condition = condition | clause + + result = await db.execute( + select(func.count(Subscription.id)).where( + Subscription.status.in_( + [ + SubscriptionStatus.ACTIVE.value, + SubscriptionStatus.TRIAL.value, + ] + ), + condition, + ) + ) + + return (result.scalar() or 0) > 0 + + +async def get_group_server_ids(group: ServerGroup) -> List[int]: + return [member.server_squad_id for member in group.servers if member.server_squad_id] + + +async def get_group_server_uuids(group: ServerGroup) -> List[str]: + return [ + member.server.squad_uuid + for member in group.servers + if member.server and member.server.squad_uuid + ] diff --git a/app/database/models.py b/app/database/models.py index 09695965..5986022b 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1315,6 +1315,12 @@ class ServerSquad(Base): back_populates="server_squads", lazy="selectin", ) + + groups = relationship( + "ServerGroupServer", + back_populates="server", + cascade="all, delete-orphan", + ) @property def price_rubles(self) -> float: @@ -1336,6 +1342,50 @@ class ServerSquad(Base): return "Доступен" +class ServerGroup(Base): + __tablename__ = "server_groups" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String(255), nullable=False, unique=True) + is_active = Column(Boolean, default=True, nullable=False) + sort_order = Column(Integer, default=0, nullable=False) + created_at = Column(DateTime, default=func.now(), nullable=False) + updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False) + + servers = relationship( + "ServerGroupServer", + back_populates="group", + cascade="all, delete-orphan", + order_by="ServerGroupServer.id", + ) + + def __repr__(self) -> str: + return f"" + + +class ServerGroupServer(Base): + __tablename__ = "server_group_servers" + __table_args__ = ( + UniqueConstraint("group_id", "server_squad_id", name="uq_group_server"), + ) + + id = Column(Integer, primary_key=True, index=True) + group_id = Column(Integer, ForeignKey("server_groups.id", ondelete="CASCADE"), nullable=False, index=True) + server_squad_id = Column(Integer, ForeignKey("server_squads.id", ondelete="CASCADE"), nullable=False, index=True) + is_enabled = Column(Boolean, default=True, nullable=False) + created_at = Column(DateTime, default=func.now(), nullable=False) + updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False) + + group = relationship("ServerGroup", back_populates="servers") + server = relationship("ServerSquad", back_populates="groups") + + def __repr__(self) -> str: + return ( + f"" + ) + + class SubscriptionServer(Base): __tablename__ = "subscription_servers" diff --git a/app/database/universal_migration.py b/app/database/universal_migration.py index 50f05d5b..f8895ec2 100644 --- a/app/database/universal_migration.py +++ b/app/database/universal_migration.py @@ -3533,6 +3533,154 @@ async def add_promo_group_priority_column() -> bool: return False +async def create_server_groups_table() -> bool: + table_exists = await check_table_exists("server_groups") + if table_exists: + logger.info("ℹ️ Таблица server_groups уже существует") + return True + + try: + async with engine.begin() as conn: + db_type = await get_database_type() + + if db_type == "sqlite": + await conn.execute(text( + """ + CREATE TABLE server_groups ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + is_active INTEGER NOT NULL DEFAULT 1, + sort_order INTEGER NOT NULL DEFAULT 0, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + """ + )) + await conn.execute(text( + "CREATE INDEX IF NOT EXISTS idx_server_groups_sort ON server_groups(sort_order, name)" + )) + elif db_type == "postgresql": + await conn.execute(text( + """ + CREATE TABLE server_groups ( + id SERIAL PRIMARY KEY, + name VARCHAR(255) NOT NULL UNIQUE, + is_active BOOLEAN NOT NULL DEFAULT TRUE, + sort_order INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW() + ); + """ + )) + await conn.execute(text( + "CREATE INDEX IF NOT EXISTS idx_server_groups_sort ON server_groups(sort_order, name)" + )) + else: # MySQL + await conn.execute(text( + """ + CREATE TABLE server_groups ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) NOT NULL UNIQUE, + is_active TINYINT(1) NOT NULL DEFAULT 1, + sort_order INT NOT NULL DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + ) ENGINE=InnoDB; + """ + )) + await conn.execute(text( + "CREATE INDEX idx_server_groups_sort ON server_groups(sort_order, name)" + )) + + logger.info("✅ Таблица server_groups создана") + return True + except Exception as error: + logger.error(f"❌ Ошибка создания таблицы server_groups: {error}") + return False + + +async def create_server_group_servers_table() -> bool: + table_exists = await check_table_exists("server_group_servers") + if table_exists: + logger.info("ℹ️ Таблица server_group_servers уже существует") + return True + + try: + async with engine.begin() as conn: + db_type = await get_database_type() + + if db_type == "sqlite": + await conn.execute(text( + """ + CREATE TABLE server_group_servers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + group_id INTEGER NOT NULL, + server_squad_id INTEGER NOT NULL, + is_enabled INTEGER NOT NULL DEFAULT 1, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE(group_id, server_squad_id), + FOREIGN KEY(group_id) REFERENCES server_groups(id) ON DELETE CASCADE, + FOREIGN KEY(server_squad_id) REFERENCES server_squads(id) ON DELETE CASCADE + ); + """ + )) + await conn.execute(text( + "CREATE INDEX IF NOT EXISTS idx_server_group_servers_group ON server_group_servers(group_id)" + )) + await conn.execute(text( + "CREATE INDEX IF NOT EXISTS idx_server_group_servers_server ON server_group_servers(server_squad_id)" + )) + elif db_type == "postgresql": + await conn.execute(text( + """ + CREATE TABLE server_group_servers ( + id SERIAL PRIMARY KEY, + group_id INTEGER NOT NULL REFERENCES server_groups(id) ON DELETE CASCADE, + server_squad_id INTEGER NOT NULL REFERENCES server_squads(id) ON DELETE CASCADE, + is_enabled BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW(), + UNIQUE(group_id, server_squad_id) + ); + """ + )) + await conn.execute(text( + "CREATE INDEX IF NOT EXISTS idx_server_group_servers_group ON server_group_servers(group_id)" + )) + await conn.execute(text( + "CREATE INDEX IF NOT EXISTS idx_server_group_servers_server ON server_group_servers(server_squad_id)" + )) + else: # MySQL + await conn.execute(text( + """ + CREATE TABLE server_group_servers ( + id INT AUTO_INCREMENT PRIMARY KEY, + group_id INT NOT NULL, + server_squad_id INT NOT NULL, + is_enabled TINYINT(1) NOT NULL DEFAULT 1, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + UNIQUE KEY uq_group_server (group_id, server_squad_id), + CONSTRAINT fk_group FOREIGN KEY (group_id) REFERENCES server_groups(id) ON DELETE CASCADE, + CONSTRAINT fk_group_server FOREIGN KEY (server_squad_id) REFERENCES server_squads(id) ON DELETE CASCADE + ) ENGINE=InnoDB; + """ + )) + await conn.execute(text( + "CREATE INDEX idx_server_group_servers_group ON server_group_servers(group_id)" + )) + await conn.execute(text( + "CREATE INDEX idx_server_group_servers_server ON server_group_servers(server_squad_id)" + )) + + logger.info("✅ Таблица server_group_servers создана") + return True + except Exception as error: + logger.error(f"❌ Ошибка создания таблицы server_group_servers: {error}") + return False + + async def create_user_promo_groups_table() -> bool: """Создает таблицу user_promo_groups для связи Many-to-Many между users и promo_groups.""" table_exists = await check_table_exists("user_promo_groups") @@ -3742,6 +3890,16 @@ async def run_universal_migration(): else: logger.warning("⚠️ Проблемы с колонкой is_trial_eligible") + logger.info("=== СОЗДАНИЕ ТАБЛИЦЫ SERVER_GROUPS ===") + server_groups_ready = await create_server_groups_table() + if not server_groups_ready: + logger.warning("⚠️ Не удалось создать таблицу server_groups") + + logger.info("=== СОЗДАНИЕ ТАБЛИЦЫ SERVER_GROUP_SERVERS ===") + server_group_servers_ready = await create_server_group_servers_table() + if not server_group_servers_ready: + logger.warning("⚠️ Не удалось создать таблицу server_group_servers") + logger.info("=== СОЗДАНИЕ ТАБЛИЦЫ PRIVACY_POLICIES ===") privacy_policies_ready = await create_privacy_policies_table() if privacy_policies_ready: @@ -4134,6 +4292,8 @@ async def check_migration_status(): "promo_offer_templates_active_discount_column": False, "promo_offer_logs_table": False, "subscription_temporary_access_table": False, + "server_groups_table": False, + "server_group_servers_table": False, } status["has_made_first_topup_column"] = await check_column_exists('users', 'has_made_first_topup') @@ -4156,6 +4316,8 @@ async def check_migration_status(): status["promo_offer_templates_active_discount_column"] = await check_column_exists('promo_offer_templates', 'active_discount_hours') status["promo_offer_logs_table"] = await check_table_exists('promo_offer_logs') status["subscription_temporary_access_table"] = await check_table_exists('subscription_temporary_access') + status["server_groups_table"] = await check_table_exists('server_groups') + status["server_group_servers_table"] = await check_table_exists('server_group_servers') status["welcome_texts_is_enabled_column"] = await check_column_exists('welcome_texts', 'is_enabled') status["users_promo_group_column"] = await check_column_exists('users', 'promo_group_id') @@ -4221,6 +4383,8 @@ async def check_migration_status(): "promo_offer_templates_active_discount_column": "Колонка active_discount_hours в promo_offer_templates", "promo_offer_logs_table": "Таблица promo_offer_logs", "subscription_temporary_access_table": "Таблица subscription_temporary_access", + "server_groups_table": "Таблица server_groups", + "server_group_servers_table": "Таблица server_group_servers", } for check_key, check_status in status.items(): diff --git a/app/services/remnawave_service.py b/app/services/remnawave_service.py index 4c066533..a422c73d 100644 --- a/app/services/remnawave_service.py +++ b/app/services/remnawave_service.py @@ -38,6 +38,7 @@ from app.database.models import ( SubscriptionStatus, ServerSquad, ) +from app.utils.cache import cache, cache_key from app.utils.subscription_utils import ( resolve_hwid_device_limit_for_payload, ) @@ -536,8 +537,8 @@ class RemnaWaveService: 'bytes': daily_bytes }) - nodes_weekly_data = list(nodes_by_name.values()) - nodes_weekly_data.sort(key=lambda x: x['total_bytes'], reverse=True) + nodes_weekly_data = list(nodes_by_name.values()) + nodes_weekly_data.sort(key=lambda x: x['total_bytes'], reverse=True) result = { "system": { @@ -619,7 +620,7 @@ class RemnaWaveService: logger.info(f"Статистика сформирована: пользователи={result['system']['total_users']}, общий трафик={total_user_traffic}") return result - + except RemnaWaveAPIError as e: logger.error(f"Ошибка Remnawave API при получении статистики: {e}") return {"error": str(e)} @@ -627,7 +628,83 @@ class RemnaWaveService: logger.error(f"Общая ошибка получения системной статистики: {e}") return {"error": f"Внутренняя ошибка сервера: {str(e)}"} - + + async def get_internal_squad_usage_map( + self, + *, + force_refresh: bool = False, + cache_ttl: int = 60, + ) -> Dict[str, Dict[str, int]]: + cache_name = cache_key("remnawave", "internal_squad_usage") + + if not force_refresh: + cached = await cache.get(cache_name) + if isinstance(cached, dict): + return cached + + usage_map: Dict[str, Dict[str, int]] = {} + + internal_squads: List[RemnaWaveInternalSquad] = [] + realtime_usage: List[Dict[str, Any]] = [] + + try: + async with self.get_api_client() as api: + try: + internal_squads = await api.get_internal_squads() + except Exception as error: + logger.warning("Не удалось получить список сквадов RemnaWave: %s", error) + + try: + realtime_usage = await api.get_nodes_realtime_usage() + except Exception as error: + logger.warning("Не удалось получить статистику нагрузки сквадов: %s", error) + except RemnaWaveConfigurationError as error: + logger.debug("RemnaWave API не настроен: %s", error) + return {} + + for squad in internal_squads: + usage_map[squad.uuid] = { + "members_count": int(squad.members_count or 0), + "inbounds_count": int(squad.inbounds_count or 0), + "download_bytes": 0, + "upload_bytes": 0, + "bandwidth_bytes": 0, + } + + for node in realtime_usage: + squad_uuid = ( + node.get("internalSquadUuid") + or node.get("squadUuid") + or node.get("squadUUID") + or node.get("internal_squad_uuid") + or node.get("internalSquadUUID") + ) + if not squad_uuid: + continue + + stats = usage_map.setdefault( + squad_uuid, + { + "members_count": 0, + "inbounds_count": 0, + "download_bytes": 0, + "upload_bytes": 0, + "bandwidth_bytes": 0, + }, + ) + + download = int(node.get("downloadBytes", 0) or 0) + upload = int(node.get("uploadBytes", 0) or 0) + stats["download_bytes"] += download + stats["upload_bytes"] += upload + stats["bandwidth_bytes"] = stats["download_bytes"] + stats["upload_bytes"] + + if usage_map: + await cache.set(cache_name, usage_map, expire=cache_ttl) + + return usage_map + + def _parse_bandwidth_string(self, bandwidth_str: str) -> int: try: if not bandwidth_str or bandwidth_str == '0 B' or bandwidth_str == '0': diff --git a/app/services/server_group_service.py b/app/services/server_group_service.py new file mode 100644 index 00000000..63231746 --- /dev/null +++ b/app/services/server_group_service.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Callable, List, Optional + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database.crud.server_group import get_server_group_by_id +from app.database.crud.server_squad import count_active_users_for_squad +from app.database.models import ServerGroup, ServerGroupServer, ServerSquad +from app.services.remnawave_service import RemnaWaveService + +logger = logging.getLogger(__name__) + + +@dataclass +class ServerLoadSnapshot: + server: ServerSquad + membership: ServerGroupServer + active_users: int + members_count: int + bandwidth_bytes: int + download_bytes: int + upload_bytes: int + is_available: bool + is_enabled: bool + is_overloaded: bool + load_ratio: float + + @property + def maintenance(self) -> bool: + return not self.is_enabled + + +@dataclass +class ServerGroupSnapshot: + group: ServerGroup + servers: List[ServerLoadSnapshot] + total_active_users: int + total_bandwidth_bytes: int + available_servers: List[ServerLoadSnapshot] + + @property + def is_empty(self) -> bool: + return not self.servers + + @property + def is_overloaded(self) -> bool: + return all(server.is_overloaded or not server.is_available for server in self.available_servers) + + +async def build_group_snapshot( + db: AsyncSession, + remnawave: RemnaWaveService, + group: ServerGroup | int, + *, + refresh: bool = False, +) -> ServerGroupSnapshot: + if isinstance(group, int): + group_obj = await get_server_group_by_id(db, group) + else: + group_obj = group + + if not group_obj: + raise ValueError("Server group not found") + + usage_map = await remnawave.get_internal_squad_usage_map(force_refresh=refresh) + + servers: List[ServerLoadSnapshot] = [] + total_active_users = 0 + total_bandwidth_bytes = 0 + + for membership in group_obj.servers: + server = membership.server + if not server: + continue + + usage = usage_map.get(server.squad_uuid, {}) or {} + members_count = int(usage.get("members_count", 0) or 0) + download_bytes = int(usage.get("download_bytes", 0) or 0) + upload_bytes = int(usage.get("upload_bytes", 0) or 0) + bandwidth_bytes = int(usage.get("bandwidth_bytes", download_bytes + upload_bytes) or 0) + + db_active = await count_active_users_for_squad(db, server.squad_uuid) + active_users = max(db_active, members_count, int(server.current_users or 0)) + + capacity = server.max_users or 0 + load_ratio = (active_users / capacity) if capacity else float(active_users) + is_overloaded = capacity > 0 and active_users >= capacity + + snapshot = ServerLoadSnapshot( + server=server, + membership=membership, + active_users=active_users, + members_count=members_count, + bandwidth_bytes=bandwidth_bytes, + download_bytes=download_bytes, + upload_bytes=upload_bytes, + is_available=bool(server.is_available), + is_enabled=bool(membership.is_enabled), + is_overloaded=is_overloaded, + load_ratio=load_ratio, + ) + servers.append(snapshot) + total_active_users += active_users + total_bandwidth_bytes += bandwidth_bytes + + available_servers = [ + server + for server in servers + if server.is_available and server.is_enabled + ] + + return ServerGroupSnapshot( + group=group_obj, + servers=servers, + total_active_users=total_active_users, + total_bandwidth_bytes=total_bandwidth_bytes, + available_servers=available_servers, + ) + + +async def choose_optimal_server( + db: AsyncSession, + remnawave: RemnaWaveService, + group: ServerGroup | int, + *, + refresh_stats: bool = False, + notify_overload: Optional[Callable[[ServerGroupSnapshot], None]] = None, +) -> Optional[tuple[ServerGroupSnapshot, ServerLoadSnapshot]]: + snapshot = await build_group_snapshot(db, remnawave, group, refresh=refresh_stats) + + if not snapshot.servers: + logger.warning("Server group %s не содержит серверов", getattr(snapshot.group, "name", snapshot.group.id)) + return None + + candidates = [ + server for server in snapshot.servers if server.is_available and server.is_enabled + ] + + if not candidates: + candidates = [server for server in snapshot.servers if server.is_enabled] + + if not candidates: + logger.error("Нет доступных серверов в группе %s", snapshot.group.name) + if notify_overload: + notify_overload(snapshot) + return None + + candidates.sort( + key=lambda item: ( + item.active_users, + item.bandwidth_bytes, + item.server.current_users or 0, + item.server.sort_order, + item.server.display_name, + ) + ) + + best = candidates[0] + + logger.info( + "Выбран сервер %s (%s) для группы %s: активных=%s, трафик=%s", + best.server.display_name, + best.server.squad_uuid, + snapshot.group.name, + best.active_users, + best.bandwidth_bytes, + ) + + if notify_overload and snapshot.is_overloaded: + notify_overload(snapshot) + + return snapshot, best diff --git a/app/services/subscription_purchase_service.py b/app/services/subscription_purchase_service.py index 911c4d4c..113a0fa9 100644 --- a/app/services/subscription_purchase_service.py +++ b/app/services/subscription_purchase_service.py @@ -7,6 +7,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.config import PERIOD_PRICES, settings +from app.database.crud.server_group import get_server_groups from app.database.crud.server_squad import ( add_user_to_servers, get_available_server_squads, @@ -24,6 +25,8 @@ from app.database.crud.transaction import create_transaction from app.database.crud.user import subtract_user_balance from app.database.models import ServerSquad, Subscription, SubscriptionStatus, TransactionType, User from app.localization.texts import get_texts +from app.services.remnawave_service import RemnaWaveService +from app.services.server_group_service import choose_optimal_server from app.services.subscription_service import SubscriptionService from app.utils.pricing_utils import ( calculate_months_from_days, @@ -103,6 +106,9 @@ class PurchaseServerOption: original_price_label: Optional[str] = None discount_percent: int = 0 is_available: bool = True + selection_key: Optional[str] = None + option_type: str = "server" + metadata: Optional[Dict[str, Any]] = None def to_payload(self) -> Dict[str, Any]: payload: Dict[str, Any] = { @@ -112,6 +118,10 @@ class PurchaseServerOption: "price_label": self.price_label, "is_available": self.is_available, } + payload["selection_key"] = self.selection_key or self.uuid + payload["type"] = self.option_type + if self.metadata: + payload["meta"] = self.metadata if self.original_price_per_month is not None and ( self.original_price_label and self.original_price_per_month != self.price_per_month ): @@ -255,6 +265,7 @@ class PurchaseOptionsContext: default_period: PurchasePeriodConfig period_map: Dict[str, PurchasePeriodConfig] server_uuid_to_id: Dict[str, int] + server_selection_map: Dict[str, str] payload: Dict[str, Any] @@ -322,6 +333,8 @@ def _build_server_option( original_price_label=texts.format_price(base_per_month) if base_per_month != discounted_per_month else None, discount_percent=max(0, discount_percent), is_available=bool(getattr(server, "is_available", True) and not getattr(server, "is_full", False)), + selection_key=server.squad_uuid, + option_type="server", ) @@ -352,6 +365,74 @@ class MiniAppSubscriptionPurchaseService: if existing: server_catalog[uuid] = existing + remnawave_service = RemnaWaveService() + server_groups = await get_server_groups(db) + groups_context: List[Dict[str, Any]] = [] + grouped_server_uuids: set[str] = set() + selection_map: Dict[str, str] = {} + preferred_selection_for_server: Dict[str, str] = {} + + for group in server_groups: + if not group.is_active or not group.servers: + continue + try: + choice = await choose_optimal_server(db, remnawave_service, group) + except Exception as error: # pragma: no cover - defensive logging + logger.warning("Failed to evaluate server group %s: %s", getattr(group, "name", group.id), error) + continue + + if not choice: + continue + + snapshot, best = choice + selection_key = f"group:{group.id}" + selection_map[selection_key] = best.server.squad_uuid + + metadata = { + "groupId": group.id, + "groupName": group.name, + "totalActiveUsers": snapshot.total_active_users, + "totalBandwidthBytes": snapshot.total_bandwidth_bytes, + "isOverloaded": snapshot.is_overloaded, + "selectedServerUuid": best.server.squad_uuid, + "selectedServerName": best.server.display_name, + "servers": [ + { + "uuid": member.server.squad_uuid if member.server else None, + "name": member.server.display_name if member.server else None, + "activeUsers": member.active_users, + "membersCount": member.members_count, + "bandwidthBytes": member.bandwidth_bytes, + "isAvailable": member.is_available, + "isEnabled": member.is_enabled, + "isOverloaded": member.is_overloaded, + } + for member in snapshot.servers + if member.server + ], + } + + for member in snapshot.servers: + if member.server and member.server.squad_uuid: + grouped_server_uuids.add(member.server.squad_uuid) + preferred_selection_for_server.setdefault(member.server.squad_uuid, selection_key) + + groups_context.append( + { + "group": group, + "selection_key": selection_key, + "snapshot": snapshot, + "choice": best, + "metadata": metadata, + } + ) + + server_catalog.setdefault(best.server.squad_uuid, best.server) + + for uuid, server in list(server_catalog.items()): + selection_map.setdefault(uuid, uuid) + preferred_selection_for_server.setdefault(uuid, uuid) + server_uuid_to_id: Dict[str, int] = {} for server in server_catalog.values(): try: @@ -366,6 +447,12 @@ class MiniAppSubscriptionPurchaseService: default_connected = [server.squad_uuid] break + default_selection_keys: List[str] = [] + for uuid in default_connected: + key = preferred_selection_for_server.get(uuid) + if key and key not in default_selection_keys: + default_selection_keys.append(key) + available_periods: Sequence[int] = settings.get_available_subscription_periods() periods: List[PurchasePeriodConfig] = [] period_map: Dict[str, PurchasePeriodConfig] = {} @@ -414,7 +501,9 @@ class MiniAppSubscriptionPurchaseService: texts, period_days, server_catalog, - default_connected, + default_selection_keys, + groups_context=groups_context, + grouped_server_uuids=grouped_server_uuids, ) devices_config = self._build_devices_config( user, @@ -448,6 +537,9 @@ class MiniAppSubscriptionPurchaseService: default_period = period_map.get(f"days:{default_period_days}") or periods[0] + default_selection_keys_payload = list(default_period.servers.default_selection) + default_selection_uuids_payload = [selection_map.get(key, key) for key in default_selection_keys_payload] + default_selection = { "period_id": default_period.id, "periodId": default_period.id, @@ -459,15 +551,18 @@ class MiniAppSubscriptionPurchaseService: "trafficValue": default_period.traffic.current_value if default_period.traffic.current_value is not None else default_period.traffic.default_value, - "servers": list(default_period.servers.default_selection), - "countries": list(default_period.servers.default_selection), - "server_uuids": list(default_period.servers.default_selection), - "serverUuids": list(default_period.servers.default_selection), + "servers": default_selection_keys_payload, + "countries": default_selection_keys_payload, + "server_uuids": default_selection_uuids_payload, + "serverUuids": default_selection_uuids_payload, "devices": default_period.devices.current, "device_limit": default_period.devices.current, "deviceLimit": default_period.devices.current, } + servers_payload = default_period.servers.to_payload() + servers_payload["selection_map"] = selection_map + payload = { "currency": currency, "balance_kopeks": balance_kopeks, @@ -478,7 +573,7 @@ class MiniAppSubscriptionPurchaseService: "subscriptionId": getattr(subscription, "id", None), "periods": [period.to_payload() for period in periods], "traffic": default_period.traffic.to_payload(), - "servers": default_period.servers.to_payload(), + "servers": servers_payload, "devices": default_period.devices.to_payload(), "selection": default_selection, "summary": None, @@ -493,6 +588,7 @@ class MiniAppSubscriptionPurchaseService: default_period=default_period, period_map=period_map, server_uuid_to_id=server_uuid_to_id, + server_selection_map=selection_map, payload=payload, ) @@ -568,22 +664,39 @@ class MiniAppSubscriptionPurchaseService: period_days: int, server_catalog: Dict[str, ServerSquad], default_selection: List[str], + *, + groups_context: List[Dict[str, Any]], + grouped_server_uuids: set, ) -> PurchaseServersConfig: discount_percent = user.get_promo_discount("servers", period_days) options: List[PurchaseServerOption] = [] + for group_ctx in groups_context: + best_snapshot = group_ctx["choice"] + best_server = best_snapshot.server + option = _build_server_option(best_server, discount_percent, texts) + option.selection_key = group_ctx["selection_key"] + option.option_type = "group" + option.is_available = bool(best_snapshot.is_available and best_snapshot.is_enabled) + option.metadata = group_ctx["metadata"] + options.append(option) + for uuid, server in server_catalog.items(): + if uuid in grouped_server_uuids: + continue option = _build_server_option(server, discount_percent, texts) options.append(option) if not options: default_selection = [] + elif not default_selection: + default_selection = [options[0].selection_key or options[0].uuid] return PurchaseServersConfig( options=options, min_selectable=1 if options else 0, max_selectable=len(options), - default_selection=default_selection if default_selection else [opt.uuid for opt in options[:1]], + default_selection=default_selection, hint=None, ) @@ -683,6 +796,14 @@ class MiniAppSubscriptionPurchaseService: if not servers: servers = list(period.servers.default_selection) + resolved_servers: List[str] = [] + for key in servers: + actual_uuid = context.server_selection_map.get(key, key) + if actual_uuid not in resolved_servers: + resolved_servers.append(actual_uuid) + + servers = resolved_servers + if period.servers.min_selectable and len(servers) < period.servers.min_selectable: raise PurchaseValidationError("Select at least one server", code="invalid_servers") diff --git a/app/states.py b/app/states.py index 39b3744d..38768103 100644 --- a/app/states.py +++ b/app/states.py @@ -183,6 +183,13 @@ class SquadMigrationStates(StatesGroup): confirming = State() +class ServerGroupStates(StatesGroup): + waiting_for_name = State() + selecting_servers = State() + editing_name = State() + editing_servers = State() + + class RemnaWaveSyncStates(StatesGroup): waiting_for_schedule = State()