mirror of
https://github.com/BEDOLAGA-DEV/remnawave-bedolaga-telegram-bot.git
synced 2026-01-20 03:40:26 +00:00
Preserve zero device limit when replacing trials
This commit is contained in:
@@ -186,6 +186,91 @@ async def create_paid_subscription(
|
||||
return subscription
|
||||
|
||||
|
||||
async def replace_subscription(
|
||||
db: AsyncSession,
|
||||
subscription: Subscription,
|
||||
*,
|
||||
duration_days: int,
|
||||
traffic_limit_gb: int,
|
||||
device_limit: int,
|
||||
connected_squads: List[str],
|
||||
is_trial: bool,
|
||||
autopay_enabled: Optional[bool] = None,
|
||||
autopay_days_before: Optional[int] = None,
|
||||
update_server_counters: bool = False,
|
||||
) -> Subscription:
|
||||
"""Перезаписывает параметры существующей подписки пользователя."""
|
||||
|
||||
current_time = datetime.utcnow()
|
||||
old_squads = set(subscription.connected_squads or [])
|
||||
new_squads = set(connected_squads or [])
|
||||
|
||||
new_autopay_enabled = (
|
||||
subscription.autopay_enabled
|
||||
if autopay_enabled is None
|
||||
else autopay_enabled
|
||||
)
|
||||
new_autopay_days_before = (
|
||||
subscription.autopay_days_before
|
||||
if autopay_days_before is None
|
||||
else autopay_days_before
|
||||
)
|
||||
|
||||
subscription.status = SubscriptionStatus.ACTIVE.value
|
||||
subscription.is_trial = is_trial
|
||||
subscription.start_date = current_time
|
||||
subscription.end_date = current_time + timedelta(days=duration_days)
|
||||
subscription.traffic_limit_gb = traffic_limit_gb
|
||||
subscription.traffic_used_gb = 0.0
|
||||
subscription.device_limit = device_limit
|
||||
subscription.connected_squads = list(new_squads)
|
||||
subscription.subscription_url = None
|
||||
subscription.subscription_crypto_link = None
|
||||
subscription.remnawave_short_uuid = None
|
||||
subscription.autopay_enabled = new_autopay_enabled
|
||||
subscription.autopay_days_before = new_autopay_days_before
|
||||
subscription.updated_at = current_time
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(subscription)
|
||||
|
||||
if update_server_counters:
|
||||
try:
|
||||
from app.database.crud.server_squad import (
|
||||
add_user_to_servers,
|
||||
get_server_ids_by_uuids,
|
||||
remove_user_from_servers,
|
||||
)
|
||||
|
||||
squads_to_remove = old_squads - new_squads
|
||||
squads_to_add = new_squads - old_squads
|
||||
|
||||
if squads_to_remove:
|
||||
server_ids = await get_server_ids_by_uuids(db, list(squads_to_remove))
|
||||
if server_ids:
|
||||
await remove_user_from_servers(db, sorted(server_ids))
|
||||
|
||||
if squads_to_add:
|
||||
server_ids = await get_server_ids_by_uuids(db, list(squads_to_add))
|
||||
if server_ids:
|
||||
await add_user_to_servers(db, sorted(server_ids))
|
||||
|
||||
logger.info(
|
||||
"♻️ Обновлены параметры подписки %s: удалено сквадов %s, добавлено %s",
|
||||
subscription.id,
|
||||
len(squads_to_remove),
|
||||
len(squads_to_add),
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
"⚠️ Ошибка обновления счетчиков серверов при замене подписки %s: %s",
|
||||
subscription.id,
|
||||
error,
|
||||
)
|
||||
|
||||
return subscription
|
||||
|
||||
|
||||
async def extend_subscription(
|
||||
db: AsyncSession,
|
||||
subscription: Subscription,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Security, status
|
||||
@@ -8,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import settings
|
||||
from app.database.crud.server_squad import get_random_trial_squad_uuid
|
||||
from app.database.crud.subscription import (
|
||||
add_subscription_devices,
|
||||
add_subscription_squad,
|
||||
@@ -16,6 +18,7 @@ from app.database.crud.subscription import (
|
||||
create_trial_subscription,
|
||||
extend_subscription,
|
||||
get_subscription_by_user_id,
|
||||
replace_subscription,
|
||||
remove_subscription_squad,
|
||||
)
|
||||
from app.database.models import Subscription, SubscriptionStatus
|
||||
@@ -30,6 +33,8 @@ from ..schemas.subscriptions import (
|
||||
SubscriptionTrafficRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -55,6 +60,28 @@ def _serialize_subscription(subscription: Subscription) -> SubscriptionResponse:
|
||||
)
|
||||
|
||||
|
||||
async def _choose_trial_squads(
|
||||
db: AsyncSession, requested_squad_uuid: Optional[str], fallback_squads: list[str]
|
||||
) -> list[str]:
|
||||
if requested_squad_uuid:
|
||||
return [requested_squad_uuid]
|
||||
|
||||
if fallback_squads:
|
||||
return fallback_squads
|
||||
|
||||
try:
|
||||
squad_uuid = await get_random_trial_squad_uuid(db)
|
||||
except Exception as error:
|
||||
logger.error("Failed to select trial squad: %s", error)
|
||||
squad_uuid = None
|
||||
|
||||
if not squad_uuid:
|
||||
return []
|
||||
|
||||
logger.debug("Selected trial squad %s for subscription replacement", squad_uuid)
|
||||
return [squad_uuid]
|
||||
|
||||
|
||||
async def _get_subscription(db: AsyncSession, subscription_id: int) -> Subscription:
|
||||
result = await db.execute(
|
||||
select(Subscription)
|
||||
@@ -109,7 +136,7 @@ async def create_subscription(
|
||||
db: AsyncSession = Depends(get_db_session),
|
||||
) -> SubscriptionResponse:
|
||||
existing = await get_subscription_by_user_id(db, payload.user_id)
|
||||
if existing:
|
||||
if existing and not payload.replace_existing:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "User already has a subscription")
|
||||
|
||||
forced_devices = None
|
||||
@@ -120,15 +147,36 @@ async def create_subscription(
|
||||
trial_device_limit = payload.device_limit
|
||||
if trial_device_limit is None:
|
||||
trial_device_limit = forced_devices
|
||||
duration_days = payload.duration_days or settings.TRIAL_DURATION_DAYS
|
||||
traffic_limit_gb = payload.traffic_limit_gb or settings.TRIAL_TRAFFIC_LIMIT_GB
|
||||
|
||||
subscription = await create_trial_subscription(
|
||||
db,
|
||||
user_id=payload.user_id,
|
||||
duration_days=payload.duration_days,
|
||||
traffic_limit_gb=payload.traffic_limit_gb,
|
||||
device_limit=trial_device_limit,
|
||||
squad_uuid=payload.squad_uuid,
|
||||
)
|
||||
if existing:
|
||||
connected_squads = await _choose_trial_squads(
|
||||
db, payload.squad_uuid, list(existing.connected_squads or [])
|
||||
)
|
||||
subscription = await replace_subscription(
|
||||
db,
|
||||
existing,
|
||||
duration_days=duration_days,
|
||||
traffic_limit_gb=traffic_limit_gb,
|
||||
device_limit=(
|
||||
trial_device_limit
|
||||
if trial_device_limit is not None
|
||||
else settings.TRIAL_DEVICE_LIMIT
|
||||
),
|
||||
connected_squads=connected_squads,
|
||||
is_trial=True,
|
||||
update_server_counters=True,
|
||||
)
|
||||
else:
|
||||
subscription = await create_trial_subscription(
|
||||
db,
|
||||
user_id=payload.user_id,
|
||||
duration_days=duration_days,
|
||||
traffic_limit_gb=traffic_limit_gb,
|
||||
device_limit=trial_device_limit,
|
||||
squad_uuid=payload.squad_uuid,
|
||||
)
|
||||
else:
|
||||
if payload.duration_days is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "duration_days is required for paid subscriptions")
|
||||
@@ -138,15 +186,27 @@ async def create_subscription(
|
||||
device_limit = forced_devices
|
||||
else:
|
||||
device_limit = settings.DEFAULT_DEVICE_LIMIT
|
||||
subscription = await create_paid_subscription(
|
||||
db,
|
||||
user_id=payload.user_id,
|
||||
duration_days=payload.duration_days,
|
||||
traffic_limit_gb=payload.traffic_limit_gb or settings.DEFAULT_TRAFFIC_LIMIT_GB,
|
||||
device_limit=device_limit,
|
||||
connected_squads=payload.connected_squads or [],
|
||||
update_server_counters=True,
|
||||
)
|
||||
if existing:
|
||||
subscription = await replace_subscription(
|
||||
db,
|
||||
existing,
|
||||
duration_days=payload.duration_days,
|
||||
traffic_limit_gb=payload.traffic_limit_gb or settings.DEFAULT_TRAFFIC_LIMIT_GB,
|
||||
device_limit=device_limit,
|
||||
connected_squads=payload.connected_squads or [],
|
||||
is_trial=False,
|
||||
update_server_counters=True,
|
||||
)
|
||||
else:
|
||||
subscription = await create_paid_subscription(
|
||||
db,
|
||||
user_id=payload.user_id,
|
||||
duration_days=payload.duration_days,
|
||||
traffic_limit_gb=payload.traffic_limit_gb or settings.DEFAULT_TRAFFIC_LIMIT_GB,
|
||||
device_limit=device_limit,
|
||||
connected_squads=payload.connected_squads or [],
|
||||
update_server_counters=True,
|
||||
)
|
||||
|
||||
subscription = await _get_subscription(db, subscription.id)
|
||||
return _serialize_subscription(subscription)
|
||||
|
||||
@@ -34,6 +34,7 @@ class SubscriptionCreateRequest(BaseModel):
|
||||
device_limit: Optional[int] = None
|
||||
squad_uuid: Optional[str] = None
|
||||
connected_squads: Optional[List[str]] = None
|
||||
replace_existing: bool = False
|
||||
|
||||
|
||||
class SubscriptionExtendRequest(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user