feat(notifications): enhance notification security and ownership checks

- Added ownership verification for user notifications to ensure only the rightful owner can mark them as read.
- Implemented checks to confirm that admin notifications are correctly identified before allowing them to be marked as read.
- Introduced a new method to retrieve notifications by ID in the TicketNotificationCRUD for improved data handling.
This commit is contained in:
PEDZEO
2026-01-19 00:39:36 +03:00
parent 63e45e12de
commit b1206a84c7
3 changed files with 54 additions and 15 deletions

View File

@@ -84,12 +84,22 @@ async def mark_notification_as_read(
db: AsyncSession = Depends(get_cabinet_db),
):
"""Mark a notification as read."""
success = await TicketNotificationCRUD.mark_as_read(db, notification_id)
if not success:
# Security: Verify notification belongs to current user and is not an admin notification
notification = await TicketNotificationCRUD.get_by_id(db, notification_id)
if not notification:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Notification not found",
)
# Check ownership: notification must belong to user and not be an admin notification
if notification.user_id != user.id or notification.is_for_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to mark this notification as read",
)
await TicketNotificationCRUD.mark_as_read(db, notification_id)
return {"success": True}
@@ -154,12 +164,22 @@ async def mark_admin_notification_as_read(
db: AsyncSession = Depends(get_cabinet_db),
):
"""Mark an admin notification as read."""
success = await TicketNotificationCRUD.mark_as_read(db, notification_id)
if not success:
# Security: Verify notification exists and is an admin notification
notification = await TicketNotificationCRUD.get_by_id(db, notification_id)
if not notification:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Notification not found",
)
# Check that this is actually an admin notification
if not notification.is_for_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="This is not an admin notification",
)
await TicketNotificationCRUD.mark_as_read(db, notification_id)
return {"success": True}

View File

@@ -63,7 +63,10 @@ class CabinetConnectionManager:
async def send_to_user(self, user_id: int, message: dict) -> None:
"""Отправить сообщение конкретному пользователю."""
connections = self._user_connections.get(user_id, set())
# Snapshot connections under the lock to avoid mutation during iteration
async with self._lock:
connections = list(self._user_connections.get(user_id, set()))
if not connections:
return
@@ -78,19 +81,27 @@ class CabinetConnectionManager:
disconnected.add(ws)
# Cleanup disconnected
if disconnected:
async with self._lock:
for ws in disconnected:
self._user_connections.get(user_id, set()).discard(ws)
async def send_to_admins(self, message: dict) -> None:
"""Отправить сообщение всем админам."""
# Snapshot connections under the lock to avoid mutation during iteration
async with self._lock:
if not self._admin_connections:
return
# Create a snapshot: list of (user_id, list of websockets)
admin_snapshot = [
(user_id, list(connections))
for user_id, connections in self._admin_connections.items()
]
data = json.dumps(message, default=str, ensure_ascii=False)
disconnected_by_user: Dict[int, Set[WebSocket]] = {}
for user_id, connections in self._admin_connections.items():
for user_id, connections in admin_snapshot:
for ws in connections:
try:
await ws.send_text(data)
@@ -101,6 +112,7 @@ class CabinetConnectionManager:
disconnected_by_user[user_id].add(ws)
# Cleanup disconnected
if disconnected_by_user:
async with self._lock:
for user_id, ws_set in disconnected_by_user.items():
for ws in ws_set:

View File

@@ -17,6 +17,13 @@ logger = logging.getLogger(__name__)
class TicketNotificationCRUD:
"""CRUD operations for ticket notifications in cabinet."""
@staticmethod
async def get_by_id(db: AsyncSession, notification_id: int) -> Optional[TicketNotification]:
"""Get notification by ID."""
query = select(TicketNotification).where(TicketNotification.id == notification_id)
result = await db.execute(query)
return result.scalar_one_or_none()
@staticmethod
async def create(
db: AsyncSession,