From a3532e5878027f5acd701268e85f896097a1d1f6 Mon Sep 17 00:00:00 2001 From: Egor Date: Tue, 11 Nov 2025 13:06:10 +0300 Subject: [PATCH] Handle CryptoBot renewal payload fallbacks --- app/services/payment/cryptobot.py | 197 +++++++ app/services/subscription_renewal_service.py | 564 ++++++++++++++++++ app/webapi/routes/miniapp.py | 559 +++++++++--------- app/webapi/schemas/miniapp.py | 9 + tests/test_miniapp_payments.py | 575 +++++++++++++++++++ 5 files changed, 1645 insertions(+), 259 deletions(-) create mode 100644 app/services/subscription_renewal_service.py diff --git a/app/services/payment/cryptobot.py b/app/services/payment/cryptobot.py index 79a4101b..6c4b26c6 100644 --- a/app/services/payment/cryptobot.py +++ b/app/services/payment/cryptobot.py @@ -16,12 +16,24 @@ from app.database.models import PaymentMethod, TransactionType from app.services.subscription_auto_purchase_service import ( auto_purchase_saved_cart_after_topup, ) +from app.services.subscription_renewal_service import ( + SubscriptionRenewalChargeError, + SubscriptionRenewalPricing, + SubscriptionRenewalService, + RenewalPaymentDescriptor, + build_renewal_period_id, + decode_payment_payload, + parse_payment_metadata, +) from app.utils.currency_converter import currency_converter from app.utils.user_utils import format_referrer_info logger = logging.getLogger(__name__) +renewal_service = SubscriptionRenewalService() + + @dataclass(slots=True) class _AdminNotificationContext: user_id: int @@ -173,6 +185,36 @@ class CryptoBotPaymentMixin: db, invoice_id, status, paid_at ) + descriptor = decode_payment_payload( + getattr(updated_payment, "payload", "") or "", + expected_user_id=updated_payment.user_id, + ) + + if descriptor is None: + inline_payload = payload.get("payload") + if isinstance(inline_payload, str) and inline_payload: + descriptor = decode_payment_payload( + inline_payload, + expected_user_id=updated_payment.user_id, + ) + + if descriptor is None: + metadata = payload.get("metadata") + if isinstance(metadata, dict) and metadata: + descriptor = parse_payment_metadata( + metadata, + expected_user_id=updated_payment.user_id, + ) + if descriptor: + renewal_handled = await self._process_subscription_renewal_payment( + db, + updated_payment, + descriptor, + cryptobot_crud, + ) + if renewal_handled: + return True + if not updated_payment.transaction_id: amount_usd = updated_payment.amount_float @@ -394,6 +436,161 @@ class CryptoBotPaymentMixin: ) return False + async def _process_subscription_renewal_payment( + self, + db: AsyncSession, + payment: Any, + descriptor: RenewalPaymentDescriptor, + cryptobot_crud: Any, + ) -> bool: + try: + payment_service_module = import_module("app.services.payment_service") + user = await payment_service_module.get_user_by_id(db, payment.user_id) + except Exception as error: + logger.error( + "Не удалось загрузить пользователя %s для продления через CryptoBot: %s", + getattr(payment, "user_id", None), + error, + ) + return False + + if not user: + logger.error( + "Пользователь %s не найден при обработке продления через CryptoBot", + getattr(payment, "user_id", None), + ) + return False + + subscription = getattr(user, "subscription", None) + if not subscription or subscription.id != descriptor.subscription_id: + logger.warning( + "Продление через CryptoBot отклонено: подписка %s не совпадает с ожидаемой %s", + getattr(subscription, "id", None), + descriptor.subscription_id, + ) + return False + + pricing_model: Optional[SubscriptionRenewalPricing] = None + if descriptor.pricing_snapshot: + try: + pricing_model = SubscriptionRenewalPricing.from_payload( + descriptor.pricing_snapshot + ) + except Exception as error: + logger.warning( + "Не удалось восстановить сохраненную стоимость продления из payload %s: %s", + payment.invoice_id, + error, + ) + + if pricing_model is None: + try: + pricing_model = await renewal_service.calculate_pricing( + db, + user, + subscription, + descriptor.period_days, + ) + except Exception as error: + logger.error( + "Не удалось пересчитать стоимость продления для CryptoBot %s: %s", + payment.invoice_id, + error, + ) + return False + + if pricing_model.final_total != descriptor.total_amount_kopeks: + logger.warning( + "Сумма продления через CryptoBot %s изменилась (ожидалось %s, получено %s)", + payment.invoice_id, + descriptor.total_amount_kopeks, + pricing_model.final_total, + ) + pricing_model.final_total = descriptor.total_amount_kopeks + pricing_model.per_month = ( + descriptor.total_amount_kopeks // pricing_model.months + if pricing_model.months + else descriptor.total_amount_kopeks + ) + + pricing_model.period_days = descriptor.period_days + pricing_model.period_id = build_renewal_period_id(descriptor.period_days) + + required_balance = max( + 0, + min( + pricing_model.final_total, + descriptor.balance_component_kopeks, + ), + ) + + current_balance = getattr(user, "balance_kopeks", 0) + if current_balance < required_balance: + logger.warning( + "Недостаточно средств на балансе пользователя %s для завершения продления: нужно %s, доступно %s", + user.id, + required_balance, + current_balance, + ) + return False + + description = f"Продление подписки на {descriptor.period_days} дней" + + try: + result = await renewal_service.finalize( + db, + user, + subscription, + pricing_model, + charge_balance_amount=required_balance, + description=description, + payment_method=PaymentMethod.CRYPTOBOT, + ) + except SubscriptionRenewalChargeError as error: + logger.error( + "Списание баланса не выполнено при продлении через CryptoBot %s: %s", + payment.invoice_id, + error, + ) + return False + except Exception as error: + logger.error( + "Ошибка завершения продления через CryptoBot %s: %s", + payment.invoice_id, + error, + exc_info=True, + ) + return False + + transaction = result.transaction + if transaction: + try: + await cryptobot_crud.link_cryptobot_payment_to_transaction( + db, + payment.invoice_id, + transaction.id, + ) + except Exception as error: + logger.warning( + "Не удалось связать платеж CryptoBot %s с транзакцией %s: %s", + payment.invoice_id, + transaction.id, + error, + ) + + external_amount_label = settings.format_price(descriptor.missing_amount_kopeks) + balance_amount_label = settings.format_price(required_balance) + + logger.info( + "Подписка %s продлена через CryptoBot invoice %s (внешний платеж %s, списано с баланса %s)", + subscription.id, + payment.invoice_id, + external_amount_label, + balance_amount_label, + ) + + return True + async def _deliver_admin_topup_notification( self, context: _AdminNotificationContext ) -> None: diff --git a/app/services/subscription_renewal_service.py b/app/services/subscription_renewal_service.py new file mode 100644 index 00000000..6912af23 --- /dev/null +++ b/app/services/subscription_renewal_service.py @@ -0,0 +1,564 @@ +from __future__ import annotations + +import base64 +import json +import logging +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Awaitable, Callable, Dict, List, Optional +from uuid import uuid4 + +from aiogram import Bot +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import settings +from app.database.crud.server_squad import get_server_ids_by_uuids +from app.database.crud.subscription import ( + add_subscription_servers, + calculate_subscription_total_cost, + extend_subscription, +) +from app.database.crud.transaction import create_transaction +from app.database.crud.user import subtract_user_balance +from app.database.models import PaymentMethod, Subscription, Transaction, TransactionType, User +from app.services.admin_notification_service import AdminNotificationService +from app.services.remnawave_service import RemnaWaveConfigurationError +from app.services.subscription_service import SubscriptionService +from app.utils.pricing_utils import ( + apply_percentage_discount, + calculate_months_from_days, + format_period_description, + validate_pricing_calculation, +) + +logger = logging.getLogger(__name__) + + +class SubscriptionRenewalError(Exception): + """Base class for subscription renewal related errors.""" + + +class SubscriptionRenewalChargeError(SubscriptionRenewalError): + """Raised when the balance charge step fails.""" + + +@dataclass(slots=True) +class SubscriptionRenewalPricing: + period_days: int + period_id: str + months: int + base_original_total: int + discounted_total: int + final_total: int + promo_discount_value: int + promo_discount_percent: int + overall_discount_percent: int + per_month: int + server_ids: List[int] + details: Dict[str, Any] + + def to_payload(self) -> Dict[str, Any]: + return { + "period_id": self.period_id, + "period_days": self.period_days, + "months": self.months, + "base_original_total": self.base_original_total, + "discounted_total": self.discounted_total, + "final_total": self.final_total, + "promo_discount_value": self.promo_discount_value, + "promo_discount_percent": self.promo_discount_percent, + "overall_discount_percent": self.overall_discount_percent, + "per_month": self.per_month, + "server_ids": list(self.server_ids), + "details": dict(self.details), + } + + @classmethod + def from_payload(cls, payload: Dict[str, Any]) -> "SubscriptionRenewalPricing": + return cls( + period_days=int(payload.get("period_days", 0) or 0), + period_id=str(payload.get("period_id") or build_renewal_period_id(int(payload.get("period_days", 0) or 0))), + months=int(payload.get("months", 0) or 0), + base_original_total=int(payload.get("base_original_total", 0) or 0), + discounted_total=int(payload.get("discounted_total", 0) or 0), + final_total=int(payload.get("final_total", 0) or 0), + promo_discount_value=int(payload.get("promo_discount_value", 0) or 0), + promo_discount_percent=int(payload.get("promo_discount_percent", 0) or 0), + overall_discount_percent=int(payload.get("overall_discount_percent", 0) or 0), + per_month=int(payload.get("per_month", 0) or 0), + server_ids=list(payload.get("server_ids", []) or []), + details=dict(payload.get("details", {}) or {}), + ) + + +@dataclass(slots=True) +class SubscriptionRenewalResult: + subscription: Subscription + transaction: Optional[Transaction] + total_amount_kopeks: int + charged_from_balance_kopeks: int + old_end_date: Optional[datetime] + + +@dataclass(slots=True) +class RenewalPaymentDescriptor: + user_id: int + subscription_id: int + period_days: int + total_amount_kopeks: int + missing_amount_kopeks: int + payload_id: str + pricing_snapshot: Optional[Dict[str, Any]] = None + + @property + def balance_component_kopeks(self) -> int: + remaining = self.total_amount_kopeks - self.missing_amount_kopeks + return max(0, remaining) + + +_PAYLOAD_PREFIX = "subscription_renewal" + + +def build_renewal_period_id(period_days: int) -> str: + return f"days:{period_days}" + + +def build_payment_descriptor( + user_id: int, + subscription_id: int, + period_days: int, + total_amount_kopeks: int, + missing_amount_kopeks: int, + *, + pricing_snapshot: Optional[Dict[str, Any]] = None, +) -> RenewalPaymentDescriptor: + return RenewalPaymentDescriptor( + user_id=user_id, + subscription_id=subscription_id, + period_days=period_days, + total_amount_kopeks=max(0, int(total_amount_kopeks)), + missing_amount_kopeks=max(0, int(missing_amount_kopeks)), + payload_id=uuid4().hex[:8], + pricing_snapshot=pricing_snapshot or None, + ) + + +def encode_payment_payload(descriptor: RenewalPaymentDescriptor) -> str: + snapshot_segment = "" + if descriptor.pricing_snapshot: + try: + raw_snapshot = json.dumps( + descriptor.pricing_snapshot, + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + snapshot_segment = base64.urlsafe_b64encode(raw_snapshot).decode("ascii").rstrip("=") + except (TypeError, ValueError): + snapshot_segment = "" + + payload = ( + f"{_PAYLOAD_PREFIX}|{descriptor.user_id}|{descriptor.subscription_id}|" + f"{descriptor.period_days}|{descriptor.total_amount_kopeks}|" + f"{descriptor.missing_amount_kopeks}|{descriptor.payload_id}" + ) + + if snapshot_segment: + payload = f"{payload}|{snapshot_segment}" + + return payload + + +def decode_payment_payload(payload: str, expected_user_id: Optional[int] = None) -> Optional[RenewalPaymentDescriptor]: + if not payload or not payload.startswith(f"{_PAYLOAD_PREFIX}|"): + return None + + parts = payload.split("|") + if len(parts) < 7: + return None + + try: + ( + _, + user_id_raw, + subscription_raw, + period_raw, + total_raw, + missing_raw, + payload_id, + *snapshot_parts, + ) = parts + user_id = int(user_id_raw) + subscription_id = int(subscription_raw) + period_days = int(period_raw) + total_amount = int(total_raw) + missing_amount = int(missing_raw) + except (TypeError, ValueError): + return None + + pricing_snapshot: Optional[Dict[str, Any]] = None + if snapshot_parts: + encoded_snapshot = snapshot_parts[0] + if encoded_snapshot: + padding = "=" * (-len(encoded_snapshot) % 4) + try: + decoded = base64.urlsafe_b64decode((encoded_snapshot + padding).encode("ascii")) + snapshot_data = json.loads(decoded.decode("utf-8")) + if isinstance(snapshot_data, dict): + pricing_snapshot = snapshot_data + except (ValueError, json.JSONDecodeError, UnicodeDecodeError): + logger.warning("Failed to decode renewal pricing snapshot from payload") + + if expected_user_id is not None and user_id != expected_user_id: + return None + + return RenewalPaymentDescriptor( + user_id=user_id, + subscription_id=subscription_id, + period_days=period_days, + total_amount_kopeks=max(0, total_amount), + missing_amount_kopeks=max(0, missing_amount), + payload_id=payload_id, + pricing_snapshot=pricing_snapshot, + ) + + +def build_payment_metadata(descriptor: RenewalPaymentDescriptor) -> Dict[str, Any]: + return { + "payment_purpose": _PAYLOAD_PREFIX, + "subscription_id": str(descriptor.subscription_id), + "period_days": str(descriptor.period_days), + "total_amount_kopeks": str(descriptor.total_amount_kopeks), + "missing_amount_kopeks": str(descriptor.missing_amount_kopeks), + "payload_id": descriptor.payload_id, + "pricing_snapshot": descriptor.pricing_snapshot or {}, + } + + +def parse_payment_metadata( + metadata: Optional[Dict[str, Any]], + *, + expected_user_id: Optional[int] = None, +) -> Optional[RenewalPaymentDescriptor]: + if not metadata: + return None + + if metadata.get("payment_purpose") != _PAYLOAD_PREFIX: + return None + + try: + subscription_id = int(metadata.get("subscription_id")) + period_days = int(metadata.get("period_days")) + total_amount = int(metadata.get("total_amount_kopeks")) + missing_amount = int(metadata.get("missing_amount_kopeks")) + except (TypeError, ValueError): + return None + + payload_id = str(metadata.get("payload_id") or "") + user_id = metadata.get("user_id") + if user_id is not None: + try: + user_id_int = int(user_id) + except (TypeError, ValueError): + user_id_int = None + else: + user_id_int = None + + if expected_user_id is not None and user_id_int is not None and user_id_int != expected_user_id: + return None + + pricing_snapshot = metadata.get("pricing_snapshot") + if isinstance(pricing_snapshot, dict): + snapshot_dict = pricing_snapshot + else: + snapshot_dict = None + + return RenewalPaymentDescriptor( + user_id=user_id_int or expected_user_id or 0, + subscription_id=subscription_id, + period_days=period_days, + total_amount_kopeks=max(0, total_amount), + missing_amount_kopeks=max(0, missing_amount), + payload_id=payload_id, + pricing_snapshot=snapshot_dict, + ) + + +async def with_admin_notification_service( + handler: Callable[[AdminNotificationService], Awaitable[Any]], +) -> None: + if not getattr(settings, "ADMIN_NOTIFICATIONS_ENABLED", False): + return + if not settings.BOT_TOKEN: + logger.debug("Skipping admin notification: bot token is not configured") + return + + bot: Bot | None = None + try: + bot = Bot(token=settings.BOT_TOKEN) + service = AdminNotificationService(bot) + await handler(service) + except Exception as error: # pragma: no cover - defensive logging + logger.error("Failed to send admin notification from renewal service: %s", error) + finally: + if bot is not None: + await bot.session.close() + + +class SubscriptionRenewalService: + """Shared helpers for subscription renewal pricing and processing.""" + + async def calculate_pricing( + self, + db: AsyncSession, + user: User, + subscription: Subscription, + period_days: int, + ) -> SubscriptionRenewalPricing: + connected_uuids = [str(uuid) for uuid in list(subscription.connected_squads or [])] + server_ids: List[int] = [] + if connected_uuids: + server_ids = await get_server_ids_by_uuids(db, connected_uuids) + + traffic_limit = subscription.traffic_limit_gb + if traffic_limit is None: + traffic_limit = settings.DEFAULT_TRAFFIC_LIMIT_GB + + devices_limit = subscription.device_limit + if devices_limit is None: + devices_limit = settings.DEFAULT_DEVICE_LIMIT + + total_cost, details = await calculate_subscription_total_cost( + db, + period_days, + int(traffic_limit or 0), + server_ids, + int(devices_limit or 0), + user=user, + ) + + months = details.get("months_in_period") or calculate_months_from_days(period_days) + + base_original_total = ( + details.get("base_price_original", 0) + + details.get("traffic_price_per_month", 0) * months + + details.get("servers_price_per_month", 0) * months + + details.get("devices_price_per_month", 0) * months + ) + + discounted_total = total_cost + + monthly_additions = 0 + if months > 0: + monthly_additions = ( + details.get("total_servers_price", 0) // months + + details.get("total_devices_price", 0) // months + + details.get("total_traffic_price", 0) // months + ) + + if not validate_pricing_calculation( + details.get("base_price", 0), + monthly_additions, + months, + discounted_total, + ): + logger.warning( + "Renewal pricing validation failed for subscription %s (period %s)", + subscription.id, + period_days, + ) + + from app.utils.promo_offer import get_user_active_promo_discount_percent + + promo_percent = get_user_active_promo_discount_percent(user) + + final_total = discounted_total + promo_discount_value = 0 + if promo_percent > 0 and discounted_total > 0: + final_total, promo_discount_value = apply_percentage_discount( + discounted_total, + promo_percent, + ) + + overall_discount_value = max(0, base_original_total - final_total) + overall_discount_percent = 0 + if base_original_total > 0 and overall_discount_value > 0: + overall_discount_percent = int( + round(overall_discount_value * 100 / base_original_total) + ) + + per_month = final_total // months if months else final_total + + return SubscriptionRenewalPricing( + period_days=period_days, + period_id=build_renewal_period_id(period_days), + months=months, + base_original_total=base_original_total, + discounted_total=discounted_total, + final_total=final_total, + promo_discount_value=promo_discount_value, + promo_discount_percent=promo_percent if promo_discount_value else 0, + overall_discount_percent=overall_discount_percent, + per_month=per_month, + server_ids=list(server_ids), + details=details, + ) + + async def finalize( + self, + db: AsyncSession, + user: User, + subscription: Subscription, + pricing: SubscriptionRenewalPricing, + *, + charge_balance_amount: Optional[int] = None, + description: Optional[str] = None, + payment_method: Optional[PaymentMethod] = None, + ) -> SubscriptionRenewalResult: + final_total = int(pricing.final_total) + if final_total < 0: + final_total = 0 + + period_days = int(pricing.period_days) + charge_from_balance = charge_balance_amount + if charge_from_balance is None: + charge_from_balance = final_total + charge_from_balance = max(0, min(charge_from_balance, final_total)) + + consume_promo_offer = bool(pricing.promo_discount_value) + + description_text = description or f"Продление подписки на {period_days} дней" + + if charge_from_balance > 0 or consume_promo_offer: + success = await subtract_user_balance( + db, + user, + charge_from_balance, + description_text, + consume_promo_offer=consume_promo_offer, + ) + if not success: + raise SubscriptionRenewalChargeError("Failed to charge balance") + await db.refresh(user) + + subscription_before = subscription + old_end_date = subscription_before.end_date + + subscription_after = await extend_subscription(db, subscription_before, period_days) + + server_ids = pricing.server_ids or [] + server_prices_for_period = pricing.details.get("servers_individual_prices", []) + if server_ids: + try: + await add_subscription_servers( + db, + subscription_after, + server_ids, + server_prices_for_period, + ) + except Exception as error: # pragma: no cover - defensive logging + logger.warning( + "Failed to record renewal server prices for subscription %s: %s", + subscription_after.id, + error, + ) + + subscription_service = SubscriptionService() + try: + await subscription_service.update_remnawave_user( + db, + subscription_after, + reset_traffic=settings.RESET_TRAFFIC_ON_PAYMENT, + reset_reason="subscription renewal", + ) + except RemnaWaveConfigurationError as error: # pragma: no cover - configuration issues + logger.warning("RemnaWave update skipped: %s", error) + except Exception as error: # pragma: no cover - defensive logging + logger.error( + "Failed to update RemnaWave user for subscription %s: %s", + subscription_after.id, + error, + ) + + transaction: Optional[Transaction] = None + try: + transaction = await create_transaction( + db=db, + user_id=user.id, + type=TransactionType.SUBSCRIPTION_PAYMENT, + amount_kopeks=final_total, + description=description_text, + payment_method=payment_method, + ) + except Exception as error: # pragma: no cover - defensive logging + logger.warning( + "Failed to create renewal transaction for subscription %s: %s", + subscription_after.id, + error, + ) + + await db.refresh(user) + await db.refresh(subscription_after) + + if transaction and old_end_date and subscription_after.end_date: + await with_admin_notification_service( + lambda service: service.send_subscription_extension_notification( + db, + user, + subscription_after, + transaction, + period_days, + old_end_date, + new_end_date=subscription_after.end_date, + balance_after=user.balance_kopeks, + ) + ) + + return SubscriptionRenewalResult( + subscription=subscription_after, + transaction=transaction, + total_amount_kopeks=final_total, + charged_from_balance_kopeks=charge_from_balance, + old_end_date=old_end_date, + ) + + def build_option_payload( + self, + pricing: SubscriptionRenewalPricing, + *, + language: str, + ) -> Dict[str, Any]: + label = format_period_description(pricing.period_days, language) + price_label = settings.format_price(pricing.final_total) + original_label = None + if ( + pricing.base_original_total + and pricing.base_original_total != pricing.final_total + ): + original_label = settings.format_price(pricing.base_original_total) + + per_month_label = settings.format_price(pricing.per_month) + + payload = { + "id": pricing.period_id, + "days": pricing.period_days, + "months": pricing.months, + "price_kopeks": pricing.final_total, + "price_label": price_label, + "original_price_kopeks": pricing.base_original_total, + "original_price_label": original_label, + "discount_percent": pricing.overall_discount_percent, + "price_per_month_kopeks": pricing.per_month, + "price_per_month_label": per_month_label, + "title": label, + } + + return payload + + +def calculate_missing_amount(balance_kopeks: int, total_kopeks: int) -> int: + if total_kopeks <= 0: + return 0 + if balance_kopeks <= 0: + return total_kopeks + return max(0, total_kopeks - min(balance_kopeks, total_kopeks)) + diff --git a/app/webapi/routes/miniapp.py b/app/webapi/routes/miniapp.py index f4937461..203e6e42 100644 --- a/app/webapi/routes/miniapp.py +++ b/app/webapi/routes/miniapp.py @@ -3,10 +3,10 @@ from __future__ import annotations import logging import re import math -from decimal import Decimal, InvalidOperation, ROUND_HALF_UP, ROUND_FLOOR +from decimal import Decimal, InvalidOperation, ROUND_HALF_UP, ROUND_FLOOR, ROUND_UP from datetime import datetime, timedelta, timezone from uuid import uuid4 -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Union from aiogram import Bot from fastapi import APIRouter, Depends, HTTPException, status @@ -28,13 +28,11 @@ from app.database.crud.promo_offer_template import get_promo_offer_template_by_i from app.database.crud.server_squad import ( add_user_to_servers, get_available_server_squads, - get_server_ids_by_uuids, get_server_squad_by_uuid, remove_user_from_servers, ) from app.database.crud.subscription import ( add_subscription_servers, - calculate_subscription_total_cost, create_trial_subscription, extend_subscription, remove_subscription_servers, @@ -55,7 +53,6 @@ from app.database.models import ( PaymentMethod, User, ) -from app.services.admin_notification_service import AdminNotificationService from app.services.faq_service import FaqService from app.services.privacy_policy_service import PrivacyPolicyService from app.services.public_offer_service import PublicOfferService @@ -68,6 +65,16 @@ from app.services.payment_service import PaymentService, get_wata_payment_by_lin from app.services.promo_offer_service import promo_offer_service from app.services.promocode_service import PromoCodeService from app.services.subscription_service import SubscriptionService +from app.services.subscription_renewal_service import ( + SubscriptionRenewalChargeError, + SubscriptionRenewalService, + build_payment_descriptor, + build_renewal_period_id, + decode_payment_payload, + calculate_missing_amount, + encode_payment_payload, + with_admin_notification_service, +) from app.services.trial_activation_service import ( TrialPaymentChargeFailed, TrialPaymentInsufficientFunds, @@ -94,11 +101,9 @@ from app.utils.user_utils import ( ) from app.utils.pricing_utils import ( apply_percentage_discount, - calculate_months_from_days, calculate_prorated_price, format_period_description, get_remaining_months, - validate_pricing_calculation, ) from app.utils.promo_offer import get_user_active_promo_discount_percent @@ -180,27 +185,7 @@ logger = logging.getLogger(__name__) router = APIRouter() promo_code_service = PromoCodeService() - - -async def _with_admin_notification_service( - handler: Callable[[AdminNotificationService], Awaitable[Any]], -) -> None: - if not getattr(settings, "ADMIN_NOTIFICATIONS_ENABLED", False): - return - if not settings.BOT_TOKEN: - logger.debug("Skipping admin notification: bot token is not configured") - return - - bot: Bot | None = None - try: - bot = Bot(token=settings.BOT_TOKEN) - service = AdminNotificationService(bot) - await handler(service) - except Exception as error: # pragma: no cover - defensive logging - logger.error("Failed to send admin notification from miniapp: %s", error) - finally: - if bot: - await bot.session.close() +renewal_service = SubscriptionRenewalService() _CRYPTOBOT_MIN_USD = 1.0 @@ -1644,6 +1629,9 @@ async def _resolve_cryptobot_payment_status( except (InvalidOperation, TypeError): amount_kopeks = None + descriptor = decode_payment_payload(getattr(payment, "payload", "") or "", expected_user_id=user.id) + purpose = "subscription_renewal" if descriptor else "balance_topup" + return MiniAppPaymentStatusResult( method="cryptobot", status=status, @@ -1660,6 +1648,9 @@ async def _resolve_cryptobot_payment_status( "invoice_id": payment.invoice_id, "payload": query.payload, "started_at": query.started_at, + "purpose": purpose, + "subscription_id": descriptor.subscription_id if descriptor else None, + "period_days": descriptor.period_days if descriptor else None, }, ) @@ -3368,7 +3359,7 @@ async def activate_subscription_trial_endpoint( else: message = f"{message}\n\n💳 {charged_amount_label} has been deducted from your balance." - await _with_admin_notification_service( + await with_admin_notification_service( lambda service: service.send_trial_activation_notification( db, user, @@ -3803,10 +3794,93 @@ def _build_promo_offer_payload(user: Optional[User]) -> Optional[Dict[str, Any]] return payload -def _build_renewal_period_id(period_days: int) -> str: - return f"days:{period_days}" +def _format_payment_method_title(method: str) -> str: + mapping = { + "cryptobot": "CryptoBot", + "yookassa": "YooKassa", + "yookassa_sbp": "YooKassa СБП", + "mulenpay": "MulenPay", + "pal24": "Pal24", + "wata": "WataPay", + "heleket": "Heleket", + "tribute": "Tribute", + "stars": "Telegram Stars", + } + key = (method or "").lower() + return mapping.get(key, method.title() if method else "") +def _build_renewal_success_message( + user: User, + subscription: Subscription, + charged_amount: int, + promo_discount_value: int = 0, +) -> str: + language_code = _normalize_language_code(user) + amount_label = settings.format_price(max(0, charged_amount)) + date_label = ( + format_local_datetime(subscription.end_date, "%d.%m.%Y %H:%M") + if subscription.end_date + else "" + ) + + if language_code == "ru": + if charged_amount > 0: + message = ( + f"Подписка продлена до {date_label}. " if date_label else "Подписка продлена. " + ) + f"Списано {amount_label}." + else: + message = ( + f"Подписка продлена до {date_label}." + if date_label + else "Подписка успешно продлена." + ) + else: + if charged_amount > 0: + message = ( + f"Subscription renewed until {date_label}. " if date_label else "Subscription renewed. " + ) + f"Charged {amount_label}." + else: + message = ( + f"Subscription renewed until {date_label}." + if date_label + else "Subscription renewed successfully." + ) + + if promo_discount_value > 0: + discount_label = settings.format_price(promo_discount_value) + if language_code == "ru": + message += f" Применена дополнительная скидка {discount_label}." + else: + message += f" Promo discount applied: {discount_label}." + + return message + + +def _build_renewal_pending_message( + user: User, + missing_amount: int, + method: str, +) -> str: + language_code = _normalize_language_code(user) + amount_label = settings.format_price(max(0, missing_amount)) + method_title = _format_payment_method_title(method) + + if language_code == "ru": + if method_title: + return ( + f"Недостаточно средств на балансе. Доплатите {amount_label} через {method_title}, " + "чтобы завершить продление." + ) + return ( + f"Недостаточно средств на балансе. Доплатите {amount_label}, чтобы завершить продление." + ) + + if method_title: + return ( + f"Not enough balance. Pay the remaining {amount_label} via {method_title} to finish the renewal." + ) + return f"Not enough balance. Pay the remaining {amount_label} to finish the renewal." def _parse_period_identifier(identifier: Optional[str]) -> Optional[int]: if not identifier: return None @@ -3826,95 +3900,14 @@ async def _calculate_subscription_renewal_pricing( user: User, subscription: Subscription, period_days: int, -) -> Dict[str, Any]: - connected_uuids = [str(uuid) for uuid in list(subscription.connected_squads or [])] - server_ids: List[int] = [] - if connected_uuids: - server_ids = await get_server_ids_by_uuids(db, connected_uuids) - - traffic_limit = subscription.traffic_limit_gb - if traffic_limit is None: - traffic_limit = settings.DEFAULT_TRAFFIC_LIMIT_GB - - devices_limit = subscription.device_limit - if devices_limit is None: - devices_limit = settings.DEFAULT_DEVICE_LIMIT - - total_cost, details = await calculate_subscription_total_cost( +): + return await renewal_service.calculate_pricing( db, + user, + subscription, period_days, - int(traffic_limit or 0), - server_ids, - int(devices_limit or 0), - user=user, ) - months = details.get("months_in_period") or calculate_months_from_days(period_days) - - base_original_total = ( - details.get("base_price_original", 0) - + details.get("traffic_price_per_month", 0) * months - + details.get("servers_price_per_month", 0) * months - + details.get("devices_price_per_month", 0) * months - ) - - discounted_total = total_cost - - monthly_additions = 0 - if months > 0: - monthly_additions = ( - details.get("total_servers_price", 0) // months - + details.get("total_devices_price", 0) // months - + details.get("total_traffic_price", 0) // months - ) - - if not validate_pricing_calculation( - details.get("base_price", 0), - monthly_additions, - months, - discounted_total, - ): - logger.warning( - "Renewal pricing validation failed for subscription %s (period %s)", - subscription.id, - period_days, - ) - - promo_percent = get_user_active_promo_discount_percent(user) - final_total = discounted_total - promo_discount_value = 0 - if promo_percent > 0 and discounted_total > 0: - final_total, promo_discount_value = apply_percentage_discount( - discounted_total, - promo_percent, - ) - - overall_discount_value = max(0, base_original_total - final_total) - overall_discount_percent = 0 - if base_original_total > 0 and overall_discount_value > 0: - overall_discount_percent = int( - round(overall_discount_value * 100 / base_original_total) - ) - - per_month = final_total // months if months else final_total - - pricing_payload: Dict[str, Any] = { - "period_id": _build_renewal_period_id(period_days), - "period_days": period_days, - "months": months, - "base_original_total": base_original_total, - "discounted_total": discounted_total, - "final_total": final_total, - "promo_discount_value": promo_discount_value, - "promo_discount_percent": promo_percent if promo_discount_value else 0, - "overall_discount_percent": overall_discount_percent, - "per_month": per_month, - "server_ids": list(server_ids), - "details": details, - } - - return pricing_payload - async def _prepare_subscription_renewal_options( db: AsyncSession, @@ -3929,12 +3922,13 @@ async def _prepare_subscription_renewal_options( for period_days in available_periods: try: - pricing = await _calculate_subscription_renewal_pricing( + pricing_model = await _calculate_subscription_renewal_pricing( db, user, subscription, period_days, ) + pricing = pricing_model.to_payload() except Exception as error: # pragma: no cover - defensive logging logger.warning( "Failed to calculate renewal pricing for subscription %s (period %s): %s", @@ -4108,7 +4102,11 @@ async def _authorize_miniapp_user( return user -def _ensure_paid_subscription(user: User) -> Subscription: +def _ensure_paid_subscription( + user: User, + *, + allowed_statuses: Optional[Collection[str]] = None, +) -> Subscription: subscription = getattr(user, "subscription", None) if not subscription: raise HTTPException( @@ -4116,7 +4114,9 @@ def _ensure_paid_subscription(user: User) -> Subscription: detail={"code": "subscription_not_found", "message": "Subscription not found"}, ) - if getattr(subscription, "is_trial", False): + normalized_allowed_statuses = set(allowed_statuses or {"active"}) + + if getattr(subscription, "is_trial", False) and "trial" not in normalized_allowed_statuses: raise HTTPException( status.HTTP_403_FORBIDDEN, detail={ @@ -4125,7 +4125,28 @@ def _ensure_paid_subscription(user: User) -> Subscription: }, ) - if not getattr(subscription, "is_active", False): + actual_status = getattr(subscription, "actual_status", None) or "" + + if actual_status not in normalized_allowed_statuses: + if actual_status == "trial": + detail = { + "code": "paid_subscription_required", + "message": "This action is available only for paid subscriptions", + } + elif actual_status == "disabled": + detail = { + "code": "subscription_disabled", + "message": "Subscription is disabled", + } + else: + detail = { + "code": "subscription_inactive", + "message": "Subscription must be active to manage settings", + } + + raise HTTPException(status.HTTP_403_FORBIDDEN, detail=detail) + + if not getattr(subscription, "is_active", False) and "expired" not in normalized_allowed_statuses: raise HTTPException( status.HTTP_403_FORBIDDEN, detail={ @@ -4398,7 +4419,10 @@ async def get_subscription_renewal_options_endpoint( db: AsyncSession = Depends(get_db_session), ) -> MiniAppSubscriptionRenewalOptionsResponse: user = await _authorize_miniapp_user(payload.init_data, db) - subscription = _ensure_paid_subscription(user) + subscription = _ensure_paid_subscription( + user, + allowed_statuses={"active", "trial", "expired"}, + ) _validate_subscription_id(payload.subscription_id, subscription) periods, pricing_map, default_period_id = await _prepare_subscription_renewal_options( @@ -4477,7 +4501,10 @@ async def submit_subscription_renewal_endpoint( db: AsyncSession = Depends(get_db_session), ) -> MiniAppSubscriptionRenewalResponse: user = await _authorize_miniapp_user(payload.init_data, db) - subscription = _ensure_paid_subscription(user) + subscription = _ensure_paid_subscription( + user, + allowed_statuses={"active", "trial", "expired"}, + ) _validate_subscription_id(payload.subscription_id, subscription) period_days: Optional[int] = None @@ -4508,8 +4535,10 @@ async def submit_subscription_renewal_endpoint( detail={"code": "period_unavailable", "message": "Selected renewal period is not available"}, ) + method = (payload.method or "").strip().lower() + try: - pricing = await _calculate_subscription_renewal_pricing( + pricing_model = await _calculate_subscription_renewal_pricing( db, user, subscription, @@ -4529,156 +4558,168 @@ async def submit_subscription_renewal_endpoint( detail={"code": "pricing_failed", "message": "Failed to calculate renewal pricing"}, ) from error - final_total = int(pricing.get("final_total") or 0) + pricing = pricing_model.to_payload() + final_total = int(pricing_model.final_total) balance_kopeks = getattr(user, "balance_kopeks", 0) - - if final_total > 0 and balance_kopeks < final_total: - missing = final_total - balance_kopeks - raise HTTPException( - status.HTTP_402_PAYMENT_REQUIRED, - detail={ - "code": "insufficient_funds", - "message": "Not enough funds to renew the subscription", - "missing_amount_kopeks": missing, - }, - ) - - consume_promo_offer = bool(pricing.get("promo_discount_value")) + missing_amount = calculate_missing_amount(balance_kopeks, final_total) description = f"Продление подписки на {period_days} дней" - old_end_date = subscription.end_date - if final_total > 0 or consume_promo_offer: - success = await subtract_user_balance( - db, - user, - final_total, - description, - consume_promo_offer=consume_promo_offer, - ) - if not success: + if not method or missing_amount <= 0: + if final_total > 0 and balance_kopeks < final_total: + missing = final_total - balance_kopeks raise HTTPException( - status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={"code": "charge_failed", "message": "Failed to charge balance"}, + status.HTTP_402_PAYMENT_REQUIRED, + detail={ + "code": "insufficient_funds", + "message": "Not enough funds to renew the subscription", + "missing_amount_kopeks": missing, + }, ) - await db.refresh(user) - subscription = await extend_subscription(db, subscription, period_days) - - server_ids = pricing.get("server_ids") or [] - server_prices_for_period = pricing.get("details", {}).get( - "servers_individual_prices", - [], - ) - if server_ids: try: - await add_subscription_servers( - db, - subscription, - server_ids, - server_prices_for_period, - ) - except Exception as error: # pragma: no cover - defensive logging - logger.warning( - "Failed to record renewal server prices for subscription %s: %s", - subscription.id, - error, - ) - - subscription_service = SubscriptionService() - try: - await subscription_service.update_remnawave_user( - db, - subscription, - reset_traffic=settings.RESET_TRAFFIC_ON_PAYMENT, - reset_reason="subscription renewal", - ) - except RemnaWaveConfigurationError as error: # pragma: no cover - configuration issues - logger.warning("RemnaWave update skipped: %s", error) - except Exception as error: # pragma: no cover - defensive logging - logger.error( - "Failed to update RemnaWave user for subscription %s: %s", - subscription.id, - error, - ) - - transaction: Optional[Transaction] = None - try: - transaction = await create_transaction( - db=db, - user_id=user.id, - type=TransactionType.SUBSCRIPTION_PAYMENT, - amount_kopeks=final_total, - description=description, - ) - except Exception as error: # pragma: no cover - defensive logging - logger.warning( - "Failed to create renewal transaction for subscription %s: %s", - subscription.id, - error, - ) - - await db.refresh(user) - await db.refresh(subscription) - - if transaction and old_end_date and subscription.end_date: - await _with_admin_notification_service( - lambda service: service.send_subscription_extension_notification( + result = await renewal_service.finalize( db, user, subscription, - transaction, - period_days, - old_end_date, - new_end_date=subscription.end_date, - balance_after=user.balance_kopeks, + pricing_model, + description=description, ) + except SubscriptionRenewalChargeError as error: + logger.error( + "Failed to charge balance for subscription renewal %s: %s", + subscription.id, + error, + ) + raise HTTPException( + status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"code": "charge_failed", "message": "Failed to charge balance"}, + ) from error + + updated_subscription = result.subscription + message = _build_renewal_success_message( + user, + updated_subscription, + result.total_amount_kopeks, + pricing_model.promo_discount_value, ) - language_code = _normalize_language_code(user) - amount_label = settings.format_price(final_total) - date_label = ( - format_local_datetime(subscription.end_date, "%d.%m.%Y %H:%M") - if subscription.end_date - else "" - ) + return MiniAppSubscriptionRenewalResponse( + message=message, + balance_kopeks=user.balance_kopeks, + balance_label=settings.format_price(user.balance_kopeks), + subscription_id=updated_subscription.id, + renewed_until=updated_subscription.end_date, + ) - if language_code == "ru": - if final_total > 0: - message = ( - f"Подписка продлена до {date_label}. " if date_label else "Подписка продлена. " - ) + f"Списано {amount_label}." - else: - message = ( - f"Подписка продлена до {date_label}." - if date_label - else "Подписка успешно продлена." + supported_methods = {"cryptobot"} + if method not in supported_methods: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={"code": "unsupported_method", "message": "Payment method is not supported for renewal"}, + ) + + if method == "cryptobot": + if not settings.is_cryptobot_enabled(): + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="Payment method is unavailable") + + rate = await _get_usd_to_rub_rate() + min_amount_kopeks, max_amount_kopeks = _compute_cryptobot_limits(rate) + if missing_amount < min_amount_kopeks: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={ + "code": "amount_below_minimum", + "message": f"Amount is below minimum ({min_amount_kopeks / 100:.2f} RUB)", + }, ) - else: - if final_total > 0: - message = ( - f"Subscription renewed until {date_label}. " if date_label else "Subscription renewed. " - ) + f"Charged {amount_label}." - else: - message = ( - f"Subscription renewed until {date_label}." - if date_label - else "Subscription renewed successfully." + if missing_amount > max_amount_kopeks: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={ + "code": "amount_above_maximum", + "message": f"Amount exceeds maximum ({max_amount_kopeks / 100:.2f} RUB)", + }, ) - promo_discount_value = pricing.get("promo_discount_value") or 0 - if consume_promo_offer and promo_discount_value > 0: - discount_label = settings.format_price(promo_discount_value) - if language_code == "ru": - message += f" Применена дополнительная скидка {discount_label}." - else: - message += f" Promo discount applied: {discount_label}." + try: + decimal_amount = (Decimal(missing_amount) / Decimal(100) / Decimal(str(rate))) + amount_usd = float( + decimal_amount.quantize(Decimal("0.01"), rounding=ROUND_UP) + ) + except (InvalidOperation, ValueError) as error: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={"code": "conversion_failed", "message": "Unable to convert amount to USD"}, + ) from error - return MiniAppSubscriptionRenewalResponse( - message=message, - balance_kopeks=user.balance_kopeks, - balance_label=settings.format_price(user.balance_kopeks), - subscription_id=subscription.id, - renewed_until=subscription.end_date, + if amount_usd <= 0: + amount_usd = float( + decimal_amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) + ) + + descriptor = build_payment_descriptor( + user.id, + subscription.id, + period_days, + final_total, + missing_amount, + pricing_snapshot=pricing, + ) + payload_value = encode_payment_payload(descriptor) + + payment_service = PaymentService() + result = await payment_service.create_cryptobot_payment( + db=db, + user_id=user.id, + amount_usd=amount_usd, + asset=settings.CRYPTOBOT_DEFAULT_ASSET, + description=description, + payload=payload_value, + ) + if not result: + raise HTTPException( + status.HTTP_502_BAD_GATEWAY, + detail={"code": "payment_creation_failed", "message": "Failed to create payment"}, + ) + + payment_url = ( + result.get("mini_app_invoice_url") + or result.get("bot_invoice_url") + or result.get("web_app_invoice_url") + ) + if not payment_url: + raise HTTPException( + status.HTTP_502_BAD_GATEWAY, + detail={"code": "payment_url_missing", "message": "Failed to obtain payment url"}, + ) + + extra_payload = { + "bot_invoice_url": result.get("bot_invoice_url"), + "mini_app_invoice_url": result.get("mini_app_invoice_url"), + "web_app_invoice_url": result.get("web_app_invoice_url"), + } + + message = _build_renewal_pending_message(user, missing_amount, method) + + return MiniAppSubscriptionRenewalResponse( + success=False, + message=message, + balance_kopeks=user.balance_kopeks, + balance_label=settings.format_price(user.balance_kopeks), + subscription_id=subscription.id, + requires_payment=True, + payment_method=method, + payment_url=payment_url, + payment_amount_kopeks=missing_amount, + payment_id=result.get("local_payment_id"), + invoice_id=result.get("invoice_id"), + payment_payload=payload_value, + payment_extra={key: value for key, value in extra_payload.items() if value}, + ) + + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail={"code": "unsupported_method", "message": "Payment method is not supported for renewal"}, ) @@ -4791,7 +4832,7 @@ async def subscription_purchase_endpoint( pass if subscription and transaction and period_days: - await _with_admin_notification_service( + await with_admin_notification_service( lambda service: service.send_subscription_purchase_notification( db, user, @@ -5020,7 +5061,7 @@ async def update_subscription_servers_endpoint( service = SubscriptionService() await service.update_remnawave_user(db, subscription) - await _with_admin_notification_service( + await with_admin_notification_service( lambda service: service.send_subscription_update_notification( db, user, @@ -5184,7 +5225,7 @@ async def update_subscription_traffic_endpoint( service = SubscriptionService() await service.update_remnawave_user(db, subscription) - await _with_admin_notification_service( + await with_admin_notification_service( lambda service: service.send_subscription_update_notification( db, user, @@ -5335,7 +5376,7 @@ async def update_subscription_devices_endpoint( service = SubscriptionService() await service.update_remnawave_user(db, subscription) - await _with_admin_notification_service( + await with_admin_notification_service( lambda service: service.send_subscription_update_notification( db, user, diff --git a/app/webapi/schemas/miniapp.py b/app/webapi/schemas/miniapp.py index ac31f350..5a41c124 100644 --- a/app/webapi/schemas/miniapp.py +++ b/app/webapi/schemas/miniapp.py @@ -202,6 +202,7 @@ class MiniAppSubscriptionRenewalRequest(BaseModel): subscription_id: Optional[int] = Field(default=None, alias="subscriptionId") period_id: Optional[str] = Field(default=None, alias="periodId") period_days: Optional[int] = Field(default=None, alias="periodDays") + method: Optional[str] = None model_config = ConfigDict(populate_by_name=True) @@ -213,6 +214,14 @@ class MiniAppSubscriptionRenewalResponse(BaseModel): balance_label: Optional[str] = Field(default=None, alias="balanceLabel") subscription_id: Optional[int] = Field(default=None, alias="subscriptionId") renewed_until: Optional[datetime] = Field(default=None, alias="renewedUntil") + requires_payment: bool = Field(default=False, alias="requiresPayment") + payment_method: Optional[str] = Field(default=None, alias="paymentMethod") + payment_url: Optional[str] = Field(default=None, alias="paymentUrl") + payment_amount_kopeks: Optional[int] = Field(default=None, alias="paymentAmountKopeks") + payment_id: Optional[int] = Field(default=None, alias="paymentId") + invoice_id: Optional[str] = Field(default=None, alias="invoiceId") + payment_payload: Optional[str] = Field(default=None, alias="paymentPayload") + payment_extra: Optional[Dict[str, Any]] = Field(default=None, alias="paymentExtra") model_config = ConfigDict(populate_by_name=True) diff --git a/tests/test_miniapp_payments.py b/tests/test_miniapp_payments.py index 9bbbb3a7..ffbffdc4 100644 --- a/tests/test_miniapp_payments.py +++ b/tests/test_miniapp_payments.py @@ -19,11 +19,20 @@ os.environ.setdefault('BOT_TOKEN', 'test-token') from app.config import settings from app.webapi.routes import miniapp from app.database.models import PaymentMethod +from app.services.subscription_renewal_service import ( + SubscriptionRenewalPricing, + SubscriptionRenewalResult, + build_payment_descriptor, + decode_payment_payload, + encode_payment_payload, +) +from app.services.payment.cryptobot import CryptoBotPaymentMixin from app.webapi.schemas.miniapp import ( MiniAppPaymentCreateRequest, MiniAppPaymentIntegrationType, MiniAppPaymentMethodsRequest, MiniAppPaymentStatusQuery, + MiniAppSubscriptionRenewalRequest, ) @@ -44,6 +53,572 @@ def test_compute_cryptobot_limits_scale_with_rate(): assert high_rate_max > low_rate_max +def test_encode_decode_renewal_payload_preserves_snapshot(): + pricing_model = SubscriptionRenewalPricing( + period_days=30, + period_id='days:30', + months=1, + base_original_total=12000, + discounted_total=10000, + final_total=9000, + promo_discount_value=1000, + promo_discount_percent=10, + overall_discount_percent=25, + per_month=9000, + server_ids=[1, 2], + details={'servers_price_per_month': 1000}, + ) + + descriptor = build_payment_descriptor( + user_id=1, + subscription_id=42, + period_days=30, + total_amount_kopeks=pricing_model.final_total, + missing_amount_kopeks=1000, + pricing_snapshot=pricing_model.to_payload(), + ) + + payload_value = encode_payment_payload(descriptor) + decoded = decode_payment_payload(payload_value, expected_user_id=1) + + assert decoded is not None + assert decoded.total_amount_kopeks == 9000 + assert decoded.missing_amount_kopeks == 1000 + assert decoded.pricing_snapshot is not None + assert decoded.pricing_snapshot.get('server_ids') == [1, 2] + + +@pytest.mark.anyio("asyncio") +async def test_submit_subscription_renewal_uses_balance_when_sufficient(monkeypatch): + monkeypatch.setattr(settings, 'ADMIN_NOTIFICATIONS_ENABLED', False, raising=False) + monkeypatch.setattr(settings, 'BOT_TOKEN', 'token', raising=False) + monkeypatch.setattr(settings, 'RESET_TRAFFIC_ON_PAYMENT', False, raising=False) + monkeypatch.setattr(settings, 'DEFAULT_LANGUAGE', 'ru', raising=False) + monkeypatch.setattr(settings, 'CRYPTOBOT_ENABLED', False, raising=False) + monkeypatch.setattr(settings, 'CRYPTOBOT_API_TOKEN', None, raising=False) + monkeypatch.setattr(type(settings), 'get_available_renewal_periods', lambda self: [30], raising=False) + + user = types.SimpleNamespace(id=10, balance_kopeks=10000, language='ru') + subscription = types.SimpleNamespace( + id=77, + connected_squads=[], + traffic_limit_gb=100, + device_limit=5, + end_date=datetime.utcnow(), + ) + + pricing_model = SubscriptionRenewalPricing( + period_days=30, + period_id='days:30', + months=1, + base_original_total=10000, + discounted_total=10000, + final_total=10000, + promo_discount_value=0, + promo_discount_percent=0, + overall_discount_percent=0, + per_month=10000, + server_ids=[], + details={}, + ) + + async def fake_authorize(init_data, db): # noqa: ARG001 + return user + + def fake_ensure(subscription_user, allowed_statuses=None): # noqa: ARG001 + return subscription + + async def fake_calculate(db, u, sub, period): # noqa: ARG001 + return pricing_model + + captured: dict[str, Any] = {} + + async def fake_finalize(db, u, sub, pricing, *, charge_balance_amount=None, description=None, payment_method=None): # noqa: ARG001 + charge = charge_balance_amount if charge_balance_amount is not None else pricing.final_total + captured['charge'] = charge + captured['description'] = description + return SubscriptionRenewalResult( + subscription=types.SimpleNamespace(id=sub.id, end_date=datetime.utcnow()), + transaction=types.SimpleNamespace(id=501), + total_amount_kopeks=pricing.final_total, + charged_from_balance_kopeks=charge, + old_end_date=sub.end_date, + ) + + monkeypatch.setattr(miniapp, '_authorize_miniapp_user', fake_authorize) + monkeypatch.setattr(miniapp, '_ensure_paid_subscription', fake_ensure) + monkeypatch.setattr(miniapp, '_validate_subscription_id', lambda *args, **kwargs: None) + monkeypatch.setattr(miniapp, '_calculate_subscription_renewal_pricing', fake_calculate) + monkeypatch.setattr(miniapp.renewal_service, 'finalize', fake_finalize) + + payload = MiniAppSubscriptionRenewalRequest( + initData='init', + subscriptionId=77, + periodId='days:30', + ) + + response = await miniapp.submit_subscription_renewal_endpoint(payload, db=types.SimpleNamespace()) + + assert response.success is True + assert response.requires_payment is False + assert response.subscription_id == 77 + assert response.renewed_until is not None + assert 'Подписка' in (response.message or '') + assert captured['charge'] == 10000 + + +@pytest.mark.anyio("asyncio") +async def test_submit_subscription_renewal_returns_cryptobot_invoice(monkeypatch): + monkeypatch.setattr(settings, 'ADMIN_NOTIFICATIONS_ENABLED', False, raising=False) + monkeypatch.setattr(settings, 'BOT_TOKEN', 'token', raising=False) + monkeypatch.setattr(settings, 'RESET_TRAFFIC_ON_PAYMENT', False, raising=False) + monkeypatch.setattr(settings, 'DEFAULT_LANGUAGE', 'ru', raising=False) + monkeypatch.setattr(settings, 'CRYPTOBOT_ENABLED', True, raising=False) + monkeypatch.setattr(settings, 'CRYPTOBOT_API_TOKEN', 'token', raising=False) + monkeypatch.setattr(settings, 'CRYPTOBOT_DEFAULT_ASSET', 'USDT', raising=False) + monkeypatch.setattr(type(settings), 'get_available_renewal_periods', lambda self: [30], raising=False) + + user = types.SimpleNamespace(id=15, balance_kopeks=5000, language='ru') + subscription = types.SimpleNamespace( + id=88, + connected_squads=[], + traffic_limit_gb=100, + device_limit=5, + end_date=datetime.utcnow(), + ) + + pricing_model = SubscriptionRenewalPricing( + period_days=30, + period_id='days:30', + months=1, + base_original_total=20000, + discounted_total=20000, + final_total=20000, + promo_discount_value=0, + promo_discount_percent=0, + overall_discount_percent=0, + per_month=20000, + server_ids=[], + details={}, + ) + + async def fake_authorize(init_data, db): # noqa: ARG001 + return user + + def fake_ensure(subscription_user, allowed_statuses=None): # noqa: ARG001 + return subscription + + async def fake_calculate(db, u, sub, period): # noqa: ARG001 + return pricing_model + + created_calls: dict[str, Any] = {} + + class DummyPaymentService: + def __init__(self, *args, **kwargs): + pass + + async def create_cryptobot_payment(self, db, **kwargs): + created_calls.update(kwargs) + return { + 'local_payment_id': 321, + 'invoice_id': 'inv_123', + 'bot_invoice_url': 'https://t.me/invoice', + 'mini_app_invoice_url': 'https://mini.app/pay', + 'web_app_invoice_url': None, + } + + async def fake_rate(): + return 100.0 + + monkeypatch.setattr(miniapp, '_authorize_miniapp_user', fake_authorize) + monkeypatch.setattr(miniapp, '_ensure_paid_subscription', fake_ensure) + monkeypatch.setattr(miniapp, '_validate_subscription_id', lambda *args, **kwargs: None) + monkeypatch.setattr(miniapp, '_calculate_subscription_renewal_pricing', fake_calculate) + monkeypatch.setattr(miniapp, 'PaymentService', lambda *args, **kwargs: DummyPaymentService()) + monkeypatch.setattr(miniapp, '_get_usd_to_rub_rate', fake_rate) + + payload = MiniAppSubscriptionRenewalRequest( + initData='init', + subscriptionId=88, + periodId='days:30', + method='cryptobot', + ) + + response = await miniapp.submit_subscription_renewal_endpoint(payload, db=types.SimpleNamespace()) + + assert response.success is False + assert response.requires_payment is True + assert response.payment_method == 'cryptobot' + assert response.payment_amount_kopeks == 15000 + assert response.payment_url == 'https://mini.app/pay' + assert response.invoice_id == 'inv_123' + assert response.payment_id == 321 + assert response.payment_payload and response.payment_payload.startswith('subscription_renewal') + assert created_calls.get('amount_usd') == pytest.approx(1.5) + assert created_calls.get('description') == 'Продление подписки на 30 дней' + + +@pytest.mark.anyio("asyncio") +async def test_submit_subscription_renewal_rounds_up_cryptobot_amount(monkeypatch): + monkeypatch.setattr(settings, 'ADMIN_NOTIFICATIONS_ENABLED', False, raising=False) + monkeypatch.setattr(settings, 'BOT_TOKEN', 'token', raising=False) + monkeypatch.setattr(settings, 'RESET_TRAFFIC_ON_PAYMENT', False, raising=False) + monkeypatch.setattr(settings, 'DEFAULT_LANGUAGE', 'ru', raising=False) + monkeypatch.setattr(settings, 'CRYPTOBOT_ENABLED', True, raising=False) + monkeypatch.setattr(settings, 'CRYPTOBOT_API_TOKEN', 'token', raising=False) + monkeypatch.setattr(settings, 'CRYPTOBOT_DEFAULT_ASSET', 'USDT', raising=False) + monkeypatch.setattr(type(settings), 'get_available_renewal_periods', lambda self: [30], raising=False) + + user = types.SimpleNamespace(id=42, balance_kopeks=0, language='ru') + subscription = types.SimpleNamespace( + id=99, + connected_squads=[], + traffic_limit_gb=100, + device_limit=5, + end_date=datetime.utcnow(), + ) + + pricing_model = SubscriptionRenewalPricing( + period_days=30, + period_id='days:30', + months=1, + base_original_total=9512, + discounted_total=9512, + final_total=9512, + promo_discount_value=0, + promo_discount_percent=0, + overall_discount_percent=0, + per_month=9512, + server_ids=[], + details={}, + ) + + async def fake_authorize(init_data, db): # noqa: ARG001 + return user + + def fake_ensure(subscription_user, allowed_statuses=None): # noqa: ARG001 + return subscription + + async def fake_calculate(db, u, sub, period): # noqa: ARG001 + return pricing_model + + captured: dict[str, Any] = {} + + class DummyPaymentService: + def __init__(self, *args, **kwargs): + pass + + async def create_cryptobot_payment(self, db, **kwargs): + captured.update(kwargs) + return { + 'local_payment_id': 654, + 'invoice_id': 'inv_round', + 'bot_invoice_url': 'https://t.me/pay', + 'mini_app_invoice_url': 'https://mini.app/pay-round', + 'web_app_invoice_url': None, + } + + async def fake_rate(): + return 95.0 + + monkeypatch.setattr(miniapp, '_authorize_miniapp_user', fake_authorize) + monkeypatch.setattr(miniapp, '_ensure_paid_subscription', fake_ensure) + monkeypatch.setattr(miniapp, '_validate_subscription_id', lambda *args, **kwargs: None) + monkeypatch.setattr(miniapp, '_calculate_subscription_renewal_pricing', fake_calculate) + monkeypatch.setattr(miniapp, 'PaymentService', lambda *args, **kwargs: DummyPaymentService()) + monkeypatch.setattr(miniapp, '_get_usd_to_rub_rate', fake_rate) + + payload = MiniAppSubscriptionRenewalRequest( + initData='init', + subscriptionId=99, + periodId='days:30', + method='cryptobot', + ) + + response = await miniapp.submit_subscription_renewal_endpoint(payload, db=types.SimpleNamespace()) + + assert response.requires_payment is True + assert captured.get('amount_usd') == pytest.approx(1.01) + assert response.payment_amount_kopeks == 9512 + + +@pytest.mark.anyio("asyncio") +async def test_cryptobot_renewal_uses_pricing_snapshot(monkeypatch): + module = sys.modules['app.services.payment.cryptobot'] + mixin = CryptoBotPaymentMixin() + + subscription = types.SimpleNamespace(id=77, connected_squads=[], traffic_limit_gb=100, device_limit=5) + user = types.SimpleNamespace(id=5, balance_kopeks=7000, subscription=subscription) + + pricing_model = SubscriptionRenewalPricing( + period_days=30, + period_id='days:30', + months=1, + base_original_total=12000, + discounted_total=10000, + final_total=10000, + promo_discount_value=0, + promo_discount_percent=0, + overall_discount_percent=0, + per_month=10000, + server_ids=[11, 22], + details={'servers_individual_prices': [500, 500]}, + ) + + descriptor = build_payment_descriptor( + user_id=5, + subscription_id=77, + period_days=30, + total_amount_kopeks=10000, + missing_amount_kopeks=3000, + pricing_snapshot=pricing_model.to_payload(), + ) + + payment = types.SimpleNamespace(invoice_id='INV-1', user_id=5) + + async def fake_get_user_by_id(db, user_id): # noqa: ARG001 + return user if user_id == 5 else None + + monkeypatch.setitem(sys.modules, 'app.services.payment_service', types.SimpleNamespace(get_user_by_id=fake_get_user_by_id)) + + async def fail_calculate(*args, **kwargs): # noqa: ARG001 + raise AssertionError('calculate_pricing should not be called when snapshot is present') + + monkeypatch.setattr(module.renewal_service, 'calculate_pricing', fail_calculate) + + captured: dict[str, Any] = {} + + async def fake_finalize(db, u, sub, pricing, *, charge_balance_amount=None, description=None, payment_method=None): # noqa: ARG001 + captured['pricing'] = pricing + captured['charge'] = charge_balance_amount + captured['description'] = description + captured['payment_method'] = payment_method + return SubscriptionRenewalResult( + subscription=types.SimpleNamespace(id=sub.id, end_date=datetime.utcnow()), + transaction=types.SimpleNamespace(id=999), + total_amount_kopeks=pricing.final_total, + charged_from_balance_kopeks=charge_balance_amount or pricing.final_total, + old_end_date=None, + ) + + monkeypatch.setattr(module.renewal_service, 'finalize', fake_finalize) + + async def fake_link(db, invoice_id, transaction_id): # noqa: ARG001 + captured['linked'] = (invoice_id, transaction_id) + + cryptobot_crud = types.SimpleNamespace(link_cryptobot_payment_to_transaction=fake_link) + + result = await mixin._process_subscription_renewal_payment( + db=types.SimpleNamespace(), + payment=payment, + descriptor=descriptor, + cryptobot_crud=cryptobot_crud, + ) + + assert result is True + assert captured['pricing'].server_ids == [11, 22] + assert captured['pricing'].final_total == 10000 + assert captured['charge'] == 7000 + assert captured['payment_method'] == PaymentMethod.CRYPTOBOT + assert captured['linked'] == ('INV-1', 999) + + +@pytest.mark.anyio("asyncio") +async def test_cryptobot_renewal_accepts_changed_pricing_without_snapshot(monkeypatch): + module = sys.modules['app.services.payment.cryptobot'] + mixin = CryptoBotPaymentMixin() + + subscription = types.SimpleNamespace(id=55, connected_squads=[], traffic_limit_gb=50, device_limit=3) + user = types.SimpleNamespace(id=8, balance_kopeks=4000, subscription=subscription) + + descriptor = build_payment_descriptor( + user_id=8, + subscription_id=55, + period_days=30, + total_amount_kopeks=5000, + missing_amount_kopeks=1000, + ) + + payment = types.SimpleNamespace(invoice_id='INV-2', user_id=8) + + async def fake_get_user_by_id(db, user_id): # noqa: ARG001 + return user if user_id == 8 else None + + monkeypatch.setitem(sys.modules, 'app.services.payment_service', types.SimpleNamespace(get_user_by_id=fake_get_user_by_id)) + + recalculated_pricing = SubscriptionRenewalPricing( + period_days=30, + period_id='days:30', + months=1, + base_original_total=5200, + discounted_total=5200, + final_total=5200, + promo_discount_value=0, + promo_discount_percent=0, + overall_discount_percent=0, + per_month=5200, + server_ids=[], + details={}, + ) + + async def fake_calculate(db, u, sub, period): # noqa: ARG001 + return recalculated_pricing + + monkeypatch.setattr(module.renewal_service, 'calculate_pricing', fake_calculate) + + captured: dict[str, Any] = {} + + async def fake_finalize(db, u, sub, pricing, *, charge_balance_amount=None, description=None, payment_method=None): # noqa: ARG001 + captured['pricing'] = pricing + captured['charge'] = charge_balance_amount + return SubscriptionRenewalResult( + subscription=types.SimpleNamespace(id=sub.id, end_date=datetime.utcnow()), + transaction=None, + total_amount_kopeks=pricing.final_total, + charged_from_balance_kopeks=charge_balance_amount or pricing.final_total, + old_end_date=None, + ) + + monkeypatch.setattr(module.renewal_service, 'finalize', fake_finalize) + + async def noop_link(*args, **kwargs): + return None + + cryptobot_crud = types.SimpleNamespace(link_cryptobot_payment_to_transaction=noop_link) + + result = await mixin._process_subscription_renewal_payment( + db=types.SimpleNamespace(), + payment=payment, + descriptor=descriptor, + cryptobot_crud=cryptobot_crud, + ) + + assert result is True + assert captured['pricing'].final_total == 5000 + assert captured['charge'] == 4000 + + +@pytest.mark.anyio("asyncio") +async def test_cryptobot_webhook_uses_inline_payload_when_db_missing(monkeypatch): + module = sys.modules['app.services.payment.cryptobot'] + mixin = CryptoBotPaymentMixin() + + subscription = types.SimpleNamespace(id=91, connected_squads=[], traffic_limit_gb=80, device_limit=4) + user = types.SimpleNamespace(id=21, balance_kopeks=6000, subscription=subscription) + + pricing_model = SubscriptionRenewalPricing( + period_days=30, + period_id='days:30', + months=1, + base_original_total=9000, + discounted_total=9000, + final_total=9000, + promo_discount_value=0, + promo_discount_percent=0, + overall_discount_percent=0, + per_month=9000, + server_ids=[5, 6], + details={'servers_individual_prices': [300, 300]}, + ) + + descriptor = build_payment_descriptor( + user_id=21, + subscription_id=91, + period_days=30, + total_amount_kopeks=9000, + missing_amount_kopeks=3000, + pricing_snapshot=pricing_model.to_payload(), + ) + + encoded_payload = encode_payment_payload(descriptor) + + payment = types.SimpleNamespace( + invoice_id='INV-webhook', + user_id=21, + status='active', + payload=None, + amount='90.00', + asset='USDT', + bot_invoice_url=None, + mini_app_invoice_url=None, + web_app_invoice_url=None, + description='Продление подписки', + ) + + async def fake_get_user_by_id(db, user_id): # noqa: ARG001 + return user if user_id == 21 else None + + monkeypatch.setitem( + sys.modules, + 'app.services.payment_service', + types.SimpleNamespace(get_user_by_id=fake_get_user_by_id), + ) + + async def fail_calculate(*args, **kwargs): # noqa: ARG001 + raise AssertionError('calculate_pricing should not be called') + + monkeypatch.setattr(module.renewal_service, 'calculate_pricing', fail_calculate) + + captured: dict[str, Any] = {} + + async def fake_finalize(db, u, sub, pricing, *, charge_balance_amount=None, description=None, payment_method=None): # noqa: ARG001 + captured['pricing'] = pricing + captured['charge'] = charge_balance_amount + captured['description'] = description + captured['payment_method'] = payment_method + return SubscriptionRenewalResult( + subscription=types.SimpleNamespace(id=sub.id, end_date=datetime.utcnow()), + transaction=types.SimpleNamespace(id=1234), + total_amount_kopeks=pricing.final_total, + charged_from_balance_kopeks=charge_balance_amount or pricing.final_total, + old_end_date=None, + ) + + monkeypatch.setattr(module.renewal_service, 'finalize', fake_finalize) + + linked: dict[str, Any] = {} + + async def fake_get(db, invoice_id): # noqa: ARG001 + return payment if invoice_id == payment.invoice_id else None + + async def fake_update(db, invoice_id, status, paid_at): # noqa: ARG001 + if invoice_id == payment.invoice_id: + payment.status = status + payment.paid_at = paid_at + return payment + + async def fake_link(db, invoice_id, transaction_id): # noqa: ARG001 + linked['value'] = (invoice_id, transaction_id) + return payment + + monkeypatch.setitem( + sys.modules, + 'app.database.crud.cryptobot', + types.SimpleNamespace( + get_cryptobot_payment_by_invoice_id=fake_get, + update_cryptobot_payment_status=fake_update, + link_cryptobot_payment_to_transaction=fake_link, + ), + ) + + webhook_payload = { + 'update_type': 'invoice_paid', + 'payload': { + 'invoice_id': payment.invoice_id, + 'paid_at': '2024-05-01T12:00:00Z', + 'payload': encoded_payload, + }, + } + + result = await mixin.process_cryptobot_webhook(types.SimpleNamespace(), webhook_payload) + + assert result is True + assert captured['pricing'].final_total == 9000 + assert captured['charge'] == 6000 + assert captured['payment_method'] == PaymentMethod.CRYPTOBOT + assert linked['value'] == (payment.invoice_id, 1234) + + @pytest.mark.anyio("asyncio") async def test_create_payment_link_pal24_uses_selected_option(monkeypatch): monkeypatch.setattr(settings, 'PAL24_ENABLED', True, raising=False)