Files
remnawave-bedolaga-telegram…/app/database/crud/discount_offer.py
2025-10-09 05:19:44 +03:00

276 lines
7.9 KiB
Python

from __future__ import annotations
import logging
from datetime import datetime, timedelta
from typing import List, Optional
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database.crud.promo_offer_log import log_promo_offer_action
from app.database.models import DiscountOffer
# Module-level logger, named after this module, used for best-effort warnings.
logger: logging.Logger = logging.getLogger(__name__)
async def upsert_discount_offer(
    db: AsyncSession,
    *,
    user_id: int,
    subscription_id: Optional[int],
    notification_type: str,
    discount_percent: int,
    bonus_amount_kopeks: int,
    valid_hours: int,
    effect_type: str = "percent_discount",
    extra_data: Optional[dict] = None,
) -> DiscountOffer:
    """Create or refresh a discount offer for a user.

    The newest active offer with the same ``user_id`` and
    ``notification_type`` is reused when it has not been claimed yet: its
    terms, subscription, effect and expiry are overwritten in place.
    Otherwise a new active offer is inserted. The session is committed and
    the returned instance is refreshed from the database.

    Args:
        db: Async session used for the lookup and the commit.
        user_id: Owner of the offer.
        subscription_id: Subscription to attach to the offer, or ``None``.
        notification_type: Offer source/type used to find an offer to reuse.
        discount_percent: Discount percentage to store.
        bonus_amount_kopeks: Bonus amount to store, in kopeks.
        valid_hours: Offer lifetime; ``expires_at`` is set to the current
            naive UTC time plus this many hours.
        effect_type: Effect kind to store.
        extra_data: Extra payload to store; replaces any previous value when
            an existing offer is refreshed.

    Returns:
        The created or refreshed ``DiscountOffer``.
    """
    expires_at = datetime.utcnow() + timedelta(hours=valid_hours)

    # Only the newest matching row is used, so ask the database for exactly
    # one row instead of fetching every active offer and discarding the rest
    # (``Result.first()`` does not add a LIMIT to the statement by itself).
    result = await db.execute(
        select(DiscountOffer)
        .where(
            DiscountOffer.user_id == user_id,
            DiscountOffer.notification_type == notification_type,
            DiscountOffer.is_active == True,  # noqa: E712
        )
        .order_by(DiscountOffer.created_at.desc())
        .limit(1)
    )
    offer = result.scalars().first()

    if offer is not None and offer.claimed_at is None:
        # Refresh the still-unclaimed offer in place rather than adding a
        # second active offer of the same type for this user.
        offer.discount_percent = discount_percent
        offer.bonus_amount_kopeks = bonus_amount_kopeks
        offer.expires_at = expires_at
        offer.subscription_id = subscription_id
        offer.effect_type = effect_type
        offer.extra_data = extra_data
    else:
        offer = DiscountOffer(
            user_id=user_id,
            subscription_id=subscription_id,
            notification_type=notification_type,
            discount_percent=discount_percent,
            bonus_amount_kopeks=bonus_amount_kopeks,
            expires_at=expires_at,
            is_active=True,
            effect_type=effect_type,
            extra_data=extra_data,
        )
        db.add(offer)

    await db.commit()
    await db.refresh(offer)
    return offer
async def get_offer_by_id(db: AsyncSession, offer_id: int) -> Optional[DiscountOffer]:
    """Fetch a single discount offer by its ``id``.

    The related ``user`` and ``subscription`` are eagerly loaded with
    ``selectinload``.

    Returns:
        The matching ``DiscountOffer``, or ``None`` when no row has
        ``offer_id``.
    """
    eager_relations = (
        selectinload(DiscountOffer.user),
        selectinload(DiscountOffer.subscription),
    )
    query = (
        select(DiscountOffer)
        .where(DiscountOffer.id == offer_id)
        .options(*eager_relations)
    )
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
async def list_discount_offers(
    db: AsyncSession,
    *,
    offset: int = 0,
    limit: int = 50,
    user_id: Optional[int] = None,
    notification_type: Optional[str] = None,
    is_active: Optional[bool] = None,
) -> List[DiscountOffer]:
    """Return one page of discount offers, newest first.

    Args:
        db: Async session to query with.
        offset: Number of matching offers to skip.
        limit: Maximum number of offers to return.
        user_id: If given, only offers of this user.
        notification_type: If non-empty, only offers of this type (an empty
            string means "no filter").
        is_active: If given, only offers whose ``is_active`` equals it.

    Returns:
        A list of ``DiscountOffer`` objects with ``user`` and ``subscription``
        eagerly loaded.
    """
    stmt = (
        select(DiscountOffer)
        .options(
            selectinload(DiscountOffer.user),
            selectinload(DiscountOffer.subscription),
        )
        # ``id`` breaks ties between offers created at the same instant, so
        # offset-based pages have a stable order and never skip or repeat rows.
        .order_by(DiscountOffer.created_at.desc(), DiscountOffer.id.desc())
        .offset(offset)
        .limit(limit)
    )
    if user_id is not None:
        stmt = stmt.where(DiscountOffer.user_id == user_id)
    if notification_type:
        stmt = stmt.where(DiscountOffer.notification_type == notification_type)
    if is_active is not None:
        stmt = stmt.where(DiscountOffer.is_active == is_active)

    result = await db.execute(stmt)
    # ``ScalarResult.all()`` is typed as a Sequence; return a real list to
    # match the annotation.
    return list(result.scalars().all())
async def list_active_discount_offers_for_user(
    db: AsyncSession,
    user_id: int,
) -> List[DiscountOffer]:
    """List a user's offers that are active and not yet expired.

    An offer qualifies when ``is_active`` is true and ``expires_at`` is later
    than the current naive UTC time. Results are ordered by ``expires_at``
    ascending, so the soonest-expiring offer comes first. ``user`` and
    ``subscription`` are eagerly loaded.
    """
    current_time = datetime.utcnow()
    filters = (
        DiscountOffer.user_id == user_id,
        DiscountOffer.is_active == True,  # noqa: E712
        DiscountOffer.expires_at > current_time,
    )
    query = select(DiscountOffer)
    query = query.options(
        selectinload(DiscountOffer.user),
        selectinload(DiscountOffer.subscription),
    )
    query = query.where(*filters).order_by(DiscountOffer.expires_at.asc())

    rows = await db.execute(query)
    return rows.scalars().all()
async def count_discount_offers(
    db: AsyncSession,
    *,
    user_id: Optional[int] = None,
    notification_type: Optional[str] = None,
    is_active: Optional[bool] = None,
) -> int:
    """Count discount offers matching the optional filters.

    Each filter is applied only when provided; an empty ``notification_type``
    is treated the same as ``None``. Returns 0 when nothing matches.
    """
    conditions = []
    if user_id is not None:
        conditions.append(DiscountOffer.user_id == user_id)
    if notification_type:
        conditions.append(DiscountOffer.notification_type == notification_type)
    if is_active is not None:
        conditions.append(DiscountOffer.is_active == is_active)

    query = select(func.count(DiscountOffer.id))
    if conditions:
        query = query.where(*conditions)

    total = (await db.execute(query)).scalar()
    return int(total or 0)
async def mark_offer_claimed(
    db: AsyncSession,
    offer: DiscountOffer,
    *,
    details: Optional[dict] = None,
) -> DiscountOffer:
    """Mark ``offer`` as claimed and record a "claimed" audit entry.

    ``claimed_at`` is set to the current naive UTC time and ``is_active`` is
    cleared. The change is committed before logging, so it persists even if
    the audit entry cannot be written. Logging is best-effort: if it fails, a
    warning is emitted, the session is rolled back, and ``offer`` is reloaded.

    Args:
        db: Session that ``offer`` belongs to.
        offer: The offer being claimed.
        details: Optional extra data forwarded to ``log_promo_offer_action``.

    Returns:
        The same ``offer`` instance.
    """
    offer.claimed_at = datetime.utcnow()
    offer.is_active = False
    await db.commit()
    await db.refresh(offer)

    # Capture the id while it is known to be loaded, so the failure handlers
    # below never need to touch a possibly-expired attribute.
    offer_id = offer.id

    try:
        await log_promo_offer_action(
            db,
            user_id=offer.user_id,
            offer_id=offer_id,
            action="claimed",
            source=offer.notification_type,
            percent=offer.discount_percent,
            effect_type=offer.effect_type,
            details=details,
        )
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.warning(
            "Failed to record promo offer claim log for offer %s: %s",
            offer_id,
            exc,
        )
        try:
            await db.rollback()
        except Exception as rollback_error:  # pragma: no cover - defensive logging
            logger.warning(
                "Failed to rollback session after promo offer claim log failure: %s",
                rollback_error,
            )
        else:
            # ``rollback()`` expires every instance in the session. The claim
            # was already committed above, so reload it; otherwise the caller
            # would receive an expired object whose attribute access requires
            # implicit IO, which AsyncSession does not support.
            try:
                await db.refresh(offer)
            except Exception as refresh_error:  # pragma: no cover - defensive logging
                logger.warning(
                    "Failed to reload promo offer %s after claim log failure: %s",
                    offer_id,
                    refresh_error,
                )
    return offer
async def deactivate_expired_offers(db: AsyncSession) -> int:
    """Deactivate every active offer whose ``expires_at`` is in the past.

    All matching offers get ``is_active = False`` in a single commit. Then a
    "disabled" audit entry with reason ``offer_expired`` is written for each
    offer that has a truthy ``user_id``. Audit logging is best-effort: a
    failure is logged as a warning and the session is rolled back, and the
    remaining entries are still attempted.

    Returns:
        The number of offers deactivated (0 when none had expired).
    """
    cutoff = datetime.utcnow()
    expired_query = select(DiscountOffer).where(
        DiscountOffer.is_active == True,  # noqa: E712
        DiscountOffer.expires_at < cutoff,
    )
    expired = (await db.execute(expired_query)).scalars().all()
    if not expired:
        return 0

    # Snapshot the fields needed for audit logging before committing; after
    # the commit the ORM instances may be expired and need a reload.
    disabled_entries = []
    for item in expired:
        item.is_active = False
        disabled_entries.append(
            {
                "user_id": item.user_id,
                "offer_id": item.id,
                "source": item.notification_type,
                "percent": item.discount_percent,
                "effect_type": item.effect_type,
            }
        )
    await db.commit()

    for entry in disabled_entries:
        owner_id = entry.get("user_id")
        if not owner_id:
            continue
        try:
            await log_promo_offer_action(
                db,
                user_id=owner_id,
                offer_id=entry["offer_id"],
                action="disabled",
                source=entry.get("source"),
                percent=entry.get("percent"),
                effect_type=entry.get("effect_type"),
                details={"reason": "offer_expired"},
            )
        except Exception as log_error:  # pragma: no cover - defensive logging
            logger.warning(
                "Failed to record promo offer disable log for offer %s: %s",
                entry.get("offer_id"),
                log_error,
            )
            try:
                await db.rollback()
            except Exception as rollback_error:  # pragma: no cover - defensive logging
                logger.warning(
                    "Failed to rollback session after promo offer disable log failure: %s",
                    rollback_error,
                )
    return len(expired)
async def get_latest_claimed_offer_for_user(
    db: AsyncSession,
    user_id: int,
    source: Optional[str] = None,
) -> Optional[DiscountOffer]:
    """Return the user's most recently claimed offer, if any.

    Args:
        db: Async session to query with.
        user_id: Owner of the offers.
        source: Optional ``notification_type`` filter; ``None`` or an empty
            string means offers of any type.

    Returns:
        The offer with the latest ``claimed_at``, or ``None`` when the user
        has no claimed offer matching ``source``.
    """
    conditions = [
        DiscountOffer.user_id == user_id,
        DiscountOffer.claimed_at.isnot(None),
    ]
    if source:
        conditions.append(DiscountOffer.notification_type == source)

    stmt = (
        select(DiscountOffer)
        .where(*conditions)
        .order_by(DiscountOffer.claimed_at.desc())
        # Only the first row is consumed, so let the database stop there
        # (``Result.first()`` does not add a LIMIT by itself).
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalars().first()