mirror of
https://github.com/BEDOLAGA-DEV/remnawave-bedolaga-telegram-bot.git
synced 2026-04-28 16:50:08 +00:00
fix: migrate VK OAuth to VK ID OAuth 2.1 with PKCE
VK deprecated oauth.vk.com on Sep 30, 2025. Migrate to VK ID (id.vk.ru) with mandatory PKCE S256 and device_id support. - Rewrite VKProvider: new endpoints, PKCE code_verifier/challenge, user_info format - Add prepare_auth_state() hook for provider-specific state (PKCE) - Use atomic Redis GETDEL for OAuth state validation (prevent TOCTOU race) - Add CacheService.getdel() method - Check cache.set() result in generate_oauth_state - Filter ephemeral keys (_prefix) from Redis storage - Fix garbled log messages, use exc_info for tracebacks - Add input validation (min_length, max_length on code/state) - Generic error messages (no provider name leakage)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""OAuth 2.0 provider implementations for cabinet authentication."""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TypedDict
|
||||
@@ -33,7 +35,7 @@ class OAuthTokenResponse(TypedDict, total=False):
|
||||
expires_in: int
|
||||
refresh_token: str
|
||||
scope: str
|
||||
# VK-specific: email and user_id come in token response
|
||||
# Provider-specific extra fields (optional)
|
||||
email: str
|
||||
user_id: int
|
||||
|
||||
@@ -67,15 +69,19 @@ class DiscordUserInfoResponse(TypedDict, total=False):
|
||||
avatar: str
|
||||
|
||||
|
||||
class VKUserInfoItem(TypedDict, total=False):
|
||||
id: int
|
||||
class VKIDUserData(TypedDict, total=False):
|
||||
"""VK ID /oauth2/user_info response user object."""
|
||||
|
||||
user_id: str
|
||||
first_name: str
|
||||
last_name: str
|
||||
photo_200: str
|
||||
phone: str
|
||||
avatar: str
|
||||
email: str
|
||||
|
||||
|
||||
class VKUserInfoResponse(TypedDict, total=False):
|
||||
response: list[VKUserInfoItem]
|
||||
class VKIDUserInfoResponse(TypedDict, total=False):
|
||||
user: VKIDUserData
|
||||
|
||||
|
||||
# --- Models ---
|
||||
@@ -97,23 +103,38 @@ class OAuthUserInfo(BaseModel):
|
||||
# --- CSRF state management (Redis) ---
|
||||
|
||||
|
||||
async def generate_oauth_state(provider: str) -> str:
|
||||
"""Generate a CSRF state token for OAuth flow. Stored in Redis with TTL."""
|
||||
async def generate_oauth_state(provider: str, extra_data: dict[str, str] | None = None) -> str:
|
||||
"""Generate a CSRF state token for OAuth flow.
|
||||
|
||||
Stores provider name and optional extra data (e.g., PKCE code_verifier) in Redis with TTL.
|
||||
Keys prefixed with '_' are ephemeral and NOT stored in Redis (e.g., _code_challenge).
|
||||
CacheService handles JSON serialization internally.
|
||||
"""
|
||||
state = secrets.token_urlsafe(32)
|
||||
await cache.set(cache_key('oauth_state', state), provider, expire=STATE_TTL_SECONDS)
|
||||
value: dict[str, Any] = {'provider': provider}
|
||||
if extra_data:
|
||||
# Filter out ephemeral keys (prefixed with '_') — they're only needed for the URL
|
||||
value.update({k: v for k, v in extra_data.items() if not k.startswith('_')})
|
||||
stored = await cache.set(cache_key('oauth_state', state), value, expire=STATE_TTL_SECONDS)
|
||||
if not stored:
|
||||
logger.error('Failed to store OAuth state in Redis')
|
||||
raise RuntimeError('Failed to store OAuth state')
|
||||
return state
|
||||
|
||||
|
||||
async def validate_oauth_state(state: str, provider: str) -> bool:
|
||||
"""Validate and consume a CSRF state token from Redis."""
|
||||
async def validate_oauth_state(state: str, provider: str) -> dict[str, Any] | None:
|
||||
"""Validate and consume a CSRF state token from Redis.
|
||||
|
||||
Uses atomic GETDEL to prevent TOCTOU race conditions.
|
||||
Returns the stored data dict (with 'provider' key + any extra data) or None if invalid.
|
||||
"""
|
||||
key = cache_key('oauth_state', state)
|
||||
stored_provider: str | None = await cache.get(key)
|
||||
if stored_provider is None:
|
||||
return False
|
||||
await cache.delete(key)
|
||||
if stored_provider != provider:
|
||||
return False
|
||||
return True
|
||||
data: Any = await cache.getdel(key)
|
||||
if data is None:
|
||||
return None
|
||||
if not isinstance(data, dict) or data.get('provider') != provider:
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
# --- Provider implementations ---
|
||||
@@ -130,13 +151,28 @@ class OAuthProvider(ABC):
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
@abstractmethod
|
||||
def get_authorization_url(self, state: str) -> str:
|
||||
"""Build the authorization URL for the provider."""
|
||||
def prepare_auth_state(self) -> dict[str, str]:
|
||||
"""Return extra data to store with OAuth state (e.g., PKCE code_verifier).
|
||||
|
||||
Override in providers that need PKCE or other state-stored data.
|
||||
The returned dict is stored in Redis alongside the state token
|
||||
and passed back via validate_oauth_state().
|
||||
"""
|
||||
return {}
|
||||
|
||||
@abstractmethod
|
||||
async def exchange_code(self, code: str) -> OAuthTokenResponse:
|
||||
"""Exchange authorization code for tokens."""
|
||||
def get_authorization_url(self, state: str, **kwargs: Any) -> str:
|
||||
"""Build the authorization URL for the provider.
|
||||
|
||||
kwargs may contain extra data from prepare_auth_state() (e.g., code_challenge).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def exchange_code(self, code: str, **kwargs: Any) -> OAuthTokenResponse:
|
||||
"""Exchange authorization code for tokens.
|
||||
|
||||
kwargs may contain provider-specific params (e.g., device_id, code_verifier for VK).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_info(self, token_data: OAuthTokenResponse) -> OAuthUserInfo:
|
||||
@@ -151,7 +187,7 @@ class GoogleProvider(OAuthProvider):
|
||||
TOKEN_URL = 'https://oauth2.googleapis.com/token'
|
||||
USERINFO_URL = 'https://www.googleapis.com/oauth2/v3/userinfo'
|
||||
|
||||
def get_authorization_url(self, state: str) -> str:
|
||||
def get_authorization_url(self, state: str, **kwargs: Any) -> str:
|
||||
params: dict[str, str] = {
|
||||
'client_id': self.client_id,
|
||||
'redirect_uri': self.redirect_uri,
|
||||
@@ -164,7 +200,7 @@ class GoogleProvider(OAuthProvider):
|
||||
request = httpx.Request('GET', self.AUTHORIZE_URL, params=params)
|
||||
return str(request.url)
|
||||
|
||||
async def exchange_code(self, code: str) -> OAuthTokenResponse:
|
||||
async def exchange_code(self, code: str, **kwargs: Any) -> OAuthTokenResponse:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.post(
|
||||
self.TOKEN_URL,
|
||||
@@ -209,7 +245,7 @@ class YandexProvider(OAuthProvider):
|
||||
TOKEN_URL = 'https://oauth.yandex.com/token'
|
||||
USERINFO_URL = 'https://login.yandex.ru/info'
|
||||
|
||||
def get_authorization_url(self, state: str) -> str:
|
||||
def get_authorization_url(self, state: str, **kwargs: Any) -> str:
|
||||
params: dict[str, str] = {
|
||||
'client_id': self.client_id,
|
||||
'redirect_uri': self.redirect_uri,
|
||||
@@ -221,7 +257,7 @@ class YandexProvider(OAuthProvider):
|
||||
request = httpx.Request('GET', self.AUTHORIZE_URL, params=params)
|
||||
return str(request.url)
|
||||
|
||||
async def exchange_code(self, code: str) -> OAuthTokenResponse:
|
||||
async def exchange_code(self, code: str, **kwargs: Any) -> OAuthTokenResponse:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.post(
|
||||
self.TOKEN_URL,
|
||||
@@ -275,7 +311,7 @@ class DiscordProvider(OAuthProvider):
|
||||
TOKEN_URL = 'https://discord.com/api/oauth2/token'
|
||||
USERINFO_URL = 'https://discord.com/api/v10/users/@me'
|
||||
|
||||
def get_authorization_url(self, state: str) -> str:
|
||||
def get_authorization_url(self, state: str, **kwargs: Any) -> str:
|
||||
params: dict[str, str] = {
|
||||
'client_id': self.client_id,
|
||||
'redirect_uri': self.redirect_uri,
|
||||
@@ -287,7 +323,7 @@ class DiscordProvider(OAuthProvider):
|
||||
request = httpx.Request('GET', self.AUTHORIZE_URL, params=params)
|
||||
return str(request.url)
|
||||
|
||||
async def exchange_code(self, code: str) -> OAuthTokenResponse:
|
||||
async def exchange_code(self, code: str, **kwargs: Any) -> OAuthTokenResponse:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.post(
|
||||
self.TOKEN_URL,
|
||||
@@ -329,35 +365,72 @@ class DiscordProvider(OAuthProvider):
|
||||
|
||||
|
||||
class VKProvider(OAuthProvider):
|
||||
"""VK ID OAuth 2.1 provider (id.vk.ru).
|
||||
|
||||
Uses OAuth 2.1 with mandatory PKCE (S256).
|
||||
Old oauth.vk.com endpoints deprecated since September 30, 2025.
|
||||
"""
|
||||
|
||||
name = 'vk'
|
||||
display_name = 'VK'
|
||||
|
||||
AUTHORIZE_URL = 'https://oauth.vk.com/authorize'
|
||||
TOKEN_URL = 'https://oauth.vk.com/access_token'
|
||||
USERINFO_URL = 'https://api.vk.com/method/users.get'
|
||||
API_VERSION = '5.131'
|
||||
AUTHORIZE_URL = 'https://id.vk.ru/authorize'
|
||||
TOKEN_URL = 'https://id.vk.ru/oauth2/auth'
|
||||
USERINFO_URL = 'https://id.vk.ru/oauth2/user_info'
|
||||
|
||||
def get_authorization_url(self, state: str) -> str:
|
||||
@staticmethod
|
||||
def _generate_pkce() -> tuple[str, str]:
|
||||
"""Generate PKCE code_verifier and code_challenge (S256)."""
|
||||
code_verifier = secrets.token_urlsafe(64)
|
||||
digest = hashlib.sha256(code_verifier.encode('ascii')).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii')
|
||||
return code_verifier, code_challenge
|
||||
|
||||
def prepare_auth_state(self) -> dict[str, str]:
|
||||
"""Generate PKCE pair. code_verifier stored in Redis, code_challenge only goes to URL."""
|
||||
code_verifier, code_challenge = self._generate_pkce()
|
||||
# code_challenge is ephemeral — only needed for the authorization URL,
|
||||
# not stored in Redis (code_verifier is the secret used during token exchange)
|
||||
return {
|
||||
'code_verifier': code_verifier,
|
||||
'_code_challenge': code_challenge,
|
||||
}
|
||||
|
||||
def get_authorization_url(self, state: str, **kwargs: Any) -> str:
|
||||
code_challenge: str = kwargs.get('_code_challenge', '')
|
||||
params: dict[str, str] = {
|
||||
'client_id': self.client_id,
|
||||
'redirect_uri': self.redirect_uri,
|
||||
'response_type': 'code',
|
||||
'scope': 'email',
|
||||
'scope': 'vkid.personal_info email',
|
||||
'state': state,
|
||||
'v': self.API_VERSION,
|
||||
'code_challenge': code_challenge,
|
||||
'code_challenge_method': 'S256',
|
||||
}
|
||||
request = httpx.Request('GET', self.AUTHORIZE_URL, params=params)
|
||||
return str(request.url)
|
||||
|
||||
async def exchange_code(self, code: str) -> OAuthTokenResponse:
|
||||
async def exchange_code(self, code: str, **kwargs: Any) -> OAuthTokenResponse:
|
||||
device_id: str = kwargs.get('device_id', '')
|
||||
code_verifier: str = kwargs.get('code_verifier', '')
|
||||
state: str = kwargs.get('state', '')
|
||||
|
||||
if not device_id:
|
||||
raise ValueError('device_id is required for VK ID token exchange')
|
||||
if not code_verifier:
|
||||
raise ValueError('code_verifier is required for VK ID token exchange')
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(
|
||||
response = await client.post(
|
||||
self.TOKEN_URL,
|
||||
params={
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
data={
|
||||
'grant_type': 'authorization_code',
|
||||
'code': code,
|
||||
'redirect_uri': self.redirect_uri,
|
||||
'client_id': self.client_id,
|
||||
'device_id': device_id,
|
||||
'code_verifier': code_verifier,
|
||||
'state': state,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -366,33 +439,37 @@ class VKProvider(OAuthProvider):
|
||||
|
||||
async def get_user_info(self, token_data: OAuthTokenResponse) -> OAuthUserInfo:
|
||||
access_token = token_data['access_token']
|
||||
user_id: int | None = token_data.get('user_id')
|
||||
# VK returns email in token response, not in userinfo
|
||||
email: str | None = token_data.get('email')
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(
|
||||
response = await client.post(
|
||||
self.USERINFO_URL,
|
||||
params={
|
||||
data={
|
||||
'access_token': access_token,
|
||||
'fields': 'photo_200',
|
||||
'v': self.API_VERSION,
|
||||
'client_id': self.client_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data: VKUserInfoResponse = response.json()
|
||||
data: VKIDUserInfoResponse = response.json()
|
||||
|
||||
users: list[Any] = data.get('response', [])
|
||||
user_data: VKUserInfoItem = users[0] if users else {} # type: ignore[assignment]
|
||||
user_data = data.get('user')
|
||||
if not user_data:
|
||||
raise ValueError('VK ID response missing user data')
|
||||
|
||||
user_id = user_data.get('user_id')
|
||||
if not user_id:
|
||||
raise ValueError('VK ID response missing user_id')
|
||||
|
||||
# VK ID returns email only if 'email' scope was granted and user has a verified email
|
||||
email: str | None = user_data.get('email') or None
|
||||
|
||||
return OAuthUserInfo(
|
||||
provider='vk',
|
||||
provider_id=str(user_id or user_data.get('id', '')),
|
||||
provider_id=str(user_id),
|
||||
email=email,
|
||||
email_verified=bool(email),
|
||||
first_name=user_data.get('first_name'),
|
||||
last_name=user_data.get('last_name'),
|
||||
avatar_url=user_data.get('photo_200'),
|
||||
avatar_url=user_data.get('avatar'),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -75,8 +75,9 @@ class OAuthAuthorizeResponse(BaseModel):
|
||||
|
||||
|
||||
class OAuthCallbackRequest(BaseModel):
|
||||
code: str = Field(..., description='Authorization code from provider')
|
||||
state: str = Field(..., description='CSRF state token')
|
||||
code: str = Field(..., min_length=1, max_length=2048, description='Authorization code from provider')
|
||||
state: str = Field(..., max_length=128, description='CSRF state token')
|
||||
device_id: str | None = Field(None, max_length=256, description='Device ID from VK ID callback')
|
||||
campaign_slug: str | None = Field(
|
||||
None, min_length=1, max_length=64, pattern=r'^[a-zA-Z0-9_-]+$', description='Campaign slug from web link'
|
||||
)
|
||||
@@ -105,11 +106,13 @@ async def get_oauth_authorize_url(provider: str):
|
||||
if not oauth_provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f'OAuth provider "{provider}" is not enabled',
|
||||
detail='Requested OAuth provider is not available',
|
||||
)
|
||||
|
||||
state = await generate_oauth_state(provider)
|
||||
authorize_url = oauth_provider.get_authorization_url(state)
|
||||
# Generate extra state data (e.g., PKCE code_verifier for VK)
|
||||
auth_extra = oauth_provider.prepare_auth_state()
|
||||
state = await generate_oauth_state(provider, extra_data=auth_extra or None)
|
||||
authorize_url = oauth_provider.get_authorization_url(state, **auth_extra)
|
||||
|
||||
return OAuthAuthorizeResponse(authorize_url=authorize_url, state=state)
|
||||
|
||||
@@ -121,8 +124,9 @@ async def oauth_callback(
|
||||
db: AsyncSession = Depends(get_cabinet_db),
|
||||
):
|
||||
"""Handle OAuth callback: exchange code, find/create user, return JWT."""
|
||||
# 1. Validate CSRF state
|
||||
if not await validate_oauth_state(request.state, provider):
|
||||
# 1. Validate CSRF state and retrieve stored data (e.g., PKCE code_verifier)
|
||||
state_data = await validate_oauth_state(request.state, provider)
|
||||
if not state_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid or expired OAuth state',
|
||||
@@ -133,14 +137,21 @@ async def oauth_callback(
|
||||
if not oauth_provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f'OAuth provider "{provider}" is not enabled',
|
||||
detail='Requested OAuth provider is not available',
|
||||
)
|
||||
|
||||
# 3. Exchange code for tokens
|
||||
# 3. Exchange code for tokens (pass PKCE code_verifier and device_id if present)
|
||||
exchange_kwargs: dict[str, str] = {'state': request.state}
|
||||
code_verifier = state_data.get('code_verifier')
|
||||
if code_verifier:
|
||||
exchange_kwargs['code_verifier'] = code_verifier
|
||||
if request.device_id:
|
||||
exchange_kwargs['device_id'] = request.device_id
|
||||
|
||||
try:
|
||||
token_data = await oauth_provider.exchange_code(request.code)
|
||||
token_data = await oauth_provider.exchange_code(request.code, **exchange_kwargs)
|
||||
except Exception as exc:
|
||||
logger.error('OAuth code exchange failed for', provider=provider, exc=exc)
|
||||
logger.error('OAuth code exchange failed', provider=provider, exc_info=exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Failed to exchange authorization code',
|
||||
@@ -150,7 +161,7 @@ async def oauth_callback(
|
||||
try:
|
||||
user_info: OAuthUserInfo = await oauth_provider.get_user_info(token_data)
|
||||
except Exception as exc:
|
||||
logger.error('OAuth user info fetch failed for', provider=provider, exc=exc)
|
||||
logger.error('OAuth user info fetch failed', provider=provider, exc_info=exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Failed to fetch user information from provider',
|
||||
@@ -159,7 +170,7 @@ async def oauth_callback(
|
||||
# 5. Find user by provider ID
|
||||
user = await get_user_by_oauth_provider(db, provider, user_info.provider_id)
|
||||
if user:
|
||||
logger.info('OAuth login via for existing user', provider=provider, user_id=user.id)
|
||||
logger.info('OAuth login for existing user', provider=provider, user_id=user.id)
|
||||
return await _finalize_oauth_login(db, user, provider, request.campaign_slug, request.referral_code)
|
||||
|
||||
# 6. Find user by email (if verified) and link provider
|
||||
@@ -167,7 +178,7 @@ async def oauth_callback(
|
||||
user = await get_user_by_email(db, user_info.email)
|
||||
if user:
|
||||
await set_user_oauth_provider_id(db, user, provider, user_info.provider_id)
|
||||
logger.info('OAuth login via linked to existing email user', provider=provider, user_id=user.id)
|
||||
logger.info('OAuth provider linked to existing email user', provider=provider, user_id=user.id)
|
||||
return await _finalize_oauth_login(db, user, provider, request.campaign_slug, request.referral_code)
|
||||
|
||||
# 7. Resolve referral code for new user
|
||||
@@ -205,5 +216,5 @@ async def oauth_callback(
|
||||
username=user_info.username,
|
||||
referred_by_id=referrer_id,
|
||||
)
|
||||
logger.info('OAuth new user created via with id', provider=provider, user_id=user.id)
|
||||
logger.info('New OAuth user created', provider=provider, user_id=user.id)
|
||||
return await _finalize_oauth_login(db, user, provider, request.campaign_slug, request.referral_code)
|
||||
|
||||
@@ -82,6 +82,23 @@ class CacheService:
|
||||
logger.error('Ошибка setnx в кеш', key=key, error=e)
|
||||
return False
|
||||
|
||||
async def getdel(self, key: str) -> Any | None:
|
||||
"""Atomically get and delete a key (Redis GETDEL).
|
||||
|
||||
Returns the deserialized value if it existed, None otherwise.
|
||||
"""
|
||||
if not self._connected:
|
||||
return None
|
||||
|
||||
try:
|
||||
value = await self.redis_client.getdel(key)
|
||||
if value:
|
||||
return json.loads(value)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error('Ошибка атомарного getdel из кеша', key=key, error=e)
|
||||
return None
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
if not self._connected:
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user