diff --git a/app/webapi/routes/miniapp.py b/app/webapi/routes/miniapp.py index 5ec70d63..8ac94690 100644 --- a/app/webapi/routes/miniapp.py +++ b/app/webapi/routes/miniapp.py @@ -6,7 +6,7 @@ import math from decimal import Decimal, InvalidOperation, ROUND_HALF_UP, ROUND_FLOOR from datetime import datetime, timedelta, timezone from uuid import uuid4 -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from aiogram import Bot from fastapi import APIRouter, Depends, HTTPException, status @@ -24,7 +24,17 @@ from app.database.crud.discount_offer import ( from app.database.crud.promo_group import get_auto_assign_promo_groups from app.database.crud.rules import get_rules_by_language from app.database.crud.promo_offer_template import get_promo_offer_template_by_id -from app.database.crud.server_squad import get_server_squad_by_uuid +from app.database.crud.server_squad import ( + add_user_to_servers, + get_available_server_squads, + get_server_squad_by_uuid, + remove_user_from_servers, +) +from app.database.crud.subscription import ( + add_subscription_servers, + get_subscription_servers, + remove_subscription_servers, +) from app.database.crud.transaction import get_user_total_spent_kopeks from app.database.crud.user import get_user_by_telegram_id from app.database.models import ( @@ -93,7 +103,13 @@ from ..schemas.miniapp import ( MiniAppReferralStats, MiniAppReferralTerms, MiniAppRichTextDocument, + MiniAppSubscriptionActionResponse, + MiniAppSubscriptionDevicesUpdateRequest, MiniAppSubscriptionRequest, + MiniAppSubscriptionServersUpdateRequest, + MiniAppSubscriptionSettingsRequest, + MiniAppSubscriptionSettingsResponse, + MiniAppSubscriptionTrafficUpdateRequest, MiniAppSubscriptionResponse, MiniAppSubscriptionUser, MiniAppTransaction, @@ -1958,6 +1974,206 @@ async def _build_referral_info( ) +def _normalize_currency(currency: Optional[str]) -> str: + if isinstance(currency, str) and currency.strip(): + return currency.strip().upper() + return "RUB" + + +def _apply_discount(amount: Optional[int], percent: Optional[int]) -> int: + if not amount or amount <= 0: + return 0 + normalized_percent = max(0, min(100, percent or 0)) + if normalized_percent == 0: + return amount + discounted = amount - (amount * normalized_percent) // 100 + return max(0, discounted) + + +async def _resolve_subscription_with_access( + db: AsyncSession, + init_data: str, + *, + subscription_id: Optional[int] = None, +) -> Tuple[User, Subscription]: + user, _ = await _resolve_user_from_init_data(db, init_data) + + subscription = getattr(user, "subscription", None) + if not subscription: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={"code": "subscription_not_found", "message": "Subscription not found"}, + ) + + if subscription_id and subscription.id != subscription_id: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={"code": "subscription_mismatch", "message": "Subscription mismatch"}, + ) + + if subscription.is_trial or subscription.actual_status != "active": + raise HTTPException( + status.HTTP_409_CONFLICT, + detail={"code": "subscription_inactive", "message": "Active paid subscription required"}, + ) + + return user, subscription + + +async def _build_subscription_settings( + db: AsyncSession, + user: User, + subscription: Subscription, +) -> Dict[str, Any]: + currency = _normalize_currency(getattr(user, "balance_currency", None)) + + connected_squads: List[str] = list(subscription.connected_squads or []) + current_servers = await _resolve_connected_servers(db, connected_squads) + current_server_set = {server.uuid for server in current_servers} + + available_servers = await get_available_server_squads( + db, promo_group_id=getattr(user, "promo_group_id", None) + ) + subscription_servers_info = await get_subscription_servers(db, subscription.id) + + servers_by_uuid: Dict[str, Dict[str, Any]] = {} + try: + servers_discount = user.get_promo_discount("servers") + except AttributeError: + servers_discount = 0 + + for server in available_servers: + base_price = int(getattr(server, "price_kopeks", 0) or 0) + discounted_price = _apply_discount(base_price, servers_discount) + servers_by_uuid[server.squad_uuid] = { + "uuid": server.squad_uuid, + "name": server.display_name or server.squad_uuid, + "price_kopeks": discounted_price, + "price_label": settings.format_price(discounted_price) if discounted_price else None, + "discount_percent": servers_discount or None, + "is_available": bool(getattr(server, "is_available", True)), + } + + for info in subscription_servers_info: + uuid = info.get("squad_uuid") + if not uuid or uuid in servers_by_uuid: + continue + base_price = int(info.get("paid_price_kopeks") or 0) + servers_by_uuid[uuid] = { + "uuid": uuid, + "name": info.get("display_name") or uuid, + "price_kopeks": base_price if base_price > 0 else None, + "price_label": settings.format_price(base_price) if base_price else None, + "discount_percent": None, + "is_available": bool(info.get("is_available", False)), + } + + server_options: List[Dict[str, Any]] = [] + for uuid, data in servers_by_uuid.items(): + option = dict(data) + option["is_connected"] = uuid in current_server_set + server_options.append(option) + + traffic_packages = [ + pkg for pkg in settings.get_traffic_packages() if pkg.get("enabled", True) + ] + try: + traffic_discount = user.get_promo_discount("traffic") + except AttributeError: + traffic_discount = 0 + + traffic_options: List[Dict[str, Any]] = [] + for package in traffic_packages: + value = int(package.get("gb") or 0) + base_price = int(package.get("price") or 0) + discounted_price = _apply_discount(base_price, traffic_discount) + option = { + "value": value, + "label": _format_limit_label(value if value > 0 else None), + "price_kopeks": discounted_price, + "price_label": settings.format_price(discounted_price) if discounted_price else None, + "is_available": True, + } + if traffic_discount: + option["discount_percent"] = traffic_discount + option["is_current"] = value == (subscription.traffic_limit_gb or 0) + traffic_options.append(option) + + devices_min = max(1, settings.DEFAULT_DEVICE_LIMIT) + devices_max = max(devices_min, settings.MAX_DEVICES_LIMIT) + devices_current = max(devices_min, int(subscription.device_limit or devices_min)) + + try: + devices_discount = user.get_promo_discount("devices") + except AttributeError: + devices_discount = 0 + + device_options: List[Dict[str, Any]] = [] + for count in range(devices_min, devices_max + 1): + additional = max(0, count - settings.DEFAULT_DEVICE_LIMIT) + base_price = additional * settings.PRICE_PER_DEVICE + discounted_price = _apply_discount(base_price, devices_discount) + option = { + "value": count, + "label": str(count), + "price_kopeks": discounted_price, + "price_label": settings.format_price(discounted_price) if discounted_price else None, + } + if count == devices_current: + option["is_current"] = True + device_options.append(option) + + settings_payload: Dict[str, Any] = { + "subscription_id": subscription.id, + "currency": currency, + "current": { + "servers": [server.dict() for server in current_servers], + "traffic_limit_gb": subscription.traffic_limit_gb, + "traffic_limit_label": _format_limit_label(subscription.traffic_limit_gb), + "device_limit": subscription.device_limit, + }, + "servers": { + "available": server_options, + "min": 1, + "max": len(server_options) if server_options else 0, + "can_update": True, + }, + "traffic": { + "options": traffic_options, + "can_update": True, + }, + "devices": { + "options": device_options, + "min": devices_min, + "max": devices_max, + "step": 1, + "current": devices_current, + "can_update": True, + }, + } + + return settings_payload + + +@router.post( + "/subscription/settings", + response_model=MiniAppSubscriptionSettingsResponse, +) +async def get_subscription_settings( + payload: MiniAppSubscriptionSettingsRequest, + db: AsyncSession = Depends(get_db_session), +) -> MiniAppSubscriptionSettingsResponse: + user, subscription = await _resolve_subscription_with_access( + db, + payload.init_data, + subscription_id=payload.subscription_id, + ) + + settings_payload = await _build_subscription_settings(db, user, subscription) + + return MiniAppSubscriptionSettingsResponse(settings=settings_payload) + + @router.post("/subscription", response_model=MiniAppSubscriptionResponse) async def get_subscription_details( payload: MiniAppSubscriptionRequest, @@ -2292,6 +2508,218 @@ async def get_subscription_details( ) +@router.post( + "/subscription/servers", + response_model=MiniAppSubscriptionActionResponse, +) +async def update_subscription_servers( + payload: MiniAppSubscriptionServersUpdateRequest, + db: AsyncSession = Depends(get_db_session), +) -> MiniAppSubscriptionActionResponse: + user, subscription = await _resolve_subscription_with_access( + db, + payload.init_data, + subscription_id=payload.subscription_id, + ) + + raw_servers = payload.servers or [] + selected_servers: List[str] = [] + seen_servers: Set[str] = set() + for entry in raw_servers: + if not entry: + continue + uuid = str(entry).strip() + if not uuid or uuid in seen_servers: + continue + selected_servers.append(uuid) + seen_servers.add(uuid) + + if not selected_servers: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={"code": "invalid_servers", "message": "At least one server must be selected"}, + ) + + current_servers_info = await get_subscription_servers(db, subscription.id) + current_map: Dict[str, Dict[str, Any]] = { + str(info.get("squad_uuid")): info for info in current_servers_info if info.get("squad_uuid") + } + current_set = set(current_map.keys()) + + available_servers = await get_available_server_squads( + db, promo_group_id=getattr(user, "promo_group_id", None) + ) + allowed_servers = {server.squad_uuid for server in available_servers} + + server_records: Dict[str, Any] = {} + for uuid in selected_servers: + if uuid not in allowed_servers and uuid not in current_set: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={"code": "server_unavailable", "message": "Selected server is not available"}, + ) + server = await get_server_squad_by_uuid(db, uuid) + if not server: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={"code": "server_not_found", "message": "Server not found"}, + ) + server_records[uuid] = server + + servers_to_add: List[int] = [] + for uuid in selected_servers: + if uuid not in current_set: + servers_to_add.append(int(server_records[uuid].id)) + + servers_to_remove: List[int] = [] + for uuid, info in current_map.items(): + if uuid not in selected_servers: + try: + server_id = int(info.get("server_id") or 0) + except (TypeError, ValueError): + server_id = 0 + if server_id: + servers_to_remove.append(server_id) + + if not servers_to_add and not servers_to_remove: + return MiniAppSubscriptionActionResponse(success=True) + + if servers_to_remove: + await remove_subscription_servers(db, subscription.id, servers_to_remove) + await remove_user_from_servers(db, servers_to_remove) + + if servers_to_add: + await add_subscription_servers(db, subscription, servers_to_add) + await add_user_to_servers(db, servers_to_add) + + subscription.connected_squads = selected_servers + subscription.updated_at = datetime.utcnow() + await db.commit() + await db.refresh(subscription) + + service = SubscriptionService() + try: + await service.update_remnawave_user(db, subscription) + except Exception as error: # pragma: no cover - best effort sync + logger.warning("Failed to sync subscription servers to RemnaWave: %s", error) + + return MiniAppSubscriptionActionResponse(success=True, message="Servers updated") + + +@router.post( + "/subscription/traffic", + response_model=MiniAppSubscriptionActionResponse, +) +async def update_subscription_traffic( + payload: MiniAppSubscriptionTrafficUpdateRequest, + db: AsyncSession = Depends(get_db_session), +) -> MiniAppSubscriptionActionResponse: + user, subscription = await _resolve_subscription_with_access( + db, + payload.init_data, + subscription_id=payload.subscription_id, + ) + + selected = payload.traffic if payload.traffic is not None else payload.traffic_gb + if selected is None: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={"code": "invalid_traffic", "message": "Traffic limit is required"}, + ) + + try: + selected_value = int(selected) + except (TypeError, ValueError) as error: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={"code": "invalid_traffic", "message": "Invalid traffic value"}, + ) from error + + available_values = { + int(pkg.get("gb") or 0) + for pkg in settings.get_traffic_packages() + if pkg.get("enabled", True) + } + if selected_value not in available_values: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={"code": "traffic_unavailable", "message": "Selected traffic option is not available"}, + ) + + current_value = int(subscription.traffic_limit_gb or 0) + if selected_value == current_value: + return MiniAppSubscriptionActionResponse(success=True) + + subscription.traffic_limit_gb = selected_value + subscription.updated_at = datetime.utcnow() + await db.commit() + await db.refresh(subscription) + + service = SubscriptionService() + try: + await service.update_remnawave_user(db, subscription) + except Exception as error: # pragma: no cover - best effort sync + logger.warning("Failed to sync subscription traffic to RemnaWave: %s", error) + + return MiniAppSubscriptionActionResponse(success=True, message="Traffic limit updated") + + +@router.post( + "/subscription/devices", + response_model=MiniAppSubscriptionActionResponse, +) +async def update_subscription_devices( + payload: MiniAppSubscriptionDevicesUpdateRequest, + db: AsyncSession = Depends(get_db_session), +) -> MiniAppSubscriptionActionResponse: + _, subscription = await _resolve_subscription_with_access( + db, + payload.init_data, + subscription_id=payload.subscription_id, + ) + + selected = payload.devices if payload.devices is not None else payload.device_limit + if selected is None: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={"code": "invalid_devices", "message": "Device limit is required"}, + ) + + try: + selected_value = int(selected) + except (TypeError, ValueError) as error: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={"code": "invalid_devices", "message": "Invalid device limit"}, + ) from error + + devices_min = max(1, settings.DEFAULT_DEVICE_LIMIT) + devices_max = max(devices_min, settings.MAX_DEVICES_LIMIT) + + if selected_value < devices_min or selected_value > devices_max: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={"code": "devices_out_of_range", "message": "Device limit is out of range"}, + ) + + current_value = int(subscription.device_limit or devices_min) + if selected_value == current_value: + return MiniAppSubscriptionActionResponse(success=True) + + subscription.device_limit = selected_value + subscription.updated_at = datetime.utcnow() + await db.commit() + await db.refresh(subscription) + + service = SubscriptionService() + try: + await service.update_remnawave_user(db, subscription) + except Exception as error: # pragma: no cover - best effort sync + logger.warning("Failed to sync subscription devices to RemnaWave: %s", error) + + return MiniAppSubscriptionActionResponse(success=True, message="Device limit updated") + + @router.post( "/promo-codes/activate", response_model=MiniAppPromoCodeActivationResponse, diff --git a/app/webapi/schemas/miniapp.py b/app/webapi/schemas/miniapp.py index 3f668548..6baac644 100644 --- a/app/webapi/schemas/miniapp.py +++ b/app/webapi/schemas/miniapp.py @@ -15,6 +15,41 @@ class MiniAppSubscriptionRequest(BaseModel): init_data: str = Field(..., alias="initData") +class MiniAppSubscriptionSettingsRequest(BaseModel): + init_data: str = Field(..., alias="initData") + subscription_id: Optional[int] = Field(default=None, alias="subscriptionId") + + +class MiniAppSubscriptionServersUpdateRequest(BaseModel): + init_data: str = Field(..., alias="initData") + servers: List[str] = Field(default_factory=list) + subscription_id: Optional[int] = Field(default=None, alias="subscriptionId") + + +class MiniAppSubscriptionTrafficUpdateRequest(BaseModel): + init_data: str = Field(..., alias="initData") + traffic: Optional[int] = None + traffic_gb: Optional[int] = Field(default=None, alias="trafficGb") + subscription_id: Optional[int] = Field(default=None, alias="subscriptionId") + + +class MiniAppSubscriptionDevicesUpdateRequest(BaseModel): + init_data: str = Field(..., alias="initData") + devices: Optional[int] = None + device_limit: Optional[int] = Field(default=None, alias="deviceLimit") + subscription_id: Optional[int] = Field(default=None, alias="subscriptionId") + + +class MiniAppSubscriptionActionResponse(BaseModel): + success: bool = True + message: Optional[str] = None + + +class MiniAppSubscriptionSettingsResponse(BaseModel): + success: bool = True + settings: Dict[str, Any] = Field(default_factory=dict) + + class MiniAppSubscriptionUser(BaseModel): telegram_id: int username: Optional[str] = None