mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
(feat:connectors) abstracting auth, base class
This commit is contained in:
@@ -3995,9 +3995,9 @@ class GoogleDriveAuth(Resource):
|
||||
def get(self):
|
||||
"""Get Google Drive OAuth authorization URL"""
|
||||
try:
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
|
||||
auth = GoogleDriveAuth()
|
||||
auth = ConnectorCreator.create_auth("google_drive")
|
||||
|
||||
# Generate state parameter for CSRF protection
|
||||
import uuid
|
||||
@@ -4029,7 +4029,7 @@ class GoogleDriveCallback(Resource):
|
||||
def get(self):
|
||||
"""Handle Google Drive OAuth callback"""
|
||||
try:
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from flask import request
|
||||
import uuid
|
||||
|
||||
@@ -4050,7 +4050,7 @@ class GoogleDriveCallback(Resource):
|
||||
|
||||
# Exchange code for tokens
|
||||
try:
|
||||
auth = GoogleDriveAuth()
|
||||
auth = ConnectorCreator.create_auth("google_drive")
|
||||
token_info = auth.exchange_code_for_tokens(authorization_code)
|
||||
|
||||
# Log detailed information about the token_info we received
|
||||
@@ -4193,7 +4193,7 @@ class GoogleDriveRefresh(Resource):
|
||||
def post(self):
|
||||
"""Refresh Google Drive access token"""
|
||||
try:
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
|
||||
data = request.get_json()
|
||||
refresh_token = data.get('refresh_token')
|
||||
@@ -4203,7 +4203,7 @@ class GoogleDriveRefresh(Resource):
|
||||
jsonify({"success": False, "error": "Refresh token not provided"}), 400
|
||||
)
|
||||
|
||||
auth = GoogleDriveAuth()
|
||||
auth = ConnectorCreator.create_auth("google_drive")
|
||||
token_info = auth.refresh_access_token(refresh_token)
|
||||
|
||||
return make_response(
|
||||
@@ -4241,7 +4241,7 @@ class GoogleDriveFiles(Resource):
|
||||
def post(self):
|
||||
"""Get list of files from Google Drive"""
|
||||
try:
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
|
||||
data = request.get_json()
|
||||
session_token = data.get('session_token')
|
||||
@@ -4254,7 +4254,7 @@ class GoogleDriveFiles(Resource):
|
||||
)
|
||||
|
||||
# Create Google Drive loader with session token only
|
||||
loader = GoogleDriveLoader(session_token)
|
||||
loader = ConnectorCreator.create_connector("google_drive", session_token)
|
||||
|
||||
# Get files from Google Drive (limit to first N files, metadata only)
|
||||
files_config = {
|
||||
@@ -4329,7 +4329,7 @@ class GoogleDriveValidateSession(Resource):
|
||||
"""Validate Google Drive session token and return user info"""
|
||||
try:
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
|
||||
data = request.get_json()
|
||||
session_token = data.get('session_token')
|
||||
@@ -4352,8 +4352,8 @@ class GoogleDriveValidateSession(Resource):
|
||||
|
||||
# Get token info and check if it's expired
|
||||
token_info = session["token_info"]
|
||||
auth = GoogleDriveAuth()
|
||||
|
||||
auth = ConnectorCreator.create_auth("google_drive")
|
||||
|
||||
# Check if token is expired using our improved method
|
||||
is_expired = auth.is_token_expired(token_info)
|
||||
|
||||
|
||||
@@ -5,7 +5,14 @@ This module contains connectors for external knowledge bases and document storag
|
||||
that require authentication and specialized handling, separate from simple web scrapers.
|
||||
"""
|
||||
|
||||
from .base import BaseConnectorAuth, BaseConnectorLoader
|
||||
from .connector_creator import ConnectorCreator
|
||||
from .google_drive import GoogleDriveAuth, GoogleDriveLoader
|
||||
|
||||
__all__ = ['ConnectorCreator', 'GoogleDriveAuth', 'GoogleDriveLoader']
|
||||
__all__ = [
|
||||
'BaseConnectorAuth',
|
||||
'BaseConnectorLoader',
|
||||
'ConnectorCreator',
|
||||
'GoogleDriveAuth',
|
||||
'GoogleDriveLoader'
|
||||
]
|
||||
|
||||
129
application/parser/connectors/base.py
Normal file
129
application/parser/connectors/base.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Base classes for external knowledge base connectors.
|
||||
|
||||
This module provides minimal abstract base classes that define the essential
|
||||
interface for external knowledge base connectors.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
class BaseConnectorAuth(ABC):
    """
    Authentication contract shared by every connector.

    Concrete auth providers subclass this and implement all four
    abstract methods below; the class itself cannot be instantiated.
    """

    @abstractmethod
    def get_authorization_url(self, state: Optional[str] = None) -> str:
        """
        Build the URL that starts the OAuth authorization flow.

        Args:
            state: Optional state value used for CSRF protection.

        Returns:
            The authorization URL.
        """

    @abstractmethod
    def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
        """
        Trade an OAuth authorization code for access tokens.

        Args:
            authorization_code: Code received on the OAuth callback.

        Returns:
            Dictionary with the token information.
        """

    @abstractmethod
    def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
        """
        Obtain fresh token information for an expired access token.

        Args:
            refresh_token: The refresh token to use.

        Returns:
            Dictionary with the refreshed token information.
        """

    @abstractmethod
    def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
        """
        Report whether the given token has expired.

        Args:
            token_info: Token information dictionary.

        Returns:
            True when the token is expired, False otherwise.
        """
|
||||
|
||||
|
||||
class BaseConnectorLoader(ABC):
    """
    Abstract base class for connector loaders.

    Defines the minimal interface that all connector loader
    implementations must follow.
    """

    @abstractmethod
    def __init__(self, session_token: str):
        """
        Initialize the connector loader.

        Args:
            session_token: Authentication session token
        """

    @abstractmethod
    def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
        """
        Load documents from the external knowledge base.

        Args:
            inputs: Configuration dictionary containing:
                - file_ids: Optional list of specific file IDs to load
                - folder_ids: Optional list of folder IDs to browse/download
                - limit: Maximum number of items to return
                - list_only: If True, return metadata without content
                - recursive: Whether to recursively process folders

        Returns:
            List of Document objects
        """

    @abstractmethod
    def download_to_directory(
        self,
        local_dir: str,
        # Explicit Optional: the default is None, and implicit Optional is
        # no longer accepted by PEP 484 / current type checkers.
        source_config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Download files/folders to a local directory.

        Args:
            local_dir: Local directory path to download files to
            source_config: Configuration for what to download; may be
                omitted (defaults to None)

        Returns:
            Dictionary containing download results:
                - files_downloaded: Number of files downloaded
                - directory_path: Path where files were downloaded
                - empty_result: Whether no files were downloaded
                - source_type: Type of connector
                - config_used: Configuration that was used
                - error: Error message if download failed (optional)
        """
|
||||
@@ -1,30 +1,35 @@
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
|
||||
|
||||
class ConnectorCreator:
|
||||
"""
|
||||
Factory class for creating external knowledge base connectors.
|
||||
|
||||
Factory class for creating external knowledge base connectors and auth providers.
|
||||
|
||||
These are different from remote loaders as they typically require
|
||||
authentication and connect to external document storage systems.
|
||||
"""
|
||||
|
||||
|
||||
connectors = {
|
||||
"google_drive": GoogleDriveLoader,
|
||||
}
|
||||
|
||||
auth_providers = {
|
||||
"google_drive": GoogleDriveAuth,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_connector(cls, connector_type, *args, **kwargs):
|
||||
"""
|
||||
Create a connector instance for the specified type.
|
||||
|
||||
|
||||
Args:
|
||||
connector_type: Type of connector to create (e.g., 'google_drive')
|
||||
*args, **kwargs: Arguments to pass to the connector constructor
|
||||
|
||||
|
||||
Returns:
|
||||
Connector instance
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If connector type is not supported
|
||||
"""
|
||||
@@ -33,11 +38,30 @@ class ConnectorCreator:
|
||||
raise ValueError(f"No connector class found for type {connector_type}")
|
||||
return connector_class(*args, **kwargs)
|
||||
|
||||
@classmethod
def create_auth(cls, connector_type):
    """
    Instantiate the auth provider registered for a connector type.

    The lookup in ``cls.auth_providers`` uses the lower-cased
    ``connector_type``.

    Args:
        connector_type: Connector type name (e.g., 'google_drive')

    Returns:
        A new instance of the matching auth provider class

    Raises:
        ValueError: If no auth provider is registered for the type
    """
    provider_cls = cls.auth_providers.get(connector_type.lower())
    if provider_cls:
        return provider_cls()
    raise ValueError(f"No auth class found for type {connector_type}")
|
||||
|
||||
@classmethod
|
||||
def get_supported_connectors(cls):
|
||||
"""
|
||||
Get list of supported connector types.
|
||||
|
||||
|
||||
Returns:
|
||||
List of supported connector type strings
|
||||
"""
|
||||
@@ -47,10 +71,10 @@ class ConnectorCreator:
|
||||
def is_supported(cls, connector_type):
|
||||
"""
|
||||
Check if a connector type is supported.
|
||||
|
||||
|
||||
Args:
|
||||
connector_type: Type of connector to check
|
||||
|
||||
|
||||
Returns:
|
||||
True if supported, False otherwise
|
||||
"""
|
||||
|
||||
@@ -8,9 +8,10 @@ from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.base import BaseConnectorAuth
|
||||
|
||||
|
||||
class GoogleDriveAuth:
|
||||
class GoogleDriveAuth(BaseConnectorAuth):
|
||||
"""
|
||||
Handles Google OAuth 2.0 authentication for Google Drive access.
|
||||
"""
|
||||
@@ -31,15 +32,6 @@ class GoogleDriveAuth:
|
||||
|
||||
|
||||
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||
"""
|
||||
Generate Google OAuth authorization URL.
|
||||
|
||||
Args:
|
||||
state: Optional state parameter for CSRF protection
|
||||
|
||||
Returns:
|
||||
Authorization URL for Google OAuth flow
|
||||
"""
|
||||
try:
|
||||
flow = Flow.from_client_config(
|
||||
{
|
||||
@@ -69,15 +61,6 @@ class GoogleDriveAuth:
|
||||
raise
|
||||
|
||||
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Exchange authorization code for access and refresh tokens.
|
||||
|
||||
Args:
|
||||
authorization_code: Authorization code from OAuth callback
|
||||
|
||||
Returns:
|
||||
Dictionary containing token information
|
||||
"""
|
||||
try:
|
||||
if not authorization_code:
|
||||
raise ValueError("Authorization code is required")
|
||||
|
||||
@@ -11,12 +11,12 @@ from typing import List, Dict, Any, Optional
|
||||
from googleapiclient.http import MediaIoBaseDownload
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from application.parser.connectors.base import BaseConnectorLoader
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
class GoogleDriveLoader(BaseRemote):
|
||||
class GoogleDriveLoader(BaseConnectorLoader):
|
||||
|
||||
SUPPORTED_MIME_TYPES = {
|
||||
'application/pdf': '.pdf',
|
||||
@@ -104,25 +104,6 @@ class GoogleDriveLoader(BaseRemote):
|
||||
return None
|
||||
|
||||
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""
|
||||
Load items from Google Drive according to simple browsing semantics.
|
||||
|
||||
Behavior:
|
||||
- If file_ids are provided: return those files (optionally with content).
|
||||
- If folder_id is provided: return the immediate children (folders and files) of that folder.
|
||||
- If no folder_id: return the immediate children (folders and files) of Drive 'root'.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary containing configuration:
|
||||
- folder_id: Optional Google Drive folder ID whose direct children to list
|
||||
- file_ids: Optional list of specific file IDs to load
|
||||
- limit: Maximum number of items to return
|
||||
- list_only: If True, only return metadata without content
|
||||
- session_token: Optional session token to use for authentication (backward compatibility)
|
||||
|
||||
Returns:
|
||||
List of Document objects (folders are returned as metadata-only documents)
|
||||
"""
|
||||
session_token = inputs.get('session_token')
|
||||
if session_token and session_token != self.session_token:
|
||||
logging.warning("Session token in inputs differs from loader's session token. Using loader's session token.")
|
||||
|
||||
Reference in New Issue
Block a user