Mirror of https://github.com/BEDOLAGA-DEV/remnawave-bedolaga-telegram-bot.git
Synced 2026-02-28 07:11:37 +00:00
609 lines · 23 KiB · Python
import asyncio
|
||
import time
|
||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||
from contextlib import asynccontextmanager
|
||
from functools import wraps
|
||
from typing import ParamSpec, TypeVar
|
||
|
||
import structlog
|
||
from sqlalchemy import bindparam, event, text
|
||
from sqlalchemy.engine import Engine
|
||
from sqlalchemy.exc import InterfaceError, OperationalError
|
||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
|
||
|
||
from app.config import settings
|
||
|
||
|
||
logger = structlog.get_logger(__name__)
|
||
|
||
T = TypeVar('T')
|
||
P = ParamSpec('P')
|
||
R = TypeVar('R')
|
||
|
||
# ============================================================================
|
||
# PRODUCTION-GRADE CONNECTION POOLING
|
||
# ============================================================================
|
||
|
||
|
||
def _is_sqlite_url(url: str) -> bool:
|
||
"""Проверка на SQLite URL (поддерживает sqlite:// и sqlite+aiosqlite://)"""
|
||
return url.startswith('sqlite') or ':memory:' in url
|
||
|
||
|
||
# Resolve the database URL once at import time and detect the backend.
DATABASE_URL = settings.get_database_url()
IS_SQLITE = _is_sqlite_url(DATABASE_URL)

if IS_SQLITE:
    # SQLite gains nothing from pooling; NullPool opens/closes per use.
    poolclass = NullPool
    pool_kwargs = {}
else:
    poolclass = AsyncAdaptedQueuePool
    pool_kwargs = {
        'pool_size': 20,  # Reduced from 30 so we stay below PostgreSQL max_connections
        'max_overflow': 20,  # Reduced from 50: at most 40 connections instead of 80
        'pool_timeout': 30,  # Reduced from 60: return 503 faster under overload
        'pool_recycle': 1800,  # 30 minutes, for quicker connection recycling
        'pool_pre_ping': True,
        # Aggressive cleanup of dead connections
        'pool_reset_on_return': 'rollback',
    }
|
||
|
||
# ============================================================================
|
||
# ENGINE WITH ADVANCED OPTIMIZATIONS
|
||
# ============================================================================
|
||
|
||
# PostgreSQL-специфичные connect_args
|
||
# PostgreSQL-specific connect_args (forwarded to the asyncpg driver).
_pg_connect_args = {
    'server_settings': {
        'application_name': 'remnawave_bot',
        'jit': 'on',
        'statement_timeout': '60000',  # 60 seconds
        'idle_in_transaction_session_timeout': '300000',  # 5 minutes
    },
    'command_timeout': 30,  # Reduced from 60: detect hung queries sooner
    'timeout': 10,  # Reduced from 60: fail fast when PostgreSQL is unreachable
}

# Primary async engine shared by the whole application.
engine = create_async_engine(
    DATABASE_URL,
    poolclass=poolclass,
    echo='debug' if settings.DEBUG else False,
    future=True,
    # Cache of compiled statements (correct placement for this option)
    query_cache_size=500,
    connect_args=_pg_connect_args if not IS_SQLITE else {},
    execution_options={
        'isolation_level': 'READ COMMITTED',
    },
    **pool_kwargs,
)
|
||
|
||
# ============================================================================
|
||
# SESSION FACTORY WITH OPTIMIZATIONS
|
||
# ============================================================================
|
||
|
||
# Shared session factory bound to the primary engine.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autoflush=False,  # Critical for performance
    autocommit=False,
)
|
||
|
||
# ============================================================================
|
||
# RETRY LOGIC FOR DATABASE OPERATIONS
|
||
# ============================================================================
|
||
|
||
# Exceptions treated as transient connectivity failures; safe to retry.
RETRYABLE_EXCEPTIONS = (OperationalError, InterfaceError, ConnectionRefusedError, OSError, TimeoutError)
DEFAULT_RETRY_ATTEMPTS = 3
DEFAULT_RETRY_DELAY = 0.5  # seconds
|
||
|
||
|
||
def with_db_retry(
    attempts: int = DEFAULT_RETRY_ATTEMPTS,
    delay: float = DEFAULT_RETRY_DELAY,
    backoff: float = 2.0,
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
    """
    Decorator that automatically retries an async DB call on transient failures.

    Args:
        attempts: Number of attempts (must be >= 1)
        delay: Initial delay between attempts (seconds)
        backoff: Delay multiplier applied after each failed attempt

    Raises:
        ValueError: If ``attempts`` < 1. Validated eagerly (matching
            ``execute_with_retry``); without this guard the wrapper's loop
            would never run and it would execute ``raise None``, producing a
            confusing ``TypeError`` instead of a clear error.
    """
    if attempts < 1:
        raise ValueError(f'attempts must be >= 1, got {attempts}')

    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
        @wraps(func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            last_exception: Exception | None = None
            current_delay = delay

            for attempt in range(1, attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except RETRYABLE_EXCEPTIONS as e:
                    last_exception = e
                    if attempt < attempts:
                        # Transient failure: wait, then grow the delay exponentially.
                        logger.warning(
                            'Ошибка БД (попытка /): . Повтор через сек...',
                            attempt=attempt,
                            attempts=attempts,
                            e=str(e)[:100],
                            current_delay=current_delay,
                        )
                        await asyncio.sleep(current_delay)
                        current_delay *= backoff
                    else:
                        logger.error('Ошибка БД: все попыток исчерпаны. Последняя ошибка', attempts=attempts, e=str(e))

            # Loop ran at least once (attempts >= 1), so this is a real exception.
            raise last_exception  # type: ignore[misc]

        return wrapper  # type: ignore[return-value]

    return decorator
|
||
|
||
|
||
async def execute_with_retry(
    session: AsyncSession,
    statement,
    attempts: int = DEFAULT_RETRY_ATTEMPTS,
):
    """Execute a SQL statement, retrying transient connection failures with exponential backoff."""
    if attempts < 1:
        raise ValueError(f'attempts must be >= 1, got {attempts}')

    last_exception: Exception | None = None
    delay = DEFAULT_RETRY_DELAY

    attempt = 0
    while attempt < attempts:
        attempt += 1
        try:
            return await session.execute(statement)
        except RETRYABLE_EXCEPTIONS as exc:
            last_exception = exc
            if attempt == attempts:
                break
            logger.warning('SQL retry (попытка /)', attempt=attempt, attempts=attempts, e=str(exc)[:100])
            await asyncio.sleep(delay)
            delay *= 2

    # attempts >= 1 guarantees the loop ran, so last_exception is set here.
    raise last_exception  # type: ignore[misc]
|
||
|
||
|
||
# ============================================================================
|
||
# QUERY PERFORMANCE MONITORING
|
||
# ============================================================================
|
||
|
||
if settings.DEBUG:
    # Debug-only query timing. Listeners attach to the sync Engine class
    # globally, so they fire for every engine created in this process.

    @event.listens_for(Engine, 'before_cursor_execute')
    def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        # Push start times onto a stack so nested cursor executions pair up.
        conn.info.setdefault('query_start_time', []).append(time.time())
        logger.debug('🔍 Executing query: ...', statement=statement[:100])

    @event.listens_for(Engine, 'after_cursor_execute')
    def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        total = time.time() - conn.info['query_start_time'].pop(-1)
        if total > 0.1:  # Log slow queries (> 100ms)
            logger.warning('🐌 Slow query (s): ...', total=round(total, 3), statement=statement[:100])
        else:
            logger.debug('⚡ Query executed in', total=round(total, 3))
|
||
|
||
# ============================================================================
|
||
# ADVANCED SESSION MANAGER WITH READ REPLICAS
|
||
# ============================================================================
|
||
|
||
HEALTH_CHECK_TIMEOUT = 5.0  # seconds; upper bound for health-check round trips
|
||
|
||
|
||
def _validate_database_url(url: str | None) -> str | None:
|
||
"""Валидация URL базы данных."""
|
||
if not url:
|
||
return None
|
||
url = url.strip()
|
||
if not url or url.isspace():
|
||
return None
|
||
# Простая проверка на валидный формат
|
||
if not ('://' in url or url.startswith('sqlite')):
|
||
logger.warning('Невалидный DATABASE_URL (не содержит ://)')
|
||
return None
|
||
return url
|
||
|
||
|
||
class DatabaseManager:
    """Advanced DB manager with read-replica support and pre-built session factories."""

    def __init__(self):
        # Primary engine is the module-level one; the replica is optional.
        self.engine = engine
        self.read_replica_engine: AsyncEngine | None = None
        self._read_replica_session_factory: async_sessionmaker | None = None

        # Validate and create the read replica engine
        replica_url = _validate_database_url(getattr(settings, 'DATABASE_READ_REPLICA_URL', None))
        if replica_url:
            try:
                self.read_replica_engine = create_async_engine(
                    replica_url,
                    poolclass=poolclass,
                    pool_size=30,  # Larger, since the replica serves read traffic
                    max_overflow=50,
                    pool_pre_ping=True,
                    pool_recycle=3600,
                    echo=False,
                )
                # Build the sessionmaker once here, not on every session() call
                self._read_replica_session_factory = async_sessionmaker(
                    bind=self.read_replica_engine,
                    class_=AsyncSession,
                    expire_on_commit=False,
                    autoflush=False,
                )
                from sqlalchemy.engine import make_url

                # Log the replica URL with credentials masked.
                safe_url = make_url(replica_url).render_as_string(hide_password=True)
                logger.info('Read replica настроена', replica_url=safe_url)
            except Exception as e:
                # Best-effort: fall back to primary-only operation.
                logger.error('Не удалось настроить read replica', e=e)
                self.read_replica_engine = None

    @asynccontextmanager
    async def session(self, read_only: bool = False):
        """Yield a DB session; read_only requests route to the replica when configured.

        Writes are committed on clean exit; any exception triggers a rollback
        and is re-raised.
        """
        # Use the pre-built sessionmaker instead of creating a new one each time
        if read_only and self._read_replica_session_factory:
            session_factory = self._read_replica_session_factory
        else:
            session_factory = AsyncSessionLocal

        async with session_factory() as session:
            try:
                yield session
                if not read_only:
                    await session.commit()
            except Exception:
                await session.rollback()
                raise

    async def health_check(self, timeout: float = HEALTH_CHECK_TIMEOUT) -> dict:
        """
        Check primary-DB health with a timeout.

        Args:
            timeout: Maximum wait time (seconds)

        Returns:
            Dict with 'status' ('healthy' / 'timeout' / 'unhealthy'),
            'latency_ms' of the SELECT 1 probe, and pool metrics.
        """
        pool = self.engine.pool
        status = 'unhealthy'
        latency = None

        try:
            async with asyncio.timeout(timeout):
                async with AsyncSessionLocal() as session:
                    start = time.time()
                    await session.execute(text('SELECT 1'))
                    latency = (time.time() - start) * 1000
                    status = 'healthy'
        except TimeoutError:
            logger.error('Health check таймаут (сек)', timeout=timeout)
            status = 'timeout'
        except Exception as e:
            logger.error('Database health check failed', e=e)
            status = 'unhealthy'

        return {
            'status': status,
            # NOTE(review): truthiness test — a measured 0.0 ms would report None
            'latency_ms': round(latency, 2) if latency else None,
            'pool': _collect_health_pool_metrics(pool),
        }

    async def health_check_replica(self, timeout: float = HEALTH_CHECK_TIMEOUT) -> dict | None:
        """Check read-replica health; returns None when no replica is configured."""
        if not self.read_replica_engine:
            return None

        pool = self.read_replica_engine.pool
        status = 'unhealthy'
        latency = None

        try:
            async with asyncio.timeout(timeout):
                async with self._read_replica_session_factory() as session:
                    start = time.time()
                    await session.execute(text('SELECT 1'))
                    latency = (time.time() - start) * 1000
                    status = 'healthy'
        except TimeoutError:
            status = 'timeout'
        except Exception as e:
            logger.error('Read replica health check failed', e=e)

        return {
            'status': status,
            'latency_ms': round(latency, 2) if latency else None,
            'pool': _collect_health_pool_metrics(pool),
        }
|
||
|
||
|
||
# Module-level singleton; constructed at import time (may create the replica engine).
db_manager = DatabaseManager()
|
||
|
||
# ============================================================================
|
||
# SESSION DEPENDENCY FOR FASTAPI/AIOGRAM
|
||
# ============================================================================
|
||
|
||
|
||
async def get_db() -> AsyncGenerator[AsyncSession]:
    """Standard FastAPI dependency: yield a session, commit on success, roll back on error."""
    async with AsyncSessionLocal() as db_session:
        try:
            yield db_session
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
|
||
|
||
|
||
async def get_db_read_only() -> AsyncGenerator[AsyncSession]:
    """Read-only dependency for heavy SELECT queries (routed to the replica when available)."""
    async with db_manager.session(read_only=True) as session:
        yield session
|
||
|
||
|
||
# ============================================================================
|
||
# BATCH OPERATIONS FOR PERFORMANCE
|
||
# ============================================================================
|
||
|
||
|
||
class BatchOperations:
    """Utilities for bulk insert/update operations."""

    @staticmethod
    async def bulk_insert(session: AsyncSession, model, data: list[dict], chunk_size: int = 1000):
        """Bulk insert rows in chunks, committing once at the end.

        Args:
            session: Active async session; committed by this method.
            model: ORM model class instantiated once per row dict.
            data: List of keyword dicts, one per row.
            chunk_size: Rows flushed per chunk (bounds unit-of-work size).
        """
        # Early return mirrors bulk_update: empty input should not flush/commit.
        if not data:
            return

        for i in range(0, len(data), chunk_size):
            chunk = data[i : i + chunk_size]
            session.add_all([model(**item) for item in chunk])
            # Flush per chunk to keep memory and the identity map bounded.
            await session.flush()
        await session.commit()

    @staticmethod
    async def bulk_update(session: AsyncSession, model, data: list[dict], chunk_size: int = 1000):
        """Bulk update rows in chunks via a single executemany statement.

        Each dict in ``data`` must contain every primary-key value; keys that
        are not columns of the model's table are silently dropped.

        Raises:
            ValueError: If the model has no primary key, no updatable columns,
                or an item is missing a primary-key value.
        """
        if not data:
            return

        primary_keys = [column.name for column in model.__table__.primary_key.columns]
        if not primary_keys:
            raise ValueError('Model must have a primary key for bulk_update')

        updatable_columns = [column.name for column in model.__table__.columns if column.name not in primary_keys]

        if not updatable_columns:
            raise ValueError('No columns available for update in bulk_update')

        # One UPDATE ... WHERE pk = :pk statement, bound per-row via executemany.
        stmt = (
            model.__table__.update()
            .where(*[getattr(model.__table__.c, pk) == bindparam(pk) for pk in primary_keys])
            .values(**{column: bindparam(column, required=False) for column in updatable_columns})
        )

        for i in range(0, len(data), chunk_size):
            chunk = data[i : i + chunk_size]
            filtered_chunk = []
            for item in chunk:
                missing_keys = [pk for pk in primary_keys if pk not in item]
                if missing_keys:
                    raise ValueError(f'Missing primary key values {missing_keys} for bulk_update')

                # Keep only keys that map to real columns.
                filtered_item = {
                    key: value for key, value in item.items() if key in primary_keys or key in updatable_columns
                }
                filtered_chunk.append(filtered_item)

            await session.execute(stmt, filtered_chunk)
        await session.commit()
|
||
|
||
|
||
# Shared instance for convenience (all BatchOperations methods are static).
batch_ops = BatchOperations()
|
||
|
||
# ============================================================================
|
||
# INITIALIZATION AND CLEANUP
|
||
# ============================================================================
|
||
|
||
|
||
async def close_db() -> None:
    """Gracefully dispose of every database connection pool (primary and replica)."""
    logger.info('Закрытие соединений с БД...')

    await engine.dispose()

    replica = db_manager.read_replica_engine
    if replica:
        await replica.dispose()

    logger.info('Все подключения к базе данных закрыты')
|
||
|
||
|
||
# ============================================================================
|
||
# SEQUENCE SYNCHRONIZATION (after DB restores)
|
||
# ============================================================================
|
||
|
||
|
||
def _quote_ident(name: str) -> str:
|
||
"""Quote a PostgreSQL identifier to prevent SQL injection."""
|
||
return '"' + name.replace('"', '""') + '"'
|
||
|
||
|
||
async def sync_postgres_sequences() -> bool:
    """Ensure PostgreSQL sequences match the current max values after restores.

    Returns:
        True on success (or when skipped for SQLite), False if any error occurred.
    """
    if IS_SQLITE:
        logger.debug('Пропускаем синхронизацию последовательностей: SQLite')
        return True

    try:
        async with engine.begin() as conn:
            # Find every serial-backed column in user schemas via its default.
            result = await conn.execute(
                text(
                    """
                    SELECT
                        cols.table_schema,
                        cols.table_name,
                        cols.column_name,
                        pg_get_serial_sequence(
                            format('%I.%I', cols.table_schema, cols.table_name),
                            cols.column_name
                        ) AS sequence_path
                    FROM information_schema.columns AS cols
                    WHERE cols.column_default LIKE 'nextval(%'
                      AND cols.table_schema NOT IN ('pg_catalog', 'information_schema')
                    """
                )
            )

            sequences = result.fetchall()

            if not sequences:
                logger.info('Не найдено последовательностей PostgreSQL для синхронизации')
                return True

            for table_schema, table_name, column_name, sequence_path in sequences:
                if not sequence_path:
                    continue

                # Quote all identifiers before interpolating into SQL below.
                q_col = _quote_ident(column_name)
                q_schema = _quote_ident(table_schema)
                q_table = _quote_ident(table_name)

                max_result = await conn.execute(text(f'SELECT COALESCE(MAX({q_col}), 0) FROM {q_schema}.{q_table}'))
                max_value = max_result.scalar() or 0

                # pg_get_serial_sequence returns e.g. '"public"."users_id_seq"'.
                # Split on '"."' to handle quoted identifiers that may contain dots.
                if '"."' in sequence_path:
                    seq_schema, seq_name = sequence_path.split('"."', 1)
                    seq_schema = seq_schema.strip('"')
                    seq_name = seq_name.strip('"')
                else:
                    parts = sequence_path.split('.')
                    if len(parts) == 2:
                        seq_schema, seq_name = parts
                    else:
                        seq_schema, seq_name = 'public', parts[-1]
                q_seq_schema = _quote_ident(seq_schema)
                q_seq_name = _quote_ident(seq_name)
                current_result = await conn.execute(
                    text(f'SELECT last_value, is_called FROM {q_seq_schema}.{q_seq_name}')
                )
                current_row = current_result.fetchone()

                if current_row:
                    current_last, is_called = current_row
                    # Next value the sequence would hand out.
                    current_next = current_last + 1 if is_called else current_last
                    if current_next > max_value:
                        # Sequence is already ahead of the table; nothing to fix.
                        continue

                await conn.execute(
                    text(
                        """
                        SELECT setval(:sequence_name, :new_value, TRUE)
                        """
                    ),
                    {'sequence_name': sequence_path, 'new_value': max_value},
                )
                logger.info(
                    'Последовательность синхронизирована',
                    sequence_path=sequence_path,
                    max_value=max_value,
                    next_id=max_value + 1,
                )

        return True

    except Exception as error:
        logger.error('Ошибка синхронизации последовательностей PostgreSQL', error=error)
        return False
|
||
|
||
|
||
# ============================================================================
|
||
# CONNECTION POOL METRICS (для мониторинга)
|
||
# ============================================================================
|
||
|
||
|
||
def _pool_counters(pool):
|
||
"""Return basic pool counters or ``None`` when unsupported."""
|
||
|
||
required_methods = ('size', 'checkedin', 'checkedout', 'overflow')
|
||
|
||
for method_name in required_methods:
|
||
method = getattr(pool, method_name, None)
|
||
if method is None or not callable(method):
|
||
return None
|
||
|
||
size = pool.size()
|
||
checked_in = pool.checkedin()
|
||
checked_out = pool.checkedout()
|
||
overflow = pool.overflow()
|
||
|
||
total_connections = size + overflow
|
||
|
||
return {
|
||
'size': size,
|
||
'checked_in': checked_in,
|
||
'checked_out': checked_out,
|
||
'overflow': overflow,
|
||
'total_connections': total_connections,
|
||
'utilization_percent': (checked_out / total_connections * 100) if total_connections else 0.0,
|
||
}
|
||
|
||
|
||
def _collect_health_pool_metrics(pool) -> dict:
    """Shape pool counters for health-check payloads; zeroed when unsupported."""
    counters = _pool_counters(pool)

    if counters is None:
        return {
            'metrics_available': False,
            'size': 0,
            'checked_in': 0,
            'checked_out': 0,
            'overflow': 0,
            'total_connections': 0,
            'utilization': '0.0%',
        }

    metrics = {'metrics_available': True}
    for key in ('size', 'checked_in', 'checked_out', 'overflow', 'total_connections'):
        metrics[key] = counters[key]
    metrics['utilization'] = f'{counters["utilization_percent"]:.1f}%'
    return metrics
|
||
|
||
|
||
async def get_pool_metrics() -> dict:
    """Detailed pool metrics for Prometheus/Grafana dashboards."""
    pool = engine.pool
    counters = _pool_counters(pool)

    if counters is None:
        # Pool type (e.g. NullPool) does not expose counters.
        return {
            'metrics_available': False,
            'pool_size': 0,
            'checked_in_connections': 0,
            'checked_out_connections': 0,
            'overflow_connections': 0,
            'total_connections': 0,
            'max_possible_connections': 0,
            'pool_utilization_percent': 0.0,
        }

    max_overflow = getattr(pool, '_max_overflow', 0) or 0
    return {
        'metrics_available': True,
        'pool_size': counters['size'],
        'checked_in_connections': counters['checked_in'],
        'checked_out_connections': counters['checked_out'],
        'overflow_connections': counters['overflow'],
        'total_connections': counters['total_connections'],
        'max_possible_connections': counters['total_connections'] + max_overflow,
        'pool_utilization_percent': round(counters['utilization_percent'], 2),
    }
|