Merge pull request #1089 from Fr1ngg/64afzg-bedolaga/fix-subscription-renewal-section-in-miniapp

Implement mini app subscription renewal flow
This commit is contained in:
Egor
2025-10-10 11:57:18 +03:00
committed by GitHub
2 changed files with 709 additions and 1 deletions

View File

@@ -14,7 +14,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.config import settings
from app.config import PERIOD_PRICES, settings
from app.database.crud.discount_offer import (
get_latest_claimed_offer_for_user,
get_offer_by_id,
@@ -26,6 +26,7 @@ from app.database.crud.rules import get_rules_by_language
from app.database.crud.promo_offer_template import get_promo_offer_template_by_id
from app.database.crud.server_squad import (
get_available_server_squads,
get_server_ids_by_uuids,
get_server_squad_by_uuid,
add_user_to_servers,
remove_user_from_servers,
@@ -40,12 +41,14 @@ from app.database.models import (
PromoGroup,
PromoOfferTemplate,
Subscription,
SubscriptionStatus,
SubscriptionTemporaryAccess,
Transaction,
TransactionType,
PaymentMethod,
User,
)
from app.services.admin_notification_service import AdminNotificationService
from app.services.faq_service import FaqService
from app.services.privacy_policy_service import PrivacyPolicyService
from app.services.public_offer_service import PublicOfferService
@@ -75,9 +78,12 @@ from app.utils.user_utils import (
)
from app.utils.pricing_utils import (
apply_percentage_discount,
calculate_months_from_days,
calculate_prorated_price,
get_remaining_months,
validate_pricing_calculation,
)
from app.utils.promo_offer import get_user_active_promo_discount_percent
from ..dependencies import get_db_session
from ..schemas.miniapp import (
@@ -112,6 +118,12 @@ from ..schemas.miniapp import (
MiniAppReferralStats,
MiniAppReferralTerms,
MiniAppRichTextDocument,
MiniAppSubscriptionRenewal,
MiniAppSubscriptionRenewalPromoOffer,
MiniAppSubscriptionRenewalOptionsRequest,
MiniAppSubscriptionRenewalOptionsResponse,
MiniAppSubscriptionRenewalRequest,
MiniAppSubscriptionRenewalResponse,
MiniAppSubscriptionRequest,
MiniAppSubscriptionResponse,
MiniAppSubscriptionUser,
@@ -2325,6 +2337,7 @@ async def get_subscription_details(
promo_offer_discount_source=promo_offer_source,
)
renewal_payload = await _build_subscription_renewal_payload(db, user, subscription)
referral_info = await _build_referral_info(db, user)
return MiniAppSubscriptionResponse(
@@ -2368,6 +2381,8 @@ async def get_subscription_details(
faq=faq_payload,
legal_documents=legal_documents_payload,
referral=referral_info,
subscription_renewal=renewal_payload,
renewal=renewal_payload,
)
@@ -2820,6 +2835,371 @@ def _validate_subscription_id(
)
def _get_discount_percent_for_user(
user: Optional[User],
category: str,
period_days: Optional[int],
) -> int:
if not user:
return 0
try:
percent = user.get_promo_discount(category, period_days)
except AttributeError:
return 0
try:
return int(percent)
except (TypeError, ValueError):
return 0
def _resolve_currency_code(user: Optional[User]) -> str:
currency = getattr(user, "balance_currency", None)
if isinstance(currency, str) and currency.strip():
return currency.strip().upper()
return "RUB"
async def _calculate_subscription_renewal_pricing(
db: AsyncSession,
user: User,
subscription: Subscription,
period_days: int,
*,
subscription_service: Optional[SubscriptionService] = None,
) -> Optional[Dict[str, Any]]:
if period_days is None or period_days <= 0:
return None
service = subscription_service or SubscriptionService()
try:
months_in_period = calculate_months_from_days(period_days)
if months_in_period <= 0:
months_in_period = 1
base_price_original = PERIOD_PRICES.get(period_days, 0)
period_discount_percent = _get_discount_percent_for_user(
user,
"period",
period_days,
)
base_price, base_discount_total = apply_percentage_discount(
base_price_original,
period_discount_percent,
)
servers_price_per_month, per_server_monthly_prices = (
await service.get_countries_price_by_uuids(
subscription.connected_squads,
db,
promo_group_id=getattr(user, "promo_group_id", None),
)
)
servers_discount_percent = _get_discount_percent_for_user(
user,
"servers",
period_days,
)
servers_discount_per_month = servers_price_per_month * servers_discount_percent // 100
discounted_servers_per_month = servers_price_per_month - servers_discount_per_month
total_servers_price = discounted_servers_per_month * months_in_period
server_uuid_prices: Dict[str, int] = {}
if subscription.connected_squads and per_server_monthly_prices:
for squad_uuid, monthly_price in zip(
subscription.connected_squads,
per_server_monthly_prices,
):
discount_per_month = monthly_price * servers_discount_percent // 100
discounted_per_month = monthly_price - discount_per_month
server_uuid_prices[squad_uuid] = discounted_per_month * months_in_period
additional_devices = max(0, subscription.device_limit - settings.DEFAULT_DEVICE_LIMIT)
devices_price_per_month = additional_devices * settings.PRICE_PER_DEVICE
devices_discount_percent = _get_discount_percent_for_user(
user,
"devices",
period_days,
)
devices_discount_per_month = devices_price_per_month * devices_discount_percent // 100
discounted_devices_per_month = devices_price_per_month - devices_discount_per_month
total_devices_price = discounted_devices_per_month * months_in_period
traffic_price_per_month = settings.get_traffic_price(subscription.traffic_limit_gb)
traffic_discount_percent = _get_discount_percent_for_user(
user,
"traffic",
period_days,
)
traffic_discount_per_month = traffic_price_per_month * traffic_discount_percent // 100
discounted_traffic_per_month = traffic_price_per_month - traffic_discount_per_month
total_traffic_price = discounted_traffic_per_month * months_in_period
subtotal_price = (
base_price
+ total_servers_price
+ total_devices_price
+ total_traffic_price
)
monthly_additions = (
discounted_servers_per_month
+ discounted_devices_per_month
+ discounted_traffic_per_month
)
if not validate_pricing_calculation(
base_price,
monthly_additions,
months_in_period,
subtotal_price,
):
logger.error(
"Invalid renewal pricing calculation for subscription %s and period %s",
getattr(subscription, "id", None),
period_days,
)
return None
original_total_price = (
base_price_original
+ servers_price_per_month * months_in_period
+ devices_price_per_month * months_in_period
+ traffic_price_per_month * months_in_period
)
promo_percent = get_user_active_promo_discount_percent(user)
final_price = subtotal_price
promo_discount_value = 0
if subtotal_price > 0 and promo_percent > 0:
final_price, promo_discount_value = apply_percentage_discount(
subtotal_price,
promo_percent,
)
discount_percent = 0
if original_total_price > 0 and final_price < original_total_price:
discount_percent = int(
round(
(original_total_price - final_price)
* 100
/ original_total_price
)
)
discount_percent = max(0, min(100, discount_percent))
if months_in_period > 0:
price_per_month = int(
Decimal(final_price)
/ Decimal(months_in_period)
.quantize(Decimal("1"), rounding=ROUND_HALF_UP)
)
else:
price_per_month = final_price
return {
"period_days": period_days,
"months": months_in_period,
"base_price": base_price,
"base_price_original": base_price_original,
"base_discount_percent": period_discount_percent,
"base_discount_value": base_discount_total,
"servers_price_per_month": servers_price_per_month,
"servers_discount_per_month": servers_discount_per_month,
"servers_discount_percent": servers_discount_percent,
"server_uuid_prices": server_uuid_prices,
"total_servers_price": total_servers_price,
"devices_price_per_month": devices_price_per_month,
"devices_discount_per_month": devices_discount_per_month,
"devices_discount_percent": devices_discount_percent,
"total_devices_price": total_devices_price,
"traffic_price_per_month": traffic_price_per_month,
"traffic_discount_per_month": traffic_discount_per_month,
"traffic_discount_percent": traffic_discount_percent,
"total_traffic_price": total_traffic_price,
"subtotal_price": subtotal_price,
"original_total_price": original_total_price,
"final_price": final_price,
"discount_percent": discount_percent,
"price_per_month": price_per_month,
"promo_discount_percent": promo_percent if promo_discount_value else 0,
"promo_discount_value": promo_discount_value,
}
except Exception as error: # pragma: no cover - defensive logging
logger.error(
"Failed to calculate renewal pricing for subscription %s: %s",
getattr(subscription, "id", None),
error,
)
return None
async def _build_subscription_renewal_payload(
db: AsyncSession,
user: User,
subscription: Subscription,
) -> MiniAppSubscriptionRenewal:
available_periods = [
period for period in settings.get_available_renewal_periods() if period and period > 0
]
currency = _resolve_currency_code(user)
balance_kopeks = getattr(user, "balance_kopeks", None)
balance_label = (
settings.format_price(balance_kopeks)
if balance_kopeks is not None
else None
)
service = SubscriptionService()
periods: List[MiniAppSubscriptionRenewalPeriod] = []
recommended_id: Optional[str] = None
best_metric: Optional[Tuple[int, int, int]] = None
minimal_price: Optional[int] = None
for period_days in available_periods:
pricing = await _calculate_subscription_renewal_pricing(
db,
user,
subscription,
period_days,
subscription_service=service,
)
if not pricing:
continue
final_price = pricing["final_price"]
original_price = pricing["original_total_price"]
discount_percent = pricing["discount_percent"] or 0
months_in_period = pricing["months"]
price_per_month = pricing["price_per_month"] if months_in_period > 0 else None
if minimal_price is None or (
final_price is not None and final_price < minimal_price
):
minimal_price = final_price
price_label = settings.format_price(final_price) if final_price is not None else None
original_label = (
settings.format_price(original_price)
if original_price and original_price > final_price
else None
)
price_per_month_label = (
settings.format_price(price_per_month)
if price_per_month is not None and months_in_period > 0
else None
)
period_id = str(period_days)
periods.append(
MiniAppSubscriptionRenewalPeriod(
id=period_id,
period_id=period_id,
days=period_days,
period_days=period_days,
months=months_in_period,
period_months=months_in_period,
price_kopeks=final_price,
price_label=price_label,
original_price_kopeks=original_price,
original_price_label=original_label,
discount_percent=discount_percent if discount_percent > 0 else None,
price_per_month_kopeks=price_per_month,
price_per_month_label=price_per_month_label,
)
)
metric = (
price_per_month if price_per_month is not None else final_price or 0,
-discount_percent,
-period_days,
)
if recommended_id is None or (best_metric and metric < best_metric) or best_metric is None:
recommended_id = period_id
best_metric = metric
if recommended_id and periods:
for period in periods:
if period.id == recommended_id:
period.is_recommended = True
break
promo_group = getattr(user, "promo_group", None)
promo_group_payload: Optional[MiniAppPromoGroup]
if promo_group:
promo_group_payload = MiniAppPromoGroup(
id=promo_group.id,
name=promo_group.name,
**_extract_promo_discounts(promo_group),
)
else:
promo_group_payload = None
promo_percent = get_user_active_promo_discount_percent(user)
promo_offer_payload = None
if promo_percent and promo_percent > 0:
promo_offer_payload = MiniAppSubscriptionRenewalPromoOffer(
percent=promo_percent,
expires_at=getattr(user, "promo_offer_discount_expires_at", None),
)
missing_amount = None
if minimal_price is not None and balance_kopeks is not None:
if balance_kopeks < minimal_price:
missing_amount = minimal_price - balance_kopeks
missing_label = (
settings.format_price(missing_amount)
if missing_amount is not None and missing_amount > 0
else None
)
default_period_id = recommended_id or (periods[0].id if periods else None)
return MiniAppSubscriptionRenewal(
subscription_id=getattr(subscription, "id", None),
currency=currency,
balance_kopeks=balance_kopeks,
balance_label=balance_label,
periods=periods,
default_period_id=default_period_id,
promo_group=promo_group_payload,
promo_offer=promo_offer_payload,
missing_amount_kopeks=missing_amount,
missing_amount_label=missing_label,
)
def _resolve_period_days_from_payload(
payload: MiniAppSubscriptionRenewalRequest,
) -> Optional[int]:
if payload.period_days is not None:
try:
period_days = int(payload.period_days)
if period_days > 0:
return period_days
except (TypeError, ValueError):
return None
period_id = payload.period_id
if isinstance(period_id, str) and period_id.strip():
candidate = period_id.strip()
if candidate.isdigit():
return int(candidate)
match = re.search(r"(\d+)", candidate)
if match:
try:
return int(match.group(1))
except (TypeError, ValueError):
return None
return None
async def _authorize_miniapp_user(
init_data: str,
db: AsyncSession,
@@ -3251,6 +3631,211 @@ async def subscription_purchase_endpoint(
)
@router.post(
"/subscription/renewal/options",
response_model=MiniAppSubscriptionRenewalOptionsResponse,
)
async def get_subscription_renewal_options_endpoint(
payload: MiniAppSubscriptionRenewalOptionsRequest,
db: AsyncSession = Depends(get_db_session),
) -> MiniAppSubscriptionRenewalOptionsResponse:
user = await _authorize_miniapp_user(payload.init_data, db)
subscription = _ensure_paid_subscription(user)
_validate_subscription_id(payload.subscription_id, subscription)
renewal_payload = await _build_subscription_renewal_payload(db, user, subscription)
balance_kopeks = getattr(user, "balance_kopeks", None)
balance_label = renewal_payload.balance_label or (
settings.format_price(balance_kopeks)
if balance_kopeks is not None
else None
)
currency = renewal_payload.currency or _resolve_currency_code(user)
return MiniAppSubscriptionRenewalOptionsResponse(
currency=currency,
balance_kopeks=balance_kopeks,
balance_label=balance_label,
subscription_id=subscription.id,
renewal=renewal_payload,
)
@router.post(
"/subscription/renewal",
response_model=MiniAppSubscriptionRenewalResponse,
)
async def submit_subscription_renewal_endpoint(
payload: MiniAppSubscriptionRenewalRequest,
db: AsyncSession = Depends(get_db_session),
) -> MiniAppSubscriptionRenewalResponse:
user = await _authorize_miniapp_user(payload.init_data, db)
subscription = _ensure_paid_subscription(user)
_validate_subscription_id(payload.subscription_id, subscription)
period_days = _resolve_period_days_from_payload(payload)
if not period_days:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail={
"code": "validation_error",
"message": "Renewal period must be specified",
},
)
available_periods = {period for period in settings.get_available_renewal_periods() if period > 0}
if period_days not in available_periods:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail={
"code": "invalid_period",
"message": "Selected renewal period is not available",
},
)
pricing = await _calculate_subscription_renewal_pricing(db, user, subscription, period_days)
if not pricing:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "pricing_error",
"message": "Unable to calculate renewal pricing",
},
)
final_price = max(0, int(pricing.get("final_price") or 0))
balance_before = getattr(user, "balance_kopeks", 0)
if final_price > balance_before:
missing = final_price - balance_before
raise HTTPException(
status.HTTP_402_PAYMENT_REQUIRED,
detail={
"code": "insufficient_funds",
"message": (
"Недостаточно средств на балансе. "
f"Не хватает {settings.format_price(missing)}"
),
},
)
months_in_period = pricing.get("months") or calculate_months_from_days(period_days)
description = (
f"Продление подписки на {period_days} дней ({months_in_period} мес)"
if months_in_period
else f"Продление подписки на {period_days} дней"
)
success = await subtract_user_balance(
db,
user,
final_price,
description,
consume_promo_offer=pricing.get("promo_discount_value", 0) > 0,
)
if not success:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"code": "balance_error",
"message": "Failed to charge balance for renewal",
},
)
old_end_date = subscription.end_date
now = datetime.utcnow()
if subscription.end_date and subscription.end_date > now:
new_end_date = subscription.end_date + timedelta(days=period_days)
else:
new_end_date = now + timedelta(days=period_days)
subscription.end_date = new_end_date
subscription.status = SubscriptionStatus.ACTIVE.value
subscription.updated_at = now
await db.commit()
await db.refresh(subscription)
await db.refresh(user)
server_ids = await get_server_ids_by_uuids(db, subscription.connected_squads or [])
if server_ids:
from sqlalchemy import select as sa_select
from app.database.models import ServerSquad
result = await db.execute(
sa_select(ServerSquad.id, ServerSquad.squad_uuid).where(ServerSquad.id.in_(server_ids))
)
id_to_uuid = {row.id: row.squad_uuid for row in result}
total_servers_price = int(pricing.get("total_servers_price") or 0)
default_price = (
total_servers_price // len(server_ids)
if server_ids and total_servers_price > 0
else 0
)
server_prices = [
int(pricing.get("server_uuid_prices", {}).get(id_to_uuid.get(server_id, ""), default_price))
for server_id in server_ids
]
await add_subscription_servers(db, subscription, server_ids, server_prices)
subscription_service = SubscriptionService()
try:
await subscription_service.update_remnawave_user(
db,
subscription,
reset_traffic=settings.RESET_TRAFFIC_ON_PAYMENT,
reset_reason="продление подписки",
)
except Exception as error: # pragma: no cover - defensive logging
logger.warning(
"Failed to update RemnaWave after renewal for user %s: %s",
getattr(user, "telegram_id", None),
error,
)
transaction = await create_transaction(
db=db,
user_id=user.id,
type=TransactionType.SUBSCRIPTION_PAYMENT,
amount_kopeks=final_price,
description=description,
)
if settings.BOT_TOKEN:
bot = Bot(token=settings.BOT_TOKEN)
try:
notification_service = AdminNotificationService(bot)
await notification_service.send_subscription_extension_notification(
db,
user,
subscription,
transaction,
period_days,
old_end_date,
new_end_date=new_end_date,
balance_after=user.balance_kopeks,
)
except Exception as error: # pragma: no cover - defensive logging
logger.warning(
"Failed to send renewal notification for user %s: %s",
getattr(user, "telegram_id", None),
error,
)
finally:
await bot.session.close()
renewal_payload = await _build_subscription_renewal_payload(db, user, subscription)
balance_after = getattr(user, "balance_kopeks", 0)
return MiniAppSubscriptionRenewalResponse(
balance_kopeks=balance_after,
balance_label=settings.format_price(balance_after),
subscription_id=subscription.id,
renewal=renewal_payload,
)
@router.post(
"/subscription/settings",
response_model=MiniAppSubscriptionSettingsResponse,

View File

@@ -319,6 +319,69 @@ class MiniAppPaymentStatusResponse(BaseModel):
results: List[MiniAppPaymentStatusResult] = Field(default_factory=list)
class MiniAppSubscriptionRenewalPromoOffer(BaseModel):
percent: int
expires_at: Optional[datetime] = Field(default=None, alias="expiresAt")
title: Optional[str] = None
message: Optional[str] = None
model_config = ConfigDict(populate_by_name=True)
class MiniAppSubscriptionRenewalPeriod(BaseModel):
id: str
period_id: Optional[str] = Field(default=None, alias="periodId")
days: Optional[int] = None
period_days: Optional[int] = Field(default=None, alias="periodDays")
months: Optional[int] = None
period_months: Optional[int] = Field(default=None, alias="periodMonths")
price_kopeks: Optional[int] = Field(default=None, alias="priceKopeks")
price_label: Optional[str] = Field(default=None, alias="priceLabel")
original_price_kopeks: Optional[int] = Field(default=None, alias="originalPriceKopeks")
original_price_label: Optional[str] = Field(default=None, alias="originalPriceLabel")
discount_percent: Optional[int] = Field(default=None, alias="discountPercent")
price_per_month_kopeks: Optional[int] = Field(default=None, alias="pricePerMonthKopeks")
price_per_month_label: Optional[str] = Field(default=None, alias="pricePerMonthLabel")
is_recommended: bool = Field(default=False, alias="isRecommended")
badge: Optional[str] = None
title: Optional[str] = None
description: Optional[str] = None
model_config = ConfigDict(populate_by_name=True)
@model_validator(mode="before")
@classmethod
def _ensure_aliases(cls, values: Any) -> Any:
if isinstance(values, dict):
if "id" in values and "periodId" not in values:
values.setdefault("periodId", values["id"])
if values.get("days") is None and values.get("periodDays") is not None:
values["days"] = values.get("periodDays")
if values.get("periodDays") is None and values.get("days") is not None:
values["periodDays"] = values.get("days")
if values.get("months") is None and values.get("periodMonths") is not None:
values["months"] = values.get("periodMonths")
if values.get("periodMonths") is None and values.get("months") is not None:
values["periodMonths"] = values.get("months")
return values
class MiniAppSubscriptionRenewal(BaseModel):
subscription_id: Optional[int] = Field(default=None, alias="subscriptionId")
currency: str = "RUB"
balance_kopeks: Optional[int] = Field(default=None, alias="balanceKopeks")
balance_label: Optional[str] = Field(default=None, alias="balanceLabel")
periods: List[MiniAppSubscriptionRenewalPeriod] = Field(default_factory=list)
default_period_id: Optional[str] = Field(default=None, alias="defaultPeriodId")
promo_group: Optional[MiniAppPromoGroup] = Field(default=None, alias="promoGroup")
promo_offer: Optional[MiniAppSubscriptionRenewalPromoOffer] = Field(default=None, alias="promoOffer")
missing_amount_kopeks: Optional[int] = Field(default=None, alias="missingAmountKopeks")
missing_amount_label: Optional[str] = Field(default=None, alias="missingAmountLabel")
status_message: Optional[str] = Field(default=None, alias="statusMessage")
model_config = ConfigDict(populate_by_name=True)
class MiniAppSubscriptionResponse(BaseModel):
success: bool = True
subscription_id: int
@@ -353,6 +416,11 @@ class MiniAppSubscriptionResponse(BaseModel):
faq: Optional[MiniAppFaq] = None
legal_documents: Optional[MiniAppLegalDocuments] = None
referral: Optional[MiniAppReferralInfo] = None
subscription_renewal: Optional[MiniAppSubscriptionRenewal] = Field(
default=None,
alias="subscriptionRenewal",
)
renewal: Optional[MiniAppSubscriptionRenewal] = None
class MiniAppSubscriptionServerOption(BaseModel):
@@ -525,6 +593,61 @@ class MiniAppSubscriptionUpdateResponse(BaseModel):
message: Optional[str] = None
class MiniAppSubscriptionRenewalOptionsRequest(BaseModel):
init_data: str = Field(..., alias="initData")
subscription_id: Optional[int] = Field(default=None, alias="subscriptionId")
model_config = ConfigDict(populate_by_name=True)
class MiniAppSubscriptionRenewalOptionsResponse(BaseModel):
success: bool = True
currency: str
balance_kopeks: Optional[int] = Field(default=None, alias="balanceKopeks")
balance_label: Optional[str] = Field(default=None, alias="balanceLabel")
subscription_id: Optional[int] = Field(default=None, alias="subscriptionId")
renewal: MiniAppSubscriptionRenewal
model_config = ConfigDict(populate_by_name=True)
class MiniAppSubscriptionRenewalRequest(BaseModel):
init_data: str = Field(..., alias="initData")
subscription_id: Optional[int] = Field(default=None, alias="subscriptionId")
period_id: Optional[str] = Field(default=None, alias="periodId")
period_days: Optional[int] = Field(default=None, alias="periodDays")
model_config = ConfigDict(populate_by_name=True)
@model_validator(mode="before")
@classmethod
def _populate_period_fields(cls, values: Any) -> Any:
if isinstance(values, dict):
alias_map = {
"period_id": ["periodId", "period"],
"period_days": ["periodDays", "days", "durationDays", "duration_days"],
}
for target, sources in alias_map.items():
if values.get(target) is not None:
continue
for source in sources:
if source in values and values[source] is not None:
values[target] = values[source]
break
return values
class MiniAppSubscriptionRenewalResponse(BaseModel):
success: bool = True
message: Optional[str] = None
balance_kopeks: Optional[int] = Field(default=None, alias="balanceKopeks")
balance_label: Optional[str] = Field(default=None, alias="balanceLabel")
subscription_id: Optional[int] = Field(default=None, alias="subscriptionId")
renewal: Optional[MiniAppSubscriptionRenewal] = None
model_config = ConfigDict(populate_by_name=True)
class MiniAppSubscriptionPurchaseOptionsRequest(BaseModel):
init_data: str = Field(..., alias="initData")