# remnawave-bedolaga-telegram…/database.py
# Snapshot dated 2025-08-07 07:49:48 +03:00 (736 lines, 32 KiB, Python).
#
# The code host flagged this file as containing ambiguous Unicode characters
# (characters that may be confused with look-alike characters). These are most
# likely the Cyrillic letters in the Russian-language comments, which are
# intentional.
import logging
from datetime import datetime, timedelta
from typing import List, Optional

from sqlalchemy import (
    BigInteger,
    Boolean,
    DateTime,
    Float,
    Integer,
    String,
    Text,
    delete,
    desc,
    func,
    select,
    text,
    update,
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
# Module-wide logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class Base(DeclarativeBase):
    """Declarative base class shared by every ORM model in this module."""
class User(Base):
    """A bot user, identified by their Telegram id (``users`` table)."""

    __tablename__ = 'users'

    id: Mapped[int] = mapped_column(primary_key=True)
    # Telegram user id; unique and indexed. Stored as BigInteger, presumably
    # because Telegram ids can exceed the 32-bit range — confirm.
    telegram_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
    username: Mapped[Optional[str]] = mapped_column(String(255))
    first_name: Mapped[Optional[str]] = mapped_column(String(255))
    last_name: Mapped[Optional[str]] = mapped_column(String(255))
    # Interface language code; defaults to 'ru'.
    language: Mapped[str] = mapped_column(String(10), default='ru')
    # Account balance; Database.add_balance increments it in SQL.
    # NOTE(review): Float is imprecise for money amounts; consider Numeric.
    balance: Mapped[float] = mapped_column(Float, default=0.0)
    is_admin: Mapped[bool] = mapped_column(Boolean, default=False)
    # Naive UTC timestamp filled in on insert.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    # Presumably the user's UUID in the Remnawave panel; not set in this module.
    remnawave_uuid: Mapped[Optional[str]] = mapped_column(String(255))
    # Set by Database.mark_trial_used and read by Database.has_used_trial.
    is_trial_used: Mapped[bool] = mapped_column(Boolean, default=False)
class Subscription(Base):
    """A subscription plan (``subscriptions`` table)."""

    __tablename__ = 'subscriptions'

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(255))
    description: Mapped[Optional[str]] = mapped_column(Text)
    price: Mapped[float] = mapped_column(Float)
    duration_days: Mapped[int] = mapped_column(Integer)
    traffic_limit_gb: Mapped[int] = mapped_column(Integer, default=0) # 0 = unlimited
    # Presumably the Remnawave squad this plan grants access to — verify.
    squad_uuid: Mapped[str] = mapped_column(String(255))
    # Inactive plans are skipped by Database.get_all_subscriptions unless
    # include_inactive=True is passed.
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    # Trial plans: skipped by get_all_subscriptions by default and returned
    # by get_trial_subscriptions.
    is_trial: Mapped[bool] = mapped_column(Boolean, default=False)
    # Presumably plans imported from an external source; skipped by
    # get_all_subscriptions by default. Older tables receive this column via
    # migrate_subscription_imported_field.
    is_imported: Mapped[bool] = mapped_column(Boolean, default=False)
class UserSubscription(Base):
    """A plan held by a user (``user_subscriptions`` table).

    No foreign keys are declared: ``user_id`` and ``subscription_id`` are
    plain indexed columns.
    """

    __tablename__ = 'user_subscriptions'

    id: Mapped[int] = mapped_column(primary_key=True)
    # NOTE(review): presumably the user's Telegram id (BigInteger, like
    # User.telegram_id) rather than User.id — confirm against callers.
    user_id: Mapped[int] = mapped_column(BigInteger, index=True)
    # Matches Subscription.id (Database.get_stats joins on it).
    subscription_id: Mapped[int] = mapped_column(Integer, index=True)
    # unique=True was REMOVED: the same short_uuid may appear in several rows.
    short_uuid: Mapped[str] = mapped_column(String(255))
    expires_at: Mapped[datetime] = mapped_column(DateTime)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    # Added field; older tables receive it via migrate_user_subscriptions.
    traffic_limit_gb: Mapped[Optional[int]] = mapped_column(Integer)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    # Added field; refreshed on every ORM update. Older tables receive it via
    # migrate_user_subscriptions.
    updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, onupdate=datetime.utcnow)
class Payment(Base):
    """A payment record (``payments`` table)."""

    __tablename__ = 'payments'

    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int] = mapped_column(BigInteger, index=True)
    amount: Mapped[float] = mapped_column(Float)
    # 'topup', 'subscription'; Database.get_stats also excludes a 'trial' type.
    payment_type: Mapped[str] = mapped_column(String(50))
    description: Mapped[str] = mapped_column(Text)
    # 'pending', 'completed', 'cancelled'; only 'completed' counts as revenue
    # in Database.get_stats.
    status: Mapped[str] = mapped_column(String(50), default='pending')
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
class Promocode(Base):
    """A promotional code (``promocodes`` table)."""

    __tablename__ = 'promocodes'

    id: Mapped[int] = mapped_column(primary_key=True)
    # Unique, indexed lookup key (Database.get_promocode_by_code).
    code: Mapped[str] = mapped_column(String(255), unique=True, index=True)
    # No column default; Database.create_promocode passes 0 when omitted.
    discount_amount: Mapped[float] = mapped_column(Float)
    discount_percent: Mapped[Optional[int]] = mapped_column(Integer)
    usage_limit: Mapped[int] = mapped_column(Integer, default=1)
    # Incremented by Database.use_promocode.
    used_count: Mapped[int] = mapped_column(Integer, default=0)
    # Optional expiry; presumably NULL means "never expires" — confirm.
    expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
class PromocodeUsage(Base):
    """One application of a promocode by a user (``promocode_usage`` table).

    No unique constraint on (user_id, promocode_id) is declared;
    Database.use_promocode looks for an existing row before inserting one.
    """

    __tablename__ = 'promocode_usage'

    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int] = mapped_column(BigInteger, index=True)
    promocode_id: Mapped[int] = mapped_column(Integer, index=True)
    used_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
class Database:
    """Async data-access layer for the bot's tables.

    Every public method opens its own short-lived session from
    ``session_factory``.

    Error handling:
    - Read helpers log failures and return ``None``, an empty list, or
      zeroed values.
    - ``create_*``/``update_*`` methods for users, plans, payments and
      promocodes log, roll back and re-raise.
    - The remaining write helpers log and return ``False`` (or ``None``).

    The ``user_id`` parameter of ``add_balance``, ``has_used_trial`` and
    ``mark_trial_used`` is matched against ``User.telegram_id``.
    """

    def __init__(self, database_url: str):
        """Create the async engine and the session factory.

        Args:
            database_url: SQLAlchemy async database URL.
        """
        self.engine = create_async_engine(
            database_url,
            echo=False,
            pool_pre_ping=True,  # test pooled connections before handing them out
            pool_recycle=300,  # replace pooled connections older than 300 seconds
        )
        # expire_on_commit=False keeps returned ORM objects readable after the
        # session that loaded them has been closed.
        self.session_factory = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

    async def init_db(self):
        """Create missing tables, then run the column migrations."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        # create_all does not alter tables that already exist, so columns added
        # later are applied by these migrations. They run after the create_all
        # transaction has committed, each on its own connection.
        await self.migrate_user_subscriptions()
        await self.migrate_subscription_imported_field()

    async def close(self):
        """Dispose of the engine and its connection pool."""
        await self.engine.dispose()

    # ------------------------------------------------------------------
    # User methods
    # ------------------------------------------------------------------

    async def get_user_by_telegram_id(self, telegram_id: int) -> Optional[User]:
        """Return the user with this Telegram id, or ``None`` if absent or on error."""
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    select(User).where(User.telegram_id == telegram_id)
                )
                return result.scalar_one_or_none()
            except Exception as e:
                logger.error(f"Error getting user by telegram_id {telegram_id}: {e}")
                return None

    async def create_user(self, telegram_id: int, username: Optional[str] = None,
                          first_name: Optional[str] = None, last_name: Optional[str] = None,
                          language: str = 'ru', is_admin: bool = False) -> User:
        """Insert a new user and return it with generated columns loaded.

        Raises:
            Exception: any database error is logged, rolled back and re-raised.
        """
        async with self.session_factory() as session:
            try:
                user = User(
                    telegram_id=telegram_id,
                    username=username,
                    first_name=first_name,
                    last_name=last_name,
                    language=language,
                    is_admin=is_admin,
                )
                session.add(user)
                await session.commit()
                # Load generated/default values (id, created_at, ...).
                await session.refresh(user)
                return user
            except Exception as e:
                logger.error(f"Error creating user {telegram_id}: {e}")
                await session.rollback()
                raise

    async def update_user(self, user: User) -> User:
        """Persist the state of ``user`` and return the same object.

        ``merge`` copies the (detached) object's state into this session.
        Errors are logged, rolled back and re-raised.
        """
        async with self.session_factory() as session:
            try:
                await session.merge(user)
                await session.commit()
                return user
            except Exception as e:
                logger.error(f"Error updating user {user.telegram_id}: {e}")
                await session.rollback()
                raise

    async def add_balance(self, user_id: int, amount: float) -> bool:
        """Add ``amount`` to the balance of the user whose telegram_id is ``user_id``.

        Returns ``True`` if a row was updated, ``False`` if no such user
        exists or on error.
        """
        async with self.session_factory() as session:
            try:
                # Increment in SQL so concurrent updates cannot overwrite each other.
                result = await session.execute(
                    update(User)
                    .where(User.telegram_id == user_id)
                    .values(balance=User.balance + amount)
                )
                await session.commit()
                return result.rowcount > 0
            except Exception as e:
                logger.error(f"Error adding balance to user {user_id}: {e}")
                await session.rollback()
                return False

    # ------------------------------------------------------------------
    # Subscription plan methods
    # ------------------------------------------------------------------

    async def get_all_subscriptions(self, include_inactive: bool = False,
                                    exclude_trial: bool = True,
                                    exclude_imported: bool = True) -> List[Subscription]:
        """Return subscription plans, optionally filtered.

        Args:
            include_inactive: also return plans whose ``is_active`` is false.
            exclude_trial: omit plans with ``is_trial`` set.
            exclude_imported: omit plans with ``is_imported`` set.

        Returns an empty list on error.
        """
        async with self.session_factory() as session:
            try:
                query = select(Subscription)
                if not include_inactive:
                    query = query.where(Subscription.is_active.is_(True))
                if exclude_trial:
                    query = query.where(Subscription.is_trial.is_(False))
                if exclude_imported:
                    query = query.where(Subscription.is_imported.is_(False))
                result = await session.execute(query)
                return list(result.scalars().all())
            except Exception as e:
                logger.error(f"Error getting subscriptions: {e}")
                return []

    async def get_all_subscriptions_admin(self) -> List[Subscription]:
        """Return every plan, including inactive, trial and imported ones."""
        async with self.session_factory() as session:
            try:
                result = await session.execute(select(Subscription))
                return list(result.scalars().all())
            except Exception as e:
                logger.error(f"Error getting admin subscriptions: {e}")
                return []

    async def migrate_subscription_imported_field(self):
        """Add the ``is_imported`` column to an existing ``subscriptions`` table.

        This is a best-effort step: failures are logged, never raised. The
        statement uses PostgreSQL-style ``ADD COLUMN IF NOT EXISTS``, so
        re-running it is a no-op there.
        """
        try:
            async with self.engine.begin() as conn:
                try:
                    await conn.execute(text("""
                        ALTER TABLE subscriptions
                        ADD COLUMN IF NOT EXISTS is_imported BOOLEAN DEFAULT FALSE
                    """))
                    logger.info("Successfully added is_imported field to subscriptions table")
                except Exception as e:
                    logger.info(f"Migration may have already been applied: {e}")
        except Exception as e:
            # Connection / commit failures.
            logger.error(f"Error during subscription migration: {e}")

    async def get_subscription_by_id(self, subscription_id: int) -> Optional[Subscription]:
        """Return the plan with this id, or ``None`` if absent or on error."""
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    select(Subscription).where(Subscription.id == subscription_id)
                )
                return result.scalar_one_or_none()
            except Exception as e:
                logger.error(f"Error getting subscription {subscription_id}: {e}")
                return None

    async def create_subscription(self, name: str, description: str, price: float,
                                  duration_days: int, traffic_limit_gb: int,
                                  squad_uuid: str, is_imported: bool = False) -> Subscription:
        """Insert a new plan and return it.

        Raises:
            Exception: any database error is logged, rolled back and re-raised.
        """
        async with self.session_factory() as session:
            try:
                subscription = Subscription(
                    name=name,
                    description=description,
                    price=price,
                    duration_days=duration_days,
                    traffic_limit_gb=traffic_limit_gb,
                    squad_uuid=squad_uuid,
                    is_imported=is_imported,
                )
                session.add(subscription)
                await session.commit()
                await session.refresh(subscription)
                return subscription
            except Exception as e:
                logger.error(f"Error creating subscription: {e}")
                await session.rollback()
                raise

    async def update_subscription(self, subscription: Subscription) -> Subscription:
        """Persist the state of ``subscription`` (via ``merge``) and return it.

        Errors are logged, rolled back and re-raised.
        """
        async with self.session_factory() as session:
            try:
                await session.merge(subscription)
                await session.commit()
                return subscription
            except Exception as e:
                logger.error(f"Error updating subscription {subscription.id}: {e}")
                await session.rollback()
                raise

    async def delete_subscription(self, subscription_id: int) -> bool:
        """Delete a plan by id.

        Returns ``True`` if a row was deleted, ``False`` if none matched or on
        error. User subscriptions referencing the plan are not touched.
        """
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    delete(Subscription).where(Subscription.id == subscription_id)
                )
                await session.commit()
                return result.rowcount > 0
            except Exception as e:
                logger.error(f"Error deleting subscription {subscription_id}: {e}")
                await session.rollback()
                return False

    # ------------------------------------------------------------------
    # User subscription methods
    # ------------------------------------------------------------------

    async def get_user_subscriptions(self, user_id: int) -> List[UserSubscription]:
        """Return all subscriptions stored for ``user_id`` (empty list on error)."""
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    select(UserSubscription).where(UserSubscription.user_id == user_id)
                )
                return list(result.scalars().all())
            except Exception as e:
                logger.error(f"Error getting user subscriptions for {user_id}: {e}")
                return []

    async def create_user_subscription(self, user_id: int, subscription_id: int,
                                       short_uuid: str, expires_at: datetime,
                                       is_active: bool = True,
                                       traffic_limit_gb: Optional[int] = None) -> Optional[UserSubscription]:
        """Create a user subscription unless ``user_id`` already has one with ``short_uuid``.

        Returns:
            The new row. If a row for this user and short_uuid already
            exists, returns it instead and logs a warning. Returns ``None`` on
            error.
        """
        async with self.session_factory() as session:
            try:
                # short_uuid is not unique, so take the oldest match rather than
                # failing when duplicate rows exist.
                existing = await session.execute(
                    select(UserSubscription)
                    .where(
                        UserSubscription.user_id == user_id,
                        UserSubscription.short_uuid == short_uuid,
                    )
                    .order_by(UserSubscription.id)
                    .limit(1)
                )
                existing_sub = existing.scalars().first()
                if existing_sub is not None:
                    logger.warning(f"Subscription with short_uuid {short_uuid} already exists for user {user_id}")
                    return existing_sub

                new_subscription = UserSubscription(
                    user_id=user_id,
                    subscription_id=subscription_id,
                    short_uuid=short_uuid,
                    expires_at=expires_at,
                    is_active=is_active,
                    # Previously accepted but never stored.
                    traffic_limit_gb=traffic_limit_gb,
                )
                session.add(new_subscription)
                await session.commit()
                await session.refresh(new_subscription)
                return new_subscription
            except Exception as e:
                logger.error(f"Error creating user subscription: {e}")
                await session.rollback()
                return None

    # ------------------------------------------------------------------
    # Payment methods
    # ------------------------------------------------------------------

    async def create_payment(self, user_id: int, amount: float, payment_type: str,
                             description: str, status: str = 'pending') -> Payment:
        """Insert a payment record and return it.

        Raises:
            Exception: any database error is logged, rolled back and re-raised.
        """
        async with self.session_factory() as session:
            try:
                payment = Payment(
                    user_id=user_id,
                    amount=amount,
                    payment_type=payment_type,
                    description=description,
                    status=status,
                )
                session.add(payment)
                await session.commit()
                await session.refresh(payment)
                return payment
            except Exception as e:
                logger.error(f"Error creating payment: {e}")
                await session.rollback()
                raise

    async def get_payment_by_id(self, payment_id: int) -> Optional[Payment]:
        """Return the payment with this id, or ``None`` if absent or on error."""
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    select(Payment).where(Payment.id == payment_id)
                )
                return result.scalar_one_or_none()
            except Exception as e:
                logger.error(f"Error getting payment {payment_id}: {e}")
                return None

    async def update_payment(self, payment: Payment) -> Payment:
        """Persist the state of ``payment`` (via ``merge``) and return it.

        Errors are logged, rolled back and re-raised.
        """
        async with self.session_factory() as session:
            try:
                await session.merge(payment)
                await session.commit()
                return payment
            except Exception as e:
                logger.error(f"Error updating payment {payment.id}: {e}")
                await session.rollback()
                raise

    async def get_user_payments(self, user_id: int) -> List[Payment]:
        """Return all payments of ``user_id``, newest first (empty list on error)."""
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    select(Payment)
                    .where(Payment.user_id == user_id)
                    .order_by(desc(Payment.created_at))
                )
                return list(result.scalars().all())
            except Exception as e:
                logger.error(f"Error getting user payments for {user_id}: {e}")
                return []

    # ------------------------------------------------------------------
    # Promocode methods
    # ------------------------------------------------------------------

    async def get_promocode_by_code(self, code: str) -> Optional[Promocode]:
        """Return the promocode with this exact code, or ``None`` if absent or on error."""
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    select(Promocode).where(Promocode.code == code)
                )
                return result.scalar_one_or_none()
            except Exception as e:
                logger.error(f"Error getting promocode {code}: {e}")
                return None

    async def create_promocode(self, code: str, discount_amount: float = 0,
                               discount_percent: Optional[int] = None, usage_limit: int = 1,
                               expires_at: Optional[datetime] = None) -> Promocode:
        """Insert a promocode and return it.

        Raises:
            Exception: any database error is logged, rolled back and re-raised.
        """
        async with self.session_factory() as session:
            try:
                promocode = Promocode(
                    code=code,
                    discount_amount=discount_amount,
                    discount_percent=discount_percent,
                    usage_limit=usage_limit,
                    expires_at=expires_at,
                )
                session.add(promocode)
                await session.commit()
                await session.refresh(promocode)
                return promocode
            except Exception as e:
                logger.error(f"Error creating promocode: {e}")
                await session.rollback()
                raise

    async def use_promocode(self, user_id: int, promocode: Promocode) -> bool:
        """Record that ``user_id`` used ``promocode`` and bump its ``used_count``.

        Returns ``False`` if this user already used the code, or on error.
        Only after a successful commit is ``promocode.used_count`` on the
        passed object incremented too. This method does not check
        ``usage_limit``, ``expires_at`` or ``is_active``.
        """
        async with self.session_factory() as session:
            try:
                existing = await session.execute(
                    select(PromocodeUsage.id)
                    .where(
                        PromocodeUsage.user_id == user_id,
                        PromocodeUsage.promocode_id == promocode.id,
                    )
                    .limit(1)
                )
                if existing.first() is not None:
                    return False

                session.add(PromocodeUsage(user_id=user_id, promocode_id=promocode.id))
                # Increment in SQL instead of merging the caller's (possibly
                # stale) object, so concurrent uses cannot lose an increment.
                await session.execute(
                    update(Promocode)
                    .where(Promocode.id == promocode.id)
                    .values(used_count=Promocode.used_count + 1)
                )
                await session.commit()
            except Exception as e:
                logger.error(f"Error using promocode: {e}")
                await session.rollback()
                return False
        # Mirror the persisted increment on the caller's object.
        promocode.used_count += 1
        return True

    async def get_all_promocodes(self) -> List[Promocode]:
        """Return every promocode (empty list on error)."""
        async with self.session_factory() as session:
            try:
                result = await session.execute(select(Promocode))
                return list(result.scalars().all())
            except Exception as e:
                logger.error(f"Error getting promocodes: {e}")
                return []

    # ------------------------------------------------------------------
    # Admin methods
    # ------------------------------------------------------------------

    async def get_all_users(self) -> List[User]:
        """Return every user (empty list on error)."""
        async with self.session_factory() as session:
            try:
                result = await session.execute(select(User))
                return list(result.scalars().all())
            except Exception as e:
                logger.error(f"Error getting all users: {e}")
                return []

    async def get_stats(self) -> dict:
        """Return aggregate statistics.

        Keys:
            total_users: number of users.
            total_subscriptions_non_trial: user subscriptions whose plan is
                not a trial.
            total_revenue: sum of ``'completed'`` payments whose type is not
                ``'trial'``.

        All values are 0 on error.
        """
        async with self.session_factory() as session:
            try:
                total_users = (await session.execute(
                    select(func.count(User.id))
                )).scalar()

                total_subs_non_trial = (await session.execute(
                    select(func.count(UserSubscription.id))
                    .join(Subscription, UserSubscription.subscription_id == Subscription.id)
                    .where(Subscription.is_trial.is_(False))
                )).scalar()

                # SUM over zero rows yields NULL, hence the `or 0`.
                total_payments = (await session.execute(
                    select(func.sum(Payment.amount)).where(
                        Payment.status == 'completed',
                        Payment.payment_type != 'trial',  # exclude trial payments
                    )
                )).scalar() or 0

                return {
                    'total_users': total_users,
                    'total_subscriptions_non_trial': total_subs_non_trial,
                    'total_revenue': total_payments,
                }
            except Exception as e:
                logger.error(f"Error getting stats: {e}")
                return {
                    'total_users': 0,
                    'total_subscriptions_non_trial': 0,
                    'total_revenue': 0,
                }

    async def get_trial_subscriptions(self) -> List[Subscription]:
        """Return only the trial plans (empty list on error)."""
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    select(Subscription).where(Subscription.is_trial.is_(True))
                )
                return list(result.scalars().all())
            except Exception as e:
                logger.error(f"Error getting trial subscriptions: {e}")
                return []

    async def get_user_subscription_by_short_uuid(self, user_id: int, short_uuid: str) -> Optional[UserSubscription]:
        """Return the user's subscription with ``short_uuid``, or ``None``.

        short_uuid is not unique, so if several rows match, the oldest one
        (lowest id) is returned. Returns ``None`` on error.
        """
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    select(UserSubscription)
                    .where(
                        UserSubscription.user_id == user_id,
                        UserSubscription.short_uuid == short_uuid,
                    )
                    .order_by(UserSubscription.id)
                    .limit(1)
                )
                return result.scalars().first()
            except Exception as e:
                logger.error(f"Error getting user subscription by short_uuid: {e}")
                return None

    async def update_user_subscription(self, user_subscription: UserSubscription) -> bool:
        """Stamp ``updated_at`` and persist ``user_subscription`` (via ``merge``).

        Returns ``True`` on success, ``False`` on error. ``updated_at`` is
        set on the passed object before the write is attempted.
        """
        async with self.session_factory() as session:
            try:
                user_subscription.updated_at = datetime.utcnow()
                await session.merge(user_subscription)
                await session.commit()
                return True
            except Exception as e:
                logger.error(f"Error updating user subscription: {e}")
                await session.rollback()
                return False

    async def migrate_user_subscriptions(self):
        """Add the ``traffic_limit_gb`` and ``updated_at`` columns to ``user_subscriptions``.

        This is a best-effort step: failures are logged, never raised. The
        statement uses PostgreSQL-style ``ADD COLUMN IF NOT EXISTS``.
        """
        try:
            async with self.engine.begin() as conn:
                try:
                    await conn.execute(text("""
                        ALTER TABLE user_subscriptions
                        ADD COLUMN IF NOT EXISTS traffic_limit_gb INTEGER,
                        ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP
                    """))
                    logger.info("Successfully migrated user_subscriptions table")
                except Exception as e:
                    logger.info(f"Migration may have already been applied or error occurred: {e}")
        except Exception as e:
            # Connection / commit failures.
            logger.error(f"Error during migration: {e}")

    async def get_expiring_subscriptions(self, user_id: int, days_threshold: int = 3) -> List[UserSubscription]:
        """Return active, not-yet-expired subscriptions of ``user_id`` that
        expire within the next ``days_threshold`` days (naive UTC).

        Returns an empty list on error.
        """
        async with self.session_factory() as session:
            try:
                # One clock reading, so both bounds refer to the same instant.
                now = datetime.utcnow()
                threshold_date = now + timedelta(days=days_threshold)
                result = await session.execute(
                    select(UserSubscription).where(
                        UserSubscription.user_id == user_id,
                        UserSubscription.is_active.is_(True),
                        UserSubscription.expires_at <= threshold_date,
                        UserSubscription.expires_at > now,
                    )
                )
                return list(result.scalars().all())
            except Exception as e:
                logger.error(f"Error getting expiring subscriptions for {user_id}: {e}")
                return []

    async def has_used_trial(self, user_id: int) -> bool:
        """Return the ``is_trial_used`` flag of the user whose telegram_id is ``user_id``.

        A missing user, a NULL flag, or a database error all yield ``False``.
        """
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    select(User.is_trial_used).where(User.telegram_id == user_id)
                )
                return bool(result.scalar_one_or_none())
            except Exception as e:
                logger.error(f"Error checking trial usage for user {user_id}: {e}")
                return False

    async def mark_trial_used(self, user_id: int) -> bool:
        """Set ``is_trial_used`` on the user whose telegram_id is ``user_id``.

        Returns ``True`` if a row was updated, ``False`` otherwise or on error.
        """
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    update(User)
                    .where(User.telegram_id == user_id)
                    .values(is_trial_used=True)
                )
                await session.commit()
                return result.rowcount > 0
            except Exception as e:
                logger.error(f"Error marking trial used for user {user_id}: {e}")
                await session.rollback()
                return False

    async def _get_payments_page(self, condition, offset: int, limit: int,
                                 error_message: str) -> tuple[List[Payment], int]:
        """Return one newest-first page of payments plus the total matching count.

        Args:
            condition: optional SQL filter applied to both the count and the
                page query; ``None`` selects all payments.
            offset: number of rows to skip.
            limit: maximum number of rows to return.
            error_message: log prefix used if a query fails.

        Returns ``([], 0)`` on error.
        """
        async with self.session_factory() as session:
            try:
                count_query = select(func.count(Payment.id))
                # id breaks ties between equal timestamps so pages do not
                # overlap or skip rows.
                page_query = select(Payment).order_by(
                    desc(Payment.created_at), desc(Payment.id)
                )
                if condition is not None:
                    count_query = count_query.where(condition)
                    page_query = page_query.where(condition)

                total_count = (await session.execute(count_query)).scalar()
                result = await session.execute(page_query.offset(offset).limit(limit))
                return list(result.scalars().all()), total_count
            except Exception as e:
                logger.error(f"{error_message}: {e}")
                return [], 0

    async def get_all_payments_paginated(self, offset: int = 0, limit: int = 10) -> tuple[List[Payment], int]:
        """Return one page of all payments (newest first) and the total count."""
        return await self._get_payments_page(
            None, offset, limit, "Error getting paginated payments"
        )

    async def get_payments_by_type_paginated(self, payment_type: str, offset: int = 0, limit: int = 10) -> tuple[List[Payment], int]:
        """Return one page of payments of ``payment_type`` and their total count."""
        return await self._get_payments_page(
            Payment.payment_type == payment_type, offset, limit,
            "Error getting paginated payments by type",
        )

    async def get_payments_by_status_paginated(self, status: str, offset: int = 0, limit: int = 10) -> tuple[List[Payment], int]:
        """Return one page of payments with ``status`` and their total count."""
        return await self._get_payments_page(
            Payment.status == status, offset, limit,
            "Error getting paginated payments by status",
        )

    async def get_user_subscriptions_by_plan_id(self, plan_id: int) -> List[UserSubscription]:
        """Return every user subscription of plan ``plan_id`` (empty list on error)."""
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    select(UserSubscription).where(UserSubscription.subscription_id == plan_id)
                )
                return list(result.scalars().all())
            except Exception as e:
                logger.error(f"Error getting user subscriptions for plan {plan_id}: {e}")
                return []

    async def delete_user_subscription(self, user_subscription_id: int) -> bool:
        """Delete a user subscription by its id.

        Returns ``True`` if a row was deleted, ``False`` otherwise or on error.
        """
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    delete(UserSubscription).where(UserSubscription.id == user_subscription_id)
                )
                await session.commit()
                return result.rowcount > 0
            except Exception as e:
                logger.error(f"Error deleting user subscription {user_subscription_id}: {e}")
                await session.rollback()
                return False