Files
remnawave-bedolaga-telegram…/app/database/crud/tariff.py
2026-01-16 08:35:26 +03:00

626 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from typing import Dict, List, Optional
from sqlalchemy import func, select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database.models import Tariff, Subscription, PromoGroup, tariff_promo_groups
logger = logging.getLogger(__name__)
def _normalize_period_prices(period_prices: Optional[Dict[int, int]]) -> Dict[str, int]:
"""Нормализует цены периодов в формат {str: int}."""
if not period_prices:
return {}
normalized: Dict[str, int] = {}
for key, value in period_prices.items():
try:
period = int(key)
price = int(value)
except (TypeError, ValueError):
continue
if period > 0 and price >= 0:
normalized[str(period)] = price
return normalized
async def get_all_tariffs(
db: AsyncSession,
*,
include_inactive: bool = False,
offset: int = 0,
limit: Optional[int] = None,
) -> List[Tariff]:
"""Получает все тарифы с опциональной фильтрацией по активности."""
query = select(Tariff).options(selectinload(Tariff.allowed_promo_groups))
if not include_inactive:
query = query.where(Tariff.is_active.is_(True))
query = query.order_by(Tariff.display_order, Tariff.id)
if offset:
query = query.offset(offset)
if limit is not None:
query = query.limit(limit)
result = await db.execute(query)
return result.scalars().all()
async def get_tariff_by_id(
db: AsyncSession,
tariff_id: int,
*,
with_promo_groups: bool = True,
) -> Optional[Tariff]:
"""Получает тариф по ID."""
query = select(Tariff).where(Tariff.id == tariff_id)
if with_promo_groups:
query = query.options(selectinload(Tariff.allowed_promo_groups))
result = await db.execute(query)
return result.scalars().first()
async def count_tariffs(db: AsyncSession, *, include_inactive: bool = False) -> int:
"""Подсчитывает количество тарифов."""
query = select(func.count(Tariff.id))
if not include_inactive:
query = query.where(Tariff.is_active.is_(True))
result = await db.execute(query)
return int(result.scalar_one())
async def get_trial_tariff(db: AsyncSession) -> Optional[Tariff]:
"""Получает тариф, доступный для триала (is_trial_available=True)."""
query = (
select(Tariff)
.where(Tariff.is_trial_available.is_(True))
.where(Tariff.is_active.is_(True))
.options(selectinload(Tariff.allowed_promo_groups))
.limit(1)
)
result = await db.execute(query)
return result.scalars().first()
async def set_trial_tariff(db: AsyncSession, tariff_id: int) -> Optional[Tariff]:
"""Устанавливает тариф как триальный (снимает флаг с других тарифов)."""
# Снимаем флаг с всех тарифов
await db.execute(
Tariff.__table__.update().values(is_trial_available=False)
)
# Устанавливаем флаг на выбранный тариф
tariff = await get_tariff_by_id(db, tariff_id)
if tariff:
tariff.is_trial_available = True
await db.commit()
await db.refresh(tariff)
return tariff
async def clear_trial_tariff(db: AsyncSession) -> None:
"""Снимает флаг триала со всех тарифов."""
await db.execute(
Tariff.__table__.update().values(is_trial_available=False)
)
await db.commit()
async def get_tariffs_for_user(
db: AsyncSession,
promo_group_id: Optional[int] = None,
) -> List[Tariff]:
"""
Получает тарифы, доступные для пользователя с учетом его промогруппы.
Если у тарифа нет ограничений по промогруппам - он доступен всем.
"""
query = (
select(Tariff)
.options(selectinload(Tariff.allowed_promo_groups))
.where(Tariff.is_active.is_(True))
.order_by(Tariff.display_order, Tariff.id)
)
result = await db.execute(query)
tariffs = result.scalars().all()
# Фильтруем по промогруппе
available_tariffs = []
for tariff in tariffs:
if not tariff.allowed_promo_groups:
# Нет ограничений - доступен всем
available_tariffs.append(tariff)
elif promo_group_id is not None:
# Проверяем, есть ли промогруппа пользователя в списке разрешенных
if any(pg.id == promo_group_id for pg in tariff.allowed_promo_groups):
available_tariffs.append(tariff)
# else: пользователь без промогруппы, а у тарифа есть ограничения - пропускаем
return available_tariffs
async def create_tariff(
db: AsyncSession,
name: str,
*,
description: Optional[str] = None,
display_order: int = 0,
is_active: bool = True,
traffic_limit_gb: int = 100,
device_limit: int = 1,
device_price_kopeks: Optional[int] = None,
max_device_limit: Optional[int] = None,
allowed_squads: Optional[List[str]] = None,
server_traffic_limits: Optional[Dict[str, dict]] = None,
period_prices: Optional[Dict[int, int]] = None,
tier_level: int = 1,
is_trial_available: bool = False,
allow_traffic_topup: bool = True,
promo_group_ids: Optional[List[int]] = None,
traffic_topup_enabled: bool = False,
traffic_topup_packages: Optional[Dict[str, int]] = None,
max_topup_traffic_gb: int = 0,
is_daily: bool = False,
daily_price_kopeks: int = 0,
# Произвольное количество дней
custom_days_enabled: bool = False,
price_per_day_kopeks: int = 0,
min_days: int = 1,
max_days: int = 365,
# Произвольный трафик при покупке
custom_traffic_enabled: bool = False,
traffic_price_per_gb_kopeks: int = 0,
min_traffic_gb: int = 1,
max_traffic_gb: int = 1000,
# Режим сброса трафика
traffic_reset_mode: Optional[str] = None, # DAY, WEEK, MONTH, NO_RESET, None = глобальная настройка
) -> Tariff:
"""Создает новый тариф."""
normalized_prices = _normalize_period_prices(period_prices)
tariff = Tariff(
name=name.strip(),
description=description.strip() if description else None,
display_order=max(0, display_order),
is_active=is_active,
traffic_limit_gb=max(0, traffic_limit_gb),
device_limit=max(1, device_limit),
device_price_kopeks=device_price_kopeks,
max_device_limit=max_device_limit,
allowed_squads=allowed_squads or [],
server_traffic_limits=server_traffic_limits or {},
period_prices=normalized_prices,
tier_level=max(1, tier_level),
is_trial_available=is_trial_available,
allow_traffic_topup=allow_traffic_topup,
traffic_topup_enabled=traffic_topup_enabled,
traffic_topup_packages=traffic_topup_packages or {},
max_topup_traffic_gb=max(0, max_topup_traffic_gb),
is_daily=is_daily,
daily_price_kopeks=max(0, daily_price_kopeks),
# Произвольное количество дней
custom_days_enabled=custom_days_enabled,
price_per_day_kopeks=max(0, price_per_day_kopeks),
min_days=max(1, min_days),
max_days=max(1, max_days),
# Произвольный трафик при покупке
custom_traffic_enabled=custom_traffic_enabled,
traffic_price_per_gb_kopeks=max(0, traffic_price_per_gb_kopeks),
min_traffic_gb=max(1, min_traffic_gb),
max_traffic_gb=max(1, max_traffic_gb),
# Режим сброса трафика
traffic_reset_mode=traffic_reset_mode,
)
db.add(tariff)
await db.flush()
# Добавляем промогруппы если указаны
if promo_group_ids:
promo_groups_result = await db.execute(
select(PromoGroup).where(PromoGroup.id.in_(promo_group_ids))
)
promo_groups = promo_groups_result.scalars().all()
tariff.allowed_promo_groups = list(promo_groups)
await db.commit()
await db.refresh(tariff)
logger.info(
"Создан тариф '%s' (id=%s, tier=%s, traffic=%sGB, devices=%s, prices=%s)",
tariff.name,
tariff.id,
tariff.tier_level,
tariff.traffic_limit_gb,
tariff.device_limit,
normalized_prices,
)
return tariff
async def update_tariff(
db: AsyncSession,
tariff: Tariff,
*,
name: Optional[str] = None,
description: Optional[str] = None,
display_order: Optional[int] = None,
is_active: Optional[bool] = None,
traffic_limit_gb: Optional[int] = None,
device_limit: Optional[int] = None,
device_price_kopeks: Optional[int] = ..., # ... = не передан, None = сбросить
max_device_limit: Optional[int] = ..., # ... = не передан, None = сбросить (без лимита)
allowed_squads: Optional[List[str]] = None,
server_traffic_limits: Optional[Dict[str, dict]] = None,
period_prices: Optional[Dict[int, int]] = None,
tier_level: Optional[int] = None,
is_trial_available: Optional[bool] = None,
allow_traffic_topup: Optional[bool] = None,
promo_group_ids: Optional[List[int]] = None,
traffic_topup_enabled: Optional[bool] = None,
traffic_topup_packages: Optional[Dict[str, int]] = None,
max_topup_traffic_gb: Optional[int] = None,
is_daily: Optional[bool] = None,
daily_price_kopeks: Optional[int] = None,
# Произвольное количество дней
custom_days_enabled: Optional[bool] = None,
price_per_day_kopeks: Optional[int] = None,
min_days: Optional[int] = None,
max_days: Optional[int] = None,
# Произвольный трафик при покупке
custom_traffic_enabled: Optional[bool] = None,
traffic_price_per_gb_kopeks: Optional[int] = None,
min_traffic_gb: Optional[int] = None,
max_traffic_gb: Optional[int] = None,
# Режим сброса трафика
traffic_reset_mode: Optional[str] = ..., # ... = не передан, None = сбросить к глобальной настройке
) -> Tariff:
"""Обновляет существующий тариф."""
if name is not None:
tariff.name = name.strip()
if description is not None:
tariff.description = description.strip() if description else None
if display_order is not None:
tariff.display_order = max(0, display_order)
if is_active is not None:
tariff.is_active = is_active
if traffic_limit_gb is not None:
tariff.traffic_limit_gb = max(0, traffic_limit_gb)
if device_limit is not None:
tariff.device_limit = max(1, device_limit)
if device_price_kopeks is not ...:
# Если передан device_price_kopeks (включая None) - обновляем
tariff.device_price_kopeks = device_price_kopeks
if max_device_limit is not ...:
# Если передан max_device_limit (включая None) - обновляем
tariff.max_device_limit = max_device_limit
if allowed_squads is not None:
tariff.allowed_squads = allowed_squads
if server_traffic_limits is not None:
tariff.server_traffic_limits = server_traffic_limits
if allow_traffic_topup is not None:
tariff.allow_traffic_topup = allow_traffic_topup
if period_prices is not None:
tariff.period_prices = _normalize_period_prices(period_prices)
if tier_level is not None:
tariff.tier_level = max(1, tier_level)
if is_trial_available is not None:
tariff.is_trial_available = is_trial_available
if traffic_topup_enabled is not None:
tariff.traffic_topup_enabled = traffic_topup_enabled
if traffic_topup_packages is not None:
tariff.traffic_topup_packages = traffic_topup_packages
if max_topup_traffic_gb is not None:
tariff.max_topup_traffic_gb = max(0, max_topup_traffic_gb)
if is_daily is not None:
tariff.is_daily = is_daily
if daily_price_kopeks is not None:
tariff.daily_price_kopeks = max(0, daily_price_kopeks)
# Произвольное количество дней
if custom_days_enabled is not None:
tariff.custom_days_enabled = custom_days_enabled
if price_per_day_kopeks is not None:
tariff.price_per_day_kopeks = max(0, price_per_day_kopeks)
if min_days is not None:
tariff.min_days = max(1, min_days)
if max_days is not None:
tariff.max_days = max(1, max_days)
# Произвольный трафик при покупке
if custom_traffic_enabled is not None:
tariff.custom_traffic_enabled = custom_traffic_enabled
if traffic_price_per_gb_kopeks is not None:
tariff.traffic_price_per_gb_kopeks = max(0, traffic_price_per_gb_kopeks)
if min_traffic_gb is not None:
tariff.min_traffic_gb = max(1, min_traffic_gb)
if max_traffic_gb is not None:
tariff.max_traffic_gb = max(1, max_traffic_gb)
# Режим сброса трафика
if traffic_reset_mode is not ...:
tariff.traffic_reset_mode = traffic_reset_mode
# Обновляем промогруппы если указаны
if promo_group_ids is not None:
if promo_group_ids:
promo_groups_result = await db.execute(
select(PromoGroup).where(PromoGroup.id.in_(promo_group_ids))
)
promo_groups = promo_groups_result.scalars().all()
tariff.allowed_promo_groups = list(promo_groups)
else:
tariff.allowed_promo_groups = []
await db.commit()
await db.refresh(tariff)
logger.info(
"Обновлен тариф '%s' (id=%s)",
tariff.name,
tariff.id,
)
return tariff
async def delete_tariff(db: AsyncSession, tariff: Tariff) -> bool:
"""
Удаляет тариф.
Подписки с этим тарифом получат tariff_id = NULL.
"""
tariff_id = tariff.id
tariff_name = tariff.name
# Подсчитываем подписки с этим тарифом
subscriptions_count = await db.execute(
select(func.count(Subscription.id)).where(Subscription.tariff_id == tariff_id)
)
affected_subscriptions = subscriptions_count.scalar_one()
# Удаляем тариф (FK с ondelete=SET NULL автоматически обнулит tariff_id в подписках)
await db.delete(tariff)
await db.commit()
logger.info(
"Удален тариф '%s' (id=%s), затронуто подписок: %s",
tariff_name,
tariff_id,
affected_subscriptions,
)
return True
async def get_tariff_subscriptions_count(db: AsyncSession, tariff_id: int) -> int:
"""Подсчитывает количество подписок на тарифе."""
result = await db.execute(
select(func.count(Subscription.id)).where(Subscription.tariff_id == tariff_id)
)
return int(result.scalar_one())
async def set_tariff_promo_groups(
db: AsyncSession,
tariff: Tariff,
promo_group_ids: List[int],
) -> Tariff:
"""Устанавливает промогруппы для тарифа."""
if promo_group_ids:
promo_groups_result = await db.execute(
select(PromoGroup).where(PromoGroup.id.in_(promo_group_ids))
)
promo_groups = promo_groups_result.scalars().all()
tariff.allowed_promo_groups = list(promo_groups)
else:
tariff.allowed_promo_groups = []
await db.commit()
await db.refresh(tariff)
return tariff
async def add_promo_group_to_tariff(
db: AsyncSession,
tariff: Tariff,
promo_group_id: int,
) -> bool:
"""Добавляет промогруппу к тарифу."""
promo_group = await db.get(PromoGroup, promo_group_id)
if not promo_group:
return False
if promo_group not in tariff.allowed_promo_groups:
tariff.allowed_promo_groups.append(promo_group)
await db.commit()
return True
async def remove_promo_group_from_tariff(
db: AsyncSession,
tariff: Tariff,
promo_group_id: int,
) -> bool:
"""Удаляет промогруппу из тарифа."""
for pg in tariff.allowed_promo_groups:
if pg.id == promo_group_id:
tariff.allowed_promo_groups.remove(pg)
await db.commit()
return True
return False
async def get_tariffs_with_subscriptions_count(
db: AsyncSession,
*,
include_inactive: bool = False,
) -> List[tuple]:
"""Получает тарифы с количеством подписок."""
query = (
select(Tariff, func.count(Subscription.id))
.outerjoin(Subscription, Subscription.tariff_id == Tariff.id)
.group_by(Tariff.id)
.order_by(Tariff.display_order, Tariff.id)
)
if not include_inactive:
query = query.where(Tariff.is_active.is_(True))
result = await db.execute(query)
return result.all()
async def reorder_tariffs(
db: AsyncSession,
tariff_order: List[int],
) -> None:
"""Изменяет порядок отображения тарифов."""
for order, tariff_id in enumerate(tariff_order):
await db.execute(
update(Tariff)
.where(Tariff.id == tariff_id)
.values(display_order=order)
)
await db.commit()
logger.info("Изменен порядок тарифов: %s", tariff_order)
async def sync_default_tariff_from_config(db: AsyncSession) -> Optional[Tariff]:
"""
Синхронизирует дефолтный тариф из конфига (.env) в БД.
Создаёт тариф "Стандартный" если в БД нет тарифов.
Обновляет цены существующего тарифа если он есть.
Returns:
Tariff или None если не требуется синхронизация
"""
from app.config import settings, PERIOD_PRICES
# Проверяем есть ли тарифы в БД
result = await db.execute(select(func.count(Tariff.id)))
tariff_count = result.scalar() or 0
# Собираем цены из конфига
period_prices = {}
for period, price in PERIOD_PRICES.items():
if price > 0:
period_prices[str(period)] = price
if not period_prices:
logger.warning("Нет цен в конфиге для создания дефолтного тарифа")
return None
# Ищем тариф с именем "Стандартный" или первый тариф
result = await db.execute(
select(Tariff).where(Tariff.name == "Стандартный").limit(1)
)
existing_tariff = result.scalar_one_or_none()
if existing_tariff:
# Обновляем цены существующего тарифа
existing_tariff.period_prices = period_prices
existing_tariff.traffic_limit_gb = settings.DEFAULT_TRAFFIC_LIMIT_GB
existing_tariff.device_limit = settings.DEFAULT_DEVICE_LIMIT
await db.commit()
await db.refresh(existing_tariff)
logger.info("Обновлён дефолтный тариф 'Стандартный' из конфига")
return existing_tariff
if tariff_count == 0:
# Создаём новый дефолтный тариф
new_tariff = Tariff(
name="Стандартный",
description="Базовый тарифный план",
is_active=True,
is_trial_available=True,
traffic_limit_gb=settings.DEFAULT_TRAFFIC_LIMIT_GB,
device_limit=settings.DEFAULT_DEVICE_LIMIT,
tier_level=1,
display_order=0,
period_prices=period_prices,
allowed_squads=[], # Все серверы по умолчанию
server_traffic_limits={},
)
db.add(new_tariff)
await db.commit()
await db.refresh(new_tariff)
logger.info("Создан дефолтный тариф 'Стандартный' из конфига: %s", period_prices)
return new_tariff
return None
async def load_period_prices_from_db(db: AsyncSession) -> None:
"""
Загружает периоды/цены из тарифа в PERIOD_PRICES.
Работает ТОЛЬКО в режиме tariffs. В режиме classic используются цены из .env.
"""
from app.config import set_period_prices_from_db, settings
# В режиме classic НЕ загружаем цены из тарифов - используем .env
if settings.is_classic_mode():
logger.info("Режим classic: цены периодов берутся из .env, тарифы игнорируются")
return
try:
# Ищем тариф "Стандартный" или первый активный тариф
result = await db.execute(
select(Tariff)
.where(Tariff.is_active.is_(True))
.order_by(Tariff.display_order, Tariff.id)
.limit(1)
)
tariff = result.scalar_one_or_none()
if tariff and tariff.period_prices:
# Преобразуем строковые ключи в int
period_prices = {
int(days): int(price)
for days, price in tariff.period_prices.items()
if int(price) > 0
}
if period_prices:
set_period_prices_from_db(period_prices)
logger.info(
"Загружены периоды из тарифа '%s': %s",
tariff.name,
{f"{d}д": f"{p//100}" for d, p in period_prices.items()}
)
else:
logger.warning("Тариф '%s' не имеет активных периодов", tariff.name)
else:
logger.info("Активные тарифы не найдены, используются цены из .env")
except Exception as e:
logger.error("Ошибка загрузки периодов из БД: %s", e)
async def ensure_tariffs_synced(db: AsyncSession) -> None:
"""
Проверяет и синхронизирует тарифы при запуске.
Вызывается при старте бота.
"""
try:
await sync_default_tariff_from_config(db)
# Загружаем периоды из БД в PERIOD_PRICES
await load_period_prices_from_db(db)
except Exception as e:
logger.error("Ошибка синхронизации тарифов: %s", e)