mirror of
https://github.com/BEDOLAGA-DEV/remnawave-bedolaga-telegram-bot.git
synced 2026-02-16 00:50:31 +00:00
Preserve zero device limit when replacing trials
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Security, status
|
||||
@@ -8,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import settings
|
||||
from app.database.crud.server_squad import get_random_trial_squad_uuid
|
||||
from app.database.crud.subscription import (
|
||||
add_subscription_devices,
|
||||
add_subscription_squad,
|
||||
@@ -16,6 +18,7 @@ from app.database.crud.subscription import (
|
||||
create_trial_subscription,
|
||||
extend_subscription,
|
||||
get_subscription_by_user_id,
|
||||
replace_subscription,
|
||||
remove_subscription_squad,
|
||||
)
|
||||
from app.database.models import Subscription, SubscriptionStatus
|
||||
@@ -30,6 +33,8 @@ from ..schemas.subscriptions import (
|
||||
SubscriptionTrafficRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -55,6 +60,28 @@ def _serialize_subscription(subscription: Subscription) -> SubscriptionResponse:
|
||||
)
|
||||
|
||||
|
||||
async def _choose_trial_squads(
|
||||
db: AsyncSession, requested_squad_uuid: Optional[str], fallback_squads: list[str]
|
||||
) -> list[str]:
|
||||
if requested_squad_uuid:
|
||||
return [requested_squad_uuid]
|
||||
|
||||
if fallback_squads:
|
||||
return fallback_squads
|
||||
|
||||
try:
|
||||
squad_uuid = await get_random_trial_squad_uuid(db)
|
||||
except Exception as error:
|
||||
logger.error("Failed to select trial squad: %s", error)
|
||||
squad_uuid = None
|
||||
|
||||
if not squad_uuid:
|
||||
return []
|
||||
|
||||
logger.debug("Selected trial squad %s for subscription replacement", squad_uuid)
|
||||
return [squad_uuid]
|
||||
|
||||
|
||||
async def _get_subscription(db: AsyncSession, subscription_id: int) -> Subscription:
|
||||
result = await db.execute(
|
||||
select(Subscription)
|
||||
@@ -109,7 +136,7 @@ async def create_subscription(
|
||||
db: AsyncSession = Depends(get_db_session),
|
||||
) -> SubscriptionResponse:
|
||||
existing = await get_subscription_by_user_id(db, payload.user_id)
|
||||
if existing:
|
||||
if existing and not payload.replace_existing:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "User already has a subscription")
|
||||
|
||||
forced_devices = None
|
||||
@@ -120,15 +147,36 @@ async def create_subscription(
|
||||
trial_device_limit = payload.device_limit
|
||||
if trial_device_limit is None:
|
||||
trial_device_limit = forced_devices
|
||||
duration_days = payload.duration_days or settings.TRIAL_DURATION_DAYS
|
||||
traffic_limit_gb = payload.traffic_limit_gb or settings.TRIAL_TRAFFIC_LIMIT_GB
|
||||
|
||||
subscription = await create_trial_subscription(
|
||||
db,
|
||||
user_id=payload.user_id,
|
||||
duration_days=payload.duration_days,
|
||||
traffic_limit_gb=payload.traffic_limit_gb,
|
||||
device_limit=trial_device_limit,
|
||||
squad_uuid=payload.squad_uuid,
|
||||
)
|
||||
if existing:
|
||||
connected_squads = await _choose_trial_squads(
|
||||
db, payload.squad_uuid, list(existing.connected_squads or [])
|
||||
)
|
||||
subscription = await replace_subscription(
|
||||
db,
|
||||
existing,
|
||||
duration_days=duration_days,
|
||||
traffic_limit_gb=traffic_limit_gb,
|
||||
device_limit=(
|
||||
trial_device_limit
|
||||
if trial_device_limit is not None
|
||||
else settings.TRIAL_DEVICE_LIMIT
|
||||
),
|
||||
connected_squads=connected_squads,
|
||||
is_trial=True,
|
||||
update_server_counters=True,
|
||||
)
|
||||
else:
|
||||
subscription = await create_trial_subscription(
|
||||
db,
|
||||
user_id=payload.user_id,
|
||||
duration_days=duration_days,
|
||||
traffic_limit_gb=traffic_limit_gb,
|
||||
device_limit=trial_device_limit,
|
||||
squad_uuid=payload.squad_uuid,
|
||||
)
|
||||
else:
|
||||
if payload.duration_days is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "duration_days is required for paid subscriptions")
|
||||
@@ -138,15 +186,27 @@ async def create_subscription(
|
||||
device_limit = forced_devices
|
||||
else:
|
||||
device_limit = settings.DEFAULT_DEVICE_LIMIT
|
||||
subscription = await create_paid_subscription(
|
||||
db,
|
||||
user_id=payload.user_id,
|
||||
duration_days=payload.duration_days,
|
||||
traffic_limit_gb=payload.traffic_limit_gb or settings.DEFAULT_TRAFFIC_LIMIT_GB,
|
||||
device_limit=device_limit,
|
||||
connected_squads=payload.connected_squads or [],
|
||||
update_server_counters=True,
|
||||
)
|
||||
if existing:
|
||||
subscription = await replace_subscription(
|
||||
db,
|
||||
existing,
|
||||
duration_days=payload.duration_days,
|
||||
traffic_limit_gb=payload.traffic_limit_gb or settings.DEFAULT_TRAFFIC_LIMIT_GB,
|
||||
device_limit=device_limit,
|
||||
connected_squads=payload.connected_squads or [],
|
||||
is_trial=False,
|
||||
update_server_counters=True,
|
||||
)
|
||||
else:
|
||||
subscription = await create_paid_subscription(
|
||||
db,
|
||||
user_id=payload.user_id,
|
||||
duration_days=payload.duration_days,
|
||||
traffic_limit_gb=payload.traffic_limit_gb or settings.DEFAULT_TRAFFIC_LIMIT_GB,
|
||||
device_limit=device_limit,
|
||||
connected_squads=payload.connected_squads or [],
|
||||
update_server_counters=True,
|
||||
)
|
||||
|
||||
subscription = await _get_subscription(db, subscription.id)
|
||||
return _serialize_subscription(subscription)
|
||||
|
||||
Reference in New Issue
Block a user