From 991e5a3112435a9bfcf90794de3dc4fde46064c6 Mon Sep 17 00:00:00 2001 From: Egor Date: Tue, 18 Nov 2025 01:14:43 +0300 Subject: [PATCH] Preserve zero device limit when replacing trials --- app/database/crud/subscription.py | 87 +++++++++++++++++++++++++- app/webapi/routes/subscriptions.py | 96 +++++++++++++++++++++++------ app/webapi/schemas/subscriptions.py | 1 + 3 files changed, 165 insertions(+), 19 deletions(-) diff --git a/app/database/crud/subscription.py b/app/database/crud/subscription.py index 9475bd30..7e1e92d5 100644 --- a/app/database/crud/subscription.py +++ b/app/database/crud/subscription.py @@ -129,7 +129,7 @@ async def create_paid_subscription( connected_squads: List[str] = None, update_server_counters: bool = False, ) -> Subscription: - + end_date = datetime.utcnow() + timedelta(days=duration_days) if device_limit is None: @@ -186,6 +186,91 @@ async def create_paid_subscription( return subscription +async def replace_subscription( + db: AsyncSession, + subscription: Subscription, + *, + duration_days: int, + traffic_limit_gb: int, + device_limit: int, + connected_squads: List[str], + is_trial: bool, + autopay_enabled: Optional[bool] = None, + autopay_days_before: Optional[int] = None, + update_server_counters: bool = False, +) -> Subscription: + """Перезаписывает параметры существующей подписки пользователя.""" + + current_time = datetime.utcnow() + old_squads = set(subscription.connected_squads or []) + new_squads = set(connected_squads or []) + + new_autopay_enabled = ( + subscription.autopay_enabled + if autopay_enabled is None + else autopay_enabled + ) + new_autopay_days_before = ( + subscription.autopay_days_before + if autopay_days_before is None + else autopay_days_before + ) + + subscription.status = SubscriptionStatus.ACTIVE.value + subscription.is_trial = is_trial + subscription.start_date = current_time + subscription.end_date = current_time + timedelta(days=duration_days) + subscription.traffic_limit_gb = traffic_limit_gb + subscription.traffic_used_gb = 0.0 + subscription.device_limit = device_limit + subscription.connected_squads = list(new_squads) + subscription.subscription_url = None + subscription.subscription_crypto_link = None + subscription.remnawave_short_uuid = None + subscription.autopay_enabled = new_autopay_enabled + subscription.autopay_days_before = new_autopay_days_before + subscription.updated_at = current_time + + await db.commit() + await db.refresh(subscription) + + if update_server_counters: + try: + from app.database.crud.server_squad import ( + add_user_to_servers, + get_server_ids_by_uuids, + remove_user_from_servers, + ) + + squads_to_remove = old_squads - new_squads + squads_to_add = new_squads - old_squads + + if squads_to_remove: + server_ids = await get_server_ids_by_uuids(db, list(squads_to_remove)) + if server_ids: + await remove_user_from_servers(db, sorted(server_ids)) + + if squads_to_add: + server_ids = await get_server_ids_by_uuids(db, list(squads_to_add)) + if server_ids: + await add_user_to_servers(db, sorted(server_ids)) + + logger.info( + "♻️ Обновлены параметры подписки %s: удалено сквадов %s, добавлено %s", + subscription.id, + len(squads_to_remove), + len(squads_to_add), + ) + except Exception as error: + logger.error( + "⚠️ Ошибка обновления счетчиков серверов при замене подписки %s: %s", + subscription.id, + error, + ) + + return subscription + + async def extend_subscription( db: AsyncSession, subscription: Subscription, diff --git a/app/webapi/routes/subscriptions.py b/app/webapi/routes/subscriptions.py index 3545f474..4cbd8346 100644 --- a/app/webapi/routes/subscriptions.py +++ b/app/webapi/routes/subscriptions.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from typing import Any, Optional from fastapi import APIRouter, Depends, HTTPException, Query, Security, status @@ -8,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.config import settings +from app.database.crud.server_squad import get_random_trial_squad_uuid from app.database.crud.subscription import ( add_subscription_devices, add_subscription_squad, @@ -16,6 +18,7 @@ from app.database.crud.subscription import ( create_trial_subscription, extend_subscription, get_subscription_by_user_id, + replace_subscription, remove_subscription_squad, ) from app.database.models import Subscription, SubscriptionStatus @@ -30,6 +33,8 @@ from ..schemas.subscriptions import ( SubscriptionTrafficRequest, ) +logger = logging.getLogger(__name__) + router = APIRouter() @@ -55,6 +60,28 @@ def _serialize_subscription(subscription: Subscription) -> SubscriptionResponse: ) +async def _choose_trial_squads( + db: AsyncSession, requested_squad_uuid: Optional[str], fallback_squads: list[str] +) -> list[str]: + if requested_squad_uuid: + return [requested_squad_uuid] + + if fallback_squads: + return fallback_squads + + try: + squad_uuid = await get_random_trial_squad_uuid(db) + except Exception as error: + logger.error("Failed to select trial squad: %s", error) + squad_uuid = None + + if not squad_uuid: + return [] + + logger.debug("Selected trial squad %s for subscription replacement", squad_uuid) + return [squad_uuid] + + async def _get_subscription(db: AsyncSession, subscription_id: int) -> Subscription: result = await db.execute( select(Subscription) @@ -109,7 +136,7 @@ async def create_subscription( db: AsyncSession = Depends(get_db_session), ) -> SubscriptionResponse: existing = await get_subscription_by_user_id(db, payload.user_id) - if existing: + if existing and not payload.replace_existing: raise HTTPException(status.HTTP_400_BAD_REQUEST, "User already has a subscription") forced_devices = None @@ -120,15 +147,36 @@ async def create_subscription( trial_device_limit = payload.device_limit if trial_device_limit is None: trial_device_limit = forced_devices + duration_days = payload.duration_days or settings.TRIAL_DURATION_DAYS + traffic_limit_gb = payload.traffic_limit_gb or settings.TRIAL_TRAFFIC_LIMIT_GB - subscription = await create_trial_subscription( - db, - user_id=payload.user_id, - duration_days=payload.duration_days, - traffic_limit_gb=payload.traffic_limit_gb, - device_limit=trial_device_limit, - squad_uuid=payload.squad_uuid, - ) + if existing: + connected_squads = await _choose_trial_squads( + db, payload.squad_uuid, list(existing.connected_squads or []) + ) + subscription = await replace_subscription( + db, + existing, + duration_days=duration_days, + traffic_limit_gb=traffic_limit_gb, + device_limit=( + trial_device_limit + if trial_device_limit is not None + else settings.TRIAL_DEVICE_LIMIT + ), + connected_squads=connected_squads, + is_trial=True, + update_server_counters=True, + ) + else: + subscription = await create_trial_subscription( + db, + user_id=payload.user_id, + duration_days=duration_days, + traffic_limit_gb=traffic_limit_gb, + device_limit=trial_device_limit, + squad_uuid=payload.squad_uuid, + ) else: if payload.duration_days is None: raise HTTPException(status.HTTP_400_BAD_REQUEST, "duration_days is required for paid subscriptions") @@ -138,15 +186,27 @@ async def create_subscription( device_limit = forced_devices else: device_limit = settings.DEFAULT_DEVICE_LIMIT - subscription = await create_paid_subscription( - db, - user_id=payload.user_id, - duration_days=payload.duration_days, - traffic_limit_gb=payload.traffic_limit_gb or settings.DEFAULT_TRAFFIC_LIMIT_GB, - device_limit=device_limit, - connected_squads=payload.connected_squads or [], - update_server_counters=True, - ) + if existing: + subscription = await replace_subscription( + db, + existing, + duration_days=payload.duration_days, + traffic_limit_gb=payload.traffic_limit_gb or settings.DEFAULT_TRAFFIC_LIMIT_GB, + device_limit=device_limit, + connected_squads=payload.connected_squads or [], + is_trial=False, + update_server_counters=True, + ) + else: + subscription = await create_paid_subscription( + db, + user_id=payload.user_id, + duration_days=payload.duration_days, + traffic_limit_gb=payload.traffic_limit_gb or settings.DEFAULT_TRAFFIC_LIMIT_GB, + device_limit=device_limit, + connected_squads=payload.connected_squads or [], + update_server_counters=True, + ) subscription = await _get_subscription(db, subscription.id) return _serialize_subscription(subscription) diff --git a/app/webapi/schemas/subscriptions.py b/app/webapi/schemas/subscriptions.py index f09b5405..12dd8708 100644 --- a/app/webapi/schemas/subscriptions.py +++ b/app/webapi/schemas/subscriptions.py @@ -34,6 +34,7 @@ class SubscriptionCreateRequest(BaseModel): device_limit: Optional[int] = None squad_uuid: Optional[str] = None connected_squads: Optional[List[str]] = None + replace_existing: bool = False class SubscriptionExtendRequest(BaseModel):