mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-03-01 15:51:10 +00:00
(feat:oneDrive) file loading for ingestion
This commit is contained in:
@@ -10,15 +10,14 @@ from application.parser.connectors.base import BaseConnectorAuth
|
||||
|
||||
class SharePointAuth(BaseConnectorAuth):
|
||||
"""
|
||||
Handles Microsoft OAuth 2.0 authentication.
|
||||
Handles Microsoft OAuth 2.0 authentication for SharePoint/OneDrive.
|
||||
|
||||
# Documentation:
|
||||
- https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow
|
||||
- https://learn.microsoft.com/en-gb/entra/msal/python/
|
||||
Note: Files.Read scope allows access to files the user has granted access to,
|
||||
similar to Google Drive's drive.file scope.
|
||||
"""
|
||||
|
||||
# Microsoft Graph scopes for SharePoint access
|
||||
SCOPES = [
|
||||
"Files.Read",
|
||||
"User.Read",
|
||||
]
|
||||
|
||||
@@ -26,17 +25,26 @@ class SharePointAuth(BaseConnectorAuth):
|
||||
self.client_id = settings.MICROSOFT_CLIENT_ID
|
||||
self.client_secret = settings.MICROSOFT_CLIENT_SECRET
|
||||
|
||||
if not self.client_id or not self.client_secret:
|
||||
if not self.client_id:
|
||||
raise ValueError(
|
||||
"Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID and MICROSOFT_CLIENT_SECRET in settings."
|
||||
"Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID in settings."
|
||||
)
|
||||
|
||||
if not self.client_secret:
|
||||
raise ValueError(
|
||||
"Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_SECRET in settings."
|
||||
)
|
||||
|
||||
self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
|
||||
self.tenant_id = settings.MICROSOFT_TENANT_ID
|
||||
self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://{self.tenant_id}.ciamlogin.com/{self.tenant_id}")
|
||||
self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://login.microsoftonline.com/{self.tenant_id}")
|
||||
|
||||
logging.info(f"SharePointAuth initialized with: client_id={self.client_id[:8]}, tenant_id={self.tenant_id}, redirect_uri={self.redirect_uri}, authority={self.authority}")
|
||||
|
||||
self.auth_app = ConfidentialClientApplication(
|
||||
client_id=self.client_id, client_credential=self.client_secret, authority=self.authority
|
||||
client_id=self.client_id,
|
||||
client_credential=self.client_secret,
|
||||
authority=self.authority
|
||||
)
|
||||
|
||||
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||
@@ -45,36 +53,94 @@ class SharePointAuth(BaseConnectorAuth):
|
||||
)
|
||||
|
||||
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||
logging.info(f"Exchanging authorization code for token with scopes: {self.SCOPES}")
|
||||
logging.info(f"Redirect URI: {self.redirect_uri}")
|
||||
|
||||
result = self.auth_app.acquire_token_by_authorization_code(
|
||||
code=authorization_code, scopes=self.SCOPES, redirect_uri=self.redirect_uri
|
||||
code=authorization_code,
|
||||
scopes=self.SCOPES,
|
||||
redirect_uri=self.redirect_uri
|
||||
)
|
||||
|
||||
if "error" in result:
|
||||
logging.error(f"Error acquiring token: {result.get('error_description')}")
|
||||
raise ValueError(f"Error acquiring token: {result.get('error_description')}")
|
||||
error_msg = f"Error acquiring token: {result.get('error_description')}"
|
||||
logging.error(f"{error_msg} - Full result: {result}")
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logging.info(f"Token acquired successfully")
|
||||
return self.map_token_response(result)
|
||||
|
||||
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
logging.info(f"Refreshing access token")
|
||||
result = self.auth_app.acquire_token_by_refresh_token(refresh_token=refresh_token, scopes=self.SCOPES)
|
||||
|
||||
if "error" in result:
|
||||
logging.error(f"Error acquiring token: {result.get('error_description')}")
|
||||
logging.error(f"Error refreshing token: {result.get('error_description')} - Full result: {result}")
|
||||
raise ValueError(f"Error acquiring token: {result.get('error_description')}")
|
||||
|
||||
logging.info(f"Token refreshed successfully")
|
||||
return self.map_token_response(result)
|
||||
|
||||
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
|
||||
try:
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
sessions_collection = db["connector_sessions"]
|
||||
session = sessions_collection.find_one({"session_token": session_token})
|
||||
|
||||
if not session:
|
||||
raise ValueError(f"Invalid session token: {session_token}")
|
||||
|
||||
if "token_info" not in session:
|
||||
raise ValueError("Session missing token information")
|
||||
|
||||
token_info = session["token_info"]
|
||||
if not token_info:
|
||||
raise ValueError("Invalid token information")
|
||||
|
||||
required_fields = ["access_token", "refresh_token"]
|
||||
missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required token fields: {missing_fields}")
|
||||
|
||||
if 'client_id' not in token_info:
|
||||
token_info['client_id'] = settings.MICROSOFT_CLIENT_ID
|
||||
if 'tenant_id' not in token_info:
|
||||
token_info['tenant_id'] = settings.MICROSOFT_TENANT_ID
|
||||
if 'client_secret' not in token_info:
|
||||
token_info['client_secret'] = settings.MICROSOFT_CLIENT_SECRET
|
||||
if 'token_uri' not in token_info:
|
||||
token_info['token_uri'] = f"https://login.microsoftonline.com/{settings.MICROSOFT_TENANT_ID}/oauth2/v2.0/token"
|
||||
|
||||
logging.info(f"Retrieved token from session. Expiry: {token_info.get('expiry')}")
|
||||
return token_info
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to retrieve SharePoint token information: {str(e)}")
|
||||
|
||||
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
|
||||
if not token_info or "expiry" not in token_info:
|
||||
# If no expiry info, consider token expired to be safe
|
||||
if not token_info:
|
||||
return True
|
||||
|
||||
# Get expiry timestamp and current time
|
||||
expiry_timestamp = token_info["expiry"]
|
||||
current_timestamp = int(datetime.datetime.now().timestamp())
|
||||
expiry_timestamp = token_info.get("expiry")
|
||||
|
||||
# Token is expired if current time is greater than or equal to expiry time
|
||||
return current_timestamp >= expiry_timestamp
|
||||
if expiry_timestamp is None:
|
||||
logging.warning("Token expiry is None, treating as expired")
|
||||
return True
|
||||
|
||||
current_timestamp = int(datetime.datetime.now().timestamp())
|
||||
expires_in = expiry_timestamp - current_timestamp
|
||||
|
||||
if expires_in < 60:
|
||||
logging.info(f"Token expires in {expires_in} seconds, treating as expired")
|
||||
return True
|
||||
|
||||
logging.debug(f"Token not expired. Expires in {expires_in} seconds")
|
||||
return False
|
||||
|
||||
def map_token_response(self, result) -> Dict[str, Any]:
|
||||
return {
|
||||
|
||||
@@ -1,44 +1,457 @@
|
||||
from typing import List, Dict, Any
|
||||
"""
|
||||
SharePoint/OneDrive loader for DocsGPT.
|
||||
Loads documents from SharePoint/OneDrive using Microsoft Graph API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from application.parser.connectors.base import BaseConnectorLoader
|
||||
from application.parser.connectors.share_point.auth import SharePointAuth
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
class SharePointLoader(BaseConnectorLoader):
|
||||
|
||||
SUPPORTED_MIME_TYPES = {
|
||||
'application/pdf': '.pdf',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
|
||||
'application/msword': '.doc',
|
||||
'application/vnd.ms-powerpoint': '.ppt',
|
||||
'application/vnd.ms-excel': '.xls',
|
||||
'text/plain': '.txt',
|
||||
'text/csv': '.csv',
|
||||
'text/html': '.html',
|
||||
'text/markdown': '.md',
|
||||
'text/x-rst': '.rst',
|
||||
'application/json': '.json',
|
||||
'application/epub+zip': '.epub',
|
||||
'application/rtf': '.rtf',
|
||||
'image/jpeg': '.jpg',
|
||||
'image/jpg': '.jpg',
|
||||
'image/png': '.png',
|
||||
}
|
||||
|
||||
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
def __init__(self, session_token: str):
|
||||
pass
|
||||
self.auth = SharePointAuth()
|
||||
self.session_token = session_token
|
||||
|
||||
token_info = self.auth.get_token_info_from_session(session_token)
|
||||
self.access_token = token_info.get('access_token')
|
||||
self.refresh_token = token_info.get('refresh_token')
|
||||
|
||||
if not self.access_token:
|
||||
raise ValueError("No access token found in session")
|
||||
|
||||
self.next_page_token = None
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
'Authorization': f'Bearer {self.access_token}',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
|
||||
def _ensure_valid_token(self):
|
||||
if not self.access_token:
|
||||
raise ValueError("No access token available")
|
||||
|
||||
token_info = {'access_token': self.access_token, 'expiry': None}
|
||||
if self.auth.is_token_expired(token_info):
|
||||
logging.info("Token expired, attempting refresh")
|
||||
try:
|
||||
new_token_info = self.auth.refresh_access_token(self.refresh_token)
|
||||
self.access_token = new_token_info.get('access_token')
|
||||
logging.info("Token refreshed successfully")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to refresh token: {e}")
|
||||
raise ValueError("Failed to refresh access token")
|
||||
|
||||
def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
|
||||
try:
|
||||
drive_item_id = file_metadata.get('id')
|
||||
file_name = file_metadata.get('name', 'Unknown')
|
||||
file_data = file_metadata.get('file', {})
|
||||
mime_type = file_data.get('mimeType', 'application/octet-stream')
|
||||
|
||||
if mime_type not in self.SUPPORTED_MIME_TYPES:
|
||||
logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
|
||||
return None
|
||||
|
||||
doc_metadata = {
|
||||
'file_name': file_name,
|
||||
'mime_type': mime_type,
|
||||
'size': file_metadata.get('size'),
|
||||
'created_time': file_metadata.get('createdDateTime'),
|
||||
'modified_time': file_metadata.get('lastModifiedDateTime'),
|
||||
'source': 'share_point'
|
||||
}
|
||||
|
||||
if not load_content:
|
||||
return Document(
|
||||
text="",
|
||||
doc_id=drive_item_id,
|
||||
extra_info=doc_metadata
|
||||
)
|
||||
|
||||
content = self._download_file_content(drive_item_id)
|
||||
if content is None:
|
||||
logging.warning(f"Could not load content for file {file_name} ({drive_item_id})")
|
||||
return None
|
||||
|
||||
return Document(
|
||||
text=content,
|
||||
doc_id=drive_item_id,
|
||||
extra_info=doc_metadata
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing file: {e}")
|
||||
return None
|
||||
|
||||
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""
|
||||
Load documents from the external knowledge base.
|
||||
try:
|
||||
documents: List[Document] = []
|
||||
|
||||
Args:
|
||||
inputs: Configuration dictionary containing:
|
||||
- file_ids: Optional list of specific file IDs to load
|
||||
- folder_ids: Optional list of folder IDs to browse/download
|
||||
- limit: Maximum number of items to return
|
||||
- list_only: If True, return metadata without content
|
||||
- recursive: Whether to recursively process folders
|
||||
folder_id = inputs.get('folder_id')
|
||||
file_ids = inputs.get('file_ids', [])
|
||||
limit = inputs.get('limit', 100)
|
||||
list_only = inputs.get('list_only', False)
|
||||
load_content = not list_only
|
||||
page_token = inputs.get('page_token')
|
||||
search_query = inputs.get('search_query')
|
||||
self.next_page_token = None
|
||||
|
||||
Returns:
|
||||
List of Document objects
|
||||
"""
|
||||
pass
|
||||
if file_ids:
|
||||
for file_id in file_ids:
|
||||
try:
|
||||
doc = self._load_file_by_id(file_id, load_content=load_content)
|
||||
if doc:
|
||||
if not search_query or (
|
||||
search_query.lower() in doc.extra_info.get('file_name', '').lower()
|
||||
):
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading file {file_id}: {e}")
|
||||
continue
|
||||
else:
|
||||
parent_id = folder_id if folder_id else 'root'
|
||||
documents = self._list_items_in_parent(
|
||||
parent_id,
|
||||
limit=limit,
|
||||
load_content=load_content,
|
||||
page_token=page_token,
|
||||
search_query=search_query
|
||||
)
|
||||
|
||||
logging.info(f"Loaded {len(documents)} documents from SharePoint/OneDrive")
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading data from SharePoint/OneDrive: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
|
||||
self._ensure_valid_token()
|
||||
|
||||
try:
|
||||
url = f"{self.GRAPH_API_BASE}/me/drive/items/{file_id}"
|
||||
params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
file_metadata = response.json()
|
||||
return self._process_file(file_metadata, load_content=load_content)
|
||||
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code in [401, 403]:
|
||||
logging.error(f"Authentication error loading file {file_id}")
|
||||
try:
|
||||
new_token_info = self.auth.refresh_access_token(self.refresh_token)
|
||||
self.access_token = new_token_info.get('access_token')
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response.raise_for_status()
|
||||
file_metadata = response.json()
|
||||
return self._process_file(file_metadata, load_content=load_content)
|
||||
except Exception as refresh_error:
|
||||
raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
|
||||
logging.error(f"HTTP error loading file {file_id}: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading file {file_id}: {e}")
|
||||
return None
|
||||
|
||||
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
|
||||
self._ensure_valid_token()
|
||||
|
||||
documents: List[Document] = []
|
||||
|
||||
try:
|
||||
url = f"{self.GRAPH_API_BASE}/me/drive/items/{parent_id}/children"
|
||||
params = {'$top': min(100, limit) if limit else 100, '$select': 'id,name,file,folder,createdDateTime,lastModifiedDateTime,size'}
|
||||
if page_token:
|
||||
params['$skipToken'] = page_token
|
||||
|
||||
if search_query:
|
||||
search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{search_query}')"
|
||||
response = requests.get(search_url, headers=self._get_headers(), params=params)
|
||||
else:
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
results = response.json()
|
||||
|
||||
items = results.get('value', [])
|
||||
for item in items:
|
||||
if 'folder' in item:
|
||||
doc_metadata = {
|
||||
'file_name': item.get('name', 'Unknown'),
|
||||
'mime_type': 'folder',
|
||||
'size': item.get('size'),
|
||||
'created_time': item.get('createdDateTime'),
|
||||
'modified_time': item.get('lastModifiedDateTime'),
|
||||
'source': 'share_point',
|
||||
'is_folder': True
|
||||
}
|
||||
documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
|
||||
else:
|
||||
doc = self._process_file(item, load_content=load_content)
|
||||
if doc:
|
||||
documents.append(doc)
|
||||
|
||||
if limit and len(documents) >= limit:
|
||||
break
|
||||
|
||||
next_link = results.get('@odata.nextLink')
|
||||
if next_link:
|
||||
# Extract skiptoken from the full URL
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
parsed = urlparse(next_link)
|
||||
query_params = parse_qs(parsed.query)
|
||||
skiptoken_list = query_params.get('$skiptoken')
|
||||
if skiptoken_list:
|
||||
self.next_page_token = skiptoken_list[0]
|
||||
else:
|
||||
self.next_page_token = None
|
||||
else:
|
||||
self.next_page_token = None
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error listing items under parent {parent_id}: {e}")
|
||||
return documents
|
||||
|
||||
def _download_file_content(self, file_id: str) -> Optional[str]:
|
||||
self._ensure_valid_token()
|
||||
|
||||
try:
|
||||
url = f"{self.GRAPH_API_BASE}/me/drive/items/{file_id}/content"
|
||||
response = requests.get(url, headers=self._get_headers())
|
||||
response.raise_for_status()
|
||||
|
||||
content_bytes = response.content
|
||||
|
||||
try:
|
||||
content = content_bytes.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
content = content_bytes.decode('latin-1')
|
||||
except UnicodeDecodeError:
|
||||
logging.error(f"Could not decode file {file_id} as text")
|
||||
return None
|
||||
|
||||
return content
|
||||
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code in [401, 403]:
|
||||
logging.error(f"Authentication error downloading file {file_id}")
|
||||
try:
|
||||
new_token_info = self.auth.refresh_access_token(self.refresh_token)
|
||||
self.access_token = new_token_info.get('access_token')
|
||||
response = requests.get(url, headers=self._get_headers())
|
||||
response.raise_for_status()
|
||||
content_bytes = response.content
|
||||
try:
|
||||
content = content_bytes.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
content = content_bytes.decode('latin-1')
|
||||
except UnicodeDecodeError:
|
||||
logging.error(f"Could not decode file {file_id} as text")
|
||||
return None
|
||||
return content
|
||||
except Exception as refresh_error:
|
||||
raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
|
||||
logging.error(f"HTTP error downloading file {file_id}: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading file {file_id}: {e}")
|
||||
return None
|
||||
|
||||
def _download_single_file(self, file_id: str, local_dir: str) -> bool:
|
||||
try:
|
||||
url = f"{self.GRAPH_API_BASE}/me/drive/items/{file_id}"
|
||||
params = {'$select': 'id,name,file'}
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
metadata = response.json()
|
||||
file_name = metadata.get('name', 'unknown')
|
||||
file_data = metadata.get('file', {})
|
||||
mime_type = file_data.get('mimeType', 'application/octet-stream')
|
||||
|
||||
if mime_type not in self.SUPPORTED_MIME_TYPES:
|
||||
logging.info(f"Skipping unsupported file type: {mime_type}")
|
||||
return False
|
||||
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
full_path = os.path.join(local_dir, file_name)
|
||||
|
||||
download_url = f"{self.GRAPH_API_BASE}/me/drive/items/{file_id}/content"
|
||||
download_response = requests.get(download_url, headers=self._get_headers())
|
||||
download_response.raise_for_status()
|
||||
|
||||
with open(full_path, 'wb') as f:
|
||||
f.write(download_response.content)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"Error in _download_single_file: {e}")
|
||||
return False
|
||||
|
||||
def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
|
||||
files_downloaded = 0
|
||||
try:
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
|
||||
url = f"{self.GRAPH_API_BASE}/me/drive/items/{folder_id}/children"
|
||||
params = {'$top': 1000}
|
||||
|
||||
while url:
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
results = response.json()
|
||||
items = results.get('value', [])
|
||||
logging.info(f"Found {len(items)} items in folder {folder_id}")
|
||||
|
||||
for item in items:
|
||||
item_name = item.get('name', 'unknown')
|
||||
item_id = item.get('id')
|
||||
|
||||
if 'folder' in item:
|
||||
if recursive:
|
||||
subfolder_path = os.path.join(local_dir, item_name)
|
||||
os.makedirs(subfolder_path, exist_ok=True)
|
||||
subfolder_files = self._download_folder_recursive(
|
||||
item_id,
|
||||
subfolder_path,
|
||||
recursive
|
||||
)
|
||||
files_downloaded += subfolder_files
|
||||
logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
|
||||
else:
|
||||
success = self._download_single_file(item_id, local_dir)
|
||||
if success:
|
||||
files_downloaded += 1
|
||||
logging.info(f"Downloaded file: {item_name}")
|
||||
else:
|
||||
logging.warning(f"Failed to download file: {item_name}")
|
||||
|
||||
url = results.get('@odata.nextLink')
|
||||
|
||||
return files_downloaded
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
|
||||
return files_downloaded
|
||||
|
||||
def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
|
||||
try:
|
||||
self._ensure_valid_token()
|
||||
return self._download_folder_recursive(folder_id, local_dir, recursive)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
|
||||
return 0
|
||||
|
||||
def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
|
||||
try:
|
||||
self._ensure_valid_token()
|
||||
return self._download_single_file(file_id, local_dir)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Download files/folders to a local directory.
|
||||
if source_config is None:
|
||||
source_config = {}
|
||||
|
||||
Args:
|
||||
local_dir: Local directory path to download files to
|
||||
source_config: Configuration for what to download
|
||||
config = source_config if source_config else getattr(self, 'config', {})
|
||||
files_downloaded = 0
|
||||
|
||||
Returns:
|
||||
Dictionary containing download results:
|
||||
- files_downloaded: Number of files downloaded
|
||||
- directory_path: Path where files were downloaded
|
||||
- empty_result: Whether no files were downloaded
|
||||
- source_type: Type of connector
|
||||
- config_used: Configuration that was used
|
||||
- error: Error message if download failed (optional)
|
||||
"""
|
||||
pass
|
||||
try:
|
||||
folder_ids = config.get('folder_ids', [])
|
||||
file_ids = config.get('file_ids', [])
|
||||
recursive = config.get('recursive', True)
|
||||
|
||||
if file_ids:
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [file_ids]
|
||||
|
||||
for file_id in file_ids:
|
||||
if self._download_file_to_directory(file_id, local_dir):
|
||||
files_downloaded += 1
|
||||
|
||||
if folder_ids:
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [folder_ids]
|
||||
|
||||
for folder_id in folder_ids:
|
||||
try:
|
||||
url = f"{self.GRAPH_API_BASE}/me/drive/items/{folder_id}"
|
||||
params = {'$select': 'id,name'}
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
folder_metadata = response.json()
|
||||
folder_name = folder_metadata.get('name', '')
|
||||
folder_path = os.path.join(local_dir, folder_name)
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
|
||||
folder_files = self._download_folder_recursive(
|
||||
folder_id,
|
||||
folder_path,
|
||||
recursive
|
||||
)
|
||||
files_downloaded += folder_files
|
||||
logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
|
||||
|
||||
if not file_ids and not folder_ids:
|
||||
raise ValueError("No folder_ids or file_ids provided for download")
|
||||
|
||||
return {
|
||||
"files_downloaded": files_downloaded,
|
||||
"directory_path": local_dir,
|
||||
"empty_result": files_downloaded == 0,
|
||||
"source_type": "share_point",
|
||||
"config_used": config
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"files_downloaded": files_downloaded,
|
||||
"directory_path": local_dir,
|
||||
"empty_result": True,
|
||||
"source_type": "share_point",
|
||||
"config_used": config,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user