diff --git a/application/parser/connectors/share_point/auth.py b/application/parser/connectors/share_point/auth.py index 85120666..980dbd79 100644 --- a/application/parser/connectors/share_point/auth.py +++ b/application/parser/connectors/share_point/auth.py @@ -10,15 +10,14 @@ from application.parser.connectors.base import BaseConnectorAuth class SharePointAuth(BaseConnectorAuth): """ - Handles Microsoft OAuth 2.0 authentication. + Handles Microsoft OAuth 2.0 authentication for SharePoint/OneDrive. - # Documentation: - - https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow - - https://learn.microsoft.com/en-gb/entra/msal/python/ + Note: Files.Read scope allows access to files the user has granted access to, + similar to Google Drive's drive.file scope. """ - # Microsoft Graph scopes for SharePoint access SCOPES = [ + "Files.Read", "User.Read", ] @@ -26,17 +25,26 @@ class SharePointAuth(BaseConnectorAuth): self.client_id = settings.MICROSOFT_CLIENT_ID self.client_secret = settings.MICROSOFT_CLIENT_SECRET - if not self.client_id or not self.client_secret: + if not self.client_id: raise ValueError( - "Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID and MICROSOFT_CLIENT_SECRET in settings." + "Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID in settings." + ) + + if not self.client_secret: + raise ValueError( + "Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_SECRET in settings." ) self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI self.tenant_id = settings.MICROSOFT_TENANT_ID - self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://{self.tenant_id}.ciamlogin.com/{self.tenant_id}") + self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://login.microsoftonline.com/{self.tenant_id}") + + logging.info(f"SharePointAuth initialized with: client_id={self.client_id[:8]}, tenant_id={self.tenant_id}, redirect_uri={self.redirect_uri}, authority={self.authority}") self.auth_app = ConfidentialClientApplication( - client_id=self.client_id, client_credential=self.client_secret, authority=self.authority + client_id=self.client_id, + client_credential=self.client_secret, + authority=self.authority ) def get_authorization_url(self, state: Optional[str] = None) -> str: @@ -45,36 +53,94 @@ class SharePointAuth(BaseConnectorAuth): ) def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]: + logging.info(f"Exchanging authorization code for token with scopes: {self.SCOPES}") + logging.info(f"Redirect URI: {self.redirect_uri}") + result = self.auth_app.acquire_token_by_authorization_code( - code=authorization_code, scopes=self.SCOPES, redirect_uri=self.redirect_uri + code=authorization_code, + scopes=self.SCOPES, + redirect_uri=self.redirect_uri ) if "error" in result: - logging.error(f"Error acquiring token: {result.get('error_description')}") - raise ValueError(f"Error acquiring token: {result.get('error_description')}") + error_msg = f"Error acquiring token: {result.get('error_description')}" + logging.error(f"{error_msg} - Full result: {result}") + raise ValueError(error_msg) + logging.info(f"Token acquired successfully") return self.map_token_response(result) def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]: + logging.info(f"Refreshing access token") result = self.auth_app.acquire_token_by_refresh_token(refresh_token=refresh_token, scopes=self.SCOPES) if "error" in result: - logging.error(f"Error acquiring token: {result.get('error_description')}") + logging.error(f"Error refreshing token: {result.get('error_description')} - Full result: {result}") raise ValueError(f"Error acquiring token: {result.get('error_description')}") + logging.info(f"Token refreshed successfully") return self.map_token_response(result) + def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]: + try: + from application.core.mongo_db import MongoDB + from application.core.settings import settings + + mongo = MongoDB.get_client() + db = mongo[settings.MONGO_DB_NAME] + + sessions_collection = db["connector_sessions"] + session = sessions_collection.find_one({"session_token": session_token}) + + if not session: + raise ValueError(f"Invalid session token: {session_token}") + + if "token_info" not in session: + raise ValueError("Session missing token information") + + token_info = session["token_info"] + if not token_info: + raise ValueError("Invalid token information") + + required_fields = ["access_token", "refresh_token"] + missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)] + if missing_fields: + raise ValueError(f"Missing required token fields: {missing_fields}") + + if 'client_id' not in token_info: + token_info['client_id'] = settings.MICROSOFT_CLIENT_ID + if 'tenant_id' not in token_info: + token_info['tenant_id'] = settings.MICROSOFT_TENANT_ID + if 'client_secret' not in token_info: + token_info['client_secret'] = settings.MICROSOFT_CLIENT_SECRET + if 'token_uri' not in token_info: + token_info['token_uri'] = f"https://login.microsoftonline.com/{settings.MICROSOFT_TENANT_ID}/oauth2/v2.0/token" + + logging.info(f"Retrieved token from session. Expiry: {token_info.get('expiry')}") + return token_info + + except Exception as e: + raise ValueError(f"Failed to retrieve SharePoint token information: {str(e)}") + def is_token_expired(self, token_info: Dict[str, Any]) -> bool: - if not token_info or "expiry" not in token_info: - # If no expiry info, consider token expired to be safe + if not token_info: return True - # Get expiry timestamp and current time - expiry_timestamp = token_info["expiry"] - current_timestamp = int(datetime.datetime.now().timestamp()) + expiry_timestamp = token_info.get("expiry") - # Token is expired if current time is greater than or equal to expiry time - return current_timestamp >= expiry_timestamp + if expiry_timestamp is None: + logging.warning("Token expiry is None, treating as expired") + return True + + current_timestamp = int(datetime.datetime.now().timestamp()) + expires_in = expiry_timestamp - current_timestamp + + if expires_in < 60: + logging.info(f"Token expires in {expires_in} seconds, treating as expired") + return True + + logging.debug(f"Token not expired. Expires in {expires_in} seconds") + return False def map_token_response(self, result) -> Dict[str, Any]: return { diff --git a/application/parser/connectors/share_point/loader.py b/application/parser/connectors/share_point/loader.py index ea081afe..753477da 100644 --- a/application/parser/connectors/share_point/loader.py +++ b/application/parser/connectors/share_point/loader.py @@ -1,44 +1,457 @@ -from typing import List, Dict, Any +""" +SharePoint/OneDrive loader for DocsGPT. +Loads documents from SharePoint/OneDrive using Microsoft Graph API. +""" + +import logging +import os +from typing import List, Dict, Any, Optional + +import requests + from application.parser.connectors.base import BaseConnectorLoader +from application.parser.connectors.share_point.auth import SharePointAuth from application.parser.schema.base import Document class SharePointLoader(BaseConnectorLoader): + + SUPPORTED_MIME_TYPES = { + 'application/pdf': '.pdf', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx', + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx', + 'application/msword': '.doc', + 'application/vnd.ms-powerpoint': '.ppt', + 'application/vnd.ms-excel': '.xls', + 'text/plain': '.txt', + 'text/csv': '.csv', + 'text/html': '.html', + 'text/markdown': '.md', + 'text/x-rst': '.rst', + 'application/json': '.json', + 'application/epub+zip': '.epub', + 'application/rtf': '.rtf', + 'image/jpeg': '.jpg', + 'image/jpg': '.jpg', + 'image/png': '.png', + } + + GRAPH_API_BASE = "https://graph.microsoft.com/v1.0" + def __init__(self, session_token: str): - pass + self.auth = SharePointAuth() + self.session_token = session_token + + token_info = self.auth.get_token_info_from_session(session_token) + self.access_token = token_info.get('access_token') + self.refresh_token = token_info.get('refresh_token') + + if not self.access_token: + raise ValueError("No access token found in session") + + self.next_page_token = None + + def _get_headers(self) -> Dict[str, str]: + return { + 'Authorization': f'Bearer {self.access_token}', + 'Accept': 'application/json' + } + + def _ensure_valid_token(self): + if not self.access_token: + raise ValueError("No access token available") + + token_info = {'access_token': self.access_token, 'expiry': None} + if self.auth.is_token_expired(token_info): + logging.info("Token expired, attempting refresh") + try: + new_token_info = self.auth.refresh_access_token(self.refresh_token) + self.access_token = new_token_info.get('access_token') + logging.info("Token refreshed successfully") + except Exception as e: + logging.error(f"Failed to refresh token: {e}") + raise ValueError("Failed to refresh access token") + + def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]: + try: + drive_item_id = file_metadata.get('id') + file_name = file_metadata.get('name', 'Unknown') + file_data = file_metadata.get('file', {}) + mime_type = file_data.get('mimeType', 'application/octet-stream') + + if mime_type not in self.SUPPORTED_MIME_TYPES: + logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}") + return None + + doc_metadata = { + 'file_name': file_name, + 'mime_type': mime_type, + 'size': file_metadata.get('size'), + 'created_time': file_metadata.get('createdDateTime'), + 'modified_time': file_metadata.get('lastModifiedDateTime'), + 'source': 'share_point' + } + + if not load_content: + return Document( + text="", + doc_id=drive_item_id, + extra_info=doc_metadata + ) + + content = self._download_file_content(drive_item_id) + if content is None: + logging.warning(f"Could not load content for file {file_name} ({drive_item_id})") + return None + + return Document( + text=content, + doc_id=drive_item_id, + extra_info=doc_metadata + ) + + except Exception as e: + logging.error(f"Error processing file: {e}") + return None def load_data(self, inputs: Dict[str, Any]) -> List[Document]: - """ - Load documents from the external knowledge base. + try: + documents: List[Document] = [] - Args: - inputs: Configuration dictionary containing: - - file_ids: Optional list of specific file IDs to load - - folder_ids: Optional list of folder IDs to browse/download - - limit: Maximum number of items to return - - list_only: If True, return metadata without content - - recursive: Whether to recursively process folders + folder_id = inputs.get('folder_id') + file_ids = inputs.get('file_ids', []) + limit = inputs.get('limit', 100) + list_only = inputs.get('list_only', False) + load_content = not list_only + page_token = inputs.get('page_token') + search_query = inputs.get('search_query') + self.next_page_token = None - Returns: - List of Document objects - """ - pass + if file_ids: + for file_id in file_ids: + try: + doc = self._load_file_by_id(file_id, load_content=load_content) + if doc: + if not search_query or ( + search_query.lower() in doc.extra_info.get('file_name', '').lower() + ): + documents.append(doc) + except Exception as e: + logging.error(f"Error loading file {file_id}: {e}") + continue + else: + parent_id = folder_id if folder_id else 'root' + documents = self._list_items_in_parent( + parent_id, + limit=limit, + load_content=load_content, + page_token=page_token, + search_query=search_query + ) + + logging.info(f"Loaded {len(documents)} documents from SharePoint/OneDrive") + return documents + + except Exception as e: + logging.error(f"Error loading data from SharePoint/OneDrive: {e}", exc_info=True) + raise + + def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]: + self._ensure_valid_token() + + try: + url = f"{self.GRAPH_API_BASE}/me/drive/items/{file_id}" + params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'} + response = requests.get(url, headers=self._get_headers(), params=params) + response.raise_for_status() + + file_metadata = response.json() + return self._process_file(file_metadata, load_content=load_content) + + except requests.exceptions.HTTPError as e: + if e.response.status_code in [401, 403]: + logging.error(f"Authentication error loading file {file_id}") + try: + new_token_info = self.auth.refresh_access_token(self.refresh_token) + self.access_token = new_token_info.get('access_token') + response = requests.get(url, headers=self._get_headers(), params=params) + response.raise_for_status() + file_metadata = response.json() + return self._process_file(file_metadata, load_content=load_content) + except Exception as refresh_error: + raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}") + logging.error(f"HTTP error loading file {file_id}: {e}") + return None + except Exception as e: + logging.error(f"Error loading file {file_id}: {e}") + return None + + def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]: + self._ensure_valid_token() + + documents: List[Document] = [] + + try: + url = f"{self.GRAPH_API_BASE}/me/drive/items/{parent_id}/children" + params = {'$top': min(100, limit) if limit else 100, '$select': 'id,name,file,folder,createdDateTime,lastModifiedDateTime,size'} + if page_token: + params['$skipToken'] = page_token + + if search_query: + search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{search_query}')" + response = requests.get(search_url, headers=self._get_headers(), params=params) + else: + response = requests.get(url, headers=self._get_headers(), params=params) + + response.raise_for_status() + + results = response.json() + + items = results.get('value', []) + for item in items: + if 'folder' in item: + doc_metadata = { + 'file_name': item.get('name', 'Unknown'), + 'mime_type': 'folder', + 'size': item.get('size'), + 'created_time': item.get('createdDateTime'), + 'modified_time': item.get('lastModifiedDateTime'), + 'source': 'share_point', + 'is_folder': True + } + documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata)) + else: + doc = self._process_file(item, load_content=load_content) + if doc: + documents.append(doc) + + if limit and len(documents) >= limit: + break + + next_link = results.get('@odata.nextLink') + if next_link: + # Extract skiptoken from the full URL + from urllib.parse import urlparse, parse_qs + parsed = urlparse(next_link) + query_params = parse_qs(parsed.query) + skiptoken_list = query_params.get('$skiptoken') + if skiptoken_list: + self.next_page_token = skiptoken_list[0] + else: + self.next_page_token = None + else: + self.next_page_token = None + return documents + + except Exception as e: + logging.error(f"Error listing items under parent {parent_id}: {e}") + return documents + + def _download_file_content(self, file_id: str) -> Optional[str]: + self._ensure_valid_token() + + try: + url = f"{self.GRAPH_API_BASE}/me/drive/items/{file_id}/content" + response = requests.get(url, headers=self._get_headers()) + response.raise_for_status() + + content_bytes = response.content + + try: + content = content_bytes.decode('utf-8') + except UnicodeDecodeError: + try: + content = content_bytes.decode('latin-1') + except UnicodeDecodeError: + logging.error(f"Could not decode file {file_id} as text") + return None + + return content + + except requests.exceptions.HTTPError as e: + if e.response.status_code in [401, 403]: + logging.error(f"Authentication error downloading file {file_id}") + try: + new_token_info = self.auth.refresh_access_token(self.refresh_token) + self.access_token = new_token_info.get('access_token') + response = requests.get(url, headers=self._get_headers()) + response.raise_for_status() + content_bytes = response.content + try: + content = content_bytes.decode('utf-8') + except UnicodeDecodeError: + try: + content = content_bytes.decode('latin-1') + except UnicodeDecodeError: + logging.error(f"Could not decode file {file_id} as text") + return None + return content + except Exception as refresh_error: + raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}") + logging.error(f"HTTP error downloading file {file_id}: {e}") + return None + except Exception as e: + logging.error(f"Error downloading file {file_id}: {e}") + return None + + def _download_single_file(self, file_id: str, local_dir: str) -> bool: + try: + url = f"{self.GRAPH_API_BASE}/me/drive/items/{file_id}" + params = {'$select': 'id,name,file'} + response = requests.get(url, headers=self._get_headers(), params=params) + response.raise_for_status() + + metadata = response.json() + file_name = metadata.get('name', 'unknown') + file_data = metadata.get('file', {}) + mime_type = file_data.get('mimeType', 'application/octet-stream') + + if mime_type not in self.SUPPORTED_MIME_TYPES: + logging.info(f"Skipping unsupported file type: {mime_type}") + return False + + os.makedirs(local_dir, exist_ok=True) + full_path = os.path.join(local_dir, file_name) + + download_url = f"{self.GRAPH_API_BASE}/me/drive/items/{file_id}/content" + download_response = requests.get(download_url, headers=self._get_headers()) + download_response.raise_for_status() + + with open(full_path, 'wb') as f: + f.write(download_response.content) + + return True + except Exception as e: + logging.error(f"Error in _download_single_file: {e}") + return False + + def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int: + files_downloaded = 0 + try: + os.makedirs(local_dir, exist_ok=True) + + url = f"{self.GRAPH_API_BASE}/me/drive/items/{folder_id}/children" + params = {'$top': 1000} + + while url: + response = requests.get(url, headers=self._get_headers(), params=params) + response.raise_for_status() + + results = response.json() + items = results.get('value', []) + logging.info(f"Found {len(items)} items in folder {folder_id}") + + for item in items: + item_name = item.get('name', 'unknown') + item_id = item.get('id') + + if 'folder' in item: + if recursive: + subfolder_path = os.path.join(local_dir, item_name) + os.makedirs(subfolder_path, exist_ok=True) + subfolder_files = self._download_folder_recursive( + item_id, + subfolder_path, + recursive + ) + files_downloaded += subfolder_files + logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}") + else: + success = self._download_single_file(item_id, local_dir) + if success: + files_downloaded += 1 + logging.info(f"Downloaded file: {item_name}") + else: + logging.warning(f"Failed to download file: {item_name}") + + url = results.get('@odata.nextLink') + + return files_downloaded + + except Exception as e: + logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True) + return files_downloaded + + def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int: + try: + self._ensure_valid_token() + return self._download_folder_recursive(folder_id, local_dir, recursive) + except Exception as e: + logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True) + return 0 + + def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool: + try: + self._ensure_valid_token() + return self._download_single_file(file_id, local_dir) + except Exception as e: + logging.error(f"Error downloading file {file_id}: {e}", exc_info=True) + return False def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]: - """ - Download files/folders to a local directory. + if source_config is None: + source_config = {} - Args: - local_dir: Local directory path to download files to - source_config: Configuration for what to download + config = source_config if source_config else getattr(self, 'config', {}) + files_downloaded = 0 - Returns: - Dictionary containing download results: - - files_downloaded: Number of files downloaded - - directory_path: Path where files were downloaded - - empty_result: Whether no files were downloaded - - source_type: Type of connector - - config_used: Configuration that was used - - error: Error message if download failed (optional) - """ - pass + try: + folder_ids = config.get('folder_ids', []) + file_ids = config.get('file_ids', []) + recursive = config.get('recursive', True) + + if file_ids: + if isinstance(file_ids, str): + file_ids = [file_ids] + + for file_id in file_ids: + if self._download_file_to_directory(file_id, local_dir): + files_downloaded += 1 + + if folder_ids: + if isinstance(folder_ids, str): + folder_ids = [folder_ids] + + for folder_id in folder_ids: + try: + url = f"{self.GRAPH_API_BASE}/me/drive/items/{folder_id}" + params = {'$select': 'id,name'} + response = requests.get(url, headers=self._get_headers(), params=params) + response.raise_for_status() + + folder_metadata = response.json() + folder_name = folder_metadata.get('name', '') + folder_path = os.path.join(local_dir, folder_name) + os.makedirs(folder_path, exist_ok=True) + + folder_files = self._download_folder_recursive( + folder_id, + folder_path, + recursive + ) + files_downloaded += folder_files + logging.info(f"Downloaded {folder_files} files from folder {folder_name}") + except Exception as e: + logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True) + + if not file_ids and not folder_ids: + raise ValueError("No folder_ids or file_ids provided for download") + + return { + "files_downloaded": files_downloaded, + "directory_path": local_dir, + "empty_result": files_downloaded == 0, + "source_type": "share_point", + "config_used": config + } + + except Exception as e: + return { + "files_downloaded": files_downloaded, + "directory_path": local_dir, + "empty_result": True, + "source_type": "share_point", + "config_used": config, + "error": str(e) + } diff --git a/frontend/src/components/FilePicker.tsx b/frontend/src/components/FilePicker.tsx index e7c4b476..e31edf79 100644 --- a/frontend/src/components/FilePicker.tsx +++ b/frontend/src/components/FilePicker.tsx @@ -56,6 +56,10 @@ export const FilePicker: React.FC = ({ displayName: 'Drive', rootName: 'My Drive', }, + share_point: { + displayName: 'SharePoint', + rootName: 'My Files', + }, } as const; const getProviderConfig = (provider: string) => { diff --git a/frontend/src/components/MessageInput.tsx b/frontend/src/components/MessageInput.tsx index 19d7afbb..4fabb699 100644 --- a/frontend/src/components/MessageInput.tsx +++ b/frontend/src/components/MessageInput.tsx @@ -11,7 +11,6 @@ import ClipIcon from '../assets/clip.svg'; import DragFileUpload from '../assets/DragFileUpload.svg'; import ExitIcon from '../assets/exit.svg'; import SendArrowIcon from './SendArrowIcon'; -import SendArrowIcon from './SendArrowIcon'; import SourceIcon from '../assets/source.svg'; import DocumentationDark from '../assets/documentation-dark.svg'; import ToolIcon from '../assets/tool.svg'; diff --git a/frontend/src/upload/Upload.tsx b/frontend/src/upload/Upload.tsx index c5f840f3..aef973d8 100644 --- a/frontend/src/upload/Upload.tsx +++ b/frontend/src/upload/Upload.tsx @@ -254,7 +254,22 @@ function Upload({ /> ); case 'share_point_picker': - return ; + return ( + { + setSelectedFiles(selectedFileIds); + setSelectedFolders(selectedFolderIds); + }} + provider="share_point" + token={token} + initialSelectedFiles={selectedFiles} + initialSelectedFolders={selectedFolders} + /> + ); default: return null; } @@ -537,6 +552,9 @@ function Upload({ const hasGoogleDrivePicker = schema.some( (field: FormField) => field.type === 'google_drive_picker', ); + const hasSharePointPicker = schema.some( + (field: FormField) => field.type === 'share_point_picker', + ); let configData: Record = { ...ingestor.config }; @@ -544,7 +562,7 @@ function Upload({ files.forEach((file) => { formData.append('file', file); }); - } else if (hasRemoteFilePicker || hasGoogleDrivePicker) { + } else if (hasRemoteFilePicker || hasGoogleDrivePicker || hasSharePointPicker) { const sessionToken = getSessionToken(ingestor.type as string); configData = { provider: ingestor.type as string, @@ -720,12 +738,15 @@ function Upload({ const hasGoogleDrivePicker = schema.some( (field: FormField) => field.type === 'google_drive_picker', ); + const hasSharePointPicker = schema.some( + (field: FormField) => field.type === 'share_point_picker', + ); if (hasLocalFilePicker) { if (files.length === 0) { return true; } - } else if (hasRemoteFilePicker || hasGoogleDrivePicker) { + } else if (hasRemoteFilePicker || hasGoogleDrivePicker || hasSharePointPicker) { if (selectedFiles.length === 0 && selectedFolders.length === 0) { return true; } diff --git a/tests/parser/remote/test_share_point_loader.py b/tests/parser/remote/test_share_point_loader.py new file mode 100644 index 00000000..042130ee --- /dev/null +++ b/tests/parser/remote/test_share_point_loader.py @@ -0,0 +1,197 @@ +"""Tests for SharePoint loader.""" + +import pytest +from unittest.mock import patch, MagicMock + +from application.parser.connectors.share_point.loader import SharePointLoader + + +def make_response(json_data=None, status_code=200, raise_error=None): + resp = MagicMock() + resp.status_code = status_code + resp.json.return_value = json_data + resp.content = b"test content" + if raise_error is not None: + resp.raise_for_status.side_effect = raise_error + else: + resp.raise_for_status.return_value = None + return resp + + +class TestSharePointLoaderProcessFile: + """Test _process_file method.""" + + def test_size_retrieved_from_root_level(self): + """Should retrieve size from root of file_metadata, not nested file object.""" + loader = SharePointLoader.__new__(SharePointLoader) + + file_metadata = { + "id": "test-id", + "name": "test.txt", + "createdDateTime": "2024-01-01T00:00:00Z", + "lastModifiedDateTime": "2024-01-01T00:00:00Z", + "size": 1024, + "file": { + "mimeType": "text/plain" + } + } + + doc = loader._process_file(file_metadata, load_content=False) + + assert doc is not None + assert doc.extra_info["size"] == 1024 + assert doc.extra_info["file_name"] == "test.txt" + assert doc.extra_info["mime_type"] == "text/plain" + + def test_size_null_when_missing(self): + """Should return None when size field is missing.""" + loader = SharePointLoader.__new__(SharePointLoader) + + file_metadata = { + "id": "test-id", + "name": "test.txt", + "createdDateTime": "2024-01-01T00:00:00Z", + "lastModifiedDateTime": "2024-01-01T00:00:00Z", + "file": { + "mimeType": "text/plain" + } + } + + doc = loader._process_file(file_metadata, load_content=False) + + assert doc is not None + assert doc.extra_info["size"] is None + + +class TestSharePointLoaderLoadFileById: + """Test _load_file_by_id method.""" + + @patch("application.parser.connectors.share_point.loader.requests.get") + @patch("application.parser.connectors.share_point.loader.SharePointAuth.get_token_info_from_session") + @patch("application.parser.connectors.share_point.loader.SharePointLoader._ensure_valid_token") + def test_load_file_by_id_includes_size_in_select(self, mock_ensure_token, mock_get_token, mock_get): + """Should include size field in $select parameter.""" + mock_get_token.return_value = { + "access_token": "test-token", + "refresh_token": "test-refresh" + } + mock_get.return_value = make_response({ + "id": "test-id", + "name": "test.txt", + "createdDateTime": "2024-01-01T00:00:00Z", + "lastModifiedDateTime": "2024-01-01T00:00:00Z", + "size": 2048, + "file": { + "mimeType": "text/plain" + } + }) + + loader = SharePointLoader("test-session") + doc = loader._load_file_by_id("test-id", load_content=False) + + assert doc is not None + assert doc.extra_info["size"] == 2048 + + call_args = mock_get.call_args + params = call_args[1]["params"] + assert "size" in params["$select"] + + @patch("application.parser.connectors.share_point.loader.requests.get") + @patch("application.parser.connectors.share_point.loader.SharePointAuth.get_token_info_from_session") + @patch("application.parser.connectors.share_point.loader.SharePointLoader._ensure_valid_token") + def test_load_file_by_id_returns_document_with_size(self, mock_ensure_token, mock_get_token, mock_get): + """Should return document with size from API response.""" + mock_get_token.return_value = { + "access_token": "test-token", + "refresh_token": "test-refresh" + } + mock_get.return_value = make_response({ + "id": "test-id", + "name": "document.pdf", + "createdDateTime": "2024-01-01T00:00:00Z", + "lastModifiedDateTime": "2024-06-15T10:30:00Z", + "size": 56789, + "file": { + "mimeType": "application/pdf" + } + }) + + loader = SharePointLoader("test-session") + doc = loader._load_file_by_id("test-id", load_content=False) + + assert doc is not None + assert doc.doc_id == "test-id" + assert doc.extra_info["file_name"] == "document.pdf" + assert doc.extra_info["mime_type"] == "application/pdf" + assert doc.extra_info["size"] == 56789 + assert doc.extra_info["created_time"] == "2024-01-01T00:00:00Z" + assert doc.extra_info["modified_time"] == "2024-06-15T10:30:00Z" + assert doc.extra_info["source"] == "share_point" + + +class TestSharePointLoaderListItems: + """Test _list_items_in_parent method.""" + + @patch("application.parser.connectors.share_point.loader.requests.get") + @patch("application.parser.connectors.share_point.loader.SharePointAuth.get_token_info_from_session") + @patch("application.parser.connectors.share_point.loader.SharePointLoader._ensure_valid_token") + def test_list_items_includes_size_in_select(self, mock_ensure_token, mock_get_token, mock_get): + """Should include size field in $select parameter when listing items.""" + mock_get_token.return_value = { + "access_token": "test-token", + "refresh_token": "test-refresh" + } + mock_get.return_value = make_response({ + "value": [ + { + "id": "file-1", + "name": "file1.txt", + "createdDateTime": "2024-01-01T00:00:00Z", + "lastModifiedDateTime": "2024-01-01T00:00:00Z", + "size": 12345, + "file": { + "mimeType": "text/plain" + } + } + ] + }) + + loader = SharePointLoader("test-session") + docs = loader._list_items_in_parent("parent-id", limit=10, load_content=False) + + assert len(docs) == 1 + assert docs[0].extra_info["size"] == 12345 + + call_args = mock_get.call_args + params = call_args[1]["params"] + assert "size" in params["$select"] + + @patch("application.parser.connectors.share_point.loader.requests.get") + @patch("application.parser.connectors.share_point.loader.SharePointAuth.get_token_info_from_session") + @patch("application.parser.connectors.share_point.loader.SharePointLoader._ensure_valid_token") + def test_list_items_folders_include_size(self, mock_ensure_token, mock_get_token, mock_get): + """Should include size for folders as well.""" + mock_get_token.return_value = { + "access_token": "test-token", + "refresh_token": "test-refresh" + } + mock_get.return_value = make_response({ + "value": [ + { + "id": "folder-1", + "name": "MyFolder", + "createdDateTime": "2024-01-01T00:00:00Z", + "lastModifiedDateTime": "2024-01-01T00:00:00Z", + "size": 0, + "folder": {} + } + ] + }) + + loader = SharePointLoader("test-session") + docs = loader._list_items_in_parent("parent-id", limit=10, load_content=False) + + assert len(docs) == 1 + assert docs[0].extra_info["is_folder"] is True + assert docs[0].extra_info["size"] == 0 +