diff --git a/app/database/models.py b/app/database/models.py index bdf8e544..1b7a7123 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -541,7 +541,11 @@ class PromoGroup(Base): "traffic": self.traffic_discount_percent, "devices": self.device_discount_percent, } - percent = mapping.get(category, 0) + percent = mapping.get(category) or 0 + + if percent == 0 and self.is_default: + base_period_discount = self._get_period_discount(period_days) + percent = max(percent, base_period_discount) return max(0, min(100, percent)) diff --git a/app/utils/price_display.py b/app/utils/price_display.py index 5d0b3add..06747be5 100644 --- a/app/utils/price_display.py +++ b/app/utils/price_display.py @@ -72,11 +72,8 @@ def calculate_user_price( # Get user's promo group discount for this category discount_percent = user.get_promo_discount(category, period_days) else: - # For None user, use base settings discount (only for period category) - if category == "period": - discount_percent = settings.get_base_promo_group_period_discount(period_days) - else: - discount_percent = 0 + # For None user, use base settings discount + discount_percent = settings.get_base_promo_group_period_discount(period_days) logger.debug( f"calculate_user_price: user={user.telegram_id if user else 'None'}, " diff --git a/tests/test_promo_group_base_discounts.py b/tests/test_promo_group_base_discounts.py new file mode 100644 index 00000000..1243b404 --- /dev/null +++ b/tests/test_promo_group_base_discounts.py @@ -0,0 +1,29 @@ +import pytest + +from app.config import settings +from app.database.models import PromoGroup + + +@pytest.fixture +def base_discount_settings(monkeypatch): + monkeypatch.setattr(settings, "BASE_PROMO_GROUP_PERIOD_DISCOUNTS_ENABLED", True) + monkeypatch.setattr(settings, "BASE_PROMO_GROUP_PERIOD_DISCOUNTS", "60:15") + yield + + +def test_base_promo_discount_applies_to_all_categories(base_discount_settings): + promo_group = PromoGroup(name="Default", is_default=True) + + assert promo_group.get_discount_percent("period", 60) == 15 + assert promo_group.get_discount_percent("servers", 60) == 15 + assert promo_group.get_discount_percent("traffic", 60) == 15 + assert promo_group.get_discount_percent("devices", 60) == 15 + + +def test_specific_category_discount_overrides_base(base_discount_settings): + promo_group = PromoGroup( + name="Default", is_default=True, server_discount_percent=5 + ) + + assert promo_group.get_discount_percent("servers", 60) == 5 + assert promo_group.get_discount_percent("devices", 60) == 15