From 9582758d1c85735c8ead8cbfeb56bbdae45288af Mon Sep 17 00:00:00 2001 From: Fringg Date: Wed, 4 Mar 2026 15:29:50 +0300 Subject: [PATCH] fix: restore merge token on DB failure, fix partner_status priority - Add restore_merge_token() to re-store consumed token if execute_merge or db.commit fails, allowing the user to retry instead of being stuck - Fix partner_status priority: PENDING (2) now beats REJECTED (1), so an active application is not lost during merge - Add tests for pending-vs-rejected edge cases (47 tests total) --- app/cabinet/auth/merge_service.py | 34 ++++++++++++++++++++ app/cabinet/routes/account_linking.py | 3 ++ app/services/account_merge_service.py | 4 +-- tests/services/test_account_merge_service.py | 30 +++++++++++++++++ 4 files changed, 69 insertions(+), 2 deletions(-) diff --git a/app/cabinet/auth/merge_service.py b/app/cabinet/auth/merge_service.py index 63f24d4a..92e30721 100644 --- a/app/cabinet/auth/merge_service.py +++ b/app/cabinet/auth/merge_service.py @@ -99,3 +99,37 @@ async def consume_merge_token(token: str) -> dict[str, Any] | None: provider=data.get('provider'), ) return data + + +async def restore_merge_token(token: str, data: dict[str, Any]) -> bool: + """Re-store a consumed merge token so the user can retry after a DB failure. + + Uses the remaining TTL based on the original ``created_at``. + Returns ``True`` if restored, ``False`` if Redis write failed. + """ + created_at_str: str = data.get('created_at', '') + try: + created_at = datetime.fromisoformat(created_at_str) + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=UTC) + elapsed = (datetime.now(UTC) - created_at).total_seconds() + remaining_ttl = max(1, int(MERGE_TOKEN_TTL_SECONDS - elapsed)) + except (ValueError, TypeError): + remaining_ttl = MERGE_TOKEN_TTL_SECONDS + + key = cache_key(MERGE_TOKEN_PREFIX, token) + stored = await cache.set(key, data, expire=remaining_ttl) + if stored: + logger.info( + 'Merge token restored after failed merge', + primary_user_id=data.get('primary_user_id'), + secondary_user_id=data.get('secondary_user_id'), + remaining_ttl=remaining_ttl, + ) + else: + logger.error( + 'Failed to restore merge token to Redis', + primary_user_id=data.get('primary_user_id'), + secondary_user_id=data.get('secondary_user_id'), + ) + return bool(stored) diff --git a/app/cabinet/routes/account_linking.py b/app/cabinet/routes/account_linking.py index a456a706..a59eb4bb 100644 --- a/app/cabinet/routes/account_linking.py +++ b/app/cabinet/routes/account_linking.py @@ -25,6 +25,7 @@ from ..auth.merge_service import ( consume_merge_token, create_merge_token, get_merge_token_data, + restore_merge_token, ) from ..auth.oauth_providers import ( generate_oauth_state, @@ -456,6 +457,7 @@ async def execute_merge_endpoint( await db.commit() except ValueError as exc: await db.rollback() + await restore_merge_token(merge_token, consumed) logger.error('Merge execution failed (ValueError)', error=str(exc)) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -463,6 +465,7 @@ async def execute_merge_endpoint( ) from exc except Exception as exc: await db.rollback() + await restore_merge_token(merge_token, consumed) logger.error('Merge execution failed', exc_info=True) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/app/services/account_merge_service.py b/app/services/account_merge_service.py index da232f30..63ef05b1 100644 --- a/app/services/account_merge_service.py +++ b/app/services/account_merge_service.py @@ -54,8 +54,8 @@ _PAYMENT_MODELS: tuple[type, ...] = ( # Приоритет партнёрских статусов (чем выше число — тем приоритетнее) _PARTNER_STATUS_PRIORITY: dict[str, int] = { PartnerStatus.NONE.value: 0, - PartnerStatus.PENDING.value: 1, - PartnerStatus.REJECTED.value: 2, + PartnerStatus.REJECTED.value: 1, + PartnerStatus.PENDING.value: 2, PartnerStatus.APPROVED.value: 3, } diff --git a/tests/services/test_account_merge_service.py b/tests/services/test_account_merge_service.py index b52bae45..6ba1fafe 100644 --- a/tests/services/test_account_merge_service.py +++ b/tests/services/test_account_merge_service.py @@ -480,6 +480,36 @@ class TestExecuteMergePartnerStatus: assert result.partner_status == 'approved' + async def test_pending_beats_rejected(self, monkeypatch): + """Pending application should not be overwritten by rejected status.""" + db = _make_db() + primary = _make_user(id=1, partner_status='pending') + secondary = _make_user(id=2, partner_status='rejected') + monkeypatch.setattr( + account_merge_service, + 'get_user_by_id', + AsyncMock(side_effect=[primary, secondary]), + ) + with _patch_remnawave_delete(): + result = await execute_merge(db, 1, 2) + + assert result.partner_status == 'pending' + + async def test_rejected_does_not_beat_pending(self, monkeypatch): + """Rejected on secondary should not overwrite pending on primary.""" + db = _make_db() + primary = _make_user(id=1, partner_status='rejected') + secondary = _make_user(id=2, partner_status='pending') + monkeypatch.setattr( + account_merge_service, + 'get_user_by_id', + AsyncMock(side_effect=[primary, secondary]), + ) + with _patch_remnawave_delete(): + result = await execute_merge(db, 1, 2) + + assert result.partner_status == 'pending' + class TestExecuteMergeReferralCommission: async def test_transfers_if_primary_has_none(self, monkeypatch):