feat: add backend support for server groups

This commit is contained in:
Egor
2025-11-09 04:36:33 +03:00
parent d6eec8787e
commit a043fc0e46
7 changed files with 854 additions and 11 deletions

View File

@@ -0,0 +1,249 @@
from __future__ import annotations
import logging
from typing import Dict, List, Optional, Sequence
from sqlalchemy import and_, delete, func, select, update, cast, String
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database.models import (
ServerGroup,
ServerGroupServer,
ServerSquad,
Subscription,
SubscriptionStatus,
)
logger = logging.getLogger(__name__)
async def get_server_groups(
db: AsyncSession,
*,
include_servers: bool = True,
) -> List[ServerGroup]:
"""Возвращает список всех групп серверов."""
query = (
select(ServerGroup)
.order_by(ServerGroup.sort_order, ServerGroup.name)
)
if include_servers:
query = query.options(
selectinload(ServerGroup.servers).selectinload(ServerGroupServer.server)
)
result = await db.execute(query)
return list(result.scalars().unique().all())
async def get_server_group_by_id(
db: AsyncSession,
group_id: int,
*,
include_servers: bool = True,
) -> Optional[ServerGroup]:
query = select(ServerGroup).where(ServerGroup.id == group_id)
if include_servers:
query = query.options(
selectinload(ServerGroup.servers).selectinload(ServerGroupServer.server)
)
result = await db.execute(query)
return result.scalars().unique().one_or_none()
async def get_server_group_by_name(db: AsyncSession, name: str) -> Optional[ServerGroup]:
result = await db.execute(
select(ServerGroup).where(func.lower(ServerGroup.name) == func.lower(name))
)
return result.scalars().one_or_none()
async def create_server_group(
db: AsyncSession,
*,
name: str,
server_ids: Optional[Sequence[int]] = None,
sort_order: int = 0,
is_active: bool = True,
) -> ServerGroup:
existing = await get_server_group_by_name(db, name)
if existing:
raise ValueError("Группа с таким названием уже существует")
group = ServerGroup(name=name.strip(), sort_order=sort_order, is_active=is_active)
db.add(group)
await db.flush()
await _sync_group_servers(db, group, server_ids or [])
await db.commit()
await db.refresh(group)
return group
async def update_server_group(
db: AsyncSession,
group_id: int,
*,
name: Optional[str] = None,
server_ids: Optional[Sequence[int]] = None,
sort_order: Optional[int] = None,
is_active: Optional[bool] = None,
) -> Optional[ServerGroup]:
group = await get_server_group_by_id(db, group_id)
if not group:
return None
updates: Dict = {}
if name is not None:
trimmed = name.strip()
if trimmed and trimmed.lower() != group.name.lower():
duplicate = await get_server_group_by_name(db, trimmed)
if duplicate and duplicate.id != group.id:
raise ValueError("Группа с таким названием уже существует")
updates["name"] = trimmed
if sort_order is not None:
updates["sort_order"] = int(sort_order)
if is_active is not None:
updates["is_active"] = bool(is_active)
if updates:
await db.execute(
update(ServerGroup)
.where(ServerGroup.id == group.id)
.values(**updates)
)
if server_ids is not None:
await _sync_group_servers(db, group, server_ids)
await db.commit()
return await get_server_group_by_id(db, group_id)
async def delete_server_group(db: AsyncSession, group_id: int) -> bool:
group = await get_server_group_by_id(db, group_id)
if not group:
return False
if await is_group_in_use(db, group):
raise ValueError("Нельзя удалить группу, пока её серверы используются активными подписками")
await db.execute(delete(ServerGroup).where(ServerGroup.id == group.id))
await db.commit()
return True
async def toggle_group_server(
db: AsyncSession,
group_id: int,
server_id: int,
*,
is_enabled: bool,
) -> bool:
result = await db.execute(
update(ServerGroupServer)
.where(
and_(
ServerGroupServer.group_id == group_id,
ServerGroupServer.server_squad_id == server_id,
)
)
.values(is_enabled=is_enabled)
)
if getattr(result, "rowcount", 0):
await db.commit()
return True
return False
async def _sync_group_servers(
db: AsyncSession,
group: ServerGroup,
server_ids: Sequence[int],
) -> None:
unique_ids = {int(server_id) for server_id in server_ids if server_id}
existing_ids = {member.server_squad_id for member in group.servers}
to_remove = existing_ids - unique_ids
to_add = unique_ids - existing_ids
if to_remove:
await db.execute(
delete(ServerGroupServer).where(
and_(
ServerGroupServer.group_id == group.id,
ServerGroupServer.server_squad_id.in_(to_remove),
)
)
)
if to_add:
new_relations = [
ServerGroupServer(group_id=group.id, server_squad_id=server_id)
for server_id in to_add
]
db.add_all(new_relations)
await db.flush()
async def is_group_in_use(
db: AsyncSession,
group: ServerGroup | int,
) -> bool:
if isinstance(group, int):
group_obj = await get_server_group_by_id(db, group)
else:
group_obj = group
if not group_obj:
return False
squad_uuids = [
member.server.squad_uuid
for member in group_obj.servers
if member.server and member.server.squad_uuid
]
if not squad_uuids:
return False
like_filters = [
cast(Subscription.connected_squads, String).like(f'%"{uuid}"%')
for uuid in squad_uuids
]
condition = like_filters[0]
for clause in like_filters[1:]:
condition = condition | clause
result = await db.execute(
select(func.count(Subscription.id)).where(
Subscription.status.in_(
[
SubscriptionStatus.ACTIVE.value,
SubscriptionStatus.TRIAL.value,
]
),
condition,
)
)
return (result.scalar() or 0) > 0
async def get_group_server_ids(group: ServerGroup) -> List[int]:
return [member.server_squad_id for member in group.servers if member.server_squad_id]
async def get_group_server_uuids(group: ServerGroup) -> List[str]:
return [
member.server.squad_uuid
for member in group.servers
if member.server and member.server.squad_uuid
]

View File

@@ -1315,6 +1315,12 @@ class ServerSquad(Base):
back_populates="server_squads",
lazy="selectin",
)
groups = relationship(
"ServerGroupServer",
back_populates="server",
cascade="all, delete-orphan",
)
@property
def price_rubles(self) -> float:
@@ -1336,6 +1342,50 @@ class ServerSquad(Base):
return "Доступен"
class ServerGroup(Base):
__tablename__ = "server_groups"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(255), nullable=False, unique=True)
is_active = Column(Boolean, default=True, nullable=False)
sort_order = Column(Integer, default=0, nullable=False)
created_at = Column(DateTime, default=func.now(), nullable=False)
updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False)
servers = relationship(
"ServerGroupServer",
back_populates="group",
cascade="all, delete-orphan",
order_by="ServerGroupServer.id",
)
def __repr__(self) -> str:
return f"<ServerGroup(id={self.id}, name='{self.name}')>"
class ServerGroupServer(Base):
__tablename__ = "server_group_servers"
__table_args__ = (
UniqueConstraint("group_id", "server_squad_id", name="uq_group_server"),
)
id = Column(Integer, primary_key=True, index=True)
group_id = Column(Integer, ForeignKey("server_groups.id", ondelete="CASCADE"), nullable=False, index=True)
server_squad_id = Column(Integer, ForeignKey("server_squads.id", ondelete="CASCADE"), nullable=False, index=True)
is_enabled = Column(Boolean, default=True, nullable=False)
created_at = Column(DateTime, default=func.now(), nullable=False)
updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False)
group = relationship("ServerGroup", back_populates="servers")
server = relationship("ServerSquad", back_populates="groups")
def __repr__(self) -> str:
return (
f"<ServerGroupServer(group_id={self.group_id}, server_squad_id={self.server_squad_id}, "
f"is_enabled={self.is_enabled})>"
)
class SubscriptionServer(Base):
__tablename__ = "subscription_servers"

View File

@@ -3533,6 +3533,154 @@ async def add_promo_group_priority_column() -> bool:
return False
async def create_server_groups_table() -> bool:
table_exists = await check_table_exists("server_groups")
if table_exists:
logger.info(" Таблица server_groups уже существует")
return True
try:
async with engine.begin() as conn:
db_type = await get_database_type()
if db_type == "sqlite":
await conn.execute(text(
"""
CREATE TABLE server_groups (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
is_active INTEGER NOT NULL DEFAULT 1,
sort_order INTEGER NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
"""
))
await conn.execute(text(
"CREATE INDEX IF NOT EXISTS idx_server_groups_sort ON server_groups(sort_order, name)"
))
elif db_type == "postgresql":
await conn.execute(text(
"""
CREATE TABLE server_groups (
id SERIAL PRIMARY KEY,
name VARCHAR(255) NOT NULL UNIQUE,
is_active BOOLEAN NOT NULL DEFAULT TRUE,
sort_order INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW(),
updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW()
);
"""
))
await conn.execute(text(
"CREATE INDEX IF NOT EXISTS idx_server_groups_sort ON server_groups(sort_order, name)"
))
else: # MySQL
await conn.execute(text(
"""
CREATE TABLE server_groups (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(255) NOT NULL UNIQUE,
is_active TINYINT(1) NOT NULL DEFAULT 1,
sort_order INT NOT NULL DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
) ENGINE=InnoDB;
"""
))
await conn.execute(text(
"CREATE INDEX idx_server_groups_sort ON server_groups(sort_order, name)"
))
logger.info("✅ Таблица server_groups создана")
return True
except Exception as error:
logger.error(f"❌ Ошибка создания таблицы server_groups: {error}")
return False
async def create_server_group_servers_table() -> bool:
table_exists = await check_table_exists("server_group_servers")
if table_exists:
logger.info(" Таблица server_group_servers уже существует")
return True
try:
async with engine.begin() as conn:
db_type = await get_database_type()
if db_type == "sqlite":
await conn.execute(text(
"""
CREATE TABLE server_group_servers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
group_id INTEGER NOT NULL,
server_squad_id INTEGER NOT NULL,
is_enabled INTEGER NOT NULL DEFAULT 1,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(group_id, server_squad_id),
FOREIGN KEY(group_id) REFERENCES server_groups(id) ON DELETE CASCADE,
FOREIGN KEY(server_squad_id) REFERENCES server_squads(id) ON DELETE CASCADE
);
"""
))
await conn.execute(text(
"CREATE INDEX IF NOT EXISTS idx_server_group_servers_group ON server_group_servers(group_id)"
))
await conn.execute(text(
"CREATE INDEX IF NOT EXISTS idx_server_group_servers_server ON server_group_servers(server_squad_id)"
))
elif db_type == "postgresql":
await conn.execute(text(
"""
CREATE TABLE server_group_servers (
id SERIAL PRIMARY KEY,
group_id INTEGER NOT NULL REFERENCES server_groups(id) ON DELETE CASCADE,
server_squad_id INTEGER NOT NULL REFERENCES server_squads(id) ON DELETE CASCADE,
is_enabled BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW(),
updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW(),
UNIQUE(group_id, server_squad_id)
);
"""
))
await conn.execute(text(
"CREATE INDEX IF NOT EXISTS idx_server_group_servers_group ON server_group_servers(group_id)"
))
await conn.execute(text(
"CREATE INDEX IF NOT EXISTS idx_server_group_servers_server ON server_group_servers(server_squad_id)"
))
else: # MySQL
await conn.execute(text(
"""
CREATE TABLE server_group_servers (
id INT AUTO_INCREMENT PRIMARY KEY,
group_id INT NOT NULL,
server_squad_id INT NOT NULL,
is_enabled TINYINT(1) NOT NULL DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
UNIQUE KEY uq_group_server (group_id, server_squad_id),
CONSTRAINT fk_group FOREIGN KEY (group_id) REFERENCES server_groups(id) ON DELETE CASCADE,
CONSTRAINT fk_group_server FOREIGN KEY (server_squad_id) REFERENCES server_squads(id) ON DELETE CASCADE
) ENGINE=InnoDB;
"""
))
await conn.execute(text(
"CREATE INDEX idx_server_group_servers_group ON server_group_servers(group_id)"
))
await conn.execute(text(
"CREATE INDEX idx_server_group_servers_server ON server_group_servers(server_squad_id)"
))
logger.info("✅ Таблица server_group_servers создана")
return True
except Exception as error:
logger.error(f"❌ Ошибка создания таблицы server_group_servers: {error}")
return False
async def create_user_promo_groups_table() -> bool:
"""Создает таблицу user_promo_groups для связи Many-to-Many между users и promo_groups."""
table_exists = await check_table_exists("user_promo_groups")
@@ -3742,6 +3890,16 @@ async def run_universal_migration():
else:
logger.warning("⚠️ Проблемы с колонкой is_trial_eligible")
logger.info("=== СОЗДАНИЕ ТАБЛИЦЫ SERVER_GROUPS ===")
server_groups_ready = await create_server_groups_table()
if not server_groups_ready:
logger.warning("⚠️ Не удалось создать таблицу server_groups")
logger.info("=== СОЗДАНИЕ ТАБЛИЦЫ SERVER_GROUP_SERVERS ===")
server_group_servers_ready = await create_server_group_servers_table()
if not server_group_servers_ready:
logger.warning("⚠️ Не удалось создать таблицу server_group_servers")
logger.info("=== СОЗДАНИЕ ТАБЛИЦЫ PRIVACY_POLICIES ===")
privacy_policies_ready = await create_privacy_policies_table()
if privacy_policies_ready:
@@ -4134,6 +4292,8 @@ async def check_migration_status():
"promo_offer_templates_active_discount_column": False,
"promo_offer_logs_table": False,
"subscription_temporary_access_table": False,
"server_groups_table": False,
"server_group_servers_table": False,
}
status["has_made_first_topup_column"] = await check_column_exists('users', 'has_made_first_topup')
@@ -4156,6 +4316,8 @@ async def check_migration_status():
status["promo_offer_templates_active_discount_column"] = await check_column_exists('promo_offer_templates', 'active_discount_hours')
status["promo_offer_logs_table"] = await check_table_exists('promo_offer_logs')
status["subscription_temporary_access_table"] = await check_table_exists('subscription_temporary_access')
status["server_groups_table"] = await check_table_exists('server_groups')
status["server_group_servers_table"] = await check_table_exists('server_group_servers')
status["welcome_texts_is_enabled_column"] = await check_column_exists('welcome_texts', 'is_enabled')
status["users_promo_group_column"] = await check_column_exists('users', 'promo_group_id')
@@ -4221,6 +4383,8 @@ async def check_migration_status():
"promo_offer_templates_active_discount_column": "Колонка active_discount_hours в promo_offer_templates",
"promo_offer_logs_table": "Таблица promo_offer_logs",
"subscription_temporary_access_table": "Таблица subscription_temporary_access",
"server_groups_table": "Таблица server_groups",
"server_group_servers_table": "Таблица server_group_servers",
}
for check_key, check_status in status.items():

View File

@@ -38,6 +38,7 @@ from app.database.models import (
SubscriptionStatus,
ServerSquad,
)
from app.utils.cache import cache, cache_key
from app.utils.subscription_utils import (
resolve_hwid_device_limit_for_payload,
)
@@ -536,8 +537,8 @@ class RemnaWaveService:
'bytes': daily_bytes
})
nodes_weekly_data = list(nodes_by_name.values())
nodes_weekly_data.sort(key=lambda x: x['total_bytes'], reverse=True)
nodes_weekly_data = list(nodes_by_name.values())
nodes_weekly_data.sort(key=lambda x: x['total_bytes'], reverse=True)
result = {
"system": {
@@ -619,7 +620,7 @@ class RemnaWaveService:
logger.info(f"Статистика сформирована: пользователи={result['system']['total_users']}, общий трафик={total_user_traffic}")
return result
except RemnaWaveAPIError as e:
logger.error(f"Ошибка Remnawave API при получении статистики: {e}")
return {"error": str(e)}
@@ -627,7 +628,83 @@ class RemnaWaveService:
logger.error(f"Общая ошибка получения системной статистики: {e}")
return {"error": f"Внутренняя ошибка сервера: {str(e)}"}
async def get_internal_squad_usage_map(
self,
*,
force_refresh: bool = False,
cache_ttl: int = 60,
) -> Dict[str, Dict[str, int]]:
cache_name = cache_key("remnawave", "internal_squad_usage")
if not force_refresh:
cached = await cache.get(cache_name)
if isinstance(cached, dict):
return cached
usage_map: Dict[str, Dict[str, int]] = {}
internal_squads: List[RemnaWaveInternalSquad] = []
realtime_usage: List[Dict[str, Any]] = []
try:
async with self.get_api_client() as api:
try:
internal_squads = await api.get_internal_squads()
except Exception as error:
logger.warning("Не удалось получить список сквадов RemnaWave: %s", error)
try:
realtime_usage = await api.get_nodes_realtime_usage()
except Exception as error:
logger.warning("Не удалось получить статистику нагрузки сквадов: %s", error)
except RemnaWaveConfigurationError as error:
logger.debug("RemnaWave API не настроен: %s", error)
return {}
for squad in internal_squads:
usage_map[squad.uuid] = {
"members_count": int(squad.members_count or 0),
"inbounds_count": int(squad.inbounds_count or 0),
"download_bytes": 0,
"upload_bytes": 0,
"bandwidth_bytes": 0,
}
for node in realtime_usage:
squad_uuid = (
node.get("internalSquadUuid")
or node.get("squadUuid")
or node.get("squadUUID")
or node.get("internal_squad_uuid")
or node.get("internalSquadUUID")
)
if not squad_uuid:
continue
stats = usage_map.setdefault(
squad_uuid,
{
"members_count": 0,
"inbounds_count": 0,
"download_bytes": 0,
"upload_bytes": 0,
"bandwidth_bytes": 0,
},
)
download = int(node.get("downloadBytes", 0) or 0)
upload = int(node.get("uploadBytes", 0) or 0)
stats["download_bytes"] += download
stats["upload_bytes"] += upload
stats["bandwidth_bytes"] = stats["download_bytes"] + stats["upload_bytes"]
if usage_map:
await cache.set(cache_name, usage_map, expire=cache_ttl)
return usage_map
def _parse_bandwidth_string(self, bandwidth_str: str) -> int:
try:
if not bandwidth_str or bandwidth_str == '0 B' or bandwidth_str == '0':

View File

@@ -0,0 +1,175 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Callable, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from app.database.crud.server_group import get_server_group_by_id
from app.database.crud.server_squad import count_active_users_for_squad
from app.database.models import ServerGroup, ServerGroupServer, ServerSquad
from app.services.remnawave_service import RemnaWaveService
logger = logging.getLogger(__name__)
@dataclass
class ServerLoadSnapshot:
server: ServerSquad
membership: ServerGroupServer
active_users: int
members_count: int
bandwidth_bytes: int
download_bytes: int
upload_bytes: int
is_available: bool
is_enabled: bool
is_overloaded: bool
load_ratio: float
@property
def maintenance(self) -> bool:
return not self.is_enabled
@dataclass
class ServerGroupSnapshot:
group: ServerGroup
servers: List[ServerLoadSnapshot]
total_active_users: int
total_bandwidth_bytes: int
available_servers: List[ServerLoadSnapshot]
@property
def is_empty(self) -> bool:
return not self.servers
@property
def is_overloaded(self) -> bool:
return all(server.is_overloaded or not server.is_available for server in self.available_servers)
async def build_group_snapshot(
db: AsyncSession,
remnawave: RemnaWaveService,
group: ServerGroup | int,
*,
refresh: bool = False,
) -> ServerGroupSnapshot:
if isinstance(group, int):
group_obj = await get_server_group_by_id(db, group)
else:
group_obj = group
if not group_obj:
raise ValueError("Server group not found")
usage_map = await remnawave.get_internal_squad_usage_map(force_refresh=refresh)
servers: List[ServerLoadSnapshot] = []
total_active_users = 0
total_bandwidth_bytes = 0
for membership in group_obj.servers:
server = membership.server
if not server:
continue
usage = usage_map.get(server.squad_uuid, {}) or {}
members_count = int(usage.get("members_count", 0) or 0)
download_bytes = int(usage.get("download_bytes", 0) or 0)
upload_bytes = int(usage.get("upload_bytes", 0) or 0)
bandwidth_bytes = int(usage.get("bandwidth_bytes", download_bytes + upload_bytes) or 0)
db_active = await count_active_users_for_squad(db, server.squad_uuid)
active_users = max(db_active, members_count, int(server.current_users or 0))
capacity = server.max_users or 0
load_ratio = (active_users / capacity) if capacity else float(active_users)
is_overloaded = capacity > 0 and active_users >= capacity
snapshot = ServerLoadSnapshot(
server=server,
membership=membership,
active_users=active_users,
members_count=members_count,
bandwidth_bytes=bandwidth_bytes,
download_bytes=download_bytes,
upload_bytes=upload_bytes,
is_available=bool(server.is_available),
is_enabled=bool(membership.is_enabled),
is_overloaded=is_overloaded,
load_ratio=load_ratio,
)
servers.append(snapshot)
total_active_users += active_users
total_bandwidth_bytes += bandwidth_bytes
available_servers = [
server
for server in servers
if server.is_available and server.is_enabled
]
return ServerGroupSnapshot(
group=group_obj,
servers=servers,
total_active_users=total_active_users,
total_bandwidth_bytes=total_bandwidth_bytes,
available_servers=available_servers,
)
async def choose_optimal_server(
db: AsyncSession,
remnawave: RemnaWaveService,
group: ServerGroup | int,
*,
refresh_stats: bool = False,
notify_overload: Optional[Callable[[ServerGroupSnapshot], None]] = None,
) -> Optional[tuple[ServerGroupSnapshot, ServerLoadSnapshot]]:
snapshot = await build_group_snapshot(db, remnawave, group, refresh=refresh_stats)
if not snapshot.servers:
logger.warning("Server group %s не содержит серверов", getattr(snapshot.group, "name", snapshot.group.id))
return None
candidates = [
server for server in snapshot.servers if server.is_available and server.is_enabled
]
if not candidates:
candidates = [server for server in snapshot.servers if server.is_enabled]
if not candidates:
logger.error("Нет доступных серверов в группе %s", snapshot.group.name)
if notify_overload:
notify_overload(snapshot)
return None
candidates.sort(
key=lambda item: (
item.active_users,
item.bandwidth_bytes,
item.server.current_users or 0,
item.server.sort_order,
item.server.display_name,
)
)
best = candidates[0]
logger.info(
"Выбран сервер %s (%s) для группы %s: активных=%s, трафик=%s",
best.server.display_name,
best.server.squad_uuid,
snapshot.group.name,
best.active_users,
best.bandwidth_bytes,
)
if notify_overload and snapshot.is_overloaded:
notify_overload(snapshot)
return snapshot, best

View File

@@ -7,6 +7,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import PERIOD_PRICES, settings
from app.database.crud.server_group import get_server_groups
from app.database.crud.server_squad import (
add_user_to_servers,
get_available_server_squads,
@@ -24,6 +25,8 @@ from app.database.crud.transaction import create_transaction
from app.database.crud.user import subtract_user_balance
from app.database.models import ServerSquad, Subscription, SubscriptionStatus, TransactionType, User
from app.localization.texts import get_texts
from app.services.remnawave_service import RemnaWaveService
from app.services.server_group_service import choose_optimal_server
from app.services.subscription_service import SubscriptionService
from app.utils.pricing_utils import (
calculate_months_from_days,
@@ -103,6 +106,9 @@ class PurchaseServerOption:
original_price_label: Optional[str] = None
discount_percent: int = 0
is_available: bool = True
selection_key: Optional[str] = None
option_type: str = "server"
metadata: Optional[Dict[str, Any]] = None
def to_payload(self) -> Dict[str, Any]:
payload: Dict[str, Any] = {
@@ -112,6 +118,10 @@ class PurchaseServerOption:
"price_label": self.price_label,
"is_available": self.is_available,
}
payload["selection_key"] = self.selection_key or self.uuid
payload["type"] = self.option_type
if self.metadata:
payload["meta"] = self.metadata
if self.original_price_per_month is not None and (
self.original_price_label and self.original_price_per_month != self.price_per_month
):
@@ -255,6 +265,7 @@ class PurchaseOptionsContext:
default_period: PurchasePeriodConfig
period_map: Dict[str, PurchasePeriodConfig]
server_uuid_to_id: Dict[str, int]
server_selection_map: Dict[str, str]
payload: Dict[str, Any]
@@ -322,6 +333,8 @@ def _build_server_option(
original_price_label=texts.format_price(base_per_month) if base_per_month != discounted_per_month else None,
discount_percent=max(0, discount_percent),
is_available=bool(getattr(server, "is_available", True) and not getattr(server, "is_full", False)),
selection_key=server.squad_uuid,
option_type="server",
)
@@ -352,6 +365,74 @@ class MiniAppSubscriptionPurchaseService:
if existing:
server_catalog[uuid] = existing
remnawave_service = RemnaWaveService()
server_groups = await get_server_groups(db)
groups_context: List[Dict[str, Any]] = []
grouped_server_uuids: set[str] = set()
selection_map: Dict[str, str] = {}
preferred_selection_for_server: Dict[str, str] = {}
for group in server_groups:
if not group.is_active or not group.servers:
continue
try:
choice = await choose_optimal_server(db, remnawave_service, group)
except Exception as error: # pragma: no cover - defensive logging
logger.warning("Failed to evaluate server group %s: %s", getattr(group, "name", group.id), error)
continue
if not choice:
continue
snapshot, best = choice
selection_key = f"group:{group.id}"
selection_map[selection_key] = best.server.squad_uuid
metadata = {
"groupId": group.id,
"groupName": group.name,
"totalActiveUsers": snapshot.total_active_users,
"totalBandwidthBytes": snapshot.total_bandwidth_bytes,
"isOverloaded": snapshot.is_overloaded,
"selectedServerUuid": best.server.squad_uuid,
"selectedServerName": best.server.display_name,
"servers": [
{
"uuid": member.server.squad_uuid if member.server else None,
"name": member.server.display_name if member.server else None,
"activeUsers": member.active_users,
"membersCount": member.members_count,
"bandwidthBytes": member.bandwidth_bytes,
"isAvailable": member.is_available,
"isEnabled": member.is_enabled,
"isOverloaded": member.is_overloaded,
}
for member in snapshot.servers
if member.server
],
}
for member in snapshot.servers:
if member.server and member.server.squad_uuid:
grouped_server_uuids.add(member.server.squad_uuid)
preferred_selection_for_server.setdefault(member.server.squad_uuid, selection_key)
groups_context.append(
{
"group": group,
"selection_key": selection_key,
"snapshot": snapshot,
"choice": best,
"metadata": metadata,
}
)
server_catalog.setdefault(best.server.squad_uuid, best.server)
for uuid, server in list(server_catalog.items()):
selection_map.setdefault(uuid, uuid)
preferred_selection_for_server.setdefault(uuid, uuid)
server_uuid_to_id: Dict[str, int] = {}
for server in server_catalog.values():
try:
@@ -366,6 +447,12 @@ class MiniAppSubscriptionPurchaseService:
default_connected = [server.squad_uuid]
break
default_selection_keys: List[str] = []
for uuid in default_connected:
key = preferred_selection_for_server.get(uuid)
if key and key not in default_selection_keys:
default_selection_keys.append(key)
available_periods: Sequence[int] = settings.get_available_subscription_periods()
periods: List[PurchasePeriodConfig] = []
period_map: Dict[str, PurchasePeriodConfig] = {}
@@ -414,7 +501,9 @@ class MiniAppSubscriptionPurchaseService:
texts,
period_days,
server_catalog,
default_connected,
default_selection_keys,
groups_context=groups_context,
grouped_server_uuids=grouped_server_uuids,
)
devices_config = self._build_devices_config(
user,
@@ -448,6 +537,9 @@ class MiniAppSubscriptionPurchaseService:
default_period = period_map.get(f"days:{default_period_days}") or periods[0]
default_selection_keys_payload = list(default_period.servers.default_selection)
default_selection_uuids_payload = [selection_map.get(key, key) for key in default_selection_keys_payload]
default_selection = {
"period_id": default_period.id,
"periodId": default_period.id,
@@ -459,15 +551,18 @@ class MiniAppSubscriptionPurchaseService:
"trafficValue": default_period.traffic.current_value
if default_period.traffic.current_value is not None
else default_period.traffic.default_value,
"servers": list(default_period.servers.default_selection),
"countries": list(default_period.servers.default_selection),
"server_uuids": list(default_period.servers.default_selection),
"serverUuids": list(default_period.servers.default_selection),
"servers": default_selection_keys_payload,
"countries": default_selection_keys_payload,
"server_uuids": default_selection_uuids_payload,
"serverUuids": default_selection_uuids_payload,
"devices": default_period.devices.current,
"device_limit": default_period.devices.current,
"deviceLimit": default_period.devices.current,
}
servers_payload = default_period.servers.to_payload()
servers_payload["selection_map"] = selection_map
payload = {
"currency": currency,
"balance_kopeks": balance_kopeks,
@@ -478,7 +573,7 @@ class MiniAppSubscriptionPurchaseService:
"subscriptionId": getattr(subscription, "id", None),
"periods": [period.to_payload() for period in periods],
"traffic": default_period.traffic.to_payload(),
"servers": default_period.servers.to_payload(),
"servers": servers_payload,
"devices": default_period.devices.to_payload(),
"selection": default_selection,
"summary": None,
@@ -493,6 +588,7 @@ class MiniAppSubscriptionPurchaseService:
default_period=default_period,
period_map=period_map,
server_uuid_to_id=server_uuid_to_id,
server_selection_map=selection_map,
payload=payload,
)
@@ -568,22 +664,39 @@ class MiniAppSubscriptionPurchaseService:
period_days: int,
server_catalog: Dict[str, ServerSquad],
default_selection: List[str],
*,
groups_context: List[Dict[str, Any]],
grouped_server_uuids: set,
) -> PurchaseServersConfig:
discount_percent = user.get_promo_discount("servers", period_days)
options: List[PurchaseServerOption] = []
for group_ctx in groups_context:
best_snapshot = group_ctx["choice"]
best_server = best_snapshot.server
option = _build_server_option(best_server, discount_percent, texts)
option.selection_key = group_ctx["selection_key"]
option.option_type = "group"
option.is_available = bool(best_snapshot.is_available and best_snapshot.is_enabled)
option.metadata = group_ctx["metadata"]
options.append(option)
for uuid, server in server_catalog.items():
if uuid in grouped_server_uuids:
continue
option = _build_server_option(server, discount_percent, texts)
options.append(option)
if not options:
default_selection = []
elif not default_selection:
default_selection = [options[0].selection_key or options[0].uuid]
return PurchaseServersConfig(
options=options,
min_selectable=1 if options else 0,
max_selectable=len(options),
default_selection=default_selection if default_selection else [opt.uuid for opt in options[:1]],
default_selection=default_selection,
hint=None,
)
@@ -683,6 +796,14 @@ class MiniAppSubscriptionPurchaseService:
if not servers:
servers = list(period.servers.default_selection)
resolved_servers: List[str] = []
for key in servers:
actual_uuid = context.server_selection_map.get(key, key)
if actual_uuid not in resolved_servers:
resolved_servers.append(actual_uuid)
servers = resolved_servers
if period.servers.min_selectable and len(servers) < period.servers.min_selectable:
raise PurchaseValidationError("Select at least one server", code="invalid_servers")

View File

@@ -183,6 +183,13 @@ class SquadMigrationStates(StatesGroup):
confirming = State()
class ServerGroupStates(StatesGroup):
waiting_for_name = State()
selecting_servers = State()
editing_name = State()
editing_servers = State()
class RemnaWaveSyncStates(StatesGroup):
waiting_for_schedule = State()