diff --git a/.env-template b/.env-template index 8b53112c..13575fc3 100644 --- a/.env-template +++ b/.env-template @@ -12,4 +12,17 @@ EMBEDDINGS_KEY= OPENAI_API_BASE= OPENAI_API_VERSION= AZURE_DEPLOYMENT_NAME= -AZURE_EMBEDDINGS_DEPLOYMENT_NAME= \ No newline at end of file +AZURE_EMBEDDINGS_DEPLOYMENT_NAME= + +#Azure AD Application (client) ID +MICROSOFT_CLIENT_ID=your-azure-ad-client-id +#Azure AD Application client secret +MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret +#Azure AD Tenant ID (or 'common' for multi-tenant) +MICROSOFT_TENANT_ID=your-azure-ad-tenant-id +#If you are using a Microsoft Entra ID tenant, +#configure the AUTHORITY variable as +#"https://login.microsoftonline.com/TENANT_GUID" +#or "https://login.microsoftonline.com/contoso.onmicrosoft.com". +#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app. +MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId} diff --git a/application/agents/react_agent.py b/application/agents/react_agent.py index 116fa4aa..92be75f6 100644 --- a/application/agents/react_agent.py +++ b/application/agents/react_agent.py @@ -235,4 +235,4 @@ class ReActAgent(BaseAgent): ) except Exception as e: logger.error(f"Error extracting content: {e}") - return "".join(collected) + return "".join(collected) \ No newline at end of file diff --git a/application/api/connector/routes.py b/application/api/connector/routes.py index 91fc3f0b..913e5349 100644 --- a/application/api/connector/routes.py +++ b/application/api/connector/routes.py @@ -146,20 +146,19 @@ class ConnectorsCallback(Resource): session_token = str(uuid.uuid4()) try: - credentials = auth.create_credentials_from_token_info(token_info) - service = auth.build_drive_service(credentials) - user_info = service.about().get(fields="user").execute() - user_email = user_info.get('user', {}).get('emailAddress', 'Connected User') + if provider == "google_drive": + credentials = auth.create_credentials_from_token_info(token_info) + service = auth.build_drive_service(credentials) + user_info = service.about().get(fields="user").execute() + user_email = user_info.get('user', {}).get('emailAddress', 'Connected User') + else: + user_email = token_info.get('user_info', {}).get('email', 'Connected User') + except Exception as e: current_app.logger.warning(f"Could not get user info: {e}") user_email = 'Connected User' - sanitized_token_info = { - "access_token": token_info.get("access_token"), - "refresh_token": token_info.get("refresh_token"), - "token_uri": token_info.get("token_uri"), - "expiry": token_info.get("expiry") - } + sanitized_token_info = auth.sanitize_token_info(token_info) sessions_collection.find_one_and_update( {"_id": ObjectId(state_object_id), "provider": provider}, @@ -201,12 +200,12 @@ class ConnectorsCallback(Resource): @connectors_ns.route("/api/connectors/files") class ConnectorFiles(Resource): @api.expect(api.model("ConnectorFilesModel", { - "provider": fields.String(required=True), - "session_token": fields.String(required=True), - "folder_id": fields.String(required=False), - "limit": fields.Integer(required=False), + "provider": fields.String(required=True), + "session_token": fields.String(required=True), + "folder_id": fields.String(required=False), + "limit": fields.Integer(required=False), "page_token": fields.String(required=False), - "search_query": fields.String(required=False) + "search_query": fields.String(required=False), })) @api.doc(description="List files from a connector provider (supports pagination and search)") def post(self): @@ -214,11 +213,8 @@ class ConnectorFiles(Resource): data = request.get_json() provider = data.get('provider') session_token = data.get('session_token') - folder_id = data.get('folder_id') limit = data.get('limit', 10) - page_token = data.get('page_token') - search_query = data.get('search_query') - + if not provider or not session_token: return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400) @@ -231,15 +227,12 @@ class ConnectorFiles(Resource): return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401) loader = ConnectorCreator.create_connector(provider, session_token) + + generic_keys = {'provider', 'session_token'} input_config = { - 'limit': limit, - 'list_only': True, - 'session_token': session_token, - 'folder_id': folder_id, - 'page_token': page_token + k: v for k, v in data.items() if k not in generic_keys } - if search_query: - input_config['search_query'] = search_query + input_config['list_only'] = True documents = loader.load_data(input_config) @@ -306,12 +299,7 @@ class ConnectorValidateSession(Resource): if is_expired and token_info.get('refresh_token'): try: refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token')) - sanitized_token_info = { - "access_token": refreshed_token_info.get("access_token"), - "refresh_token": refreshed_token_info.get("refresh_token"), - "token_uri": refreshed_token_info.get("token_uri"), - "expiry": refreshed_token_info.get("expiry") - } + sanitized_token_info = auth.sanitize_token_info(refreshed_token_info) sessions_collection.update_one( {"session_token": session_token}, {"$set": {"token_info": sanitized_token_info}} @@ -328,12 +316,18 @@ class ConnectorValidateSession(Resource): "error": "Session token has expired. Please reconnect." }), 401) - return make_response(jsonify({ + _base_fields = {"access_token", "refresh_token", "token_uri", "expiry"} + provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields} + + response_data = { "success": True, "expired": False, "user_email": session.get('user_email', 'Connected User'), - "access_token": token_info.get('access_token') - }), 200) + "access_token": token_info.get('access_token'), + **provider_extras, + } + + return make_response(jsonify(response_data), 200) except Exception as e: current_app.logger.error(f"Error validating connector session: {e}", exc_info=True) return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500) diff --git a/application/api/user/agents/routes.py b/application/api/user/agents/routes.py index 64fa7bed..8f313a7e 100644 --- a/application/api/user/agents/routes.py +++ b/application/api/user/agents/routes.py @@ -1412,4 +1412,4 @@ class RemoveSharedAgent(Resource): current_app.logger.error(f"Error removing shared agent: {err}") return make_response( jsonify({"success": False, "message": "Server error"}), 500 - ) + ) \ No newline at end of file diff --git a/application/core/settings.py b/application/core/settings.py index 5c424074..5cdf7f09 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -65,8 +65,14 @@ class Settings(BaseSettings): "http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp) ) + # Microsoft Entra ID (Azure AD) integration + MICROSOFT_CLIENT_ID: Optional[str] = None # Azure AD Application (client) ID + MICROSOFT_CLIENT_SECRET: Optional[str] = None # Azure AD Application client secret + MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant) + MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}" + # GitHub source - GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access + GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access # LLM Cache CACHE_REDIS_URL: str = "redis://localhost:6379/2" diff --git a/application/parser/connectors/base.py b/application/parser/connectors/base.py index dfb6de87..b9b7f78f 100644 --- a/application/parser/connectors/base.py +++ b/application/parser/connectors/base.py @@ -62,15 +62,26 @@ class BaseConnectorAuth(ABC): def is_token_expired(self, token_info: Dict[str, Any]) -> bool: """ Check if a token is expired. - + Args: token_info: Token information dictionary - + Returns: True if token is expired, False otherwise """ pass + def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]: + """Extract the fields safe to persist in the session store. + """ + return { + "access_token": token_info.get("access_token"), + "refresh_token": token_info.get("refresh_token"), + "token_uri": token_info.get("token_uri"), + "expiry": token_info.get("expiry"), + **extra_fields, + } + class BaseConnectorLoader(ABC): """ diff --git a/application/parser/connectors/connector_creator.py b/application/parser/connectors/connector_creator.py index bf4456ca..609e6407 100644 --- a/application/parser/connectors/connector_creator.py +++ b/application/parser/connectors/connector_creator.py @@ -1,5 +1,7 @@ from application.parser.connectors.google_drive.loader import GoogleDriveLoader from application.parser.connectors.google_drive.auth import GoogleDriveAuth +from application.parser.connectors.share_point.auth import SharePointAuth +from application.parser.connectors.share_point.loader import SharePointLoader class ConnectorCreator: @@ -12,10 +14,12 @@ class ConnectorCreator: connectors = { "google_drive": GoogleDriveLoader, + "share_point": SharePointLoader, } auth_providers = { "google_drive": GoogleDriveAuth, + "share_point": SharePointAuth, } @classmethod diff --git a/application/parser/connectors/google_drive/auth.py b/application/parser/connectors/google_drive/auth.py index f5fbe056..e24368c9 100644 --- a/application/parser/connectors/google_drive/auth.py +++ b/application/parser/connectors/google_drive/auth.py @@ -232,10 +232,6 @@ class GoogleDriveAuth(BaseConnectorAuth): if missing_fields: raise ValueError(f"Missing required token fields: {missing_fields}") - if 'client_id' not in token_info: - token_info['client_id'] = settings.GOOGLE_CLIENT_ID - if 'client_secret' not in token_info: - token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET if 'token_uri' not in token_info: token_info['token_uri'] = 'https://oauth2.googleapis.com/token' diff --git a/application/parser/connectors/google_drive/loader.py b/application/parser/connectors/google_drive/loader.py index c96a08be..9605517c 100644 --- a/application/parser/connectors/google_drive/loader.py +++ b/application/parser/connectors/google_drive/loader.py @@ -327,15 +327,10 @@ class GoogleDriveLoader(BaseConnectorLoader): content_bytes = file_io.getvalue() try: - content = content_bytes.decode('utf-8') + return content_bytes.decode('utf-8') except UnicodeDecodeError: - try: - content = content_bytes.decode('latin-1') - except UnicodeDecodeError: - logging.error(f"Could not decode file {file_id} as text") - return None - - return content + logging.error(f"Could not decode file {file_id} as text") + return None except HttpError as e: logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}") diff --git a/application/parser/connectors/share_point/__init__.py b/application/parser/connectors/share_point/__init__.py new file mode 100644 index 00000000..b83bb56f --- /dev/null +++ b/application/parser/connectors/share_point/__init__.py @@ -0,0 +1,10 @@ +""" +Share Point connector package for DocsGPT. + +This module provides authentication and document loading capabilities for Share Point. +""" + +from .auth import SharePointAuth +from .loader import SharePointLoader + +__all__ = ['SharePointAuth', 'SharePointLoader'] \ No newline at end of file diff --git a/application/parser/connectors/share_point/auth.py b/application/parser/connectors/share_point/auth.py new file mode 100644 index 00000000..1da894b1 --- /dev/null +++ b/application/parser/connectors/share_point/auth.py @@ -0,0 +1,152 @@ +import datetime +import logging +from typing import Optional, Dict, Any + +from msal import ConfidentialClientApplication + +from application.core.settings import settings +from application.parser.connectors.base import BaseConnectorAuth + +logger = logging.getLogger(__name__) + + +class SharePointAuth(BaseConnectorAuth): + """ + Handles Microsoft OAuth 2.0 authentication for SharePoint/OneDrive. + + Note: Files.Read scope allows access to files the user has granted access to, + similar to Google Drive's drive.file scope. + """ + + SCOPES = [ + "Files.Read", + "Sites.Read.All", + "User.Read", + ] + + def __init__(self): + self.client_id = settings.MICROSOFT_CLIENT_ID + self.client_secret = settings.MICROSOFT_CLIENT_SECRET + + if not self.client_id: + raise ValueError( + "Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID in settings." + ) + + if not self.client_secret: + raise ValueError( + "Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_SECRET in settings." + ) + + self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI + self.tenant_id = settings.MICROSOFT_TENANT_ID + self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://login.microsoftonline.com/{self.tenant_id}") + + self.auth_app = ConfidentialClientApplication( + client_id=self.client_id, + client_credential=self.client_secret, + authority=self.authority + ) + + def get_authorization_url(self, state: Optional[str] = None) -> str: + return self.auth_app.get_authorization_request_url( + scopes=self.SCOPES, state=state, redirect_uri=self.redirect_uri + ) + + def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]: + result = self.auth_app.acquire_token_by_authorization_code( + code=authorization_code, + scopes=self.SCOPES, + redirect_uri=self.redirect_uri + ) + + if "error" in result: + logger.error("Token exchange failed: %s", result.get("error_description")) + raise ValueError(f"Error acquiring token: {result.get('error_description')}") + + return self.map_token_response(result) + + def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]: + result = self.auth_app.acquire_token_by_refresh_token(refresh_token=refresh_token, scopes=self.SCOPES) + + if "error" in result: + logger.error("Token refresh failed: %s", result.get("error_description")) + raise ValueError(f"Error refreshing token: {result.get('error_description')}") + + return self.map_token_response(result) + + def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]: + try: + from application.core.mongo_db import MongoDB + from application.core.settings import settings + + mongo = MongoDB.get_client() + db = mongo[settings.MONGO_DB_NAME] + + sessions_collection = db["connector_sessions"] + session = sessions_collection.find_one({"session_token": session_token}) + + if not session: + raise ValueError(f"Invalid session token: {session_token}") + + if "token_info" not in session: + raise ValueError("Session missing token information") + + token_info = session["token_info"] + if not token_info: + raise ValueError("Invalid token information") + + required_fields = ["access_token", "refresh_token"] + missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)] + if missing_fields: + raise ValueError(f"Missing required token fields: {missing_fields}") + + if 'token_uri' not in token_info: + token_info['token_uri'] = f"https://login.microsoftonline.com/{settings.MICROSOFT_TENANT_ID}/oauth2/v2.0/token" + + return token_info + + except Exception as e: + logger.error("Failed to retrieve token from session: %s", e) + raise ValueError(f"Failed to retrieve SharePoint token information: {str(e)}") + + def is_token_expired(self, token_info: Dict[str, Any]) -> bool: + if not token_info: + return True + + expiry_timestamp = token_info.get("expiry") + + if expiry_timestamp is None: + return True + + current_timestamp = int(datetime.datetime.now().timestamp()) + return (expiry_timestamp - current_timestamp) < 60 + + def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]: + return super().sanitize_token_info( + token_info, + allows_shared_content=token_info.get("allows_shared_content", False), + **extra_fields, + ) + + PERSONAL_ACCOUNT_TENANT_ID = "9188040d-6c67-4c5b-b112-36a304b66dad" + + def _allows_shared_content(self, id_token_claims: Dict[str, Any]) -> bool: + """Return True when the account is a work/school tenant that can access SharePoint shared content.""" + tid = id_token_claims.get("tid", "") + return bool(tid) and tid != self.PERSONAL_ACCOUNT_TENANT_ID + + def map_token_response(self, result) -> Dict[str, Any]: + claims = result.get("id_token_claims", {}) + return { + "access_token": result.get("access_token"), + "refresh_token": result.get("refresh_token"), + "token_uri": claims.get("iss"), + "scopes": result.get("scope"), + "expiry": claims.get("exp"), + "allows_shared_content": self._allows_shared_content(claims), + "user_info": { + "name": claims.get("name"), + "email": claims.get("preferred_username"), + }, + } diff --git a/application/parser/connectors/share_point/loader.py b/application/parser/connectors/share_point/loader.py new file mode 100644 index 00000000..191e3e54 --- /dev/null +++ b/application/parser/connectors/share_point/loader.py @@ -0,0 +1,649 @@ +""" +SharePoint/OneDrive loader for DocsGPT. +Loads documents from SharePoint/OneDrive using Microsoft Graph API. +""" + +import functools +import logging +import os +from typing import List, Dict, Any, Optional, Tuple +from urllib.parse import quote + +import requests + +from application.parser.connectors.base import BaseConnectorLoader +from application.parser.connectors.share_point.auth import SharePointAuth +from application.parser.schema.base import Document + + +def _retry_on_auth_failure(func): + """Retry once after refreshing the access token on 401/403 responses.""" + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except requests.exceptions.HTTPError as e: + if e.response is not None and e.response.status_code in (401, 403): + logging.info(f"Auth failure in {func.__name__}, refreshing token and retrying") + try: + new_token_info = self.auth.refresh_access_token(self.refresh_token) + self.access_token = new_token_info.get('access_token') + except Exception as refresh_error: + raise ValueError( + f"Authentication failed and could not be refreshed: {refresh_error}" + ) from e + return func(self, *args, **kwargs) + raise + return wrapper + + +class SharePointLoader(BaseConnectorLoader): + + SUPPORTED_MIME_TYPES = { + 'application/pdf': '.pdf', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx', + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx', + 'application/msword': '.doc', + 'application/vnd.ms-powerpoint': '.ppt', + 'application/vnd.ms-excel': '.xls', + 'text/plain': '.txt', + 'text/csv': '.csv', + 'text/html': '.html', + 'text/markdown': '.md', + 'text/x-rst': '.rst', + 'application/json': '.json', + 'application/epub+zip': '.epub', + 'application/rtf': '.rtf', + 'image/jpeg': '.jpg', + 'image/png': '.png', + } + + EXTENSION_TO_MIME = {v: k for k, v in SUPPORTED_MIME_TYPES.items()} + + GRAPH_API_BASE = "https://graph.microsoft.com/v1.0" + + def __init__(self, session_token: str): + self.auth = SharePointAuth() + self.session_token = session_token + + token_info = self.auth.get_token_info_from_session(session_token) + self.access_token = token_info.get('access_token') + self.refresh_token = token_info.get('refresh_token') + self.allows_shared_content = token_info.get('allows_shared_content', False) + + if not self.access_token: + raise ValueError("No access token found in session") + + self.next_page_token = None + + def _get_headers(self) -> Dict[str, str]: + return { + 'Authorization': f'Bearer {self.access_token}', + 'Accept': 'application/json' + } + + def _ensure_valid_token(self): + if not self.access_token: + raise ValueError("No access token available") + + token_info = {'access_token': self.access_token, 'expiry': None} + if self.auth.is_token_expired(token_info): + logging.info("Token expired, attempting refresh") + try: + new_token_info = self.auth.refresh_access_token(self.refresh_token) + self.access_token = new_token_info.get('access_token') + except Exception: + raise ValueError("Failed to refresh access token") + + def _get_item_url(self, item_ref: str) -> str: + if ':' in item_ref: + drive_id, item_id = item_ref.split(':', 1) + return f"{self.GRAPH_API_BASE}/drives/{drive_id}/items/{item_id}" + return f"{self.GRAPH_API_BASE}/me/drive/items/{item_ref}" + + def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]: + try: + drive_item_id = file_metadata.get('id') + file_name = file_metadata.get('name', 'Unknown') + file_data = file_metadata.get('file', {}) + mime_type = file_data.get('mimeType', 'application/octet-stream') + + if mime_type not in self.SUPPORTED_MIME_TYPES: + logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}") + return None + + doc_metadata = { + 'file_name': file_name, + 'mime_type': mime_type, + 'size': file_metadata.get('size'), + 'created_time': file_metadata.get('createdDateTime'), + 'modified_time': file_metadata.get('lastModifiedDateTime'), + 'source': 'share_point' + } + + if not load_content: + return Document( + text="", + doc_id=drive_item_id, + extra_info=doc_metadata + ) + + content = self._download_file_content(drive_item_id) + if content is None: + logging.warning(f"Could not load content for file {file_name} ({drive_item_id})") + return None + + return Document( + text=content, + doc_id=drive_item_id, + extra_info=doc_metadata + ) + + except Exception as e: + logging.error(f"Error processing file: {e}") + return None + + def load_data(self, inputs: Dict[str, Any]) -> List[Document]: + try: + documents: List[Document] = [] + + folder_id = inputs.get('folder_id') + file_ids = inputs.get('file_ids', []) + limit = inputs.get('limit', 100) + list_only = inputs.get('list_only', False) + load_content = not list_only + page_token = inputs.get('page_token') + search_query = inputs.get('search_query') + self.next_page_token = None + + shared = inputs.get('shared', False) + + if file_ids: + for file_id in file_ids: + try: + doc = self._load_file_by_id(file_id, load_content=load_content) + if doc: + if not search_query or ( + search_query.lower() in doc.extra_info.get('file_name', '').lower() + ): + documents.append(doc) + except Exception as e: + logging.error(f"Error loading file {file_id}: {e}") + continue + elif shared: + if not self.allows_shared_content: + logging.warning("Shared content is only available for work/school Microsoft accounts") + return [] + documents = self._list_shared_items( + limit=limit, + load_content=load_content, + page_token=page_token, + search_query=search_query + ) + else: + parent_id = folder_id if folder_id else 'root' + documents = self._list_items_in_parent( + parent_id, + limit=limit, + load_content=load_content, + page_token=page_token, + search_query=search_query + ) + + logging.info(f"Loaded {len(documents)} documents from SharePoint/OneDrive") + return documents + + except Exception as e: + logging.error(f"Error loading data from SharePoint/OneDrive: {e}", exc_info=True) + raise + + @_retry_on_auth_failure + def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]: + self._ensure_valid_token() + + try: + url = self._get_item_url(file_id) + params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'} + response = requests.get(url, headers=self._get_headers(), params=params) + response.raise_for_status() + + file_metadata = response.json() + return self._process_file(file_metadata, load_content=load_content) + + except requests.exceptions.HTTPError: + raise + except Exception as e: + logging.error(f"Error loading file {file_id}: {e}") + return None + + @_retry_on_auth_failure + def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]: + self._ensure_valid_token() + + documents: List[Document] = [] + + try: + url = f"{self._get_item_url(parent_id)}/children" + params = {'$top': min(100, limit) if limit else 100, '$select': 'id,name,file,folder,createdDateTime,lastModifiedDateTime,size'} + if page_token: + params['$skipToken'] = page_token + + if search_query: + encoded_query = quote(search_query, safe='') + if ':' in parent_id: + drive_id = parent_id.split(':', 1)[0] + search_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')" + else: + search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')" + response = requests.get(search_url, headers=self._get_headers(), params=params) + else: + response = requests.get(url, headers=self._get_headers(), params=params) + + response.raise_for_status() + + results = response.json() + + items = results.get('value', []) + for item in items: + if 'folder' in item: + doc_metadata = { + 'file_name': item.get('name', 'Unknown'), + 'mime_type': 'folder', + 'size': item.get('size'), + 'created_time': item.get('createdDateTime'), + 'modified_time': item.get('lastModifiedDateTime'), + 'source': 'share_point', + 'is_folder': True + } + documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata)) + else: + doc = self._process_file(item, load_content=load_content) + if doc: + documents.append(doc) + + if limit and len(documents) >= limit: + break + + next_link = results.get('@odata.nextLink') + if next_link: + from urllib.parse import urlparse, parse_qs + parsed = urlparse(next_link) + query_params = parse_qs(parsed.query) + skiptoken_list = query_params.get('$skiptoken') + if skiptoken_list: + self.next_page_token = skiptoken_list[0] + else: + self.next_page_token = None + else: + self.next_page_token = None + return documents + + except Exception as e: + logging.error(f"Error listing items under parent {parent_id}: {e}") + return documents + + + + + def _resolve_mime_type(self, resource: Dict[str, Any]) -> Tuple[str, bool]: + """Resolve mime type from resource, falling back to file extension.""" + file_data = resource.get('file', {}) + mime_type = file_data.get('mimeType') if file_data else None + + if mime_type and mime_type in self.SUPPORTED_MIME_TYPES: + return mime_type, True + + name = resource.get('name', '') + ext = os.path.splitext(name)[1].lower() + if ext in self.EXTENSION_TO_MIME: + return self.EXTENSION_TO_MIME[ext], True + + return mime_type or 'application/octet-stream', False + + def _get_user_drive_web_url(self) -> Optional[str]: + """Fetch the current user's OneDrive web URL for KQL path exclusion.""" + try: + response = requests.get( + f"{self.GRAPH_API_BASE}/me/drive", + headers=self._get_headers(), + params={'$select': 'webUrl'} + ) + response.raise_for_status() + return response.json().get('webUrl') + except Exception as e: + logging.warning(f"Could not fetch user drive web URL: {e}") + return None + + def _build_shared_kql_query(self, search_query: Optional[str], user_drive_url: Optional[str]) -> str: + """Build KQL query string that excludes the user's own drive items.""" + base_query = search_query if search_query else "*" + if user_drive_url: + return f'{base_query} AND -path:"{user_drive_url}"' + return base_query + + def _list_shared_items(self, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]: + """Fetch shared drive items using Microsoft Graph Search API with local offset paging. + + We always fetch up to a fixed maximum number of hits from Graph (single request), + then page through that array locally using `page_token` as a simple integer offset. + This avoids relying on buggy or inconsistent remote `from`/`size` semantics. + """ + self._ensure_valid_token() + documents: List[Document] = [] + + try: + user_drive_url = self._get_user_drive_web_url() + query_text = self._build_shared_kql_query(search_query, user_drive_url) + + url = f"{self.GRAPH_API_BASE}/search/query" + page_size = 500 # maximum number of hits we care about for selection + + body = { + "requests": [ + { + "entityTypes": ["driveItem"], + "query": {"queryString": query_text}, + "from": 0, + "size": page_size, + } + ] + } + + headers = self._get_headers() + headers["Content-Type"] = "application/json" + response = requests.post(url, headers=headers, json=body) + response.raise_for_status() + results = response.json() + + search_response = results.get("value", []) + if not search_response: + logging.warning("Search API returned empty value array") + self.next_page_token = None + return documents + + hits_containers = search_response[0].get("hitsContainers", []) + if not hits_containers: + logging.warning("Search API returned no hitsContainers") + self.next_page_token = None + return documents + + container = hits_containers[0] + total = container.get("total", 0) + raw_hits = container.get("hits", []) + + # Deduplicate by effective item ID (driveId:itemId) to avoid the same + # resource appearing multiple times across the result set. + deduped_hits = [] + seen_ids = set() + for hit in raw_hits: + resource = hit.get("resource", {}) + item_id = resource.get("id") + drive_id = resource.get("parentReference", {}).get("driveId") + effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id + if not effective_id or effective_id in seen_ids: + continue + seen_ids.add(effective_id) + deduped_hits.append(hit) + + hits = deduped_hits + logging.info( + f"Search API returned {total} total results, {len(raw_hits)} raw hits, {len(hits)} unique hits in this batch" + ) + try: + offset = int(page_token) if page_token is not None else 0 + except (TypeError, ValueError): + logging.warning( + f"Invalid page_token '{page_token}' for shared items search, defaulting to 0" + ) + offset = 0 + + if offset < 0: + offset = 0 + if offset >= len(hits): + self.next_page_token = None + return documents + + end_index = offset + limit if limit else len(hits) + end_index = min(end_index, len(hits)) + + for hit in hits[offset:end_index]: + resource = hit.get("resource", {}) + item_name = resource.get("name", "Unknown") + item_id = resource.get("id") + drive_id = resource.get("parentReference", {}).get("driveId") + + effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id + + is_folder = "folder" in resource + + if is_folder: + doc_metadata = { + "file_name": item_name, + "mime_type": "folder", + "size": resource.get("size"), + "created_time": resource.get("createdDateTime"), + "modified_time": resource.get("lastModifiedDateTime"), + "source": "share_point", + "is_folder": True, + } + documents.append( + Document(text="", doc_id=effective_id, extra_info=doc_metadata) + ) + else: + mime_type, supported = self._resolve_mime_type(resource) + if not supported: + logging.info( + f"Skipping unsupported shared file: {item_name} (mime: {mime_type})" + ) + continue + + doc_metadata = { + "file_name": item_name, + "mime_type": mime_type, + "size": resource.get("size"), + "created_time": resource.get("createdDateTime"), + "modified_time": resource.get("lastModifiedDateTime"), + "source": "share_point", + } + + content = "" + if load_content: + content = self._download_file_content(effective_id) or "" + + documents.append( + Document(text=content, doc_id=effective_id, extra_info=doc_metadata) + ) + + if limit and end_index < len(hits): + self.next_page_token = str(end_index) + else: + self.next_page_token = None + + return documents + + except Exception as e: + logging.error(f"Error listing shared items via search API: {e}", exc_info=True) + return documents + + @_retry_on_auth_failure + def _download_file_content(self, file_id: str) -> Optional[str]: + self._ensure_valid_token() + + try: + url = f"{self._get_item_url(file_id)}/content" + response = requests.get(url, headers=self._get_headers()) + response.raise_for_status() + + try: + return response.content.decode('utf-8') + except UnicodeDecodeError: + logging.error(f"Could not decode file {file_id} as text") + return None + + except requests.exceptions.HTTPError: + raise + except Exception as e: + logging.error(f"Error downloading file {file_id}: {e}") + return None + + def _download_single_file(self, file_id: str, local_dir: str) -> bool: + try: + url = self._get_item_url(file_id) + params = {'$select': 'id,name,file'} + response = requests.get(url, headers=self._get_headers(), params=params) + response.raise_for_status() + + metadata = response.json() + file_name = metadata.get('name', 'unknown') + file_data = metadata.get('file', {}) + mime_type = file_data.get('mimeType', 'application/octet-stream') + + if mime_type not in self.SUPPORTED_MIME_TYPES: + logging.info(f"Skipping unsupported file type: {mime_type}") + return False + + os.makedirs(local_dir, exist_ok=True) + full_path = os.path.join(local_dir, file_name) + + download_url = f"{self._get_item_url(file_id)}/content" + download_response = requests.get(download_url, headers=self._get_headers()) + download_response.raise_for_status() + + with open(full_path, 'wb') as f: + f.write(download_response.content) + + return True + except Exception as e: + logging.error(f"Error in _download_single_file: {e}") + return False + + def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int: + files_downloaded = 0 + try: + os.makedirs(local_dir, exist_ok=True) + + url = f"{self._get_item_url(folder_id)}/children" + params = {'$top': 1000} + + while url: + response = requests.get(url, headers=self._get_headers(), params=params) + response.raise_for_status() + + results = response.json() + items = results.get('value', []) + logging.info(f"Found {len(items)} items in folder {folder_id}") + + for item in items: + item_name = item.get('name', 'unknown') + item_id = item.get('id') + + if 'folder' in item: + if recursive: + subfolder_path = os.path.join(local_dir, item_name) + os.makedirs(subfolder_path, exist_ok=True) + subfolder_files = self._download_folder_recursive( + item_id, + subfolder_path, + recursive + ) + files_downloaded += subfolder_files + logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}") + else: + success = self._download_single_file(item_id, local_dir) + if success: + files_downloaded += 1 + logging.info(f"Downloaded file: {item_name}") + else: + logging.warning(f"Failed to download file: {item_name}") + + url = results.get('@odata.nextLink') + + return files_downloaded + + except Exception as e: + logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True) + return files_downloaded + + def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int: + try: + self._ensure_valid_token() + return self._download_folder_recursive(folder_id, local_dir, recursive) + except Exception as e: + logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True) + return 0 + + def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool: + try: + self._ensure_valid_token() + return self._download_single_file(file_id, local_dir) + except Exception as e: + logging.error(f"Error downloading file {file_id}: {e}", exc_info=True) + return False + + def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]: + if source_config is None: + source_config = {} + + config = source_config if source_config else getattr(self, 'config', {}) + files_downloaded = 0 + + try: + folder_ids = config.get('folder_ids', []) + file_ids = config.get('file_ids', []) + recursive = config.get('recursive', True) + + if file_ids: + if isinstance(file_ids, str): + file_ids = [file_ids] + + for file_id in file_ids: + if self._download_file_to_directory(file_id, local_dir): + files_downloaded += 1 + + if folder_ids: + if isinstance(folder_ids, str): + folder_ids = [folder_ids] + + for folder_id in folder_ids: + try: + url = self._get_item_url(folder_id) + params = {'$select': 'id,name'} + response = requests.get(url, headers=self._get_headers(), params=params) + response.raise_for_status() + + folder_metadata = response.json() + folder_name = folder_metadata.get('name', '') + folder_path = os.path.join(local_dir, folder_name) + os.makedirs(folder_path, exist_ok=True) + + folder_files = self._download_folder_recursive( + folder_id, + folder_path, + recursive + ) + files_downloaded += folder_files + logging.info(f"Downloaded {folder_files} files from folder {folder_name}") + except Exception as e: + logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True) + + if not file_ids and not folder_ids: + raise ValueError("No folder_ids or file_ids provided for download") + + return { + "files_downloaded": files_downloaded, + "directory_path": local_dir, + "empty_result": files_downloaded == 0, + "source_type": "share_point", + "config_used": config + } + + except Exception as e: + return { + "files_downloaded": files_downloaded, + "directory_path": local_dir, + "empty_result": True, + "source_type": "share_point", + "config_used": config, + "error": str(e) + } diff --git a/application/requirements.txt b/application/requirements.txt index e99b8614..9565d40b 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -47,6 +47,7 @@ markupsafe==3.0.3 marshmallow>=3.18.0,<5.0.0 mpmath==1.3.0 multidict==6.7.0 +msal==1.34.0 mypy-extensions==1.1.0 networkx==3.6.1 numpy==2.4.0 @@ -95,4 +96,4 @@ werkzeug>=3.1.0 yarl==1.22.0 markdownify==1.2.2 tldextract==5.3.0 -websockets==15.0.1 +websockets==15.0.1 \ No newline at end of file diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 649ef13c..2766d0b5 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -8103,6 +8103,7 @@ "https://github.com/sponsors/katex" ], "license": "MIT", + "license": "MIT", "dependencies": { "commander": "^8.3.0" }, diff --git a/frontend/src/agents/AgentCard.tsx b/frontend/src/agents/AgentCard.tsx index f4c88c9b..6df11e16 100644 --- a/frontend/src/agents/AgentCard.tsx +++ b/frontend/src/agents/AgentCard.tsx @@ -320,4 +320,4 @@ export default function AgentCard({ /> ); -} +} \ No newline at end of file diff --git a/frontend/src/agents/AgentsList.tsx b/frontend/src/agents/AgentsList.tsx index f550a993..e9aa1bcd 100644 --- a/frontend/src/agents/AgentsList.tsx +++ b/frontend/src/agents/AgentsList.tsx @@ -603,4 +603,4 @@ function AgentSection({ ); -} +} \ No newline at end of file diff --git a/frontend/src/agents/NewAgent.tsx b/frontend/src/agents/NewAgent.tsx index 7ff752bd..d5d28647 100644 --- a/frontend/src/agents/NewAgent.tsx +++ b/frontend/src/agents/NewAgent.tsx @@ -1405,4 +1405,4 @@ function AddPromptModal({ handleAddPrompt={handleAddPrompt} /> ); -} +} \ No newline at end of file diff --git a/frontend/src/agents/agents.config.ts b/frontend/src/agents/agents.config.ts index 3569600e..35761e2e 100644 --- a/frontend/src/agents/agents.config.ts +++ b/frontend/src/agents/agents.config.ts @@ -41,4 +41,4 @@ export const agentSectionsConfig = [ selectData: selectSharedAgents, updateAction: setSharedAgents, }, -]; +]; \ No newline at end of file diff --git a/frontend/src/agents/index.tsx b/frontend/src/agents/index.tsx index 988345ad..64481985 100644 --- a/frontend/src/agents/index.tsx +++ b/frontend/src/agents/index.tsx @@ -18,4 +18,4 @@ export default function Agents() { } /> ); -} +} \ No newline at end of file diff --git a/frontend/src/assets/sharepoint.svg b/frontend/src/assets/sharepoint.svg new file mode 100644 index 00000000..9a332f8e --- /dev/null +++ b/frontend/src/assets/sharepoint.svg @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/frontend/src/components/ConnectedStateSkeleton.tsx b/frontend/src/components/ConnectedStateSkeleton.tsx new file mode 100644 index 00000000..f871fc9b --- /dev/null +++ b/frontend/src/components/ConnectedStateSkeleton.tsx @@ -0,0 +1,13 @@ +const ConnectedStateSkeleton = () => ( +
+
+
+
+
+
+
+
+
+); + +export default ConnectedStateSkeleton; diff --git a/frontend/src/components/ConnectorAuth.tsx b/frontend/src/components/ConnectorAuth.tsx index a60d293c..6411df90 100644 --- a/frontend/src/components/ConnectorAuth.tsx +++ b/frontend/src/components/ConnectorAuth.tsx @@ -150,7 +150,7 @@ const ConnectorAuth: React.FC = ({ {isConnected ? (
-
+
= ({ displayName: 'Drive', rootName: 'My Drive', }, + share_point: { + displayName: 'SharePoint', + rootName: 'My Files', + }, } as const; const getProviderConfig = (provider: string) => { @@ -88,9 +92,14 @@ export const FilePicker: React.FC = ({ const [authError, setAuthError] = useState(''); const [isConnected, setIsConnected] = useState(false); const [userEmail, setUserEmail] = useState(''); + const [allowsSharedContent, setAllowsSharedContent] = useState(false); + const [activeTab, setActiveTab] = useState<'my_files' | 'shared'>( + 'my_files', + ); const scrollContainerRef = useRef(null); const searchTimeoutRef = useRef | null>(null); + const abortControllerRef = useRef(null); const isFolder = (file: CloudFile) => { return ( @@ -106,7 +115,13 @@ export const FilePicker: React.FC = ({ folderId: string | null, pageToken?: string, searchQuery = '', + shared = false, ) => { + // Cancel any in-flight request so stale responses never overwrite new state + abortControllerRef.current?.abort(); + const controller = new AbortController(); + abortControllerRef.current = controller; + setIsLoading(true); const apiHost = import.meta.env.VITE_API_HOST; @@ -115,20 +130,23 @@ export const FilePicker: React.FC = ({ } try { + const body: Record = { + provider: provider, + session_token: sessionToken, + folder_id: folderId, + limit: 10, + page_token: pageToken, + search_query: searchQuery, + shared: shared, + }; const response = await fetch(`${apiHost}/api/connectors/files`, { method: 'POST', headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}`, }, - body: JSON.stringify({ - provider: provider, - session_token: sessionToken, - folder_id: folderId, - limit: 10, - page_token: pageToken, - search_query: searchQuery, - }), + body: JSON.stringify(body), + signal: controller.signal, }); const data = await response.json(); @@ -145,12 +163,15 @@ export const FilePicker: React.FC = ({ } } } catch (err) { + if ((err as Error).name === 'AbortError') return; console.error('Error loading files:', err); if (!pageToken) { setFiles([]); } } finally { - setIsLoading(false); + if (!controller.signal.aborted) { + setIsLoading(false); + } } }, [token, provider], @@ -192,11 +213,17 @@ export const FilePicker: React.FC = ({ setUserEmail(validateData.user_email || 'Connected User'); setIsConnected(true); setAuthError(''); + if (provider === 'share_point') { + setAllowsSharedContent( + validateData.allows_shared_content ?? false, + ); + } setFiles([]); setNextPageToken(null); setHasMoreFiles(false); setCurrentFolderId(null); + setActiveTab('my_files'); setFolderPath([ { id: null, @@ -238,6 +265,7 @@ export const FilePicker: React.FC = ({ currentFolderId, nextPageToken, searchQuery, + activeTab === 'shared' && !currentFolderId, ); } } @@ -249,6 +277,7 @@ export const FilePicker: React.FC = ({ searchQuery, provider, loadCloudFiles, + activeTab, ]); useEffect(() => { @@ -264,6 +293,7 @@ export const FilePicker: React.FC = ({ if (searchTimeoutRef.current) { clearTimeout(searchTimeoutRef.current); } + abortControllerRef.current?.abort(); }; }, []); @@ -277,7 +307,13 @@ export const FilePicker: React.FC = ({ searchTimeoutRef.current = setTimeout(() => { const sessionToken = getSessionToken(provider); if (sessionToken) { - loadCloudFiles(sessionToken, currentFolderId, undefined, query); + loadCloudFiles( + sessionToken, + currentFolderId, + undefined, + query, + activeTab === 'shared' && !currentFolderId, + ); } }, 300); }; @@ -295,7 +331,7 @@ export const FilePicker: React.FC = ({ const sessionToken = getSessionToken(provider); if (sessionToken) { - loadCloudFiles(sessionToken, folderId, undefined, ''); + loadCloudFiles(sessionToken, folderId, undefined, '', false); } }; @@ -311,10 +347,41 @@ export const FilePicker: React.FC = ({ const sessionToken = getSessionToken(provider); if (sessionToken) { - loadCloudFiles(sessionToken, newFolderId, undefined, ''); + loadCloudFiles( + sessionToken, + newFolderId, + undefined, + '', + activeTab === 'shared' && !newFolderId, + ); } }; + const handleTabChange = (tab: 'my_files' | 'shared') => { + if (tab === activeTab) return; + setActiveTab(tab); + setFiles([]); + setNextPageToken(null); + setHasMoreFiles(false); + setCurrentFolderId(null); + setSearchQuery(''); + setFolderPath([ + { + id: null, + name: + tab === 'shared' + ? 'Shared' + : getProviderConfig(provider).rootName, + }, + ]); + const sessionToken = getSessionToken(provider); + if (sessionToken) { + loadCloudFiles(sessionToken, null, undefined, '', tab === 'shared'); + } + }; + + + const handleFileSelect = (fileId: string, isFolder: boolean) => { if (isFolder) { const newSelectedFolders = selectedFolders.includes(fileId) @@ -346,7 +413,7 @@ export const FilePicker: React.FC = ({ if (data.session_token) { setSessionToken(provider, data.session_token); - loadCloudFiles(data.session_token, null); + validateAndLoadFiles(); } }} onError={(error) => { @@ -379,6 +446,8 @@ export const FilePicker: React.FC = ({ removeSessionToken(provider); setIsConnected(false); + setAllowsSharedContent(false); + setActiveTab('my_files'); setFiles([]); setSelectedFiles([]); onSelectionChange([]); @@ -390,9 +459,32 @@ export const FilePicker: React.FC = ({ /> {isConnected && ( -
+
- {/* Breadcrumb navigation */} + {provider === 'share_point' && allowsSharedContent && ( +
+ + +
+ )}
{folderPath.map((path, index) => ( @@ -439,7 +531,7 @@ export const FilePicker: React.FC = ({
-
+
= ({ - {files.map((file, index) => ( - { - if (isFolder(file)) { - handleFolderClick(file.id, file.name); - } else { - handleFileSelect(file.id, false); - } - }} - > - -
{ - e.stopPropagation(); - handleFileSelect(file.id, isFolder(file)); + {isLoading && files.length === 0 + ? Array.from({ length: 5 }).map((_, i) => ( + + +
+ + +
+ + +
+ + +
+ + + )) + : files.map((file, index) => ( + { + if (isFolder(file)) { + handleFolderClick(file.id, file.name); + } else { + handleFileSelect(file.id, false); + } }} > - {(isFolder(file) - ? selectedFolders - : selectedFiles - ).includes(file.id) && ( - Selected - )} -
-
- -
-
- {isFolder(file) -
- {file.name} -
-
- - {formatDate(file.modifiedTime)} - - - {file.size ? formatBytes(file.size) : '-'} - - - ))} + +
{ + e.stopPropagation(); + handleFileSelect(file.id, isFolder(file)); + }} + > + {(isFolder(file) + ? selectedFolders + : selectedFiles + ).includes(file.id) && ( + Selected + )} +
+
+ +
+
+ {isFolder(file) +
+ {file.name} +
+
+ + {formatDate(file.modifiedTime)} + + + {file.size ? formatBytes(file.size) : '-'} + + + ))} + {isLoading && files.length > 0 && + Array.from({ length: 3 }).map((_, i) => ( + + +
+ + +
+ + +
+ + +
+ + + ))} - - {isLoading && ( -
-
-
- Loading more files... -
-
- )} } diff --git a/frontend/src/components/FileSelectionSkeleton.tsx b/frontend/src/components/FileSelectionSkeleton.tsx new file mode 100644 index 00000000..1aa45769 --- /dev/null +++ b/frontend/src/components/FileSelectionSkeleton.tsx @@ -0,0 +1,13 @@ +const FilesSectionSkeleton = () => ( +
+
+
+
+
+
+
+
+
+); + +export default FilesSectionSkeleton; diff --git a/frontend/src/components/GoogleDrivePicker.tsx b/frontend/src/components/GoogleDrivePicker.tsx index be7f178e..bb3d240b 100644 --- a/frontend/src/components/GoogleDrivePicker.tsx +++ b/frontend/src/components/GoogleDrivePicker.tsx @@ -7,7 +7,10 @@ import { getSessionToken, setSessionToken, removeSessionToken, + validateProviderSession, } from '../utils/providerUtils'; +import ConnectedStateSkeleton from './ConnectedStateSkeleton'; +import FilesSectionSkeleton from './FileSelectionSkeleton'; interface PickerFile { id: string; @@ -50,20 +53,9 @@ const GoogleDrivePicker: React.FC = ({ const validateSession = async (sessionToken: string) => { try { - const apiHost = import.meta.env.VITE_API_HOST; - const validateResponse = await fetch( - `${apiHost}/api/connectors/validate-session`, - { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - provider: 'google_drive', - session_token: sessionToken, - }), - }, + const validateResponse = await validateProviderSession( + token, + 'google_drive', ); if (!validateResponse.ok) { @@ -234,30 +226,6 @@ const GoogleDrivePicker: React.FC = ({ onSelectionChange([], []); }; - const ConnectedStateSkeleton = () => ( -
-
-
-
-
-
-
-
-
- ); - - const FilesSectionSkeleton = () => ( -
-
-
-
-
-
-
-
-
- ); - return (
{isValidating ? ( diff --git a/frontend/src/components/Notification.tsx b/frontend/src/components/Notification.tsx index d83f7223..0d188989 100644 --- a/frontend/src/components/Notification.tsx +++ b/frontend/src/components/Notification.tsx @@ -132,4 +132,4 @@ export default function Notification({ ); -} +} \ No newline at end of file diff --git a/frontend/src/locale/de.json b/frontend/src/locale/de.json index 4ff7a2a6..9f43f9ee 100644 --- a/frontend/src/locale/de.json +++ b/frontend/src/locale/de.json @@ -671,7 +671,10 @@ "itemsSelected": "{{count}} ausgewählt", "name": "Name", "lastModified": "Zuletzt geändert", - "size": "Größe" + "size": "Größe", + "myFiles": "Meine Dateien", + "sharedWithMe": "Mit mir geteilt", + "loadingMore": "Weitere Dateien laden..." }, "actionButtons": { "openNewChat": "Neuen Chat öffnen", diff --git a/frontend/src/locale/en.json b/frontend/src/locale/en.json index 9f40d2ab..982c5a29 100644 --- a/frontend/src/locale/en.json +++ b/frontend/src/locale/en.json @@ -321,6 +321,10 @@ "s3": { "label": "Amazon S3", "heading": "Add content from Amazon S3" + }, + "share_point": { + "label": "SharePoint", + "heading": "Upload from SharePoint" } }, "connectors": { @@ -350,6 +354,24 @@ "remove": "Remove", "folderAlt": "Folder", "fileAlt": "File" + }, + "sharePoint": { + "connect": "Connect to SharePoint", + "sessionExpired": "Session expired. Please reconnect to SharePoint.", + "sessionExpiredGeneric": "Session expired. Please reconnect your account.", + "validateFailed": "Failed to validate session. Please reconnect.", + "noSession": "No valid session found. Please reconnect to SharePoint.", + "noAccessToken": "No access token available. Please reconnect to SharePoint.", + "pickerFailed": "Failed to open file picker. Please try again.", + "selectedFiles": "Selected Files", + "selectFiles": "Select Files", + "loading": "Loading...", + "noFilesSelected": "No files or folders selected", + "folders": "Folders", + "files": "Files", + "remove": "Remove", + "folderAlt": "Folder", + "fileAlt": "File" } } }, @@ -672,7 +694,10 @@ "itemsSelected": "{{count}} selected", "name": "Name", "lastModified": "Last Modified", - "size": "Size" + "size": "Size", + "myFiles": "My Files", + "sharedWithMe": "Shared with Me", + "loadingMore": "Loading more files..." }, "actionButtons": { "openNewChat": "Open New Chat", diff --git a/frontend/src/locale/es.json b/frontend/src/locale/es.json index e8006404..f365dd33 100644 --- a/frontend/src/locale/es.json +++ b/frontend/src/locale/es.json @@ -321,6 +321,10 @@ "s3": { "label": "Amazon S3", "heading": "Agregar contenido desde Amazon S3" + }, + "share_point": { + "label": "SharePoint", + "heading": "Subir desde SharePoint" } }, "connectors": { @@ -350,6 +354,24 @@ "remove": "Eliminar", "folderAlt": "Carpeta", "fileAlt": "Archivo" + }, + "sharePoint": { + "connect": "Conectar a SharePoint", + "sessionExpired": "Sesión expirada. Por favor, reconecte a SharePoint.", + "sessionExpiredGeneric": "Sesión expirada. Por favor, reconecte su cuenta.", + "validateFailed": "Error al validar la sesión. Por favor, reconecte.", + "noSession": "No se encontró una sesión válida. Por favor, reconecte a SharePoint.", + "noAccessToken": "No hay token de acceso disponible. Por favor, reconecte a SharePoint.", + "pickerFailed": "Error al abrir el selector de archivos. Por favor, inténtelo de nuevo.", + "selectedFiles": "Archivos Seleccionados", + "selectFiles": "Seleccionar Archivos", + "loading": "Cargando...", + "noFilesSelected": "No hay archivos o carpetas seleccionados", + "folders": "Carpetas", + "files": "Archivos", + "remove": "Eliminar", + "folderAlt": "Carpeta", + "fileAlt": "Archivo" } } }, @@ -671,7 +693,10 @@ "itemsSelected": "{{count}} seleccionados", "name": "Nombre", "lastModified": "Última modificación", - "size": "Tamaño" + "size": "Tamaño", + "myFiles": "Mis archivos", + "sharedWithMe": "Compartido conmigo", + "loadingMore": "Cargando más archivos..." }, "actionButtons": { "openNewChat": "Abrir nuevo chat", diff --git a/frontend/src/locale/jp.json b/frontend/src/locale/jp.json index 75fac241..79eb3313 100644 --- a/frontend/src/locale/jp.json +++ b/frontend/src/locale/jp.json @@ -321,6 +321,10 @@ "s3": { "label": "Amazon S3", "heading": "Amazon S3からコンテンツを追加" + }, + "share_point": { + "label": "SharePoint", + "heading": "SharePointからアップロード" } }, "connectors": { @@ -350,6 +354,24 @@ "remove": "削除", "folderAlt": "フォルダ", "fileAlt": "ファイル" + }, + "sharePoint": { + "connect": "SharePointに接続", + "sessionExpired": "セッションが期限切れです。SharePointに再接続してください。", + "sessionExpiredGeneric": "セッションが期限切れです。アカウントに再接続してください。", + "validateFailed": "セッションの検証に失敗しました。再接続してください。", + "noSession": "有効なセッションが見つかりません。SharePointに再接続してください。", + "noAccessToken": "アクセストークンが利用できません。SharePointに再接続してください。", + "pickerFailed": "ファイルピッカーを開けませんでした。もう一度お試しください。", + "selectedFiles": "選択されたファイル", + "selectFiles": "ファイルを選択", + "loading": "読み込み中...", + "noFilesSelected": "ファイルまたはフォルダが選択されていません", + "folders": "フォルダ", + "files": "ファイル", + "remove": "削除", + "folderAlt": "フォルダ", + "fileAlt": "ファイル" } } }, @@ -671,7 +693,10 @@ "itemsSelected": "{{count}} 件選択済み", "name": "名前", "lastModified": "最終更新日", - "size": "サイズ" + "size": "サイズ", + "myFiles": "マイファイル", + "sharedWithMe": "共有アイテム", + "loadingMore": "さらに読み込み中..." }, "actionButtons": { "openNewChat": "新しいチャットを開く", diff --git a/frontend/src/locale/ru.json b/frontend/src/locale/ru.json index ada52aa7..1fdb7601 100644 --- a/frontend/src/locale/ru.json +++ b/frontend/src/locale/ru.json @@ -321,6 +321,10 @@ "s3": { "label": "Amazon S3", "heading": "Добавить контент из Amazon S3" + }, + "share_point": { + "label": "SharePoint", + "heading": "Загрузить из SharePoint" } }, "connectors": { @@ -350,6 +354,24 @@ "remove": "Удалить", "folderAlt": "Папка", "fileAlt": "Файл" + }, + "sharePoint": { + "connect": "Подключиться к SharePoint", + "sessionExpired": "Сеанс истек. Пожалуйста, переподключитесь к SharePoint.", + "sessionExpiredGeneric": "Сеанс истек. Пожалуйста, переподключите свою учетную запись.", + "validateFailed": "Не удалось проверить сеанс. Пожалуйста, переподключитесь.", + "noSession": "Действительный сеанс не найден. Пожалуйста, переподключитесь к SharePoint.", + "noAccessToken": "Токен доступа недоступен. Пожалуйста, переподключитесь к SharePoint.", + "pickerFailed": "Не удалось открыть средство выбора файлов. Пожалуйста, попробуйте еще раз.", + "selectedFiles": "Выбранные файлы", + "selectFiles": "Выбрать файлы", + "loading": "Загрузка...", + "noFilesSelected": "Файлы или папки не выбраны", + "folders": "Папки", + "files": "Файлы", + "remove": "Удалить", + "folderAlt": "Папка", + "fileAlt": "Файл" } } }, @@ -671,7 +693,10 @@ "itemsSelected": "{{count}} выбрано", "name": "Имя", "lastModified": "Последнее изменение", - "size": "Размер" + "size": "Размер", + "myFiles": "Мои файлы", + "sharedWithMe": "Доступные мне", + "loadingMore": "Загрузка файлов..." }, "actionButtons": { "openNewChat": "Открыть новый чат", diff --git a/frontend/src/locale/zh-TW.json b/frontend/src/locale/zh-TW.json index 247dd5f0..8a537e23 100644 --- a/frontend/src/locale/zh-TW.json +++ b/frontend/src/locale/zh-TW.json @@ -321,6 +321,10 @@ "s3": { "label": "Amazon S3", "heading": "從Amazon S3新增內容" + }, + "share_point": { + "label": "SharePoint", + "heading": "從SharePoint上傳" } }, "connectors": { @@ -350,6 +354,24 @@ "remove": "移除", "folderAlt": "資料夾", "fileAlt": "檔案" + }, + "sharePoint": { + "connect": "連接到 SharePoint", + "sessionExpired": "工作階段已過期。請重新連接到 SharePoint。", + "sessionExpiredGeneric": "工作階段已過期。請重新連接您的帳戶。", + "validateFailed": "驗證工作階段失敗。請重新連接。", + "noSession": "未找到有效工作階段。請重新連接到 SharePoint。", + "noAccessToken": "存取權杖不可用。請重新連接到 SharePoint。", + "pickerFailed": "無法開啟檔案選擇器。請重試。", + "selectedFiles": "已選擇的檔案", + "selectFiles": "選擇檔案", + "loading": "載入中...", + "noFilesSelected": "未選擇檔案或資料夾", + "folders": "資料夾", + "files": "檔案", + "remove": "移除", + "folderAlt": "資料夾", + "fileAlt": "檔案" } } }, @@ -671,7 +693,10 @@ "itemsSelected": "已選擇 {{count}} 項", "name": "名稱", "lastModified": "最後修改", - "size": "大小" + "size": "大小", + "myFiles": "我的檔案", + "sharedWithMe": "與我共用", + "loadingMore": "載入更多檔案..." }, "actionButtons": { "openNewChat": "開啟新聊天", diff --git a/frontend/src/locale/zh.json b/frontend/src/locale/zh.json index 305a898e..991ff944 100644 --- a/frontend/src/locale/zh.json +++ b/frontend/src/locale/zh.json @@ -321,6 +321,10 @@ "s3": { "label": "Amazon S3", "heading": "从Amazon S3添加内容" + }, + "share_point": { + "label": "SharePoint", + "heading": "从SharePoint上传" } }, "connectors": { @@ -350,6 +354,24 @@ "remove": "删除", "folderAlt": "文件夹", "fileAlt": "文件" + }, + "sharePoint": { + "connect": "连接到 SharePoint", + "sessionExpired": "会话已过期。请重新连接到 SharePoint。", + "sessionExpiredGeneric": "会话已过期。请重新连接您的账户。", + "validateFailed": "验证会话失败。请重新连接。", + "noSession": "未找到有效会话。请重新连接到 SharePoint。", + "noAccessToken": "访问令牌不可用。请重新连接到 SharePoint。", + "pickerFailed": "无法打开文件选择器。请重试。", + "selectedFiles": "已选择的文件", + "selectFiles": "选择文件", + "loading": "加载中...", + "noFilesSelected": "未选择文件或文件夹", + "folders": "文件夹", + "files": "文件", + "remove": "删除", + "folderAlt": "文件夹", + "fileAlt": "文件" } } }, @@ -671,7 +693,10 @@ "itemsSelected": "已选择 {{count}} 项", "name": "名称", "lastModified": "最后修改", - "size": "大小" + "size": "大小", + "myFiles": "我的文件", + "sharedWithMe": "与我共享", + "loadingMore": "加载更多文件..." }, "actionButtons": { "openNewChat": "打开新聊天", diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 571a367c..0bfb51c6 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -78,4 +78,4 @@ export default store; // TODO : use https://redux-toolkit.js.org/tutorials/typescript#define-typed-hooks everywere instead of direct useDispatch -// TODO : streamline async state management +// TODO : streamline async state management \ No newline at end of file diff --git a/frontend/src/upload/Upload.tsx b/frontend/src/upload/Upload.tsx index 33cadd06..4430b1dc 100644 --- a/frontend/src/upload/Upload.tsx +++ b/frontend/src/upload/Upload.tsx @@ -252,6 +252,23 @@ function Upload({ token={token} /> ); + case 'share_point_picker': + return ( + { + setSelectedFiles(selectedFileIds); + setSelectedFolders(selectedFolderIds); + }} + provider="share_point" + token={token} + initialSelectedFiles={selectedFiles} + initialSelectedFolders={selectedFolders} + /> + ); default: return null; } @@ -534,6 +551,9 @@ function Upload({ const hasGoogleDrivePicker = schema.some( (field: FormField) => field.type === 'google_drive_picker', ); + const hasSharePointPicker = schema.some( + (field: FormField) => field.type === 'share_point_picker', + ); let configData: Record = { ...ingestor.config }; @@ -541,7 +561,7 @@ function Upload({ files.forEach((file) => { formData.append('file', file); }); - } else if (hasRemoteFilePicker || hasGoogleDrivePicker) { + } else if (hasRemoteFilePicker || hasGoogleDrivePicker || hasSharePointPicker) { const sessionToken = getSessionToken(ingestor.type as string); configData = { provider: ingestor.type as string, @@ -717,12 +737,15 @@ function Upload({ const hasGoogleDrivePicker = schema.some( (field: FormField) => field.type === 'google_drive_picker', ); + const hasSharePointPicker = schema.some( + (field: FormField) => field.type === 'share_point_picker', + ); if (hasLocalFilePicker) { if (files.length === 0) { return true; } - } else if (hasRemoteFilePicker || hasGoogleDrivePicker) { + } else if (hasRemoteFilePicker || hasGoogleDrivePicker || hasSharePointPicker) { if (selectedFiles.length === 0 && selectedFolders.length === 0) { return true; } diff --git a/frontend/src/upload/types/ingestor.ts b/frontend/src/upload/types/ingestor.ts index 4e853262..b4440d1a 100644 --- a/frontend/src/upload/types/ingestor.ts +++ b/frontend/src/upload/types/ingestor.ts @@ -5,6 +5,7 @@ import GithubIcon from '../../assets/github.svg'; import RedditIcon from '../../assets/reddit.svg'; import DriveIcon from '../../assets/drive.svg'; import S3Icon from '../../assets/s3.svg'; +import SharePoint from '../../assets/sharepoint.svg'; export type IngestorType = | 'crawler' @@ -13,7 +14,8 @@ export type IngestorType = | 'url' | 'google_drive' | 'local_file' - | 's3'; + | 's3' + | 'share_point'; export interface IngestorConfig { type: IngestorType | null; @@ -35,7 +37,8 @@ export type FieldType = | 'boolean' | 'local_file_picker' | 'remote_file_picker' - | 'google_drive_picker'; + | 'google_drive_picker' + | 'share_point_picker'; export interface FormField { name: string; @@ -193,6 +196,24 @@ export const IngestorFormSchemas: IngestorSchema[] = [ }, ], }, + { + key: 'share_point', + label: 'Share Point', + icon: SharePoint, + heading: 'Upload from Share Point', + validate: () => { + const sharePointClientId = import.meta.env.VITE_SHARE_POINT_CLIENT_ID; + return !!sharePointClientId; + }, + fields: [ + { + name: 'files', + label: 'Select Files from Share Point', + type: 'share_point_picker', + required: true, + }, + ], + }, ]; export const IngestorDefaultConfigs: Record< @@ -232,6 +253,14 @@ export const IngestorDefaultConfigs: Record< endpoint_url: '', }, }, + share_point: { + name: '', + config: { + file_ids: '', + folder_ids: '', + recursive: true, + }, + }, }; export interface IngestorOption { diff --git a/frontend/src/utils/providerUtils.ts b/frontend/src/utils/providerUtils.ts index 25236ad2..01f25c5c 100644 --- a/frontend/src/utils/providerUtils.ts +++ b/frontend/src/utils/providerUtils.ts @@ -14,3 +14,21 @@ export const setSessionToken = (provider: string, token: string): void => { export const removeSessionToken = (provider: string): void => { localStorage.removeItem(`${provider}_session_token`); }; + +export const validateProviderSession = async ( + token: string | null, + provider: string, +) => { + const apiHost = import.meta.env.VITE_API_HOST; + return await fetch(`${apiHost}/api/connectors/validate-session`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + provider: provider, + session_token: getSessionToken(provider), + }), + }); +}; diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/parser/remote/test_share_point_loader.py b/tests/parser/remote/test_share_point_loader.py new file mode 100644 index 00000000..0433d04f --- /dev/null +++ b/tests/parser/remote/test_share_point_loader.py @@ -0,0 +1,200 @@ +"""Tests for SharePoint loader.""" + +from unittest.mock import patch, MagicMock + +from application.parser.connectors.share_point.loader import SharePointLoader + + +def make_response(json_data=None, status_code=200, raise_error=None): + resp = MagicMock() + resp.status_code = status_code + resp.json.return_value = json_data + resp.content = b"test content" + if raise_error is not None: + resp.raise_for_status.side_effect = raise_error + else: + resp.raise_for_status.return_value = None + return resp + + +class TestSharePointLoaderProcessFile: + """Test _process_file method.""" + + def test_size_retrieved_from_root_level(self): + """Should retrieve size from root of file_metadata, not nested file object.""" + loader = SharePointLoader.__new__(SharePointLoader) + + file_metadata = { + "id": "test-id", + "name": "test.txt", + "createdDateTime": "2024-01-01T00:00:00Z", + "lastModifiedDateTime": "2024-01-01T00:00:00Z", + "size": 1024, + "file": { + "mimeType": "text/plain" + } + } + + doc = loader._process_file(file_metadata, load_content=False) + + assert doc is not None + assert doc.extra_info["size"] == 1024 + assert doc.extra_info["file_name"] == "test.txt" + assert doc.extra_info["mime_type"] == "text/plain" + + def test_size_null_when_missing(self): + """Should return None when size field is missing.""" + loader = SharePointLoader.__new__(SharePointLoader) + + file_metadata = { + "id": "test-id", + "name": "test.txt", + "createdDateTime": "2024-01-01T00:00:00Z", + "lastModifiedDateTime": "2024-01-01T00:00:00Z", + "file": { + "mimeType": "text/plain" + } + } + + doc = loader._process_file(file_metadata, load_content=False) + + assert doc is not None + assert doc.extra_info["size"] is None + + +class TestSharePointLoaderLoadFileById: + """Test _load_file_by_id method.""" + + @patch("application.parser.connectors.share_point.loader.requests.get") + @patch("application.parser.connectors.share_point.loader.SharePointAuth.get_token_info_from_session") + @patch("application.parser.connectors.share_point.loader.SharePointAuth.__init__", return_value=None) + @patch("application.parser.connectors.share_point.loader.SharePointLoader._ensure_valid_token") + def test_load_file_by_id_includes_size_in_select(self, mock_ensure_token, mock_auth_init, mock_get_token, mock_get): + """Should include size field in $select parameter.""" + mock_get_token.return_value = { + "access_token": "test-token", + "refresh_token": "test-refresh" + } + mock_get.return_value = make_response({ + "id": "test-id", + "name": "test.txt", + "createdDateTime": "2024-01-01T00:00:00Z", + "lastModifiedDateTime": "2024-01-01T00:00:00Z", + "size": 2048, + "file": { + "mimeType": "text/plain" + } + }) + + loader = SharePointLoader("test-session") + doc = loader._load_file_by_id("test-id", load_content=False) + + assert doc is not None + assert doc.extra_info["size"] == 2048 + + call_args = mock_get.call_args + params = call_args[1]["params"] + assert "size" in params["$select"] + + @patch("application.parser.connectors.share_point.loader.requests.get") + @patch("application.parser.connectors.share_point.loader.SharePointAuth.get_token_info_from_session") + @patch("application.parser.connectors.share_point.loader.SharePointAuth.__init__", return_value=None) + @patch("application.parser.connectors.share_point.loader.SharePointLoader._ensure_valid_token") + def test_load_file_by_id_returns_document_with_size(self, mock_ensure_token, mock_auth_init, mock_get_token, mock_get): + """Should return document with size from API response.""" + mock_get_token.return_value = { + "access_token": "test-token", + "refresh_token": "test-refresh" + } + mock_get.return_value = make_response({ + "id": "test-id", + "name": "document.pdf", + "createdDateTime": "2024-01-01T00:00:00Z", + "lastModifiedDateTime": "2024-06-15T10:30:00Z", + "size": 56789, + "file": { + "mimeType": "application/pdf" + } + }) + + loader = SharePointLoader("test-session") + doc = loader._load_file_by_id("test-id", load_content=False) + + assert doc is not None + assert doc.doc_id == "test-id" + assert doc.extra_info["file_name"] == "document.pdf" + assert doc.extra_info["mime_type"] == "application/pdf" + assert doc.extra_info["size"] == 56789 + assert doc.extra_info["created_time"] == "2024-01-01T00:00:00Z" + assert doc.extra_info["modified_time"] == "2024-06-15T10:30:00Z" + assert doc.extra_info["source"] == "share_point" + + +class TestSharePointLoaderListItems: + """Test _list_items_in_parent method.""" + + @patch("application.parser.connectors.share_point.loader.requests.get") + @patch("application.parser.connectors.share_point.loader.SharePointAuth.get_token_info_from_session") + @patch("application.parser.connectors.share_point.loader.SharePointAuth.__init__", return_value=None) + @patch("application.parser.connectors.share_point.loader.SharePointLoader._ensure_valid_token") + def test_list_items_includes_size_in_select(self, mock_ensure_token, mock_auth_init, mock_get_token, mock_get): + """Should include size field in $select parameter when listing items.""" + mock_get_token.return_value = { + "access_token": "test-token", + "refresh_token": "test-refresh" + } + mock_get.return_value = make_response({ + "value": [ + { + "id": "file-1", + "name": "file1.txt", + "createdDateTime": "2024-01-01T00:00:00Z", + "lastModifiedDateTime": "2024-01-01T00:00:00Z", + "size": 12345, + "file": { + "mimeType": "text/plain" + } + } + ] + }) + + loader = SharePointLoader("test-session") + docs = loader._list_items_in_parent("parent-id", limit=10, load_content=False) + + assert len(docs) == 1 + assert docs[0].extra_info["size"] == 12345 + + call_args = mock_get.call_args + params = call_args[1]["params"] + assert "size" in params["$select"] + + @patch("application.parser.connectors.share_point.loader.requests.get") + @patch("application.parser.connectors.share_point.loader.SharePointAuth.get_token_info_from_session") + @patch("application.parser.connectors.share_point.loader.SharePointAuth.__init__", return_value=None) + @patch("application.parser.connectors.share_point.loader.SharePointLoader._ensure_valid_token") + def test_list_items_folders_include_size(self, mock_ensure_token, mock_auth_init, mock_get_token, mock_get): + """Should include size for folders as well.""" + mock_get_token.return_value = { + "access_token": "test-token", + "refresh_token": "test-refresh" + } + mock_get.return_value = make_response({ + "value": [ + { + "id": "folder-1", + "name": "MyFolder", + "createdDateTime": "2024-01-01T00:00:00Z", + "lastModifiedDateTime": "2024-01-01T00:00:00Z", + "size": 0, + "folder": {} + } + ] + }) + + loader = SharePointLoader("test-session") + docs = loader._list_items_in_parent("parent-id", limit=10, load_content=False) + + assert len(docs) == 1 + assert docs[0].extra_info["is_folder"] is True + assert docs[0].extra_info["size"] == 0 +