mirror of
https://github.com/BEDOLAGA-DEV/remnawave-bedolaga-telegram-bot.git
synced 2026-03-02 16:20:49 +00:00
- Multi-channel subscription enforcement via middleware, events, and cabinet API - 3-layer cache architecture: Redis -> PostgreSQL -> rate-limited Telegram API - ChatMemberUpdated event-driven tracking with automatic VPN access control - Admin management via bot FSM handler and REST API with full CRUD - Channel ID normalization: @username resolved to numeric ID at creation time - Fail-closed error handling: API errors deny access (security-first) - Background reconciliation with keyset pagination (100 per batch) - Per-user rate limiting on subscription check button (5s cooldown) - Redis connection pooling via cache singleton (no per-request connections) - Database: channel_id index, multi-row upsert optimization - Localization: en, ru, zh, fa, ua translations for all new strings - Frontend blocking UI with channel list and subscription status - Admin channel management page with toggle, delete, and create
439 lines
15 KiB
Python
439 lines
15 KiB
Python
import json
|
||
from datetime import timedelta
|
||
from typing import Any
|
||
|
||
import redis.asyncio as redis
|
||
import structlog
|
||
|
||
from app.config import settings
|
||
|
||
|
||
# Module-level structlog logger, named after this module.
logger = structlog.get_logger(__name__)
|
||
|
||
|
||
class CacheService:
|
||
def __init__(self):
|
||
self.redis_client: redis.Redis | None = None
|
||
self._connected = False
|
||
|
||
async def connect(self):
|
||
try:
|
||
self.redis_client = redis.from_url(settings.REDIS_URL)
|
||
await self.redis_client.ping()
|
||
self._connected = True
|
||
logger.info('✅ Подключение к Redis кешу установлено')
|
||
except Exception as e:
|
||
logger.warning('⚠️ Не удалось подключиться к Redis', error=e)
|
||
self._connected = False
|
||
|
||
async def disconnect(self):
|
||
if self.redis_client:
|
||
await self.redis_client.close()
|
||
self._connected = False
|
||
|
||
async def get(self, key: str) -> Any | None:
|
||
if not self._connected:
|
||
return None
|
||
|
||
try:
|
||
value = await self.redis_client.get(key)
|
||
if value:
|
||
return json.loads(value)
|
||
return None
|
||
except Exception as e:
|
||
logger.error('Ошибка получения из кеша', key=key, error=e)
|
||
return None
|
||
|
||
async def set(self, key: str, value: Any, expire: int | timedelta = None) -> bool:
|
||
if not self._connected:
|
||
return False
|
||
|
||
try:
|
||
serialized_value = json.dumps(value, default=str)
|
||
|
||
if isinstance(expire, timedelta):
|
||
expire = int(expire.total_seconds())
|
||
|
||
await self.redis_client.set(key, serialized_value, ex=expire)
|
||
return True
|
||
except Exception as e:
|
||
logger.error('Ошибка записи в кеш', key=key, error=e)
|
||
return False
|
||
|
||
async def setnx(self, key: str, value: Any, expire: int | timedelta = None) -> bool:
|
||
"""Атомарная операция SET IF NOT EXISTS.
|
||
|
||
Устанавливает значение только если ключ не существует.
|
||
Возвращает True если значение было установлено, False если ключ уже существовал.
|
||
"""
|
||
if not self._connected:
|
||
return False
|
||
|
||
try:
|
||
serialized_value = json.dumps(value, default=str)
|
||
|
||
if isinstance(expire, timedelta):
|
||
expire = int(expire.total_seconds())
|
||
|
||
# SET с NX возвращает True если установлено, None если ключ существует
|
||
result = await self.redis_client.set(key, serialized_value, ex=expire, nx=True)
|
||
return result is True
|
||
except Exception as e:
|
||
logger.error('Ошибка setnx в кеш', key=key, error=e)
|
||
return False
|
||
|
||
async def delete(self, key: str) -> bool:
|
||
if not self._connected:
|
||
return False
|
||
|
||
try:
|
||
deleted = await self.redis_client.delete(key)
|
||
return deleted > 0
|
||
except Exception as e:
|
||
logger.error('Ошибка удаления из кеша', key=key, error=e)
|
||
return False
|
||
|
||
async def delete_pattern(self, pattern: str) -> int:
|
||
if not self._connected:
|
||
return 0
|
||
|
||
try:
|
||
keys = await self.redis_client.keys(pattern)
|
||
if not keys:
|
||
return 0
|
||
|
||
deleted = await self.redis_client.delete(*keys)
|
||
return int(deleted)
|
||
except Exception as e:
|
||
logger.error('Ошибка удаления ключей по шаблону', pattern=pattern, error=e)
|
||
return 0
|
||
|
||
async def exists(self, key: str) -> bool:
|
||
if not self._connected:
|
||
return False
|
||
|
||
try:
|
||
return await self.redis_client.exists(key)
|
||
except Exception as e:
|
||
logger.error('Ошибка проверки существования в кеше', key=key, error=e)
|
||
return False
|
||
|
||
async def expire(self, key: str, seconds: int) -> bool:
|
||
if not self._connected:
|
||
return False
|
||
|
||
try:
|
||
return await self.redis_client.expire(key, seconds)
|
||
except Exception as e:
|
||
logger.error('Ошибка установки TTL для', key=key, error=e)
|
||
return False
|
||
|
||
async def get_keys(self, pattern: str = '*') -> list:
|
||
if not self._connected:
|
||
return []
|
||
|
||
try:
|
||
keys = await self.redis_client.keys(pattern)
|
||
return [key.decode() if isinstance(key, bytes) else key for key in keys]
|
||
except Exception as e:
|
||
logger.error('Ошибка получения ключей по паттерну', pattern=pattern, error=e)
|
||
return []
|
||
|
||
async def flush_all(self) -> bool:
|
||
if not self._connected:
|
||
return False
|
||
|
||
try:
|
||
await self.redis_client.flushall()
|
||
logger.info('🗑️ Кеш полностью очищен')
|
||
return True
|
||
except Exception as e:
|
||
logger.error('Ошибка очистки кеша', error=e)
|
||
return False
|
||
|
||
async def increment(self, key: str, amount: int = 1) -> int | None:
|
||
if not self._connected:
|
||
return None
|
||
|
||
try:
|
||
return await self.redis_client.incrby(key, amount)
|
||
except Exception as e:
|
||
logger.error('Ошибка инкремента', key=key, error=e)
|
||
return None
|
||
|
||
async def set_hash(self, name: str, mapping: dict, expire: int = None) -> bool:
|
||
if not self._connected:
|
||
return False
|
||
|
||
try:
|
||
await self.redis_client.hset(name, mapping=mapping)
|
||
if expire:
|
||
await self.redis_client.expire(name, expire)
|
||
return True
|
||
except Exception as e:
|
||
logger.error('Ошибка записи хеша', name=name, error=e)
|
||
return False
|
||
|
||
async def get_hash(self, name: str, key: str = None) -> dict | str | None:
|
||
if not self._connected:
|
||
return None
|
||
|
||
try:
|
||
if key:
|
||
value = await self.redis_client.hget(name, key)
|
||
return value.decode() if value else None
|
||
hash_data = await self.redis_client.hgetall(name)
|
||
return {k.decode(): v.decode() for k, v in hash_data.items()}
|
||
except Exception as e:
|
||
logger.error('Ошибка получения хеша', name=name, error=e)
|
||
return None
|
||
|
||
async def lpush(self, key: str, value: Any) -> bool:
|
||
"""Добавить элемент в начало списка (очереди)."""
|
||
if not self._connected:
|
||
return False
|
||
|
||
try:
|
||
serialized = json.dumps(value, default=str)
|
||
await self.redis_client.lpush(key, serialized)
|
||
return True
|
||
except Exception as e:
|
||
logger.error('Ошибка добавления в очередь', key=key, error=e)
|
||
return False
|
||
|
||
async def rpop(self, key: str) -> Any | None:
|
||
"""Извлечь элемент из конца списка (FIFO очередь)."""
|
||
if not self._connected:
|
||
return None
|
||
|
||
try:
|
||
value = await self.redis_client.rpop(key)
|
||
if value:
|
||
return json.loads(value)
|
||
return None
|
||
except Exception as e:
|
||
logger.error('Ошибка извлечения из очереди', key=key, error=e)
|
||
return None
|
||
|
||
async def llen(self, key: str) -> int:
|
||
"""Получить длину списка (очереди)."""
|
||
if not self._connected:
|
||
return 0
|
||
|
||
try:
|
||
return await self.redis_client.llen(key)
|
||
except Exception as e:
|
||
logger.error('Ошибка получения длины очереди', key=key, error=e)
|
||
return 0
|
||
|
||
async def lrange(self, key: str, start: int = 0, end: int = -1) -> list:
|
||
"""Получить элементы списка без удаления."""
|
||
if not self._connected:
|
||
return []
|
||
|
||
try:
|
||
items = await self.redis_client.lrange(key, start, end)
|
||
return [json.loads(item) for item in items]
|
||
except Exception as e:
|
||
logger.error('Ошибка чтения очереди', key=key, error=e)
|
||
return []
|
||
|
||
|
||
# Shared module-level CacheService instance; the helper classes and functions
# in this module route all their Redis access through it.
cache = CacheService()
|
||
|
||
|
||
def cache_key(*parts) -> str:
    """Build a Redis key by joining the string form of every part with ``:``."""
    return ':'.join(map(str, parts))
|
||
|
||
|
||
def cached_function(key: str, expire: int = 300):
    """Build a decorator that caches an async function's result under ``key``.

    The previous ``async def`` form returned a coroutine, so
    ``@cached_function(...)`` could not be applied as a decorator. This factory
    is synchronous; the object it returns is both callable (decorator use) and
    awaitable, so legacy ``decorator = await cached_function(...)`` call sites
    keep working.

    Args:
        key: Fixed cache key. It does not depend on the wrapped function's
            arguments, so every call shares one cached value.
        expire: TTL in seconds passed to ``cache.set``.

    Returns:
        A decorator for coroutine functions. The wrapper returns the cached
        value when ``cache.get(key)`` is not ``None``; otherwise it awaits the
        wrapped function and stores its result. A ``None`` result is therefore
        never served from the cache.
    """
    import functools  # local import so this block does not depend on the file header

    class _CachingDecorator:
        """Decorator object that can also be awaited to obtain itself."""

        def __call__(self, func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                cache_result = await cache.get(key)
                if cache_result is not None:
                    return cache_result

                result = await func(*args, **kwargs)
                await cache.set(key, result, expire)
                return result

            return wrapper

        def __await__(self):
            # Backward compatibility with the former coroutine-based factory.
            async def _resolve():
                return self

            return _resolve().__await__()

    return _CachingDecorator()
|
||
|
||
|
||
class UserCache:
    """Per-user cache helpers built on the shared ``cache`` instance.

    User data is stored under ``user:{user_id}`` and session entries under
    ``session:{user_id}:{session_key}``.
    """

    @staticmethod
    async def get_user_data(user_id: int) -> dict | None:
        """Return the cached data for ``user_id``, or ``None`` if nothing is cached."""
        return await cache.get(cache_key('user', user_id))

    @staticmethod
    async def set_user_data(user_id: int, data: dict, expire: int = 3600) -> bool:
        """Cache ``data`` for ``user_id`` with a TTL of ``expire`` seconds."""
        user_key = cache_key('user', user_id)
        stored = await cache.set(user_key, data, expire)
        return stored

    @staticmethod
    async def delete_user_data(user_id: int) -> bool:
        """Remove the cached data for ``user_id``."""
        return await cache.delete(cache_key('user', user_id))

    @staticmethod
    async def get_user_session(user_id: int, session_key: str) -> Any | None:
        """Return the cached session entry ``session_key`` of ``user_id``."""
        return await cache.get(cache_key('session', user_id, session_key))

    @staticmethod
    async def set_user_session(user_id: int, session_key: str, data: Any, expire: int = 1800) -> bool:
        """Cache ``data`` as session entry ``session_key`` with a TTL of ``expire`` seconds."""
        entry_key = cache_key('session', user_id, session_key)
        stored = await cache.set(entry_key, data, expire)
        return stored

    @staticmethod
    async def delete_user_session(user_id: int, session_key: str) -> bool:
        """Remove the session entry ``session_key`` of ``user_id``."""
        return await cache.delete(cache_key('session', user_id, session_key))
|
||
|
||
|
||
class SystemCache:
    """Cache helpers for system statistics, node status and daily statistics."""

    # Fixed Redis keys used by the accessors below.
    _STATS_KEY = 'system:stats'
    _NODES_KEY = 'remnawave:nodes'

    @staticmethod
    async def get_system_stats() -> dict | None:
        """Return the cached system statistics, or ``None`` if absent."""
        return await cache.get(SystemCache._STATS_KEY)

    @staticmethod
    async def set_system_stats(stats: dict, expire: int = 300) -> bool:
        """Cache the system statistics with a TTL of ``expire`` seconds."""
        return await cache.set(SystemCache._STATS_KEY, stats, expire)

    @staticmethod
    async def get_nodes_status() -> list | None:
        """Return the cached node status list, or ``None`` if absent."""
        return await cache.get(SystemCache._NODES_KEY)

    @staticmethod
    async def set_nodes_status(nodes: list, expire: int = 60) -> bool:
        """Cache the node status list with a TTL of ``expire`` seconds."""
        return await cache.set(SystemCache._NODES_KEY, nodes, expire)

    @staticmethod
    async def get_daily_stats(date: str) -> dict | None:
        """Return the cached statistics stored under ``stats:daily:{date}``."""
        return await cache.get(cache_key('stats', 'daily', date))

    @staticmethod
    async def set_daily_stats(date: str, stats: dict) -> bool:
        """Cache the statistics for ``date`` for 24 hours."""
        one_day = 86400  # seconds
        return await cache.set(cache_key('stats', 'daily', date), stats, one_day)
|
||
|
||
|
||
class RateLimitCache:
    """Fixed-window rate limiter stored under ``rate_limit:{user_id}:{action}``."""

    @staticmethod
    async def is_rate_limited(user_id: int, action: str, limit: int, window: int) -> bool:
        """Count one attempt and report whether it exceeds ``limit``.

        The counter is bumped with a single atomic INCRBY. The earlier
        read-then-write sequence could undercount concurrent requests and could
        recreate an expired counter without a TTL, leaving the user limited
        indefinitely. The TTL is set only by the call that creates the
        counter, so the window starts at the first attempt and is not extended
        by later ones.

        Args:
            user_id: User the limit applies to.
            action: Name of the limited action; part of the key.
            limit: Attempts allowed per window. With ``limit <= 0`` every
                attempt is reported as limited.
            window: Window length in seconds.

        Returns:
            ``True`` if this attempt is over the limit. ``False`` otherwise,
            including when the cache is unavailable (the limiter fails open).
        """
        key = cache_key('rate_limit', user_id, action)

        attempts = await cache.increment(key)
        if attempts is None:
            # Cache disconnected or errored: do not block the user.
            return False

        if attempts == 1:
            # First hit in this window: start the expiry clock.
            await cache.expire(key, window)

        return attempts > limit

    @staticmethod
    async def reset_rate_limit(user_id: int, action: str) -> bool:
        """Drop the counter for ``user_id``/``action``; return whether a key was removed."""
        key = cache_key('rate_limit', user_id, action)
        return await cache.delete(key)
|
||
|
||
|
||
class ChannelSubCache:
|
||
"""Cache for user channel subscription statuses.
|
||
|
||
Redis keys:
|
||
- channel_sub:{telegram_id}:{channel_id} -> "1" or "0" (TTL 600s)
|
||
- required_channels:active -> JSON list of active channels (TTL 60s)
|
||
"""
|
||
|
||
SUB_TTL = 600 # 10 min -- individual user subscription status
|
||
CHANNELS_TTL = 60 # 1 min -- list of required channels
|
||
|
||
@staticmethod
|
||
async def get_sub_status(telegram_id: int, channel_id: str) -> bool | None:
|
||
"""Get subscription status from cache. None = cache miss."""
|
||
key = cache_key('channel_sub', telegram_id, channel_id)
|
||
result = await cache.get(key)
|
||
if result is None:
|
||
return None
|
||
return result == 1
|
||
|
||
@staticmethod
|
||
async def get_sub_statuses(telegram_id: int, channel_ids: list[str]) -> dict[str, bool | None]:
|
||
"""Batch-fetch subscription statuses via Redis MGET (single round-trip).
|
||
|
||
Returns {channel_id: True/False/None} where None = cache miss.
|
||
Falls back to sequential gets if Redis pipeline is unavailable.
|
||
"""
|
||
if not channel_ids:
|
||
return {}
|
||
|
||
if not cache._connected or cache.redis_client is None:
|
||
return dict.fromkeys(channel_ids, None)
|
||
|
||
keys = [cache_key('channel_sub', telegram_id, ch_id) for ch_id in channel_ids]
|
||
try:
|
||
raw_values = await cache.redis_client.mget(keys)
|
||
except Exception as e:
|
||
logger.warning('Redis MGET failed, falling back to sequential', error=str(e))
|
||
result: dict[str, bool | None] = {}
|
||
for ch_id in channel_ids:
|
||
result[ch_id] = await ChannelSubCache.get_sub_status(telegram_id, ch_id)
|
||
return result
|
||
|
||
statuses: dict[str, bool | None] = {}
|
||
for ch_id, raw in zip(channel_ids, raw_values, strict=True):
|
||
if raw is None:
|
||
statuses[ch_id] = None
|
||
else:
|
||
try:
|
||
parsed = json.loads(raw)
|
||
statuses[ch_id] = parsed == 1
|
||
except (ValueError, TypeError):
|
||
statuses[ch_id] = None
|
||
return statuses
|
||
|
||
@staticmethod
|
||
async def set_sub_status(telegram_id: int, channel_id: str, is_member: bool) -> None:
|
||
key = cache_key('channel_sub', telegram_id, channel_id)
|
||
await cache.set(key, 1 if is_member else 0, expire=ChannelSubCache.SUB_TTL)
|
||
|
||
@staticmethod
|
||
async def invalidate_sub(telegram_id: int, channel_id: str) -> None:
|
||
key = cache_key('channel_sub', telegram_id, channel_id)
|
||
await cache.delete(key)
|
||
|
||
@staticmethod
|
||
async def invalidate_user_channels(telegram_id: int, channel_ids: list[str]) -> None:
|
||
"""Invalidate specific channel keys for a user using single Redis DELETE.
|
||
|
||
Uses multi-key DELETE (O(K)) instead of delete_pattern() which uses KEYS (O(N)).
|
||
At 100k users * 5 channels = 500k keys, KEYS would block Redis for seconds.
|
||
"""
|
||
if not channel_ids or not cache._connected or not cache.redis_client:
|
||
return
|
||
keys = [cache_key('channel_sub', telegram_id, ch_id) for ch_id in channel_ids]
|
||
try:
|
||
await cache.redis_client.delete(*keys)
|
||
except Exception as e:
|
||
logger.warning('Failed to invalidate user channel cache', telegram_id=telegram_id, error=e)
|
||
|
||
@staticmethod
|
||
async def get_required_channels() -> list[dict] | None:
|
||
"""Get the list of required channels from cache."""
|
||
return await cache.get('required_channels:active')
|
||
|
||
@staticmethod
|
||
async def set_required_channels(channels: list[dict]) -> None:
|
||
await cache.set('required_channels:active', channels, expire=ChannelSubCache.CHANNELS_TTL)
|
||
|
||
@staticmethod
|
||
async def invalidate_channels() -> None:
|
||
await cache.delete('required_channels:active')
|