mirror of
https://github.com/BEDOLAGA-DEV/remnawave-bedolaga-telegram-bot.git
synced 2026-02-10 22:20:24 +00:00
736 lines
32 KiB
Python
736 lines
32 KiB
Python
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||
from sqlalchemy import BigInteger, String, Float, DateTime, Boolean, Text, Integer, text
|
||
from datetime import datetime
|
||
from typing import Optional, List
|
||
import logging
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class Base(DeclarativeBase):
|
||
pass
|
||
|
||
class User(Base):
|
||
__tablename__ = 'users'
|
||
|
||
id: Mapped[int] = mapped_column(primary_key=True)
|
||
telegram_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
|
||
username: Mapped[Optional[str]] = mapped_column(String(255))
|
||
first_name: Mapped[Optional[str]] = mapped_column(String(255))
|
||
last_name: Mapped[Optional[str]] = mapped_column(String(255))
|
||
language: Mapped[str] = mapped_column(String(10), default='ru')
|
||
balance: Mapped[float] = mapped_column(Float, default=0.0)
|
||
is_admin: Mapped[bool] = mapped_column(Boolean, default=False)
|
||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||
remnawave_uuid: Mapped[Optional[str]] = mapped_column(String(255))
|
||
is_trial_used: Mapped[bool] = mapped_column(Boolean, default=False)
|
||
|
||
class Subscription(Base):
|
||
__tablename__ = 'subscriptions'
|
||
|
||
id: Mapped[int] = mapped_column(primary_key=True)
|
||
name: Mapped[str] = mapped_column(String(255))
|
||
description: Mapped[Optional[str]] = mapped_column(Text)
|
||
price: Mapped[float] = mapped_column(Float)
|
||
duration_days: Mapped[int] = mapped_column(Integer)
|
||
traffic_limit_gb: Mapped[int] = mapped_column(Integer, default=0) # 0 = unlimited
|
||
squad_uuid: Mapped[str] = mapped_column(String(255))
|
||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||
is_trial: Mapped[bool] = mapped_column(Boolean, default=False)
|
||
is_imported: Mapped[bool] = mapped_column(Boolean, default=False)
|
||
|
||
class UserSubscription(Base):
|
||
__tablename__ = 'user_subscriptions'
|
||
|
||
id: Mapped[int] = mapped_column(primary_key=True)
|
||
user_id: Mapped[int] = mapped_column(BigInteger, index=True)
|
||
subscription_id: Mapped[int] = mapped_column(Integer, index=True)
|
||
short_uuid: Mapped[str] = mapped_column(String(255)) # УБРАНО unique=True
|
||
expires_at: Mapped[datetime] = mapped_column(DateTime)
|
||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||
traffic_limit_gb: Mapped[Optional[int]] = mapped_column(Integer) # Добавлено поле
|
||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||
updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, onupdate=datetime.utcnow) # Добавлено поле
|
||
|
||
class Payment(Base):
|
||
__tablename__ = 'payments'
|
||
|
||
id: Mapped[int] = mapped_column(primary_key=True)
|
||
user_id: Mapped[int] = mapped_column(BigInteger, index=True)
|
||
amount: Mapped[float] = mapped_column(Float)
|
||
payment_type: Mapped[str] = mapped_column(String(50)) # 'topup', 'subscription'
|
||
description: Mapped[str] = mapped_column(Text)
|
||
status: Mapped[str] = mapped_column(String(50), default='pending') # 'pending', 'completed', 'cancelled'
|
||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||
|
||
class Promocode(Base):
|
||
__tablename__ = 'promocodes'
|
||
|
||
id: Mapped[int] = mapped_column(primary_key=True)
|
||
code: Mapped[str] = mapped_column(String(255), unique=True, index=True)
|
||
discount_amount: Mapped[float] = mapped_column(Float)
|
||
discount_percent: Mapped[Optional[int]] = mapped_column(Integer)
|
||
usage_limit: Mapped[int] = mapped_column(Integer, default=1)
|
||
used_count: Mapped[int] = mapped_column(Integer, default=0)
|
||
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
|
||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||
|
||
class PromocodeUsage(Base):
|
||
__tablename__ = 'promocode_usage'
|
||
|
||
id: Mapped[int] = mapped_column(primary_key=True)
|
||
user_id: Mapped[int] = mapped_column(BigInteger, index=True)
|
||
promocode_id: Mapped[int] = mapped_column(Integer, index=True)
|
||
used_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||
|
||
class Database:
|
||
def __init__(self, database_url: str):
|
||
self.engine = create_async_engine(
|
||
database_url,
|
||
echo=False,
|
||
pool_pre_ping=True,
|
||
pool_recycle=300
|
||
)
|
||
self.session_factory = async_sessionmaker(
|
||
self.engine,
|
||
class_=AsyncSession,
|
||
expire_on_commit=False
|
||
)
|
||
|
||
async def init_db(self):
|
||
async with self.engine.begin() as conn:
|
||
await conn.run_sync(Base.metadata.create_all)
|
||
|
||
# Выполняем миграции
|
||
await self.migrate_user_subscriptions()
|
||
await self.migrate_subscription_imported_field()
|
||
|
||
async def close(self):
|
||
await self.engine.dispose()
|
||
|
||
# User methods
|
||
async def get_user_by_telegram_id(self, telegram_id: int) -> Optional[User]:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
result = await session.execute(
|
||
select(User).where(User.telegram_id == telegram_id)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
except Exception as e:
|
||
logger.error(f"Error getting user by telegram_id {telegram_id}: {e}")
|
||
return None
|
||
|
||
async def create_user(self, telegram_id: int, username: str = None,
|
||
first_name: str = None, last_name: str = None,
|
||
language: str = 'ru', is_admin: bool = False) -> User:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
user = User(
|
||
telegram_id=telegram_id,
|
||
username=username,
|
||
first_name=first_name,
|
||
last_name=last_name,
|
||
language=language,
|
||
is_admin=is_admin
|
||
)
|
||
session.add(user)
|
||
await session.commit()
|
||
await session.refresh(user)
|
||
return user
|
||
except Exception as e:
|
||
logger.error(f"Error creating user {telegram_id}: {e}")
|
||
await session.rollback()
|
||
raise
|
||
|
||
async def update_user(self, user: User) -> User:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
await session.merge(user)
|
||
await session.commit()
|
||
return user
|
||
except Exception as e:
|
||
logger.error(f"Error updating user {user.telegram_id}: {e}")
|
||
await session.rollback()
|
||
raise
|
||
|
||
async def add_balance(self, user_id: int, amount: float) -> bool:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select, update
|
||
result = await session.execute(
|
||
update(User)
|
||
.where(User.telegram_id == user_id)
|
||
.values(balance=User.balance + amount)
|
||
)
|
||
await session.commit()
|
||
return result.rowcount > 0
|
||
except Exception as e:
|
||
logger.error(f"Error adding balance to user {user_id}: {e}")
|
||
await session.rollback()
|
||
return False
|
||
|
||
# Subscription methods
|
||
async def get_all_subscriptions(self, include_inactive: bool = False, exclude_trial: bool = True, exclude_imported: bool = True) -> List[Subscription]:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
query = select(Subscription)
|
||
if not include_inactive:
|
||
query = query.where(Subscription.is_active == True)
|
||
if exclude_trial:
|
||
query = query.where(Subscription.is_trial == False)
|
||
if exclude_imported:
|
||
query = query.where(Subscription.is_imported == False) # Исключаем импортированные
|
||
result = await session.execute(query)
|
||
return list(result.scalars().all())
|
||
except Exception as e:
|
||
logger.error(f"Error getting subscriptions: {e}")
|
||
return []
|
||
|
||
async def get_all_subscriptions_admin(self) -> List[Subscription]:
|
||
"""Get all subscriptions including imported ones (for admin purposes)"""
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
result = await session.execute(select(Subscription))
|
||
return list(result.scalars().all())
|
||
except Exception as e:
|
||
logger.error(f"Error getting admin subscriptions: {e}")
|
||
return []
|
||
|
||
async def migrate_subscription_imported_field(self):
|
||
"""Add is_imported field to subscriptions table"""
|
||
try:
|
||
async with self.engine.begin() as conn:
|
||
try:
|
||
await conn.execute(text("""
|
||
ALTER TABLE subscriptions
|
||
ADD COLUMN IF NOT EXISTS is_imported BOOLEAN DEFAULT FALSE
|
||
"""))
|
||
logger.info("Successfully added is_imported field to subscriptions table")
|
||
except Exception as e:
|
||
logger.info(f"Migration may have already been applied: {e}")
|
||
except Exception as e:
|
||
logger.error(f"Error during subscription migration: {e}")
|
||
|
||
async def get_subscription_by_id(self, subscription_id: int) -> Optional[Subscription]:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
result = await session.execute(
|
||
select(Subscription).where(Subscription.id == subscription_id)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
except Exception as e:
|
||
logger.error(f"Error getting subscription {subscription_id}: {e}")
|
||
return None
|
||
|
||
async def create_subscription(self, name: str, description: str, price: float,
|
||
duration_days: int, traffic_limit_gb: int,
|
||
squad_uuid: str, is_imported: bool = False) -> Subscription:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
subscription = Subscription(
|
||
name=name,
|
||
description=description,
|
||
price=price,
|
||
duration_days=duration_days,
|
||
traffic_limit_gb=traffic_limit_gb,
|
||
squad_uuid=squad_uuid,
|
||
is_imported=is_imported # Добавляем поддержку is_imported
|
||
)
|
||
session.add(subscription)
|
||
await session.commit()
|
||
await session.refresh(subscription)
|
||
return subscription
|
||
except Exception as e:
|
||
logger.error(f"Error creating subscription: {e}")
|
||
await session.rollback()
|
||
raise
|
||
|
||
async def update_subscription(self, subscription: Subscription) -> Subscription:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
await session.merge(subscription)
|
||
await session.commit()
|
||
return subscription
|
||
except Exception as e:
|
||
logger.error(f"Error updating subscription {subscription.id}: {e}")
|
||
await session.rollback()
|
||
raise
|
||
|
||
async def delete_subscription(self, subscription_id: int) -> bool:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import delete
|
||
result = await session.execute(
|
||
delete(Subscription).where(Subscription.id == subscription_id)
|
||
)
|
||
await session.commit()
|
||
return result.rowcount > 0
|
||
except Exception as e:
|
||
logger.error(f"Error deleting subscription {subscription_id}: {e}")
|
||
await session.rollback()
|
||
return False
|
||
|
||
async def get_user_subscriptions(self, user_id: int) -> List[UserSubscription]:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
result = await session.execute(
|
||
select(UserSubscription).where(UserSubscription.user_id == user_id)
|
||
)
|
||
return list(result.scalars().all())
|
||
except Exception as e:
|
||
logger.error(f"Error getting user subscriptions for {user_id}: {e}")
|
||
return []
|
||
|
||
async def create_user_subscription(self, user_id: int, subscription_id: int,
|
||
short_uuid: str, expires_at: datetime,
|
||
is_active: bool = True, traffic_limit_gb: int = None) -> Optional[UserSubscription]:
|
||
"""Create user subscription with proper error handling"""
|
||
async with self.session_factory() as session:
|
||
try:
|
||
# Проверяем что подписка не существует
|
||
from sqlalchemy import select
|
||
existing = await session.execute(
|
||
select(UserSubscription).where(
|
||
UserSubscription.user_id == user_id,
|
||
UserSubscription.short_uuid == short_uuid
|
||
)
|
||
)
|
||
existing_sub = existing.scalar_one_or_none()
|
||
|
||
if existing_sub:
|
||
logger.warning(f"Subscription with short_uuid {short_uuid} already exists for user {user_id}")
|
||
return existing_sub
|
||
|
||
# Создаем новую подписку
|
||
new_subscription = UserSubscription(
|
||
user_id=user_id,
|
||
subscription_id=subscription_id,
|
||
short_uuid=short_uuid,
|
||
expires_at=expires_at,
|
||
is_active=is_active
|
||
)
|
||
|
||
session.add(new_subscription)
|
||
await session.commit()
|
||
await session.refresh(new_subscription)
|
||
|
||
return new_subscription
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error creating user subscription: {e}")
|
||
await session.rollback()
|
||
return None
|
||
|
||
# Payment methods
|
||
async def create_payment(self, user_id: int, amount: float, payment_type: str,
|
||
description: str, status: str = 'pending') -> Payment:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
payment = Payment(
|
||
user_id=user_id,
|
||
amount=amount,
|
||
payment_type=payment_type,
|
||
description=description,
|
||
status=status
|
||
)
|
||
session.add(payment)
|
||
await session.commit()
|
||
await session.refresh(payment)
|
||
return payment
|
||
except Exception as e:
|
||
logger.error(f"Error creating payment: {e}")
|
||
await session.rollback()
|
||
raise
|
||
|
||
async def get_payment_by_id(self, payment_id: int) -> Optional[Payment]:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
result = await session.execute(
|
||
select(Payment).where(Payment.id == payment_id)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
except Exception as e:
|
||
logger.error(f"Error getting payment {payment_id}: {e}")
|
||
return None
|
||
|
||
async def update_payment(self, payment: Payment) -> Payment:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
await session.merge(payment)
|
||
await session.commit()
|
||
return payment
|
||
except Exception as e:
|
||
logger.error(f"Error updating payment {payment.id}: {e}")
|
||
await session.rollback()
|
||
raise
|
||
|
||
async def get_user_payments(self, user_id: int) -> List[Payment]:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select, desc
|
||
result = await session.execute(
|
||
select(Payment)
|
||
.where(Payment.user_id == user_id)
|
||
.order_by(desc(Payment.created_at))
|
||
)
|
||
return list(result.scalars().all())
|
||
except Exception as e:
|
||
logger.error(f"Error getting user payments for {user_id}: {e}")
|
||
return []
|
||
|
||
# Promocode methods
|
||
async def get_promocode_by_code(self, code: str) -> Optional[Promocode]:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
result = await session.execute(
|
||
select(Promocode).where(Promocode.code == code)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
except Exception as e:
|
||
logger.error(f"Error getting promocode {code}: {e}")
|
||
return None
|
||
|
||
async def create_promocode(self, code: str, discount_amount: float = 0,
|
||
discount_percent: int = None, usage_limit: int = 1,
|
||
expires_at: datetime = None) -> Promocode:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
promocode = Promocode(
|
||
code=code,
|
||
discount_amount=discount_amount,
|
||
discount_percent=discount_percent,
|
||
usage_limit=usage_limit,
|
||
expires_at=expires_at
|
||
)
|
||
session.add(promocode)
|
||
await session.commit()
|
||
await session.refresh(promocode)
|
||
return promocode
|
||
except Exception as e:
|
||
logger.error(f"Error creating promocode: {e}")
|
||
await session.rollback()
|
||
raise
|
||
|
||
async def use_promocode(self, user_id: int, promocode: Promocode) -> bool:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
# Check if already used
|
||
from sqlalchemy import select
|
||
existing = await session.execute(
|
||
select(PromocodeUsage).where(
|
||
PromocodeUsage.user_id == user_id,
|
||
PromocodeUsage.promocode_id == promocode.id
|
||
)
|
||
)
|
||
if existing.scalar_one_or_none():
|
||
return False
|
||
|
||
# Create usage record
|
||
usage = PromocodeUsage(user_id=user_id, promocode_id=promocode.id)
|
||
session.add(usage)
|
||
|
||
# Update promocode used count
|
||
promocode.used_count += 1
|
||
await session.merge(promocode)
|
||
|
||
await session.commit()
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"Error using promocode: {e}")
|
||
await session.rollback()
|
||
return False
|
||
|
||
async def get_all_promocodes(self) -> List[Promocode]:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
result = await session.execute(select(Promocode))
|
||
return list(result.scalars().all())
|
||
except Exception as e:
|
||
logger.error(f"Error getting promocodes: {e}")
|
||
return []
|
||
|
||
# Admin methods
|
||
async def get_all_users(self) -> List[User]:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
result = await session.execute(select(User))
|
||
return list(result.scalars().all())
|
||
except Exception as e:
|
||
logger.error(f"Error getting all users: {e}")
|
||
return []
|
||
|
||
async def get_stats(self) -> dict:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select, func
|
||
|
||
# Total users
|
||
total_users = await session.execute(
|
||
select(func.count(User.id))
|
||
)
|
||
total_users = total_users.scalar()
|
||
|
||
# Total subscriptions (excluding trial)
|
||
total_subs_non_trial = await session.execute(
|
||
select(func.count(UserSubscription.id))
|
||
.join(Subscription, UserSubscription.subscription_id == Subscription.id)
|
||
.where(Subscription.is_trial == False)
|
||
)
|
||
total_subs_non_trial = total_subs_non_trial.scalar()
|
||
|
||
# Total payments (excluding trial payments)
|
||
total_payments = await session.execute(
|
||
select(func.sum(Payment.amount)).where(
|
||
Payment.status == 'completed',
|
||
Payment.payment_type != 'trial' # Исключаем тестовые платежи
|
||
)
|
||
)
|
||
total_payments = total_payments.scalar() or 0
|
||
|
||
return {
|
||
'total_users': total_users,
|
||
'total_subscriptions_non_trial': total_subs_non_trial,
|
||
'total_revenue': total_payments
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"Error getting stats: {e}")
|
||
return {
|
||
'total_users': 0,
|
||
'total_subscriptions_non_trial': 0,
|
||
'total_revenue': 0
|
||
}
|
||
|
||
async def get_trial_subscriptions(self) -> List[Subscription]:
|
||
"""Get only trial subscriptions"""
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
result = await session.execute(
|
||
select(Subscription).where(Subscription.is_trial == True)
|
||
)
|
||
return list(result.scalars().all())
|
||
except Exception as e:
|
||
logger.error(f"Error getting trial subscriptions: {e}")
|
||
return []
|
||
|
||
async def get_user_subscription_by_short_uuid(self, user_id: int, short_uuid: str) -> Optional[UserSubscription]:
|
||
"""Get user subscription by short_uuid"""
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
result = await session.execute(
|
||
select(UserSubscription).where(
|
||
UserSubscription.user_id == user_id,
|
||
UserSubscription.short_uuid == short_uuid
|
||
)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
except Exception as e:
|
||
logger.error(f"Error getting user subscription by short_uuid: {e}")
|
||
return None
|
||
|
||
async def update_user_subscription(self, user_subscription: UserSubscription) -> bool:
|
||
"""Update user subscription"""
|
||
async with self.session_factory() as session:
|
||
try:
|
||
# Устанавливаем время обновления
|
||
user_subscription.updated_at = datetime.utcnow()
|
||
|
||
# Обновляем подписку
|
||
await session.merge(user_subscription)
|
||
await session.commit()
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"Error updating user subscription: {e}")
|
||
await session.rollback()
|
||
return False
|
||
|
||
async def migrate_user_subscriptions(self):
|
||
"""Migrate user_subscriptions table to add missing columns"""
|
||
try:
|
||
async with self.engine.begin() as conn:
|
||
# Проверяем существование столбцов и добавляем их если нет
|
||
try:
|
||
await conn.execute(text("""
|
||
ALTER TABLE user_subscriptions
|
||
ADD COLUMN IF NOT EXISTS traffic_limit_gb INTEGER,
|
||
ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP
|
||
"""))
|
||
logger.info("Successfully migrated user_subscriptions table")
|
||
except Exception as e:
|
||
logger.info(f"Migration may have already been applied or error occurred: {e}")
|
||
except Exception as e:
|
||
logger.error(f"Error during migration: {e}")
|
||
|
||
async def get_expiring_subscriptions(self, user_id: int, days_threshold: int = 3) -> List[UserSubscription]:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
from datetime import datetime, timedelta
|
||
|
||
threshold_date = datetime.utcnow() + timedelta(days=days_threshold)
|
||
|
||
result = await session.execute(
|
||
select(UserSubscription).where(
|
||
UserSubscription.user_id == user_id,
|
||
UserSubscription.is_active == True,
|
||
UserSubscription.expires_at <= threshold_date,
|
||
UserSubscription.expires_at > datetime.utcnow()
|
||
)
|
||
)
|
||
return list(result.scalars().all())
|
||
except Exception as e:
|
||
logger.error(f"Error getting expiring subscriptions for {user_id}: {e}")
|
||
return []
|
||
|
||
async def has_used_trial(self, user_id: int) -> bool:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
result = await session.execute(
|
||
select(User.is_trial_used).where(User.telegram_id == user_id)
|
||
)
|
||
is_trial_used = result.scalar_one_or_none()
|
||
return is_trial_used or False
|
||
except Exception as e:
|
||
logger.error(f"Error checking trial usage for user {user_id}: {e}")
|
||
return False
|
||
|
||
async def mark_trial_used(self, user_id: int) -> bool:
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import update
|
||
result = await session.execute(
|
||
update(User)
|
||
.where(User.telegram_id == user_id)
|
||
.values(is_trial_used=True)
|
||
)
|
||
await session.commit()
|
||
return result.rowcount > 0
|
||
except Exception as e:
|
||
logger.error(f"Error marking trial used for user {user_id}: {e}")
|
||
await session.rollback()
|
||
return False
|
||
|
||
async def get_all_payments_paginated(self, offset: int = 0, limit: int = 10) -> tuple[List[Payment], int]:
|
||
"""Get all payments with pagination"""
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select, desc, func
|
||
|
||
# Получаем общее количество записей
|
||
count_result = await session.execute(
|
||
select(func.count(Payment.id))
|
||
)
|
||
total_count = count_result.scalar()
|
||
|
||
# Получаем платежи с пагинацией
|
||
result = await session.execute(
|
||
select(Payment)
|
||
.order_by(desc(Payment.created_at))
|
||
.offset(offset)
|
||
.limit(limit)
|
||
)
|
||
payments = list(result.scalars().all())
|
||
|
||
return payments, total_count
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error getting paginated payments: {e}")
|
||
return [], 0
|
||
|
||
async def get_payments_by_type_paginated(self, payment_type: str, offset: int = 0, limit: int = 10) -> tuple[List[Payment], int]:
|
||
"""Get payments by type with pagination"""
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select, desc, func
|
||
|
||
# Получаем общее количество записей
|
||
count_result = await session.execute(
|
||
select(func.count(Payment.id)).where(Payment.payment_type == payment_type)
|
||
)
|
||
total_count = count_result.scalar()
|
||
|
||
# Получаем платежи с пагинацией
|
||
result = await session.execute(
|
||
select(Payment)
|
||
.where(Payment.payment_type == payment_type)
|
||
.order_by(desc(Payment.created_at))
|
||
.offset(offset)
|
||
.limit(limit)
|
||
)
|
||
payments = list(result.scalars().all())
|
||
|
||
return payments, total_count
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error getting paginated payments by type: {e}")
|
||
return [], 0
|
||
|
||
async def get_payments_by_status_paginated(self, status: str, offset: int = 0, limit: int = 10) -> tuple[List[Payment], int]:
|
||
"""Get payments by status with pagination"""
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select, desc, func
|
||
|
||
# Получаем общее количество записей
|
||
count_result = await session.execute(
|
||
select(func.count(Payment.id)).where(Payment.status == status)
|
||
)
|
||
total_count = count_result.scalar()
|
||
|
||
# Получаем платежи с пагинацией
|
||
result = await session.execute(
|
||
select(Payment)
|
||
.where(Payment.status == status)
|
||
.order_by(desc(Payment.created_at))
|
||
.offset(offset)
|
||
.limit(limit)
|
||
)
|
||
payments = list(result.scalars().all())
|
||
|
||
return payments, total_count
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error getting paginated payments by status: {e}")
|
||
return [], 0
|
||
|
||
async def get_user_subscriptions_by_plan_id(self, plan_id: int) -> List[UserSubscription]:
|
||
"""Get all user subscriptions for a specific plan"""
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import select
|
||
result = await session.execute(
|
||
select(UserSubscription).where(UserSubscription.subscription_id == plan_id)
|
||
)
|
||
return list(result.scalars().all())
|
||
except Exception as e:
|
||
logger.error(f"Error getting user subscriptions for plan {plan_id}: {e}")
|
||
return []
|
||
|
||
async def delete_user_subscription(self, user_subscription_id: int) -> bool:
|
||
"""Delete user subscription by ID"""
|
||
async with self.session_factory() as session:
|
||
try:
|
||
from sqlalchemy import delete
|
||
result = await session.execute(
|
||
delete(UserSubscription).where(UserSubscription.id == user_subscription_id)
|
||
)
|
||
await session.commit()
|
||
return result.rowcount > 0
|
||
except Exception as e:
|
||
logger.error(f"Error deleting user subscription {user_subscription_id}: {e}")
|
||
await session.rollback()
|
||
return False
|