diff --git a/app/config.py b/app/config.py index 1bf109b2..196f9817 100644 --- a/app/config.py +++ b/app/config.py @@ -91,6 +91,7 @@ class Settings(BaseSettings): TRIAL_PAYMENT_ENABLED: bool = False TRIAL_ACTIVATION_PRICE: int = 0 TRIAL_USER_TAG: Optional[str] = None + TRIAL_INTERNAL_SQUADS: Optional[str] = None DEFAULT_TRAFFIC_LIMIT_GB: int = 100 DEFAULT_DEVICE_LIMIT: int = 1 DEFAULT_TRAFFIC_RESET_STRATEGY: str = "MONTH" @@ -816,6 +817,28 @@ class Settings(BaseSettings): def get_trial_user_tag(self) -> Optional[str]: return self._normalize_user_tag(self.TRIAL_USER_TAG, "TRIAL_USER_TAG") + def get_trial_internal_squads(self) -> list[str]: + raw_value = self.TRIAL_INTERNAL_SQUADS + if raw_value is None: + return [] + + if isinstance(raw_value, str): + items = [item.strip() for item in re.split(r"[,\n]", raw_value) if item.strip()] + elif isinstance(raw_value, (list, tuple, set)): + items = [str(item).strip() for item in raw_value if str(item).strip()] + else: + return [] + + seen = set() + unique_items: list[str] = [] + for item in items: + lowered = item.lower() + if lowered in seen: + continue + seen.add(lowered) + unique_items.append(item) + return unique_items + def get_paid_subscription_user_tag(self) -> Optional[str]: return self._normalize_user_tag( self.PAID_SUBSCRIPTION_USER_TAG, diff --git a/app/database/crud/subscription.py b/app/database/crud/subscription.py index 0ba9f72e..4693f478 100644 --- a/app/database/crud/subscription.py +++ b/app/database/crud/subscription.py @@ -72,6 +72,20 @@ async def create_trial_subscription( end_date = datetime.utcnow() + timedelta(days=duration_days) + trial_internal_squads = settings.get_trial_internal_squads() + trial_user: Optional[User] = None + if trial_internal_squads: + try: + trial_user = await db.get(User, user_id) + if trial_user is not None: + trial_user.active_internal_squads = trial_internal_squads + except Exception as error: + logger.warning( + "Не удалось применить internal squads для триала пользователя %s: %s", + user_id, + error, + ) + subscription = Subscription( user_id=user_id, status=SubscriptionStatus.ACTIVE.value, @@ -131,10 +145,30 @@ async def create_paid_subscription( ) -> Subscription: end_date = datetime.utcnow() + timedelta(days=duration_days) - + if device_limit is None: device_limit = settings.DEFAULT_DEVICE_LIMIT + trial_internal_squads = settings.get_trial_internal_squads() + if trial_internal_squads: + try: + paid_user = await db.get(User, user_id) + if paid_user and paid_user.active_internal_squads: + current_set = { + str(item).strip().lower() + for item in paid_user.active_internal_squads + if str(item).strip() + } + trial_set = {item.lower() for item in trial_internal_squads} + if current_set == trial_set: + paid_user.active_internal_squads = [] + except Exception as error: + logger.warning( + "Не удалось сбросить trial internal squads при покупке для пользователя %s: %s", + user_id, + error, + ) + subscription = Subscription( user_id=user_id, status=SubscriptionStatus.ACTIVE.value, diff --git a/app/database/crud/user.py b/app/database/crud/user.py index 2f944df2..88ef9798 100644 --- a/app/database/crud/user.py +++ b/app/database/crud/user.py @@ -2,7 +2,7 @@ import logging import secrets import string from datetime import datetime, timedelta -from typing import Optional, List, Dict +from typing import Optional, List, Dict, Iterable from sqlalchemy import select, and_, or_, func, case, nullslast, text from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload, joinedload @@ -34,6 +34,26 @@ def generate_referral_code() -> str: return f"ref{code_suffix}" +def _normalize_internal_squads(value: Optional[Iterable[str]]) -> Optional[list[str]]: + if value is None: + return None + + try: + items = [str(item).strip() for item in value if str(item).strip()] + except TypeError: + return None + + seen = set() + normalized: list[str] = [] + for item in items: + lowered = item.lower() + if lowered in seen: + continue + seen.add(lowered) + normalized.append(item) + return normalized + + async def get_user_by_id(db: AsyncSession, user_id: int) -> Optional[User]: result = await db.execute( select(User) @@ -171,7 +191,8 @@ async def create_user_no_commit( last_name: str = None, language: str = "ru", referred_by_id: int = None, - referral_code: str = None + referral_code: str = None, + active_internal_squads: Optional[Iterable[str]] = None, ) -> User: """ Создает пользователя без немедленного коммита для пакетной обработки @@ -197,6 +218,7 @@ async def create_user_no_commit( has_had_paid_subscription=False, has_made_first_topup=False, promo_group_id=promo_group_id, + active_internal_squads=_normalize_internal_squads(active_internal_squads), ) db.add(user) @@ -222,7 +244,8 @@ async def create_user( last_name: str = None, language: str = "ru", referred_by_id: int = None, - referral_code: str = None + referral_code: str = None, + active_internal_squads: Optional[Iterable[str]] = None, ) -> User: if not referral_code: @@ -248,6 +271,7 @@ async def create_user( has_had_paid_subscription=False, has_made_first_topup=False, promo_group_id=promo_group_id, + active_internal_squads=_normalize_internal_squads(active_internal_squads), ) db.add(user) @@ -295,6 +319,8 @@ async def update_user( for field, value in kwargs.items(): if field in ("first_name", "last_name"): value = sanitize_telegram_name(value) + if field == "active_internal_squads": + value = _normalize_internal_squads(value) if hasattr(user, field): setattr(user, field, value) diff --git a/app/database/models.py b/app/database/models.py index f8e9f719..d3164d86 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -600,6 +600,7 @@ class User(Base): promo_offer_discount_source = Column(String(100), nullable=True) promo_offer_discount_expires_at = Column(DateTime, nullable=True) last_remnawave_sync = Column(DateTime, nullable=True) + active_internal_squads = Column(JSON, nullable=True) trojan_password = Column(String(255), nullable=True) vless_uuid = Column(String(255), nullable=True) ss_password = Column(String(255), nullable=True) diff --git a/app/database/universal_migration.py b/app/database/universal_migration.py index 6bf30aff..1d9d4cb4 100644 --- a/app/database/universal_migration.py +++ b/app/database/universal_migration.py @@ -1489,6 +1489,40 @@ async def ensure_user_promo_offer_discount_columns(): return False +async def ensure_user_internal_squads_column() -> bool: + try: + column_exists = await check_column_exists('users', 'active_internal_squads') + if column_exists: + return True + + async with engine.begin() as conn: + db_type = await get_database_type() + + if db_type == 'sqlite': + column_def = 'JSON NULL' + elif db_type == 'postgresql': + column_def = 'JSONB NULL' + elif db_type == 'mysql': + column_def = 'JSON NULL' + else: + raise ValueError(f"Unsupported database type: {db_type}") + + await conn.execute( + text( + f"ALTER TABLE users ADD COLUMN active_internal_squads {column_def}" + ) + ) + + logger.info("✅ Добавлена колонка active_internal_squads в таблицу users") + return True + except Exception as error: + logger.error( + "Ошибка добавления колонки active_internal_squads в users: %s", + error, + ) + return False + + async def ensure_promo_offer_template_active_duration_column() -> bool: try: column_exists = await check_column_exists('promo_offer_templates', 'active_discount_hours') @@ -3967,6 +4001,12 @@ async def run_universal_migration(): else: logger.warning("⚠️ Не удалось обновить пользовательские промо-скидки") + internal_squads_ready = await ensure_user_internal_squads_column() + if internal_squads_ready: + logger.info("✅ Колонка active_internal_squads для users готова") + else: + logger.warning("⚠️ Не удалось обновить колонку active_internal_squads для users") + effect_types_updated = await migrate_discount_offer_effect_types() if effect_types_updated: logger.info("✅ Типы эффектов промо-предложений обновлены") @@ -4267,6 +4307,7 @@ async def check_migration_status(): "promo_offer_templates_active_discount_column": False, "promo_offer_logs_table": False, "subscription_temporary_access_table": False, + "users_active_internal_squads_column": False, } status["has_made_first_topup_column"] = await check_column_exists('users', 'has_made_first_topup') @@ -4303,6 +4344,7 @@ async def check_migration_status(): status["users_promo_offer_discount_expires_column"] = await check_column_exists('users', 'promo_offer_discount_expires_at') status["users_referral_commission_percent_column"] = await check_column_exists('users', 'referral_commission_percent') status["subscription_crypto_link_column"] = await check_column_exists('subscriptions', 'subscription_crypto_link') + status["users_active_internal_squads_column"] = await check_column_exists('users', 'active_internal_squads') media_fields_exist = ( await check_column_exists('broadcast_history', 'has_media') and diff --git a/app/handlers/admin/users.py b/app/handlers/admin/users.py index d25499e6..9b0c4657 100644 --- a/app/handlers/admin/users.py +++ b/app/handlers/admin/users.py @@ -4606,7 +4606,11 @@ async def admin_buy_subscription_execute( username=target_user.username, telegram_id=target_user.telegram_id ), - active_internal_squads=subscription.connected_squads, + active_internal_squads=( + list(target_user.active_internal_squads or []) + if target_user.active_internal_squads is not None + else list(subscription.connected_squads or []) + ), ) if hwid_limit is not None: @@ -4632,7 +4636,11 @@ async def admin_buy_subscription_execute( username=target_user.username, telegram_id=target_user.telegram_id ), - active_internal_squads=subscription.connected_squads, + active_internal_squads=( + list(target_user.active_internal_squads or []) + if target_user.active_internal_squads is not None + else list(subscription.connected_squads or []) + ), ) if hwid_limit is not None: diff --git a/app/services/monitoring_service.py b/app/services/monitoring_service.py index 9a0f6f95..864bb2be 100644 --- a/app/services/monitoring_service.py +++ b/app/services/monitoring_service.py @@ -312,7 +312,11 @@ class MonitoringService: username=user.username, telegram_id=user.telegram_id ), - active_internal_squads=subscription.connected_squads, + active_internal_squads=( + list(user.active_internal_squads or []) + if user.active_internal_squads is not None + else list(subscription.connected_squads or []) + ), ) if hwid_limit is not None: diff --git a/app/services/remnawave_service.py b/app/services/remnawave_service.py index 001b7b54..468645d3 100644 --- a/app/services/remnawave_service.py +++ b/app/services/remnawave_service.py @@ -846,33 +846,60 @@ class RemnaWaveService: except Exception as e: logger.error(f"Error updating squad inbounds: {e}") return False - + + @staticmethod + def _serialize_internal_squad(squad) -> Dict[str, Any]: + inbounds = [ + asdict(inbound) if is_dataclass(inbound) else inbound + for inbound in getattr(squad, "inbounds", []) or [] + ] + return { + 'uuid': getattr(squad, 'uuid', ''), + 'name': getattr(squad, 'name', ''), + 'members_count': getattr(squad, 'members_count', 0), + 'inbounds_count': getattr(squad, 'inbounds_count', 0), + 'inbounds': inbounds, + } + async def get_all_squads(self) -> List[Dict[str, Any]]: - + try: async with self.get_api_client() as api: squads = await api.get_internal_squads() result = [] for squad in squads: - inbounds = [ - asdict(inbound) if is_dataclass(inbound) else inbound - for inbound in squad.inbounds or [] - ] - result.append({ - 'uuid': squad.uuid, - 'name': squad.name, - 'members_count': squad.members_count, - 'inbounds_count': squad.inbounds_count, - 'inbounds': inbounds, - }) - + result.append(self._serialize_internal_squad(squad)) + logger.info(f"✅ Получено {len(result)} сквадов из Remnawave") return result - + except Exception as e: logger.error(f"Ошибка получения сквадов из Remnawave: {e}") return [] + + async def get_internal_squad(self, uuid: str) -> Optional[Dict[str, Any]]: + try: + async with self.get_api_client() as api: + squad = await api.get_internal_squad_by_uuid(uuid) + if not squad: + return None + return self._serialize_internal_squad(squad) + except Exception as error: + logger.error("Ошибка получения internal squad %s: %s", uuid, error) + return None + + async def get_internal_squad_accessible_nodes(self, uuid: str) -> List[Dict[str, Any]]: + try: + async with self.get_api_client() as api: + nodes = await api.get_internal_squad_accessible_nodes(uuid) + return [ + asdict(node) if is_dataclass(node) else node + for node in nodes or [] + ] + except Exception as error: + logger.error("Ошибка получения нод internal squad %s: %s", uuid, error) + return [] async def create_squad(self, name: str, inbounds: List[str]) -> Optional[str]: try: @@ -1704,6 +1731,12 @@ class RemnaWaveService: telegram_id=user.telegram_id, ) + internal_squads = ( + list(user.active_internal_squads or []) + if user.active_internal_squads is not None + else list(subscription.connected_squads or []) + ) + create_kwargs = dict( username=username, expire_at=expire_at, @@ -1716,7 +1749,7 @@ class RemnaWaveService: username=user.username, telegram_id=user.telegram_id ), - active_internal_squads=subscription.connected_squads, + active_internal_squads=internal_squads, ) if hwid_limit is not None: @@ -1730,7 +1763,7 @@ class RemnaWaveService: traffic_limit_bytes=create_kwargs['traffic_limit_bytes'], traffic_limit_strategy=TrafficLimitStrategy.MONTH, description=create_kwargs['description'], - active_internal_squads=subscription.connected_squads, + active_internal_squads=internal_squads, ) if hwid_limit is not None: diff --git a/app/services/subscription_service.py b/app/services/subscription_service.py index f7a29e2f..6e663c38 100644 --- a/app/services/subscription_service.py +++ b/app/services/subscription_service.py @@ -139,6 +139,63 @@ class SubscriptionService: return settings.get_paid_subscription_user_tag() + @staticmethod + def _normalize_internal_squads(value: Optional[Iterable[str]]) -> Optional[list[str]]: + if value is None: + return None + + try: + items = [str(item).strip() for item in value if str(item).strip()] + except TypeError: + return None + + seen = set() + normalized: list[str] = [] + for item in items: + lowered = item.lower() + if lowered in seen: + continue + seen.add(lowered) + normalized.append(item) + return normalized + + @staticmethod + def _select_internal_squads(user: User, subscription: Subscription) -> Optional[list[str]]: + if user.active_internal_squads is not None: + return SubscriptionService._normalize_internal_squads(user.active_internal_squads) + return SubscriptionService._normalize_internal_squads(subscription.connected_squads) + + async def _resolve_internal_squad_uuids( + self, + api: RemnaWaveAPI, + squads: Optional[Iterable[str]], + ) -> Optional[list[str]]: + normalized = self._normalize_internal_squads(squads) + if normalized is None: + return None + if not normalized: + return [] + + try: + available = await api.get_internal_squads() + uuid_lookup = {squad.uuid.lower(): squad.uuid for squad in available} + name_lookup = {squad.name.lower(): squad.uuid for squad in available} + except Exception as error: + logger.warning("Не удалось получить список internal squads: %s", error) + return normalized + + resolved: list[str] = [] + for item in normalized: + lowered = item.lower() + if lowered in uuid_lookup: + resolved.append(uuid_lookup[lowered]) + continue + if lowered in name_lookup: + resolved.append(name_lookup[lowered]) + continue + logger.warning("Не удалось сопоставить internal squad '%s' с панелью", item) + return resolved + @property def is_configured(self) -> bool: return self._config_error is None @@ -175,16 +232,24 @@ class SubscriptionService: if not user: logger.error(f"Пользователь {subscription.user_id} не найден") return None - + validation_success = await self.validate_and_clean_subscription(db, subscription, user) if not validation_success: logger.error(f"Ошибка валидации подписки для пользователя {user.telegram_id}") return None user_tag = self._resolve_user_tag(subscription) + requested_internal_squads = self._select_internal_squads(user, subscription) async with self.get_api_client() as api: hwid_limit = resolve_hwid_device_limit_for_payload(subscription) + resolved_internal_squads = await self._resolve_internal_squad_uuids( + api, + requested_internal_squads, + ) + internal_squads_payload = ( + resolved_internal_squads if resolved_internal_squads is not None else [] + ) existing_users = await api.get_user_by_telegram_id(user.telegram_id) if existing_users: logger.info(f"🔄 Найден существующий пользователь в панели для {user.telegram_id}") @@ -207,7 +272,7 @@ class SubscriptionService: username=user.username, telegram_id=user.telegram_id ), - active_internal_squads=subscription.connected_squads, + active_internal_squads=internal_squads_payload, ) if user_tag is not None: @@ -245,7 +310,7 @@ class SubscriptionService: username=user.username, telegram_id=user.telegram_id ), - active_internal_squads=subscription.connected_squads, + active_internal_squads=internal_squads_payload, ) if user_tag is not None: @@ -313,9 +378,17 @@ class SubscriptionService: logger.info(f"🔔 Статус подписки {subscription.id} автоматически изменен на 'expired'") user_tag = self._resolve_user_tag(subscription) + requested_internal_squads = self._select_internal_squads(user, subscription) async with self.get_api_client() as api: hwid_limit = resolve_hwid_device_limit_for_payload(subscription) + resolved_internal_squads = await self._resolve_internal_squad_uuids( + api, + requested_internal_squads, + ) + internal_squads_payload = ( + resolved_internal_squads if resolved_internal_squads is not None else [] + ) update_kwargs = dict( uuid=user.remnawave_uuid, @@ -328,7 +401,7 @@ class SubscriptionService: username=user.username, telegram_id=user.telegram_id ), - active_internal_squads=subscription.connected_squads, + active_internal_squads=internal_squads_payload, ) if user_tag is not None: diff --git a/app/services/system_settings_service.py b/app/services/system_settings_service.py index 7b78034a..6ad92f74 100644 --- a/app/services/system_settings_service.py +++ b/app/services/system_settings_service.py @@ -654,6 +654,15 @@ class BotConfigurationService: "warning": "Неверный формат будет проигнорирован при создании пользователя.", "dependencies": "Активация триала и включенная интеграция с RemnaWave", }, + "TRIAL_INTERNAL_SQUADS": { + "description": ( + "Список internal squads, которые нужно назначать пользователям с триальной подпиской." + ), + "format": "Укажите названия сквадов через запятую или с новой строки.", + "example": "Default, Trial Access", + "warning": "При оплате подписки эти сквады будут сброшены.", + "dependencies": "RemnaWave API и активированный триал", + }, "PAID_SUBSCRIPTION_USER_TAG": { "description": ( "Тег, который бот ставит пользователю при покупке платной подписки в панели RemnaWave." diff --git a/app/webapi/routes/remnawave.py b/app/webapi/routes/remnawave.py index b85193d6..4b010edf 100644 --- a/app/webapi/routes/remnawave.py +++ b/app/webapi/routes/remnawave.py @@ -14,6 +14,8 @@ from app.database.crud.server_squad import ( from ..dependencies import get_db_session, require_api_token from ..schemas.remnawave import ( RemnaWaveConnectionStatus, + RemnaWaveAccessibleNode, + RemnaWaveAccessibleNodeListResponse, RemnaWaveGenericSyncResponse, RemnaWaveInboundsResponse, RemnaWaveNode, @@ -150,8 +152,50 @@ async def get_system_statistics( if not stats or "system" not in stats: raise HTTPException(status.HTTP_502_BAD_GATEWAY, "Не удалось получить статистику RemnaWave") - stats["last_updated"] = _parse_last_updated(stats.get("last_updated")) - return RemnaWaveSystemStatsResponse(**stats) + stats["last_updated"] = _parse_last_updated(stats.get("last_updated")) + return RemnaWaveSystemStatsResponse(**stats) + + +@router.get("/internal-squads", response_model=RemnaWaveSquadListResponse) +async def list_internal_squads( + _: Any = Security(require_api_token), +) -> RemnaWaveSquadListResponse: + service = _get_service() + _ensure_service_configured(service) + + squads = await service.get_all_squads() + items = [RemnaWaveSquad(**squad) for squad in squads] + return RemnaWaveSquadListResponse(items=items, total=len(items)) + + +@router.get("/internal-squads/{squad_uuid}", response_model=RemnaWaveSquad) +async def get_internal_squad( + squad_uuid: str, + _: Any = Security(require_api_token), +) -> RemnaWaveSquad: + service = _get_service() + _ensure_service_configured(service) + + squad = await service.get_internal_squad(squad_uuid) + if not squad: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Squad not found") + return RemnaWaveSquad(**squad) + + +@router.get( + "/internal-squads/{squad_uuid}/nodes", + response_model=RemnaWaveAccessibleNodeListResponse, +) +async def get_internal_squad_nodes( + squad_uuid: str, + _: Any = Security(require_api_token), +) -> RemnaWaveAccessibleNodeListResponse: + service = _get_service() + _ensure_service_configured(service) + + nodes = await service.get_internal_squad_accessible_nodes(squad_uuid) + items = [RemnaWaveAccessibleNode(**node) for node in nodes] + return RemnaWaveAccessibleNodeListResponse(items=items, total=len(items)) @router.get("/nodes", response_model=RemnaWaveNodeListResponse) diff --git a/app/webapi/routes/users.py b/app/webapi/routes/users.py index 7e6b444f..8d9154be 100644 --- a/app/webapi/routes/users.py +++ b/app/webapi/routes/users.py @@ -88,6 +88,7 @@ def _serialize_user(user: User) -> UserResponse: created_at=user.created_at, updated_at=user.updated_at, last_activity=user.last_activity, + active_internal_squads=list(user.active_internal_squads or []), promo_group=_serialize_promo_group(promo_group), subscription=_serialize_subscription(subscription), ) @@ -205,6 +206,7 @@ async def create_user_endpoint( last_name=payload.last_name, language=payload.language, referred_by_id=payload.referred_by_id, + active_internal_squads=payload.active_internal_squads, ) if payload.promo_group_id and payload.promo_group_id != user.promo_group_id: @@ -263,6 +265,9 @@ async def update_user_endpoint( raise HTTPException(status.HTTP_400_BAD_REQUEST, "Promo group not found") updates["promo_group_id"] = promo_group.id + if payload.active_internal_squads is not None: + updates["active_internal_squads"] = payload.active_internal_squads + if payload.referral_code is not None and payload.referral_code != found_user.referral_code: existing_code_owner = await get_user_by_referral_code(db, payload.referral_code) if existing_code_owner and existing_code_owner.id != found_user.id: diff --git a/app/webapi/schemas/remnawave.py b/app/webapi/schemas/remnawave.py index d640f5bd..cc00f3cb 100644 --- a/app/webapi/schemas/remnawave.py +++ b/app/webapi/schemas/remnawave.py @@ -137,6 +137,20 @@ class RemnaWaveSquadListResponse(BaseModel): total: int +class RemnaWaveAccessibleNode(BaseModel): + uuid: str + node_name: str + country_code: str + config_profile_uuid: str + config_profile_name: str + active_inbounds: List[str] = Field(default_factory=list) + + +class RemnaWaveAccessibleNodeListResponse(BaseModel): + items: List[RemnaWaveAccessibleNode] + total: int + + class RemnaWaveSquadCreateRequest(BaseModel): name: str inbound_uuids: List[str] = Field(default_factory=list) diff --git a/app/webapi/schemas/users.py b/app/webapi/schemas/users.py index fbf910fc..3073ffa4 100644 --- a/app/webapi/schemas/users.py +++ b/app/webapi/schemas/users.py @@ -49,6 +49,7 @@ class UserResponse(BaseModel): created_at: datetime updated_at: datetime last_activity: Optional[datetime] = None + active_internal_squads: List[str] = Field(default_factory=list) promo_group: Optional[PromoGroupSummary] = None subscription: Optional[SubscriptionSummary] = None @@ -68,6 +69,7 @@ class UserCreateRequest(BaseModel): language: str = "ru" referred_by_id: Optional[int] = None promo_group_id: Optional[int] = None + active_internal_squads: Optional[List[str]] = Field(default=None) class UserUpdateRequest(BaseModel): @@ -80,6 +82,7 @@ class UserUpdateRequest(BaseModel): referral_code: Optional[str] = None has_had_paid_subscription: Optional[bool] = None has_made_first_topup: Optional[bool] = None + active_internal_squads: Optional[List[str]] = Field(default=None) class BalanceUpdateRequest(BaseModel):