Files
remnawave-bedolaga-telegram…/app/database/database.py
2026-02-18 08:11:33 +03:00

609 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from functools import wraps
from typing import ParamSpec, TypeVar
import structlog
from sqlalchemy import bindparam, event, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import InterfaceError, OperationalError
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
from app.config import settings
logger = structlog.get_logger(__name__)
T = TypeVar('T')
P = ParamSpec('P')
R = TypeVar('R')
# ============================================================================
# PRODUCTION-GRADE CONNECTION POOLING
# ============================================================================
def _is_sqlite_url(url: str) -> bool:
"""Проверка на SQLite URL (поддерживает sqlite:// и sqlite+aiosqlite://)"""
return url.startswith('sqlite') or ':memory:' in url
DATABASE_URL = settings.get_database_url()
IS_SQLITE = _is_sqlite_url(DATABASE_URL)
# SQLite gets no pooling (NullPool opens/closes per use); PostgreSQL gets a bounded queue pool.
if IS_SQLITE:
    poolclass = NullPool
    pool_kwargs = {}
else:
    poolclass = AsyncAdaptedQueuePool
    pool_kwargs = {
        'pool_size': 20,  # Reduced from 30 so we stay under PostgreSQL's max_connections
        'max_overflow': 20,  # Reduced from 50 — at most 40 connections instead of 80
        'pool_timeout': 30,  # Reduced from 60 — fail fast (503) under overload
        'pool_recycle': 1800,  # 30 min, for quicker connection recycling
        'pool_pre_ping': True,
        # Aggressive cleanup of dead connections
        'pool_reset_on_return': 'rollback',
    }
# ============================================================================
# ENGINE WITH ADVANCED OPTIMIZATIONS
# ============================================================================
# PostgreSQL-specific connect_args (passed to asyncpg; ignored for SQLite below)
_pg_connect_args = {
    'server_settings': {
        'application_name': 'remnawave_bot',
        'jit': 'on',
        'statement_timeout': '60000',  # 60 seconds
        'idle_in_transaction_session_timeout': '300000',  # 5 minutes
    },
    'command_timeout': 30,  # Reduced from 60 — detect hung queries faster
    'timeout': 10,  # Reduced from 60 — fail fast when PostgreSQL is unreachable
}
# Primary async engine shared by the whole application.
engine = create_async_engine(
    DATABASE_URL,
    poolclass=poolclass,
    echo='debug' if settings.DEBUG else False,
    future=True,
    # Compiled-statement cache (this is the correct place to configure it)
    query_cache_size=500,
    connect_args=_pg_connect_args if not IS_SQLITE else {},
    execution_options={
        'isolation_level': 'READ COMMITTED',
    },
    **pool_kwargs,
)
# ============================================================================
# SESSION FACTORY WITH OPTIMIZATIONS
# ============================================================================
# Session factory for the primary engine.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,  # keep attributes usable after commit (no implicit expiry)
    autoflush=False,  # critical for performance
    autocommit=False,
)
# ============================================================================
# RETRY LOGIC FOR DATABASE OPERATIONS
# ============================================================================
# Connection-level failures that are considered safe to retry.
RETRYABLE_EXCEPTIONS = (OperationalError, InterfaceError, ConnectionRefusedError, OSError, TimeoutError)
DEFAULT_RETRY_ATTEMPTS = 3
DEFAULT_RETRY_DELAY = 0.5  # seconds
def with_db_retry(
    attempts: int = DEFAULT_RETRY_ATTEMPTS,
    delay: float = DEFAULT_RETRY_DELAY,
    backoff: float = 2.0,
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
    """Decorator that retries an async DB call on transient connection failures.

    Args:
        attempts: Total number of attempts (must be >= 1).
        delay: Initial delay between attempts, in seconds.
        backoff: Multiplier applied to the delay after each failed attempt.

    Raises:
        ValueError: when ``attempts`` is less than 1 (same contract as
            ``execute_with_retry``; previously the wrapper silently raised
            ``None`` in that case).
    """
    # Consistency fix: validate like execute_with_retry does; with attempts < 1
    # the loop body would never run and `raise last_exception` would raise None.
    if attempts < 1:
        raise ValueError(f'attempts must be >= 1, got {attempts}')

    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
        @wraps(func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            last_exception: Exception | None = None
            current_delay = delay
            for attempt in range(1, attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except RETRYABLE_EXCEPTIONS as e:
                    last_exception = e
                    if attempt < attempts:
                        # Event text was garbled remnants of an f-string
                        # ("попытка /"); details live in structured kwargs.
                        logger.warning(
                            'Ошибка БД, повтор после паузы',
                            attempt=attempt,
                            attempts=attempts,
                            e=str(e)[:100],
                            current_delay=current_delay,
                        )
                        await asyncio.sleep(current_delay)
                        current_delay *= backoff
                    else:
                        logger.error('Ошибка БД: все попытки исчерпаны', attempts=attempts, e=str(e))
            # All attempts exhausted — re-raise the last retryable failure.
            raise last_exception  # type: ignore[misc]

        return wrapper  # type: ignore[return-value]

    return decorator
async def execute_with_retry(
    session: AsyncSession,
    statement,
    attempts: int = DEFAULT_RETRY_ATTEMPTS,
):
    """Execute *statement* on *session*, retrying transient connection failures.

    Args:
        session: Session to execute on.
        statement: Any SQLAlchemy executable statement.
        attempts: Total number of attempts (must be >= 1).

    Raises:
        ValueError: when ``attempts`` is less than 1.
        The last retryable exception when every attempt fails.
    """
    if attempts < 1:
        raise ValueError(f'attempts must be >= 1, got {attempts}')
    last_exception: Exception | None = None
    delay = DEFAULT_RETRY_DELAY
    for attempt in range(1, attempts + 1):
        try:
            return await session.execute(statement)
        except RETRYABLE_EXCEPTIONS as e:
            last_exception = e
            if attempt < attempts:
                # Event text was a garbled f-string remnant ("попытка /");
                # attempt counters are carried in structured kwargs.
                logger.warning('SQL retry', attempt=attempt, attempts=attempts, e=str(e)[:100])
                await asyncio.sleep(delay)
                delay *= 2  # exponential backoff
    # All attempts exhausted — surface the last retryable failure.
    raise last_exception  # type: ignore[misc]
# ============================================================================
# QUERY PERFORMANCE MONITORING
# ============================================================================
if settings.DEBUG:
    # Debug-only instrumentation: time every statement via sync-Engine cursor events.
    @event.listens_for(Engine, 'before_cursor_execute')
    def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        # A stack of start times on conn.info tolerates nested/overlapping executes.
        conn.info.setdefault('query_start_time', []).append(time.time())
        logger.debug('🔍 Executing query: ...', statement=statement[:100])

    @event.listens_for(Engine, 'after_cursor_execute')
    def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        total = time.time() - conn.info['query_start_time'].pop(-1)
        if total > 0.1:  # log slow queries (> 100 ms)
            logger.warning('🐌 Slow query (s): ...', total=round(total, 3), statement=statement[:100])
        else:
            logger.debug('⚡ Query executed in', total=round(total, 3))
# ============================================================================
# ADVANCED SESSION MANAGER WITH READ REPLICAS
# ============================================================================
HEALTH_CHECK_TIMEOUT = 5.0 # секунды
def _validate_database_url(url: str | None) -> str | None:
"""Валидация URL базы данных."""
if not url:
return None
url = url.strip()
if not url or url.isspace():
return None
# Простая проверка на валидный формат
if not ('://' in url or url.startswith('sqlite')):
logger.warning('Невалидный DATABASE_URL (не содержит ://)')
return None
return url
class DatabaseManager:
    """Database manager with optional read-replica support.

    Holds the module's primary engine plus, when ``DATABASE_READ_REPLICA_URL``
    is configured, a separate engine and a pre-built session factory for
    read-only traffic.
    """

    def __init__(self):
        self.engine = engine
        self.read_replica_engine: AsyncEngine | None = None
        self._read_replica_session_factory: async_sessionmaker | None = None
        # Validate the replica URL before spending effort on engine creation.
        replica_url = _validate_database_url(getattr(settings, 'DATABASE_READ_REPLICA_URL', None))
        if replica_url:
            try:
                self.read_replica_engine = create_async_engine(
                    replica_url,
                    poolclass=poolclass,
                    pool_size=30,  # larger pool for read traffic
                    max_overflow=50,
                    pool_pre_ping=True,
                    pool_recycle=3600,
                    echo=False,
                )
                # Build the sessionmaker once (not per call).
                self._read_replica_session_factory = async_sessionmaker(
                    bind=self.read_replica_engine,
                    class_=AsyncSession,
                    expire_on_commit=False,
                    autoflush=False,
                )
                from sqlalchemy.engine import make_url
                safe_url = make_url(replica_url).render_as_string(hide_password=True)
                logger.info('Read replica настроена', replica_url=safe_url)
            except Exception as e:
                logger.error('Не удалось настроить read replica', e=e)
                self.read_replica_engine = None
                # BUG FIX: also reset the factory; a failure after its creation
                # (e.g. in make_url) previously left a stale factory that
                # session(read_only=True) would keep routing reads through.
                self._read_replica_session_factory = None

    @asynccontextmanager
    async def session(self, read_only: bool = False):
        """Yield a DB session; write sessions commit on success, all roll back on error.

        ``read_only=True`` routes to the replica factory when one is configured.
        """
        # Use the pre-built sessionmaker instead of creating a new one.
        if read_only and self._read_replica_session_factory:
            session_factory = self._read_replica_session_factory
        else:
            session_factory = AsyncSessionLocal
        async with session_factory() as session:
            try:
                yield session
                if not read_only:
                    await session.commit()
            except Exception:
                await session.rollback()
                raise

    async def health_check(self, timeout: float = HEALTH_CHECK_TIMEOUT) -> dict:
        """Run ``SELECT 1`` against the primary DB with a timeout.

        Args:
            timeout: Maximum wait, in seconds.

        Returns:
            Dict with ``status`` ('healthy' | 'timeout' | 'unhealthy'),
            ``latency_ms`` and a ``pool`` metrics snapshot.
        """
        pool = self.engine.pool
        status = 'unhealthy'
        latency = None
        try:
            async with asyncio.timeout(timeout):
                async with AsyncSessionLocal() as session:
                    start = time.time()
                    await session.execute(text('SELECT 1'))
                    latency = (time.time() - start) * 1000
                    status = 'healthy'
        except TimeoutError:
            logger.error('Health check таймаут (сек)', timeout=timeout)
            status = 'timeout'
        except Exception as e:
            logger.error('Database health check failed', e=e)
            status = 'unhealthy'
        return {
            'status': status,
            # BUG FIX: `if latency` reported a legitimate 0.0 ms latency as None.
            'latency_ms': round(latency, 2) if latency is not None else None,
            'pool': _collect_health_pool_metrics(pool),
        }

    async def health_check_replica(self, timeout: float = HEALTH_CHECK_TIMEOUT) -> dict | None:
        """Like :meth:`health_check`, but for the read replica.

        Returns:
            ``None`` when no replica is configured, else the same dict shape
            as :meth:`health_check`.
        """
        if not self.read_replica_engine:
            return None
        pool = self.read_replica_engine.pool
        status = 'unhealthy'
        latency = None
        try:
            async with asyncio.timeout(timeout):
                async with self._read_replica_session_factory() as session:
                    start = time.time()
                    await session.execute(text('SELECT 1'))
                    latency = (time.time() - start) * 1000
                    status = 'healthy'
        except TimeoutError:
            status = 'timeout'
        except Exception as e:
            logger.error('Read replica health check failed', e=e)
        return {
            'status': status,
            # BUG FIX: same `is not None` guard as health_check.
            'latency_ms': round(latency, 2) if latency is not None else None,
            'pool': _collect_health_pool_metrics(pool),
        }
db_manager = DatabaseManager()
# ============================================================================
# SESSION DEPENDENCY FOR FASTAPI/AIOGRAM
# ============================================================================
async def get_db() -> AsyncGenerator[AsyncSession]:
    """FastAPI dependency: yield a session, commit on success, roll back on error."""
    session = AsyncSessionLocal()
    try:
        yield session
    except Exception:
        await session.rollback()
        raise
    else:
        await session.commit()
    finally:
        await session.close()
async def get_db_read_only() -> AsyncGenerator[AsyncSession]:
    """Read-only dependency for heavy SELECT traffic; routed via the manager (replica-aware)."""
    async with db_manager.session(read_only=True) as ro_session:
        yield ro_session
# ============================================================================
# BATCH OPERATIONS FOR PERFORMANCE
# ============================================================================
class BatchOperations:
    """Utilities for bulk insert/update operations."""

    @staticmethod
    async def bulk_insert(session: AsyncSession, model, data: list[dict], chunk_size: int = 1000):
        """Bulk-insert *data* rows as *model* instances, flushing per chunk, committing once."""
        for i in range(0, len(data), chunk_size):
            chunk = data[i : i + chunk_size]
            session.add_all([model(**item) for item in chunk])
            # Push each chunk to the DB without ending the transaction.
            await session.flush()
        await session.commit()

    @staticmethod
    async def bulk_update(session: AsyncSession, model, data: list[dict], chunk_size: int = 1000):
        """Bulk-update *model* rows in chunks via a parameterized executemany UPDATE.

        Every dict in *data* must contain all primary-key columns; keys that are
        not table columns are silently dropped before execution.

        Raises:
            ValueError: when the model has no primary key, no updatable columns,
                or a row is missing a primary-key value.
        """
        if not data:
            return
        primary_keys = [column.name for column in model.__table__.primary_key.columns]
        if not primary_keys:
            raise ValueError('Model must have a primary key for bulk_update')
        updatable_columns = [column.name for column in model.__table__.columns if column.name not in primary_keys]
        if not updatable_columns:
            raise ValueError('No columns available for update in bulk_update')
        # One UPDATE statement with bindparams for PKs (WHERE) and all updatable columns (SET).
        # NOTE(review): required=False lets rows omit some columns, but executemany
        # generally expects homogeneous keys across a chunk — confirm callers pass
        # uniformly-shaped dicts.
        stmt = (
            model.__table__.update()
            .where(*[getattr(model.__table__.c, pk) == bindparam(pk) for pk in primary_keys])
            .values(**{column: bindparam(column, required=False) for column in updatable_columns})
        )
        for i in range(0, len(data), chunk_size):
            chunk = data[i : i + chunk_size]
            filtered_chunk = []
            for item in chunk:
                missing_keys = [pk for pk in primary_keys if pk not in item]
                if missing_keys:
                    raise ValueError(f'Missing primary key values {missing_keys} for bulk_update')
                # Keep only real column keys so stray dict entries don't break binding.
                filtered_item = {
                    key: value for key, value in item.items() if key in primary_keys or key in updatable_columns
                }
                filtered_chunk.append(filtered_item)
            await session.execute(stmt, filtered_chunk)
        await session.commit()


batch_ops = BatchOperations()
# ============================================================================
# INITIALIZATION AND CLEANUP
# ============================================================================
async def close_db() -> None:
    """Dispose every engine so all pooled connections are released on shutdown."""
    logger.info('Закрытие соединений с БД...')
    engines_to_dispose = [engine]
    replica = db_manager.read_replica_engine
    if replica is not None:
        engines_to_dispose.append(replica)
    for eng in engines_to_dispose:
        await eng.dispose()
    logger.info('Все подключения к базе данных закрыты')
# ============================================================================
# SEQUENCE SYNCHRONIZATION (after DB restores)
# ============================================================================
def _quote_ident(name: str) -> str:
"""Quote a PostgreSQL identifier to prevent SQL injection."""
return '"' + name.replace('"', '""') + '"'
async def sync_postgres_sequences() -> bool:
    """Ensure PostgreSQL sequences match the current max values after restores.

    Scans information_schema for sequence-backed columns (default starting with
    ``nextval(``), compares each sequence's next value to ``MAX(column)``, and
    calls ``setval`` when the sequence lags behind. No-op for SQLite.

    Returns:
        True on success (or when skipped), False when any error occurred.
    """
    if IS_SQLITE:
        logger.debug('Пропускаем синхронизацию последовательностей: SQLite')
        return True
    try:
        async with engine.begin() as conn:
            # Find every column whose default is driven by a sequence.
            result = await conn.execute(
                text(
                    """
                    SELECT
                        cols.table_schema,
                        cols.table_name,
                        cols.column_name,
                        pg_get_serial_sequence(
                            format('%I.%I', cols.table_schema, cols.table_name),
                            cols.column_name
                        ) AS sequence_path
                    FROM information_schema.columns AS cols
                    WHERE cols.column_default LIKE 'nextval(%'
                    AND cols.table_schema NOT IN ('pg_catalog', 'information_schema')
                    """
                )
            )
            sequences = result.fetchall()
            if not sequences:
                logger.info('Не найдено последовательностей PostgreSQL для синхронизации')
                return True
            for table_schema, table_name, column_name, sequence_path in sequences:
                if not sequence_path:
                    continue
                # Identifiers cannot be bound parameters, so they are quoted manually.
                q_col = _quote_ident(column_name)
                q_schema = _quote_ident(table_schema)
                q_table = _quote_ident(table_name)
                max_result = await conn.execute(text(f'SELECT COALESCE(MAX({q_col}), 0) FROM {q_schema}.{q_table}'))
                max_value = max_result.scalar() or 0
                # pg_get_serial_sequence returns e.g. '"public"."users_id_seq"'.
                # Split on '"."' to handle quoted identifiers that may contain dots.
                if '"."' in sequence_path:
                    seq_schema, seq_name = sequence_path.split('"."', 1)
                    seq_schema = seq_schema.strip('"')
                    seq_name = seq_name.strip('"')
                else:
                    parts = sequence_path.split('.')
                    if len(parts) == 2:
                        seq_schema, seq_name = parts
                    else:
                        # Unqualified name — assume the default 'public' schema.
                        seq_schema, seq_name = 'public', parts[-1]
                q_seq_schema = _quote_ident(seq_schema)
                q_seq_name = _quote_ident(seq_name)
                current_result = await conn.execute(
                    text(f'SELECT last_value, is_called FROM {q_seq_schema}.{q_seq_name}')
                )
                current_row = current_result.fetchone()
                if current_row:
                    current_last, is_called = current_row
                    # The next value the sequence would hand out.
                    current_next = current_last + 1 if is_called else current_last
                    if current_next > max_value:
                        # Sequence is already ahead of the data — nothing to fix.
                        continue
                await conn.execute(
                    text(
                        """
                        SELECT setval(:sequence_name, :new_value, TRUE)
                        """
                    ),
                    {'sequence_name': sequence_path, 'new_value': max_value},
                )
                logger.info(
                    'Последовательность синхронизирована',
                    sequence_path=sequence_path,
                    max_value=max_value,
                    next_id=max_value + 1,
                )
            return True
    except Exception as error:
        logger.error('Ошибка синхронизации последовательностей PostgreSQL', error=error)
        return False
# ============================================================================
# CONNECTION POOL METRICS (для мониторинга)
# ============================================================================
def _pool_counters(pool):
"""Return basic pool counters or ``None`` when unsupported."""
required_methods = ('size', 'checkedin', 'checkedout', 'overflow')
for method_name in required_methods:
method = getattr(pool, method_name, None)
if method is None or not callable(method):
return None
size = pool.size()
checked_in = pool.checkedin()
checked_out = pool.checkedout()
overflow = pool.overflow()
total_connections = size + overflow
return {
'size': size,
'checked_in': checked_in,
'checked_out': checked_out,
'overflow': overflow,
'total_connections': total_connections,
'utilization_percent': (checked_out / total_connections * 100) if total_connections else 0.0,
}
def _collect_health_pool_metrics(pool) -> dict:
    """Shape pool counters into the health-check payload (zeros when unsupported)."""
    counters = _pool_counters(pool)
    available = counters is not None
    if not available:
        counters = {
            'size': 0,
            'checked_in': 0,
            'checked_out': 0,
            'overflow': 0,
            'total_connections': 0,
            'utilization_percent': 0.0,
        }
    return {
        'metrics_available': available,
        'size': counters['size'],
        'checked_in': counters['checked_in'],
        'checked_out': counters['checked_out'],
        'overflow': counters['overflow'],
        'total_connections': counters['total_connections'],
        'utilization': f'{counters["utilization_percent"]:.1f}%',
    }
async def get_pool_metrics() -> dict:
    """Detailed pool metrics for Prometheus/Grafana dashboards."""
    pool = engine.pool
    counters = _pool_counters(pool)
    if counters is None:
        # Pool implementation (e.g. NullPool) exposes no counters.
        return {
            'metrics_available': False,
            'pool_size': 0,
            'checked_in_connections': 0,
            'checked_out_connections': 0,
            'overflow_connections': 0,
            'total_connections': 0,
            'max_possible_connections': 0,
            'pool_utilization_percent': 0.0,
        }
    configured_overflow = getattr(pool, '_max_overflow', 0) or 0
    return {
        'metrics_available': True,
        'pool_size': counters['size'],
        'checked_in_connections': counters['checked_in'],
        'checked_out_connections': counters['checked_out'],
        'overflow_connections': counters['overflow'],
        'total_connections': counters['total_connections'],
        'max_possible_connections': counters['total_connections'] + configured_overflow,
        'pool_utilization_percent': round(counters['utilization_percent'], 2),
    }