Files
remnawave-bedolaga-telegram…/database.py
2025-08-04 16:48:42 +03:00

510 lines
21 KiB
Python

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import BigInteger, String, Float, DateTime, Boolean, Text, Integer
from datetime import datetime
from typing import Optional, List
import logging
logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
pass
class User(Base):
__tablename__ = 'users'
id: Mapped[int] = mapped_column(primary_key=True)
telegram_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
username: Mapped[Optional[str]] = mapped_column(String(255))
first_name: Mapped[Optional[str]] = mapped_column(String(255))
last_name: Mapped[Optional[str]] = mapped_column(String(255))
language: Mapped[str] = mapped_column(String(10), default='ru')
balance: Mapped[float] = mapped_column(Float, default=0.0)
is_admin: Mapped[bool] = mapped_column(Boolean, default=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
remnawave_uuid: Mapped[Optional[str]] = mapped_column(String(255))
is_trial_used: Mapped[bool] = mapped_column(Boolean, default=False)
class Subscription(Base):
__tablename__ = 'subscriptions'
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(255))
description: Mapped[Optional[str]] = mapped_column(Text)
price: Mapped[float] = mapped_column(Float)
duration_days: Mapped[int] = mapped_column(Integer)
traffic_limit_gb: Mapped[int] = mapped_column(Integer, default=0) # 0 = unlimited
squad_uuid: Mapped[str] = mapped_column(String(255))
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
is_trial: Mapped[bool] = mapped_column(Boolean, default=False)
class UserSubscription(Base):
__tablename__ = 'user_subscriptions'
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, index=True)
subscription_id: Mapped[int] = mapped_column(Integer, index=True)
short_uuid: Mapped[str] = mapped_column(String(255)) # УБРАНО unique=True
expires_at: Mapped[datetime] = mapped_column(DateTime)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
class Payment(Base):
__tablename__ = 'payments'
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, index=True)
amount: Mapped[float] = mapped_column(Float)
payment_type: Mapped[str] = mapped_column(String(50)) # 'topup', 'subscription'
description: Mapped[str] = mapped_column(Text)
status: Mapped[str] = mapped_column(String(50), default='pending') # 'pending', 'completed', 'cancelled'
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
class Promocode(Base):
__tablename__ = 'promocodes'
id: Mapped[int] = mapped_column(primary_key=True)
code: Mapped[str] = mapped_column(String(255), unique=True, index=True)
discount_amount: Mapped[float] = mapped_column(Float)
discount_percent: Mapped[Optional[int]] = mapped_column(Integer)
usage_limit: Mapped[int] = mapped_column(Integer, default=1)
used_count: Mapped[int] = mapped_column(Integer, default=0)
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
class PromocodeUsage(Base):
__tablename__ = 'promocode_usage'
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, index=True)
promocode_id: Mapped[int] = mapped_column(Integer, index=True)
used_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
class Database:
def __init__(self, database_url: str):
self.engine = create_async_engine(
database_url,
echo=False,
pool_pre_ping=True,
pool_recycle=300
)
self.session_factory = async_sessionmaker(
self.engine,
class_=AsyncSession,
expire_on_commit=False
)
async def init_db(self):
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def close(self):
await self.engine.dispose()
# User methods
async def get_user_by_telegram_id(self, telegram_id: int) -> Optional[User]:
async with self.session_factory() as session:
try:
from sqlalchemy import select
result = await session.execute(
select(User).where(User.telegram_id == telegram_id)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting user by telegram_id {telegram_id}: {e}")
return None
async def create_user(self, telegram_id: int, username: str = None,
first_name: str = None, last_name: str = None,
language: str = 'ru', is_admin: bool = False) -> User:
async with self.session_factory() as session:
try:
user = User(
telegram_id=telegram_id,
username=username,
first_name=first_name,
last_name=last_name,
language=language,
is_admin=is_admin
)
session.add(user)
await session.commit()
await session.refresh(user)
return user
except Exception as e:
logger.error(f"Error creating user {telegram_id}: {e}")
await session.rollback()
raise
async def update_user(self, user: User) -> User:
async with self.session_factory() as session:
try:
await session.merge(user)
await session.commit()
return user
except Exception as e:
logger.error(f"Error updating user {user.telegram_id}: {e}")
await session.rollback()
raise
async def add_balance(self, user_id: int, amount: float) -> bool:
async with self.session_factory() as session:
try:
from sqlalchemy import select, update
result = await session.execute(
update(User)
.where(User.telegram_id == user_id)
.values(balance=User.balance + amount)
)
await session.commit()
return result.rowcount > 0
except Exception as e:
logger.error(f"Error adding balance to user {user_id}: {e}")
await session.rollback()
return False
# Subscription methods
async def get_all_subscriptions(self, include_inactive: bool = False) -> List[Subscription]:
async with self.session_factory() as session:
try:
from sqlalchemy import select
query = select(Subscription)
if not include_inactive:
query = query.where(Subscription.is_active == True)
result = await session.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting subscriptions: {e}")
return []
async def get_subscription_by_id(self, subscription_id: int) -> Optional[Subscription]:
async with self.session_factory() as session:
try:
from sqlalchemy import select
result = await session.execute(
select(Subscription).where(Subscription.id == subscription_id)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting subscription {subscription_id}: {e}")
return None
async def create_subscription(self, name: str, description: str, price: float,
duration_days: int, traffic_limit_gb: int,
squad_uuid: str) -> Subscription:
async with self.session_factory() as session:
try:
subscription = Subscription(
name=name,
description=description,
price=price,
duration_days=duration_days,
traffic_limit_gb=traffic_limit_gb,
squad_uuid=squad_uuid
)
session.add(subscription)
await session.commit()
await session.refresh(subscription)
return subscription
except Exception as e:
logger.error(f"Error creating subscription: {e}")
await session.rollback()
raise
async def update_subscription(self, subscription: Subscription) -> Subscription:
async with self.session_factory() as session:
try:
await session.merge(subscription)
await session.commit()
return subscription
except Exception as e:
logger.error(f"Error updating subscription {subscription.id}: {e}")
await session.rollback()
raise
async def delete_subscription(self, subscription_id: int) -> bool:
async with self.session_factory() as session:
try:
from sqlalchemy import delete
result = await session.execute(
delete(Subscription).where(Subscription.id == subscription_id)
)
await session.commit()
return result.rowcount > 0
except Exception as e:
logger.error(f"Error deleting subscription {subscription_id}: {e}")
await session.rollback()
return False
async def get_user_subscriptions(self, user_id: int) -> List[UserSubscription]:
async with self.session_factory() as session:
try:
from sqlalchemy import select
result = await session.execute(
select(UserSubscription).where(UserSubscription.user_id == user_id)
)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting user subscriptions for {user_id}: {e}")
return []
async def create_user_subscription(self, user_id: int, subscription_id: int,
short_uuid: str, expires_at: datetime) -> UserSubscription:
async with self.session_factory() as session:
try:
user_sub = UserSubscription(
user_id=user_id,
subscription_id=subscription_id,
short_uuid=short_uuid,
expires_at=expires_at
)
session.add(user_sub)
await session.commit()
await session.refresh(user_sub)
return user_sub
except Exception as e:
logger.error(f"Error creating user subscription: {e}")
await session.rollback()
raise
# Payment methods
async def create_payment(self, user_id: int, amount: float, payment_type: str,
description: str, status: str = 'pending') -> Payment:
async with self.session_factory() as session:
try:
payment = Payment(
user_id=user_id,
amount=amount,
payment_type=payment_type,
description=description,
status=status
)
session.add(payment)
await session.commit()
await session.refresh(payment)
return payment
except Exception as e:
logger.error(f"Error creating payment: {e}")
await session.rollback()
raise
async def get_payment_by_id(self, payment_id: int) -> Optional[Payment]:
async with self.session_factory() as session:
try:
from sqlalchemy import select
result = await session.execute(
select(Payment).where(Payment.id == payment_id)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting payment {payment_id}: {e}")
return None
async def update_payment(self, payment: Payment) -> Payment:
async with self.session_factory() as session:
try:
await session.merge(payment)
await session.commit()
return payment
except Exception as e:
logger.error(f"Error updating payment {payment.id}: {e}")
await session.rollback()
raise
async def get_user_payments(self, user_id: int) -> List[Payment]:
async with self.session_factory() as session:
try:
from sqlalchemy import select, desc
result = await session.execute(
select(Payment)
.where(Payment.user_id == user_id)
.order_by(desc(Payment.created_at))
)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting user payments for {user_id}: {e}")
return []
# Promocode methods
async def get_promocode_by_code(self, code: str) -> Optional[Promocode]:
async with self.session_factory() as session:
try:
from sqlalchemy import select
result = await session.execute(
select(Promocode).where(Promocode.code == code)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting promocode {code}: {e}")
return None
async def create_promocode(self, code: str, discount_amount: float = 0,
discount_percent: int = None, usage_limit: int = 1,
expires_at: datetime = None) -> Promocode:
async with self.session_factory() as session:
try:
promocode = Promocode(
code=code,
discount_amount=discount_amount,
discount_percent=discount_percent,
usage_limit=usage_limit,
expires_at=expires_at
)
session.add(promocode)
await session.commit()
await session.refresh(promocode)
return promocode
except Exception as e:
logger.error(f"Error creating promocode: {e}")
await session.rollback()
raise
async def use_promocode(self, user_id: int, promocode: Promocode) -> bool:
async with self.session_factory() as session:
try:
# Check if already used
from sqlalchemy import select
existing = await session.execute(
select(PromocodeUsage).where(
PromocodeUsage.user_id == user_id,
PromocodeUsage.promocode_id == promocode.id
)
)
if existing.scalar_one_or_none():
return False
# Create usage record
usage = PromocodeUsage(user_id=user_id, promocode_id=promocode.id)
session.add(usage)
# Update promocode used count
promocode.used_count += 1
await session.merge(promocode)
await session.commit()
return True
except Exception as e:
logger.error(f"Error using promocode: {e}")
await session.rollback()
return False
async def get_all_promocodes(self) -> List[Promocode]:
async with self.session_factory() as session:
try:
from sqlalchemy import select
result = await session.execute(select(Promocode))
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting promocodes: {e}")
return []
# Admin methods
async def get_all_users(self) -> List[User]:
async with self.session_factory() as session:
try:
from sqlalchemy import select
result = await session.execute(select(User))
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting all users: {e}")
return []
async def get_stats(self) -> dict:
async with self.session_factory() as session:
try:
from sqlalchemy import select, func
# Total users
total_users = await session.execute(
select(func.count(User.id))
)
total_users = total_users.scalar()
# Total subscriptions
total_subs = await session.execute(
select(func.count(UserSubscription.id))
)
total_subs = total_subs.scalar()
# Total payments
total_payments = await session.execute(
select(func.sum(Payment.amount)).where(Payment.status == 'completed')
)
total_payments = total_payments.scalar() or 0
return {
'total_users': total_users,
'total_subscriptions': total_subs,
'total_revenue': total_payments
}
except Exception as e:
logger.error(f"Error getting stats: {e}")
return {
'total_users': 0,
'total_subscriptions': 0,
'total_revenue': 0
}
async def update_user_subscription(self, user_sub: UserSubscription) -> UserSubscription:
async with self.session_factory() as session:
try:
await session.merge(user_sub)
await session.commit()
return user_sub
except Exception as e:
logger.error(f"Error updating user subscription {user_sub.id}: {e}")
await session.rollback()
raise
async def get_expiring_subscriptions(self, user_id: int, days_threshold: int = 3) -> List[UserSubscription]:
async with self.session_factory() as session:
try:
from sqlalchemy import select
from datetime import datetime, timedelta
threshold_date = datetime.utcnow() + timedelta(days=days_threshold)
result = await session.execute(
select(UserSubscription).where(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True,
UserSubscription.expires_at <= threshold_date,
UserSubscription.expires_at > datetime.utcnow()
)
)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting expiring subscriptions for {user_id}: {e}")
return []
async def has_used_trial(self, user_id: int) -> bool:
async with self.session_factory() as session:
try:
from sqlalchemy import select
result = await session.execute(
select(User.is_trial_used).where(User.telegram_id == user_id)
)
is_trial_used = result.scalar_one_or_none()
return is_trial_used or False
except Exception as e:
logger.error(f"Error checking trial usage for user {user_id}: {e}")
return False
async def mark_trial_used(self, user_id: int) -> bool:
async with self.session_factory() as session:
try:
from sqlalchemy import update
result = await session.execute(
update(User)
.where(User.telegram_id == user_id)
.values(is_trial_used=True)
)
await session.commit()
return result.rowcount > 0
except Exception as e:
logger.error(f"Error marking trial used for user {user_id}: {e}")
await session.rollback()
return False