Files
remnawave-bedolaga-telegram…/app/database/crud/server_squad.py
2025-11-09 05:48:45 +03:00

795 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import random
from datetime import datetime
from typing import Iterable, List, Optional, Sequence, Tuple
from sqlalchemy import (
select,
and_,
func,
update,
delete,
text,
or_,
cast,
String,
)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database.models import (
PromoGroup,
ServerCategory,
ServerSquad,
SubscriptionServer,
Subscription,
SubscriptionStatus,
User,
)
logger = logging.getLogger(__name__)
async def _get_default_promo_group_id(db: AsyncSession) -> Optional[int]:
    """Return the id of a promo group flagged as default, or ``None``.

    Only one row is requested from the database (``LIMIT 1``).

    Args:
        db: Async session used to run the lookup.

    Returns:
        The primary key of a default promo group, or ``None`` when no group
        has ``is_default`` set.
    """
    default_group_query = (
        select(PromoGroup.id)
        .where(PromoGroup.is_default.is_(True))
        .limit(1)
    )
    rows = await db.execute(default_group_query)
    return rows.scalar_one_or_none()
async def create_server_squad(
    db: AsyncSession,
    squad_uuid: str,
    display_name: str,
    original_name: Optional[str] = None,
    country_code: Optional[str] = None,
    price_kopeks: int = 0,
    description: Optional[str] = None,
    max_users: Optional[int] = None,
    is_available: bool = True,
    is_trial_eligible: bool = False,
    sort_order: int = 0,
    category_id: Optional[int] = None,
    promo_group_ids: Optional[Iterable[int]] = None,
) -> ServerSquad:
    """Create, commit and return a new ``ServerSquad`` row.

    Args:
        db: Async session used for all queries and the final commit.
        squad_uuid: UUID of the squad this row represents.
        display_name: Name stored in ``display_name``.
        original_name: Name stored in ``original_name``.
        country_code: Value stored in ``country_code``.
        price_kopeks: Price in kopeks.
        description: Free-form description.
        max_users: User limit; ``None`` is stored as-is.
        is_available: Initial availability flag.
        is_trial_eligible: Initial trial-eligibility flag.
        sort_order: Ordering key stored in ``sort_order``.
        category_id: Optional ``ServerCategory`` id; validated when given.
        promo_group_ids: Promo groups to link. ``None`` means "use the
            default promo group".

    Returns:
        The refreshed ``ServerSquad`` instance.

    Raises:
        ValueError: If no promo group can be linked (none requested, no
            default group, or none of the requested ids exist), or if
            ``category_id`` does not reference an existing category.
    """
    normalized_group_ids: List[int]
    if promo_group_ids is None:
        default_id = await _get_default_promo_group_id(db)
        normalized_group_ids = [default_id] if default_id is not None else []
    else:
        # Convert before de-duplicating so that e.g. 1 and "1" collapse into
        # a single id and the "not all found" check below stays accurate.
        normalized_group_ids = list({int(pg_id) for pg_id in promo_group_ids})

    if not normalized_group_ids:
        raise ValueError("Server squad must be linked to at least one promo group")

    promo_groups_result = await db.execute(
        select(PromoGroup).where(PromoGroup.id.in_(normalized_group_ids))
    )
    promo_groups = list(promo_groups_result.scalars().all())

    # Ids were requested but none of them exist: creating the squad anyway
    # would silently break the "at least one promo group" invariant.
    if not promo_groups:
        raise ValueError("Server squad must be linked to at least one promo group")

    if len(promo_groups) != len(normalized_group_ids):
        logger.warning(
            "Не все промогруппы найдены при создании сервера %s", display_name
        )

    if category_id is not None:
        category_exists = await db.execute(
            select(ServerCategory.id).where(ServerCategory.id == category_id)
        )
        if category_exists.scalar_one_or_none() is None:
            raise ValueError("Server category not found")

    server_squad = ServerSquad(
        squad_uuid=squad_uuid,
        display_name=display_name,
        original_name=original_name,
        country_code=country_code,
        price_kopeks=price_kopeks,
        description=description,
        max_users=max_users,
        is_available=is_available,
        is_trial_eligible=is_trial_eligible,
        sort_order=sort_order,
        category_id=category_id,
        allowed_promo_groups=promo_groups,
    )
    db.add(server_squad)
    await db.commit()
    await db.refresh(server_squad)

    logger.info("✅ Создан сервер %s (UUID: %s)", display_name, squad_uuid)
    return server_squad
async def get_server_squad_by_uuid(
    db: AsyncSession,
    squad_uuid: str
) -> Optional[ServerSquad]:
    """Fetch a single server squad by its squad UUID.

    The ``allowed_promo_groups`` and ``category`` relationships are loaded
    eagerly with ``selectinload``.

    Args:
        db: Async session used to run the query.
        squad_uuid: Value matched against ``ServerSquad.squad_uuid``.

    Returns:
        The matching ``ServerSquad`` or ``None`` when nothing matches.
    """
    eager_relations = (
        selectinload(ServerSquad.allowed_promo_groups),
        selectinload(ServerSquad.category),
    )
    lookup = (
        select(ServerSquad)
        .options(*eager_relations)
        .where(ServerSquad.squad_uuid == squad_uuid)
    )
    rows = await db.execute(lookup)
    return rows.scalars().unique().one_or_none()
async def get_server_squad_by_id(
    db: AsyncSession,
    server_id: int
) -> Optional[ServerSquad]:
    """Fetch a single server squad by primary key.

    The ``allowed_promo_groups`` and ``category`` relationships are loaded
    eagerly with ``selectinload``.

    Args:
        db: Async session used to run the query.
        server_id: Value matched against ``ServerSquad.id``.

    Returns:
        The matching ``ServerSquad`` or ``None`` when nothing matches.
    """
    eager_relations = (
        selectinload(ServerSquad.allowed_promo_groups),
        selectinload(ServerSquad.category),
    )
    lookup = (
        select(ServerSquad)
        .options(*eager_relations)
        .where(ServerSquad.id == server_id)
    )
    rows = await db.execute(lookup)
    return rows.scalars().unique().one_or_none()
async def get_all_server_squads(
    db: AsyncSession,
    available_only: bool = False,
    page: int = 1,
    limit: int = 50
) -> Tuple[List[ServerSquad], int]:
    """Return one page of server squads together with the total row count.

    Rows are ordered by ``sort_order`` and then ``display_name``; the
    ``category`` relationship is loaded eagerly.

    Args:
        db: Async session used to run both queries.
        available_only: When true, restrict both the page and the count to
            squads whose ``is_available`` flag is set.
        page: 1-based page number; the offset is ``(page - 1) * limit``.
        limit: Page size.

    Returns:
        ``(servers, total_count)`` where ``total_count`` ignores pagination.
    """
    listing = select(ServerSquad).options(selectinload(ServerSquad.category))
    counting = select(func.count(ServerSquad.id))

    if available_only:
        # Same predicate on both statements so the count matches the listing.
        listing = listing.where(ServerSquad.is_available == True)  # noqa: E712
        counting = counting.where(ServerSquad.is_available == True)  # noqa: E712

    total_count = (await db.execute(counting)).scalar()

    listing = (
        listing
        .order_by(ServerSquad.sort_order, ServerSquad.display_name)
        .offset((page - 1) * limit)
        .limit(limit)
    )
    servers = (await db.execute(listing)).scalars().all()

    return servers, total_count
async def get_available_server_squads(
    db: AsyncSession,
    promo_group_id: Optional[int] = None,
    category_id: Optional[int] = None,
    only_with_capacity: bool = False,
) -> List[ServerSquad]:
    """List squads whose ``is_available`` flag is set, with optional filters.

    Results are ordered by ``sort_order`` then ``display_name`` and have the
    ``allowed_promo_groups`` and ``category`` relationships loaded eagerly.

    Args:
        db: Async session used to run the query.
        promo_group_id: When given, keep only squads linked to this promo
            group through ``allowed_promo_groups``.
        category_id: When given, keep only squads in this category.
        only_with_capacity: When true, drop squads whose ``is_full``
            attribute is truthy (filtered in Python after loading).

    Returns:
        The matching squads.
    """
    stmt = (
        select(ServerSquad)
        .options(
            selectinload(ServerSquad.allowed_promo_groups),
            selectinload(ServerSquad.category),
        )
        .where(ServerSquad.is_available.is_(True))
        .order_by(ServerSquad.sort_order, ServerSquad.display_name)
    )

    if promo_group_id is not None:
        stmt = stmt.join(ServerSquad.allowed_promo_groups).where(
            PromoGroup.id == promo_group_id
        )
    if category_id is not None:
        stmt = stmt.where(ServerSquad.category_id == category_id)

    squads = (await db.execute(stmt)).scalars().unique().all()

    if not only_with_capacity:
        return squads
    return [squad for squad in squads if not squad.is_full]
async def get_active_server_squads(db: AsyncSession) -> List[ServerSquad]:
    """Return the available servers that can accept new connections.

    A server counts as having room when it has no ``max_users`` limit or its
    ``current_users`` (``None`` treated as 0) is below that limit. If every
    available server is at its limit, the full list of available servers is
    returned instead of an empty one.
    """
    available = await get_available_server_squads(db)
    if not available:
        return []

    def has_room(squad: ServerSquad) -> bool:
        limit = squad.max_users
        occupied = squad.current_users or 0
        return limit is None or occupied < limit

    with_room = [squad for squad in available if has_room(squad)]
    # Fall back to all available servers rather than returning nothing.
    return with_room if with_room else available
async def choose_random_active_server_squad(
    db: AsyncSession,
) -> Optional[ServerSquad]:
    """Pick one active server at random.

    Returns:
        A randomly chosen squad from ``get_active_server_squads``, or
        ``None`` when that list is empty.
    """
    candidates = await get_active_server_squads(db)
    return random.choice(candidates) if candidates else None
async def get_random_active_squad_uuid(
    db: AsyncSession,
    fallback_uuid: Optional[str] = None,
) -> Optional[str]:
    """Return the UUID of a random active server, or the fallback UUID.

    Args:
        db: Async session passed through to the squad lookup.
        fallback_uuid: Value returned when no active server is found.
    """
    chosen = await choose_random_active_server_squad(db)
    if not chosen:
        return fallback_uuid
    return chosen.squad_uuid
async def choose_least_loaded_server_in_category(
    db: AsyncSession,
    category_id: int,
    promo_group_id: Optional[int] = None,
) -> Optional[ServerSquad]:
    """Pick the least loaded server within a category.

    Candidates are the available, non-full squads of the category
    (optionally restricted to a promo group). They are ranked by fill ratio
    ``current_users / max_users`` (0 when there is no limit), then by
    ``current_users``, then by ``id``.

    Returns:
        The squad with the smallest ranking key, or ``None`` when there are
        no candidates.
    """
    candidates = await get_available_server_squads(
        db,
        promo_group_id=promo_group_id,
        category_id=category_id,
        only_with_capacity=True,
    )
    if not candidates:
        return None

    def ranking(squad: ServerSquad) -> tuple:
        capacity = squad.max_users or 0
        occupied = squad.current_users or 0
        # Squads without a limit are treated as completely empty.
        fill = occupied / capacity if capacity else 0
        return (fill, occupied, squad.id)

    return min(candidates, key=ranking)
async def update_server_squad_promo_groups(
    db: AsyncSession, server_id: int, promo_group_ids: Iterable[int]
) -> Optional[ServerSquad]:
    """Replace the set of promo groups linked to a server squad.

    Args:
        db: Async session used for lookups and the commit.
        server_id: Primary key of the squad to update.
        promo_group_ids: Ids of the promo groups that should be linked.

    Returns:
        The refreshed squad, or ``None`` when the squad does not exist.

    Raises:
        ValueError: If ``promo_group_ids`` is empty or none of the ids match
            an existing promo group.
    """
    requested_ids = [int(pg_id) for pg_id in set(promo_group_ids)]
    if not requested_ids:
        raise ValueError("Нужно выбрать хотя бы одну промогруппу")

    server = await get_server_squad_by_id(db, server_id)
    if not server:
        return None

    groups_query = select(PromoGroup).where(PromoGroup.id.in_(requested_ids))
    promo_groups = (await db.execute(groups_query)).scalars().all()
    if not promo_groups:
        raise ValueError("Не найдены промогруппы для обновления сервера")

    server.allowed_promo_groups = promo_groups
    await db.commit()
    await db.refresh(server)

    logger.info(
        "Обновлены промогруппы сервера %s (ID: %s): %s",
        server.display_name,
        server.id,
        ", ".join(pg.name for pg in promo_groups),
    )
    return server
async def update_server_squad(
    db: AsyncSession,
    server_id: int,
    **updates
) -> Optional[ServerSquad]:
    """Update whitelisted columns of a server squad.

    Keyword arguments that are not in the whitelist are silently ignored.

    Args:
        db: Async session used for validation, the update and the commit.
        server_id: Primary key of the squad to update.
        **updates: Column name / new value pairs.

    Returns:
        The squad re-read after the update, or ``None`` when no whitelisted
        field was supplied (or the squad cannot be found afterwards).

    Raises:
        ValueError: If a non-``None`` ``category_id`` does not reference an
            existing category.
    """
    allowed_fields = frozenset((
        "display_name",
        "original_name",
        "country_code",
        "price_kopeks",
        "description",
        "max_users",
        "is_available",
        "sort_order",
        "is_trial_eligible",
        "category_id",
    ))
    changes = {
        field: value for field, value in updates.items() if field in allowed_fields
    }

    # A missing key and an explicit None both skip validation.
    new_category_id = changes.get("category_id")
    if new_category_id is not None:
        category_lookup = await db.execute(
            select(ServerCategory.id).where(ServerCategory.id == new_category_id)
        )
        if category_lookup.scalar_one_or_none() is None:
            raise ValueError("Server category not found")

    if not changes:
        return None

    stmt = update(ServerSquad).where(ServerSquad.id == server_id).values(**changes)
    await db.execute(stmt)
    await db.commit()

    return await get_server_squad_by_id(db, server_id)
async def delete_server_squad(db: AsyncSession, server_id: int) -> bool:
    """Delete a server squad unless subscriptions are still linked to it.

    Args:
        db: Async session used for the check, the delete and the commit.
        server_id: Primary key of the squad to delete.

    Returns:
        ``False`` when at least one ``SubscriptionServer`` row references
        the squad (nothing is deleted); ``True`` after the delete statement
        has been committed.
    """
    usage_query = select(func.count(SubscriptionServer.id)).where(
        SubscriptionServer.server_squad_id == server_id
    )
    connections_count = (await db.execute(usage_query)).scalar()

    if connections_count > 0:
        logger.warning(f"⚠ Нельзя удалить сервер {server_id}: есть активные подключения ({connections_count})")
        return False

    removal = delete(ServerSquad).where(ServerSquad.id == server_id)
    await db.execute(removal)
    await db.commit()

    logger.info(f"🗑️ Удален сервер (ID: {server_id})")
    return True
async def sync_with_remnawave(
    db: AsyncSession,
    remnawave_squads: List[dict]
) -> Tuple[int, int, int]:
    """Reconcile local ``ServerSquad`` rows with a list of remote squads.

    Three phases:

    1. Remote squads already known locally (matched by ``squad_uuid``) get
       their ``original_name`` refreshed when it changed.
    2. Remote squads unknown locally are created via ``create_server_squad``
       with a generated display name, ``price_kopeks=1000`` and
       ``is_available=False``.
    3. Local squads missing from the remote list are removed: their
       ``SubscriptionServer`` links are deleted, their UUIDs are stripped
       from ``Subscription.connected_squads``, and the squad rows are deleted.

    Args:
        db: Async session used for all statements and the final commit.
        remnawave_squads: Dicts describing remote squads. Each must contain
            a ``'uuid'`` key; ``'name'`` is optional.

    Returns:
        ``(created, updated, removed)`` counters.

    Raises:
        KeyError: If an entry of ``remnawave_squads`` has no ``'uuid'`` key.
    """
    created = 0
    updated = 0
    removed = 0

    # Index every local squad by its UUID for O(1) matching below.
    existing_servers = {}
    result = await db.execute(select(ServerSquad))
    for server in result.scalars().all():
        existing_servers[server.squad_uuid] = server

    remnawave_uuids = {squad['uuid'] for squad in remnawave_squads}

    for squad in remnawave_squads:
        squad_uuid = squad['uuid']
        # Fall back to a name built from the first 8 characters of the UUID.
        original_name = squad.get('name', f'Squad {squad_uuid[:8]}')

        if squad_uuid in existing_servers:
            server = existing_servers[squad_uuid]
            if server.original_name != original_name:
                server.original_name = original_name
                updated += 1
        else:
            # NOTE(review): create_server_squad commits the session itself,
            # so newly created rows (and any pending name changes) are
            # committed before the removal phase runs — confirm intended.
            await create_server_squad(
                db=db,
                squad_uuid=squad_uuid,
                display_name=_generate_display_name(original_name),
                original_name=original_name,
                country_code=_extract_country_code(original_name),
                price_kopeks=1000,
                is_available=False
            )
            created += 1

    # Local squads whose UUID is no longer reported by the remote side.
    removed_servers = [
        server for uuid, server in existing_servers.items()
        if uuid not in remnawave_uuids
    ]

    if removed_servers:
        removed_ids = [server.id for server in removed_servers]
        removed_uuids = {server.squad_uuid for server in removed_servers}

        # Collect affected subscription ids BEFORE the link rows are deleted.
        subscription_ids_result = await db.execute(
            select(SubscriptionServer.subscription_id)
            .where(SubscriptionServer.server_squad_id.in_(removed_ids))
        )
        subscription_ids = {row[0] for row in subscription_ids_result.fetchall()}

        for server in removed_servers:
            logger.info(
                "🗑️ Удаляется сервер %s (UUID: %s)",
                server.display_name,
                server.squad_uuid,
            )

        await db.execute(
            delete(SubscriptionServer).where(SubscriptionServer.server_squad_id.in_(removed_ids))
        )

        # Subscriptions to clean, keyed by id so each is processed once even
        # if found through both lookups below.
        subscriptions_to_update: dict[int, Subscription] = {}
        if subscription_ids:
            subscriptions_result = await db.execute(
                select(Subscription).where(Subscription.id.in_(subscription_ids))
            )
            for subscription in subscriptions_result.scalars().unique().all():
                subscriptions_to_update[subscription.id] = subscription

        # Also find subscriptions that mention a removed UUID in
        # ``connected_squads`` without having a link row. The match is a
        # textual LIKE on the column cast to text, looking for the quoted UUID.
        # NOTE(review): ``::text`` is a PostgreSQL-style cast — confirm the
        # target database.
        for squad_uuid in removed_uuids:
            if not squad_uuid:
                continue
            extra_result = await db.execute(
                select(Subscription).where(
                    text("connected_squads::text LIKE :uuid_pattern")
                ),
                {"uuid_pattern": f'%"{squad_uuid}"%'}
            )
            for subscription in extra_result.scalars().unique().all():
                subscriptions_to_update[subscription.id] = subscription

        cleaned_subscriptions = 0
        for subscription in subscriptions_to_update.values():
            current_squads = list(subscription.connected_squads or [])
            if not current_squads:
                continue
            filtered_squads = [
                squad_uuid for squad_uuid in current_squads if squad_uuid not in removed_uuids
            ]
            # Only touch the row when something was actually removed.
            if len(filtered_squads) != len(current_squads):
                subscription.connected_squads = filtered_squads
                subscription.updated_at = datetime.utcnow()
                cleaned_subscriptions += 1

        await db.execute(delete(ServerSquad).where(ServerSquad.id.in_(removed_ids)))
        removed = len(removed_servers)

        if cleaned_subscriptions:
            logger.info(
                "🧹 Обновлены подписки после удаления серверов: %s",
                cleaned_subscriptions,
            )

    await db.commit()

    logger.info(f"🔄 Синхронизация завершена: +{created} ~{updated} -{removed}")
    return created, updated, removed
async def get_server_connected_users(
    db: AsyncSession,
    server_id: int
) -> List[User]:
    """List users whose subscription is connected to the given server.

    A user matches when their subscription has a ``SubscriptionServer`` row
    for ``server_id``, or — if the server's UUID can be resolved — when the
    quoted UUID appears in ``Subscription.connected_squads`` cast to a
    string. Results are ordered by ``User.id`` with ``User.subscription``
    loaded eagerly.

    Args:
        db: Async session used to run both queries.
        server_id: Primary key of the server squad.
    """
    uuid_lookup = select(ServerSquad.squad_uuid).where(ServerSquad.id == server_id)
    server_uuid = (await db.execute(uuid_lookup)).scalar_one_or_none()

    # The outer join leaves SubscriptionServer.id NULL for subscriptions
    # without a link row to this server.
    connection_filters = [SubscriptionServer.id.isnot(None)]
    if server_uuid:
        squad_pattern = f'%"{server_uuid}"%'
        connection_filters.append(
            cast(Subscription.connected_squads, String).like(squad_pattern)
        )

    link_to_this_server = and_(
        SubscriptionServer.subscription_id == Subscription.id,
        SubscriptionServer.server_squad_id == server_id,
    )
    users_query = (
        select(User)
        .join(Subscription, Subscription.user_id == User.id)
        .outerjoin(SubscriptionServer, link_to_this_server)
        .where(or_(*connection_filters))
        .options(selectinload(User.subscription))
        .order_by(User.id)
    )

    rows = await db.execute(users_query)
    return rows.scalars().unique().all()
async def get_trial_eligible_server_squads(
    db: AsyncSession,
    include_unavailable: bool = False,
) -> List[ServerSquad]:
    """Return the squads that may be used for trial subscriptions.

    Args:
        db: Async session used to run the query.
        include_unavailable: When true, return every trial-eligible squad
            without any further filtering.

    Returns:
        With ``include_unavailable`` false, the first non-empty list of:
        non-full squads that are available; non-full squads that are not
        available; all trial-eligible squads. A squad is full when it has a
        ``max_users`` limit and ``current_users`` (``None`` as 0) reached it.
    """
    trial_query = select(ServerSquad).where(ServerSquad.is_trial_eligible.is_(True))
    squads = (await db.execute(trial_query)).scalars().unique().all()

    if include_unavailable:
        return squads

    available_with_room: List[ServerSquad] = []
    unavailable_with_room: List[ServerSquad] = []
    for squad in squads:
        occupied = squad.current_users or 0
        if squad.max_users is not None and occupied >= squad.max_users:
            continue
        bucket = available_with_room if squad.is_available else unavailable_with_room
        bucket.append(squad)

    for choice in (available_with_room, unavailable_with_room):
        if choice:
            return choice
    return squads
async def choose_random_trial_server_squad(
    db: AsyncSession,
) -> Optional[ServerSquad]:
    """Pick one trial-eligible squad at random.

    Returns:
        A random element of ``get_trial_eligible_server_squads(db)``, or
        ``None`` when that list is empty.
    """
    candidates = await get_trial_eligible_server_squads(db)
    return random.choice(candidates) if candidates else None
async def get_random_trial_squad_uuid(
    db: AsyncSession,
) -> Optional[str]:
    """Return the UUID of a random trial-eligible squad, or ``None``."""
    chosen = await choose_random_trial_server_squad(db)
    if not chosen:
        return None
    return chosen.squad_uuid
# Known country codes and their display labels. Insertion order is the match
# priority when a name contains several codes.
_COUNTRY_DISPLAY_NAMES = {
    'NL': '🇳🇱 Нидерланды',
    'DE': '🇩🇪 Германия',
    'US': '🇺🇸 США',
    'FR': '🇫🇷 Франция',
    'GB': '🇬🇧 Великобритания',
    'IT': '🇮🇹 Италия',
    'ES': '🇪🇸 Испания',
    'CA': '🇨🇦 Канада',
    'JP': '🇯🇵 Япония',
    'SG': '🇸🇬 Сингапур',
    'AU': '🇦🇺 Австралия',
}

# Common spellings that are not themselves two-letter codes.
_COUNTRY_CODE_ALIASES = {
    'USA': 'US',
    'UK': 'GB',
}


def _match_country_code(original_name: str) -> Optional[str]:
    """Find a known country code that appears as a whole word in the name.

    The name is upper-cased and split on every character outside ``A-Z``, so
    ``"NL-1"``, ``"de_premium"`` and ``"US2"`` match while words that merely
    contain the letters (``"Default-Squad"``, ``"Sweden"``, ``"Finland"``,
    ``"Australia"``) do not.
    """
    if not original_name:
        return None
    letters_only = "".join(
        ch if "A" <= ch <= "Z" else " " for ch in original_name.upper()
    )
    tokens = set(letters_only.split())
    tokens.update(
        code for alias, code in _COUNTRY_CODE_ALIASES.items() if alias in tokens
    )
    for code in _COUNTRY_DISPLAY_NAMES:
        if code in tokens:
            return code
    return None


def _generate_display_name(original_name: str) -> str:
    """Build a user-facing name for a squad from its original name.

    Returns the flag-and-country label of the first known country code found
    as a whole word in ``original_name``; otherwise the original name
    prefixed with a globe emoji.
    """
    code = _match_country_code(original_name)
    if code is not None:
        return _COUNTRY_DISPLAY_NAMES[code]
    return f"🌍 {original_name}"


def _extract_country_code(original_name: str) -> Optional[str]:
    """Return the first known country code found as a whole word, or ``None``.

    Uses the same matching as ``_generate_display_name`` so the code and the
    display label derived from one name always agree.
    """
    return _match_country_code(original_name)
async def get_server_statistics(db: AsyncSession) -> dict:
    """Collect aggregate statistics about server squads.

    Returns:
        A dict with the keys ``total_servers``, ``available_servers``,
        ``unavailable_servers``, ``servers_with_connections`` (servers whose
        quoted UUID appears in ``connected_squads`` of at least one
        subscription with status ``'active'`` or ``'trial'``),
        ``total_revenue_kopeks`` (sum of
        ``SubscriptionServer.paid_price_kopeks``, 0 when there are no rows)
        and ``total_revenue_rubles`` (kopeks divided by 100).
    """
    total_servers = (
        await db.execute(select(func.count(ServerSquad.id)))
    ).scalar()

    available_query = select(func.count(ServerSquad.id)).where(
        ServerSquad.is_available == True  # noqa: E712
    )
    available_servers = (await db.execute(available_query)).scalar()

    uuid_rows = await db.execute(select(ServerSquad.squad_uuid))
    all_server_uuids = [row[0] for row in uuid_rows.fetchall()]

    # One statement object, executed once per server with a different bind.
    usage_sql = text(
        "\n"
        "SELECT COUNT(s.id)\n"
        "FROM subscriptions s\n"
        "WHERE s.status IN ('active', 'trial')\n"
        "AND s.connected_squads::text LIKE :uuid_pattern\n"
    )
    servers_with_connections = 0
    for squad_uuid in all_server_uuids:
        usage = await db.execute(usage_sql, {"uuid_pattern": f'%"{squad_uuid}"%'})
        if (usage.scalar() or 0) > 0:
            servers_with_connections += 1

    revenue_query = select(
        func.coalesce(func.sum(SubscriptionServer.paid_price_kopeks), 0)
    )
    total_revenue_kopeks = (await db.execute(revenue_query)).scalar()

    return {
        'total_servers': total_servers,
        'available_servers': available_servers,
        'unavailable_servers': total_servers - available_servers,
        'servers_with_connections': servers_with_connections,
        'total_revenue_kopeks': total_revenue_kopeks,
        'total_revenue_rubles': total_revenue_kopeks / 100
    }
async def count_active_users_for_squad(db: AsyncSession, squad_uuid: str) -> int:
    """Return the number of active subscriptions connected to the given squad.

    Counts subscriptions whose status is ``ACTIVE`` or ``TRIAL`` and whose
    ``connected_squads`` column, cast to a string, contains the quoted UUID.

    Args:
        db: Async session used to run the query.
        squad_uuid: UUID searched for inside ``connected_squads``.

    Returns:
        The matching row count; 0 when the query yields no value.
    """
    counted_statuses = [
        SubscriptionStatus.ACTIVE.value,
        SubscriptionStatus.TRIAL.value,
    ]
    squad_pattern = f'%"{squad_uuid}"%'
    count_query = select(func.count(Subscription.id)).where(
        Subscription.status.in_(counted_statuses),
        cast(Subscription.connected_squads, String).like(squad_pattern),
    )
    rows = await db.execute(count_query)
    return rows.scalar() or 0
async def add_user_to_servers(
    db: AsyncSession,
    server_squad_ids: List[int]
) -> bool:
    """Increment ``current_users`` by one for each listed server squad.

    An id that appears several times in ``server_squad_ids`` is incremented
    once per occurrence. All increments are committed together.

    Args:
        db: Async session used for the updates and the commit.
        server_squad_ids: Primary keys of the squads to increment.

    Returns:
        ``True`` after a successful commit; ``False`` if any statement or
        the commit raised, in which case the session is rolled back.
    """
    try:
        for server_id in server_squad_ids:
            await db.execute(
                update(ServerSquad)
                .where(ServerSquad.id == server_id)
                # COALESCE guards against a NULL counter: NULL + 1 is NULL in
                # SQL, which would silently drop the increment.
                .values(
                    current_users=func.coalesce(ServerSquad.current_users, 0) + 1
                )
            )

        await db.commit()
        logger.info(f"✅ Увеличен счетчик пользователей для серверов: {server_squad_ids}")
        return True

    except Exception as e:
        # Best-effort counter maintenance: report failure instead of raising.
        logger.error(f"Ошибка увеличения счетчика пользователей: {e}")
        await db.rollback()
        return False
async def remove_user_from_servers(
    db: AsyncSession,
    server_squad_ids: List[int]
) -> bool:
    """Decrement ``current_users`` by one for each listed server squad.

    The new value is ``GREATEST(current_users - 1, 0)``, so the counter is
    never written below zero. All decrements are committed together.

    Args:
        db: Async session used for the updates and the commit.
        server_squad_ids: Primary keys of the squads to decrement.

    Returns:
        ``True`` after a successful commit; ``False`` if any statement or
        the commit raised, in which case the session is rolled back.
    """
    try:
        for squad_id in server_squad_ids:
            decrement = (
                update(ServerSquad)
                .where(ServerSquad.id == squad_id)
                .values(current_users=func.greatest(ServerSquad.current_users - 1, 0))
            )
            await db.execute(decrement)

        await db.commit()
        logger.info(f"✅ Уменьшен счетчик пользователей для серверов: {server_squad_ids}")
        return True

    except Exception as e:
        logger.error(f"Ошибка уменьшения счетчика пользователей: {e}")
        await db.rollback()
        return False
async def get_server_ids_by_uuids(
    db: AsyncSession,
    squad_uuids: List[str]
) -> List[int]:
    """Translate squad UUIDs into ``ServerSquad`` primary keys.

    UUIDs without a matching row are simply absent from the result; no
    ordering is requested from the database.

    Args:
        db: Async session used to run the query.
        squad_uuids: UUIDs to look up.

    Returns:
        The ids of all squads whose ``squad_uuid`` is in ``squad_uuids``.
    """
    id_query = select(ServerSquad.id).where(ServerSquad.squad_uuid.in_(squad_uuids))
    rows = await db.execute(id_query)
    return [server_id for (server_id,) in rows.fetchall()]
async def sync_server_user_counts(db: AsyncSession) -> int:
    """Recompute ``current_users`` for every server squad.

    For each squad, counts subscriptions with status ``'active'`` or
    ``'trial'`` whose ``connected_squads`` (cast to text) contains the
    quoted squad UUID, and stores that number in ``current_users``. All
    updates are committed together.

    Returns:
        The number of squads processed, or 0 if anything raised (the
        session is rolled back in that case).
    """
    try:
        server_rows = await db.execute(select(ServerSquad.id, ServerSquad.squad_uuid))
        all_servers = server_rows.fetchall()

        logger.info(f"🔍 Найдено серверов для синхронизации: {len(all_servers)}")

        # One statement object, executed once per server with a different bind.
        usage_sql = text(
            "\n"
            "SELECT COUNT(s.id)\n"
            "FROM subscriptions s\n"
            "WHERE s.status IN ('active', 'trial')\n"
            "AND s.connected_squads::text LIKE :uuid_pattern\n"
        )

        updated_count = 0
        for server_id, squad_uuid in all_servers:
            usage = await db.execute(
                usage_sql, {"uuid_pattern": f'%"{squad_uuid}"%'}
            )
            actual_users = usage.scalar() or 0

            logger.info(f"📊 Сервер {server_id} ({squad_uuid[:8]}): {actual_users} пользователей")

            store_count = (
                update(ServerSquad)
                .where(ServerSquad.id == server_id)
                .values(current_users=actual_users)
            )
            await db.execute(store_count)
            updated_count += 1

        await db.commit()
        logger.info(f"✅ Синхронизированы счетчики для {updated_count} серверов")
        return updated_count

    except Exception as e:
        logger.error(f"Ошибка синхронизации счетчиков пользователей: {e}")
        await db.rollback()
        return 0