Preserve zero device limit when replacing trials

This commit is contained in:
Egor
2025-11-18 01:14:43 +03:00
parent 562bb69082
commit 991e5a3112
3 changed files with 165 additions and 19 deletions

View File

@@ -129,7 +129,7 @@ async def create_paid_subscription(
connected_squads: List[str] = None,
update_server_counters: bool = False,
) -> Subscription:
end_date = datetime.utcnow() + timedelta(days=duration_days)
if device_limit is None:
@@ -186,6 +186,91 @@ async def create_paid_subscription(
return subscription
async def replace_subscription(
db: AsyncSession,
subscription: Subscription,
*,
duration_days: int,
traffic_limit_gb: int,
device_limit: int,
connected_squads: List[str],
is_trial: bool,
autopay_enabled: Optional[bool] = None,
autopay_days_before: Optional[int] = None,
update_server_counters: bool = False,
) -> Subscription:
"""Перезаписывает параметры существующей подписки пользователя."""
current_time = datetime.utcnow()
old_squads = set(subscription.connected_squads or [])
new_squads = set(connected_squads or [])
new_autopay_enabled = (
subscription.autopay_enabled
if autopay_enabled is None
else autopay_enabled
)
new_autopay_days_before = (
subscription.autopay_days_before
if autopay_days_before is None
else autopay_days_before
)
subscription.status = SubscriptionStatus.ACTIVE.value
subscription.is_trial = is_trial
subscription.start_date = current_time
subscription.end_date = current_time + timedelta(days=duration_days)
subscription.traffic_limit_gb = traffic_limit_gb
subscription.traffic_used_gb = 0.0
subscription.device_limit = device_limit
subscription.connected_squads = list(new_squads)
subscription.subscription_url = None
subscription.subscription_crypto_link = None
subscription.remnawave_short_uuid = None
subscription.autopay_enabled = new_autopay_enabled
subscription.autopay_days_before = new_autopay_days_before
subscription.updated_at = current_time
await db.commit()
await db.refresh(subscription)
if update_server_counters:
try:
from app.database.crud.server_squad import (
add_user_to_servers,
get_server_ids_by_uuids,
remove_user_from_servers,
)
squads_to_remove = old_squads - new_squads
squads_to_add = new_squads - old_squads
if squads_to_remove:
server_ids = await get_server_ids_by_uuids(db, list(squads_to_remove))
if server_ids:
await remove_user_from_servers(db, sorted(server_ids))
if squads_to_add:
server_ids = await get_server_ids_by_uuids(db, list(squads_to_add))
if server_ids:
await add_user_to_servers(db, sorted(server_ids))
logger.info(
"♻️ Обновлены параметры подписки %s: удалено сквадов %s, добавлено %s",
subscription.id,
len(squads_to_remove),
len(squads_to_add),
)
except Exception as error:
logger.error(
"⚠️ Ошибка обновления счетчиков серверов при замене подписки %s: %s",
subscription.id,
error,
)
return subscription
async def extend_subscription(
db: AsyncSession,
subscription: Subscription,

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import logging
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Security, status
@@ -8,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.config import settings
from app.database.crud.server_squad import get_random_trial_squad_uuid
from app.database.crud.subscription import (
add_subscription_devices,
add_subscription_squad,
@@ -16,6 +18,7 @@ from app.database.crud.subscription import (
create_trial_subscription,
extend_subscription,
get_subscription_by_user_id,
replace_subscription,
remove_subscription_squad,
)
from app.database.models import Subscription, SubscriptionStatus
@@ -30,6 +33,8 @@ from ..schemas.subscriptions import (
SubscriptionTrafficRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -55,6 +60,28 @@ def _serialize_subscription(subscription: Subscription) -> SubscriptionResponse:
)
async def _choose_trial_squads(
db: AsyncSession, requested_squad_uuid: Optional[str], fallback_squads: list[str]
) -> list[str]:
if requested_squad_uuid:
return [requested_squad_uuid]
if fallback_squads:
return fallback_squads
try:
squad_uuid = await get_random_trial_squad_uuid(db)
except Exception as error:
logger.error("Failed to select trial squad: %s", error)
squad_uuid = None
if not squad_uuid:
return []
logger.debug("Selected trial squad %s for subscription replacement", squad_uuid)
return [squad_uuid]
async def _get_subscription(db: AsyncSession, subscription_id: int) -> Subscription:
result = await db.execute(
select(Subscription)
@@ -109,7 +136,7 @@ async def create_subscription(
db: AsyncSession = Depends(get_db_session),
) -> SubscriptionResponse:
existing = await get_subscription_by_user_id(db, payload.user_id)
if existing:
if existing and not payload.replace_existing:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "User already has a subscription")
forced_devices = None
@@ -120,15 +147,36 @@ async def create_subscription(
trial_device_limit = payload.device_limit
if trial_device_limit is None:
trial_device_limit = forced_devices
duration_days = payload.duration_days or settings.TRIAL_DURATION_DAYS
traffic_limit_gb = payload.traffic_limit_gb or settings.TRIAL_TRAFFIC_LIMIT_GB
subscription = await create_trial_subscription(
db,
user_id=payload.user_id,
duration_days=payload.duration_days,
traffic_limit_gb=payload.traffic_limit_gb,
device_limit=trial_device_limit,
squad_uuid=payload.squad_uuid,
)
if existing:
connected_squads = await _choose_trial_squads(
db, payload.squad_uuid, list(existing.connected_squads or [])
)
subscription = await replace_subscription(
db,
existing,
duration_days=duration_days,
traffic_limit_gb=traffic_limit_gb,
device_limit=(
trial_device_limit
if trial_device_limit is not None
else settings.TRIAL_DEVICE_LIMIT
),
connected_squads=connected_squads,
is_trial=True,
update_server_counters=True,
)
else:
subscription = await create_trial_subscription(
db,
user_id=payload.user_id,
duration_days=duration_days,
traffic_limit_gb=traffic_limit_gb,
device_limit=trial_device_limit,
squad_uuid=payload.squad_uuid,
)
else:
if payload.duration_days is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "duration_days is required for paid subscriptions")
@@ -138,15 +186,27 @@ async def create_subscription(
device_limit = forced_devices
else:
device_limit = settings.DEFAULT_DEVICE_LIMIT
subscription = await create_paid_subscription(
db,
user_id=payload.user_id,
duration_days=payload.duration_days,
traffic_limit_gb=payload.traffic_limit_gb or settings.DEFAULT_TRAFFIC_LIMIT_GB,
device_limit=device_limit,
connected_squads=payload.connected_squads or [],
update_server_counters=True,
)
if existing:
subscription = await replace_subscription(
db,
existing,
duration_days=payload.duration_days,
traffic_limit_gb=payload.traffic_limit_gb or settings.DEFAULT_TRAFFIC_LIMIT_GB,
device_limit=device_limit,
connected_squads=payload.connected_squads or [],
is_trial=False,
update_server_counters=True,
)
else:
subscription = await create_paid_subscription(
db,
user_id=payload.user_id,
duration_days=payload.duration_days,
traffic_limit_gb=payload.traffic_limit_gb or settings.DEFAULT_TRAFFIC_LIMIT_GB,
device_limit=device_limit,
connected_squads=payload.connected_squads or [],
update_server_counters=True,
)
subscription = await _get_subscription(db, subscription.id)
return _serialize_subscription(subscription)

View File

@@ -34,6 +34,7 @@ class SubscriptionCreateRequest(BaseModel):
device_limit: Optional[int] = None
squad_uuid: Optional[str] = None
connected_squads: Optional[List[str]] = None
replace_existing: bool = False
class SubscriptionExtendRequest(BaseModel):