Compare commits

..

8 Commits

Author SHA1 Message Date
Manish Madan
7c46d8a094 Merge pull request #2029 from abfeb8/main
feat: add Microsoft Entra ID integration
2025-10-15 13:38:48 +05:30
Abhishek Malviya
065939302b Merge pull request #1 from abfeb8/feature/auth-fe-impl
feat: add SharePoint integration with session validation and UI components
2025-10-10 15:51:44 +05:30
Abhishek Malviya
5fa87db9e7 Merge branch 'arc53:main' into main 2025-10-10 15:44:40 +05:30
Abhishek Malviya
cc54cea783 feat: add SharePoint integration with session validation and UI components 2025-10-10 15:15:38 +05:30
Abhishek Malviya
d9f0072112 refactor: remove MICROSOFT_REDIRECT_URI and update SharePointAuth to use CONNECTOR_REDIRECT_BASE_URI 2025-10-09 10:36:12 +05:30
Abhishek Malviya
2b73c0c9a0 feat: add init for Share Point connector module 2025-10-08 10:34:38 +05:30
Abhishek Malviya
da62133d21 Merge branch 'main' into main 2025-10-08 09:40:34 +05:30
Abhishek Malviya
8edb6dcf2a feat: add Microsoft Entra ID integration
- Updated .env-template and settings.py for Microsoft Entra ID configuration.
- Enhanced ConnectorsCallback to support SharePoint authentication.
- Introduced SharePointAuth and SharePointLoader classes.
- Added required dependencies in requirements.txt.
2025-10-07 15:23:32 +05:30
70 changed files with 1151 additions and 2699 deletions

View File

@@ -6,4 +6,17 @@ VITE_API_STREAMING=true
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
#Azure AD Application (client) ID
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
#Azure AD Application client secret
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
#Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#If you are using a Microsoft Entra ID tenant,
#configure the AUTHORITY variable as
#"https://login.microsoftonline.com/TENANT_GUID"
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenentId}.ciamlogin.com/{tenentId}

View File

@@ -3,7 +3,6 @@
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
All Meaningful contributors with accepted PRs that were created for issues with the `hacktoberfest` label (set by our maintainer team: dartpain, siiddhantt, pabik, ManishMadan2882) will receive a cool T-shirt! 🤩.
<img width="1331" height="678" alt="hacktoberfest-mocks-preview" src="https://github.com/user-attachments/assets/633f6377-38db-48f5-b519-a8b3855a9eb4" />
Fill in [this form](https://forms.gle/Npaba4n9Epfyx56S8
) after your PR was merged please

View File

@@ -1,321 +0,0 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid
from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
class TodoListTool(Tool):
"""Todo List
Manages todo items for users. Supports creating, viewing, updating, and deleting todos.
"""
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
"""Initialize the tool.
Args:
tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this todo list tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated todos
user_id: The authenticated user's id (should come from decoded_token["sub"]).
"""
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["todos"]
# -----------------------------
# Action implementations
# -----------------------------
def execute_action(self, action_name: str, **kwargs: Any) -> str:
"""Execute an action by name.
Args:
action_name: One of list, create, get, update, complete, delete.
**kwargs: Parameters for the action.
Returns:
A human-readable string result.
"""
if not self.user_id:
return "Error: TodoListTool requires a valid user_id."
if action_name == "list":
return self._list()
if action_name == "create":
return self._create(kwargs.get("title", ""))
if action_name == "get":
return self._get(kwargs.get("todo_id"))
if action_name == "update":
return self._update(
kwargs.get("todo_id"),
kwargs.get("title", "")
)
if action_name == "complete":
return self._complete(kwargs.get("todo_id"))
if action_name == "delete":
return self._delete(kwargs.get("todo_id"))
return f"Unknown action: {action_name}"
def get_actions_metadata(self) -> List[Dict[str, Any]]:
"""Return JSON metadata describing supported actions for tool schemas."""
return [
{
"name": "list",
"description": "List all todos for the user.",
"parameters": {"type": "object", "properties": {}},
},
{
"name": "create",
"description": "Create a new todo item.",
"parameters": {
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "Title of the todo item."
}
},
"required": ["title"],
},
},
{
"name": "get",
"description": "Get a specific todo by ID.",
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "integer",
"description": "The ID of the todo to retrieve."
}
},
"required": ["todo_id"],
},
},
{
"name": "update",
"description": "Update a todo's title by ID.",
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "integer",
"description": "The ID of the todo to update."
},
"title": {
"type": "string",
"description": "The new title for the todo."
}
},
"required": ["todo_id", "title"],
},
},
{
"name": "complete",
"description": "Mark a todo as completed.",
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "integer",
"description": "The ID of the todo to mark as completed."
}
},
"required": ["todo_id"],
},
},
{
"name": "delete",
"description": "Delete a specific todo by ID.",
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "integer",
"description": "The ID of the todo to delete."
}
},
"required": ["todo_id"],
},
},
]
def get_config_requirements(self) -> Dict[str, Any]:
"""Return configuration requirements."""
return {}
# -----------------------------
# Internal helpers
# -----------------------------
def _coerce_todo_id(self, value: Optional[Any]) -> Optional[int]:
"""Convert todo identifiers to sequential integers."""
if value is None:
return None
if isinstance(value, int):
return value if value > 0 else None
if isinstance(value, str):
stripped = value.strip()
if stripped.isdigit():
numeric_value = int(stripped)
return numeric_value if numeric_value > 0 else None
return None
def _get_next_todo_id(self) -> int:
"""Get the next sequential todo_id for this user and tool.
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
With 5-10 todos max, scanning is negligible.
"""
# Find all todos for this user/tool and get their IDs
todos = list(self.collection.find(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"todo_id": 1}
))
# Find the maximum todo_id
max_id = 0
for todo in todos:
todo_id = self._coerce_todo_id(todo.get("todo_id"))
if todo_id is not None:
max_id = max(max_id, todo_id)
return max_id + 1
def _list(self) -> str:
"""List all todos for the user."""
cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id})
todos = list(cursor)
if not todos:
return "No todos found."
result_lines = ["Todos:"]
for doc in todos:
todo_id = doc.get("todo_id")
title = doc.get("title", "Untitled")
status = doc.get("status", "open")
line = f"[{todo_id}] {title} ({status})"
result_lines.append(line)
return "\n".join(result_lines)
def _create(self, title: str) -> str:
"""Create a new todo item."""
title = (title or "").strip()
if not title:
return "Error: Title is required."
now = datetime.now()
todo_id = self._get_next_todo_id()
doc = {
"todo_id": todo_id,
"user_id": self.user_id,
"tool_id": self.tool_id,
"title": title,
"status": "open",
"created_at": now,
"updated_at": now,
}
self.collection.insert_one(doc)
return f"Todo created with ID {todo_id}: {title}"
def _get(self, todo_id: Optional[Any]) -> str:
"""Get a specific todo by ID."""
parsed_todo_id = self._coerce_todo_id(todo_id)
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
doc = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"todo_id": parsed_todo_id
})
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
title = doc.get("title", "Untitled")
status = doc.get("status", "open")
result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
return result
def _update(self, todo_id: Optional[Any], title: str) -> str:
"""Update a todo's title by ID."""
parsed_todo_id = self._coerce_todo_id(todo_id)
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
title = (title or "").strip()
if not title:
return "Error: Title is required."
result = self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
{"$set": {"title": title, "updated_at": datetime.now()}}
)
if result.matched_count == 0:
return f"Error: Todo with ID {parsed_todo_id} not found."
return f"Todo {parsed_todo_id} updated to: {title}"
def _complete(self, todo_id: Optional[Any]) -> str:
"""Mark a todo as completed."""
parsed_todo_id = self._coerce_todo_id(todo_id)
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
result = self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
{"$set": {"status": "completed", "updated_at": datetime.now()}}
)
if result.matched_count == 0:
return f"Error: Todo with ID {parsed_todo_id} not found."
return f"Todo {parsed_todo_id} marked as completed."
def _delete(self, todo_id: Optional[Any]) -> str:
"""Delete a specific todo by ID."""
parsed_todo_id = self._coerce_todo_id(todo_id)
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
result = self.collection.delete_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"todo_id": parsed_todo_id
})
if result.deleted_count == 0:
return f"Error: Todo with ID {parsed_todo_id} not found."
return f"Todo {parsed_todo_id} deleted."

View File

@@ -28,7 +28,7 @@ class ToolManager:
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
if tool_name in {"mcp_tool", "notes", "memory", "todo_list"} and user_id:
if tool_name in {"mcp_tool", "notes", "memory"} and user_id:
return obj(tool_config, user_id)
else:
return obj(tool_config)
@@ -36,7 +36,7 @@ class ToolManager:
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id:
if tool_name in {"mcp_tool", "memory"} and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)

View File

@@ -113,10 +113,14 @@ class ConnectorsCallback(Resource):
session_token = str(uuid.uuid4())
try:
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
if provider == "google_drive":
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
else:
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
except Exception as e:
current_app.logger.warning(f"Could not get user info: {e}")
user_email = 'Connected User'

View File

@@ -10,7 +10,7 @@ from application.api import api
from application.api.user.base import agents_collection, storage
from application.api.user.tasks import store_attachment
from application.core.settings import settings
from application.tts.tts_creator import TTSCreator
from application.tts.google_tts import GoogleTTS
from application.utils import safe_filename
@@ -133,7 +133,7 @@ class TextToSpeech(Resource):
data = request.get_json()
text = data["text"]
try:
tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
tts_instance = GoogleTTS()
audio_base64, detected_language = tts_instance.text_to_speech(text)
return make_response(
jsonify(

View File

@@ -55,6 +55,11 @@ class Settings(BaseSettings):
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
)
# Microsoft Entra ID (Azure AD) integration
MICROSOFT_CLIENT_ID: Optional[str] = None # Azure AD Application (client) ID
MICROSOFT_CLIENT_SECRET: Optional[str] = None # Azure AD Application client secret
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
# GitHub source
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
@@ -130,7 +135,6 @@ class Settings(BaseSettings):
# Encryption settings
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
ELEVENLABS_API_KEY: Optional[str] = None
path = Path(__file__).parent.parent.absolute()

View File

@@ -1,5 +1,7 @@
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
from application.parser.connectors.share_point.auth import SharePointAuth
from application.parser.connectors.share_point.loader import SharePointLoader
class ConnectorCreator:
@@ -12,10 +14,12 @@ class ConnectorCreator:
connectors = {
"google_drive": GoogleDriveLoader,
"share_point": SharePointLoader,
}
auth_providers = {
"google_drive": GoogleDriveAuth,
"share_point": SharePointAuth,
}
@classmethod

View File

@@ -0,0 +1,10 @@
"""
Share Point connector package for DocsGPT.
This module provides authentication and document loading capabilities for Share Point.
"""
from .auth import SharePointAuth
from .loader import SharePointLoader
__all__ = ['SharePointAuth', 'SharePointLoader']

View File

@@ -0,0 +1,91 @@
import logging
import datetime
from typing import Optional, Dict, Any
from msal import ConfidentialClientApplication
from application.core.settings import settings
from application.parser.connectors.base import BaseConnectorAuth
class SharePointAuth(BaseConnectorAuth):
"""
Handles Microsoft OAuth 2.0 authentication.
# Documentation:
- https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow
- https://learn.microsoft.com/en-gb/entra/msal/python/
"""
# Microsoft Graph scopes for SharePoint access
SCOPES = [
"User.Read",
]
def __init__(self):
self.client_id = settings.MICROSOFT_CLIENT_ID
self.client_secret = settings.MICROSOFT_CLIENT_SECRET
if not self.client_id or not self.client_secret:
raise ValueError(
"Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID and MICROSOFT_CLIENT_SECRET in settings."
)
self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
self.tenant_id = settings.MICROSOFT_TENANT_ID
self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://{self.tenant_id}.ciamlogin.com/{self.tenant_id}")
self.auth_app = ConfidentialClientApplication(
client_id=self.client_id, client_credential=self.client_secret, authority=self.authority
)
def get_authorization_url(self, state: Optional[str] = None) -> str:
return self.auth_app.get_authorization_request_url(
scopes=self.SCOPES, state=state, redirect_uri=self.redirect_uri
)
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
result = self.auth_app.acquire_token_by_authorization_code(
code=authorization_code, scopes=self.SCOPES, redirect_uri=self.redirect_uri
)
if "error" in result:
logging.error(f"Error acquiring token: {result.get('error_description')}")
raise ValueError(f"Error acquiring token: {result.get('error_description')}")
return self.map_token_response(result)
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
result = self.auth_app.acquire_token_by_refresh_token(refresh_token=refresh_token, scopes=self.SCOPES)
if "error" in result:
logging.error(f"Error acquiring token: {result.get('error_description')}")
raise ValueError(f"Error acquiring token: {result.get('error_description')}")
return self.map_token_response(result)
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
if not token_info or "expiry" not in token_info:
# If no expiry info, consider token expired to be safe
return True
# Get expiry timestamp and current time
expiry_timestamp = token_info["expiry"]
current_timestamp = int(datetime.datetime.now().timestamp())
# Token is expired if current time is greater than or equal to expiry time
return current_timestamp >= expiry_timestamp
def map_token_response(self, result) -> Dict[str, Any]:
return {
"access_token": result.get("access_token"),
"refresh_token": result.get("refresh_token"),
"token_uri": result.get("id_token_claims", {}).get("iss"),
"scopes": result.get("scope"),
"expiry": result.get("id_token_claims", {}).get("exp"),
"user_info": {
"name": result.get("id_token_claims", {}).get("name"),
"email": result.get("id_token_claims", {}).get("preferred_username"),
},
"raw_token": result,
}

View File

@@ -0,0 +1,44 @@
from typing import List, Dict, Any
from application.parser.connectors.base import BaseConnectorLoader
from application.parser.schema.base import Document
class SharePointLoader(BaseConnectorLoader):
def __init__(self, session_token: str):
pass
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
"""
Load documents from the external knowledge base.
Args:
inputs: Configuration dictionary containing:
- file_ids: Optional list of specific file IDs to load
- folder_ids: Optional list of folder IDs to browse/download
- limit: Maximum number of items to return
- list_only: If True, return metadata without content
- recursive: Whether to recursively process folders
Returns:
List of Document objects
"""
pass
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
"""
Download files/folders to a local directory.
Args:
local_dir: Local directory path to download files to
source_config: Configuration for what to download
Returns:
Dictionary containing download results:
- files_downloaded: Number of files downloaded
- directory_path: Path where files were downloaded
- empty_result: Whether no files were downloaded
- source_type: Type of connector
- config_used: Configuration that was used
- error: Error message if download failed (optional)
"""
pass

View File

@@ -10,7 +10,6 @@ ebooklib==0.18
escodegen==1.0.11
esprima==4.0.1
esutils==1.0.1
elevenlabs==2.17.0
Flask==3.1.1
faiss-cpu==1.9.0.post1
fastmcp==2.11.0
@@ -30,7 +29,7 @@ jsonpatch==1.33
jsonpointer==3.0.0
kombu==5.4.2
langchain==0.3.20
langchain-community==0.4.1
langchain-community==0.3.19
langchain-core==0.3.59
langchain-openai==0.3.16
langchain-text-splitters==0.3.8
@@ -41,6 +40,7 @@ markupsafe==3.0.2
marshmallow==3.26.1
mpmath==1.3.0
multidict==6.4.3
msal==1.34.0
mypy-extensions==1.0.0
networkx==3.4.2
numpy==2.2.1
@@ -88,4 +88,4 @@ werkzeug>=3.1.0,<3.1.2
yarl==1.20.0
markdownify==1.1.0
tldextract==5.1.3
websockets==14.1
websockets==14.1

View File

@@ -15,11 +15,10 @@ class ElevenlabsTTS(BaseTTS):
def text_to_speech(self, text):
lang = "en"
audio = self.client.text_to_speech.convert(
voice_id="nPczCjzI2devNBz1zQrb",
model_id="eleven_multilingual_v2",
audio = self.client.generate(
text=text,
output_format="mp3_44100_128"
model="eleven_multilingual_v2",
voice="Brian",
)
audio_data = BytesIO()
for chunk in audio:

View File

@@ -1,18 +0,0 @@
from application.tts.google_tts import GoogleTTS
from application.tts.elevenlabs import ElevenlabsTTS
from application.tts.base import BaseTTS
class TTSCreator:
tts_providers = {
"google_tts": GoogleTTS,
"elevenlabs": ElevenlabsTTS,
}
@classmethod
def create_tts(cls, tts_type, *args, **kwargs)-> BaseTTS:
tts_class = cls.tts_providers.get(tts_type.lower())
if not tts_class:
raise ValueError(f"No tts class found for type {tts_type}")
return tts_class(*args, **kwargs)

View File

@@ -21,7 +21,7 @@ def get_encoding():
def get_gpt_model() -> str:
"""Get GPT model based on provider"""
"""Get the appropriate GPT model based on provider"""
model_map = {
"openai": "gpt-4o-mini",
"anthropic": "claude-2",
@@ -32,7 +32,16 @@ def get_gpt_model() -> str:
def safe_filename(filename):
"""Create safe filename, preserving extension. Handles non-Latin characters."""
"""
Creates a safe filename that preserves the original extension.
Uses secure_filename, but ensures a proper filename is returned even with non-Latin characters.
Args:
filename (str): The original filename
Returns:
str: A safe filename that can be used for storage
"""
if not filename:
return str(uuid.uuid4())
_, extension = os.path.splitext(filename)
@@ -74,14 +83,8 @@ def count_tokens_docs(docs):
return tokens
def get_missing_fields(data, required_fields):
"""Check for missing required fields. Returns list of missing field names."""
return [field for field in required_fields if field not in data]
def check_required_fields(data, required_fields):
"""Validate required fields. Returns Flask 400 response if validation fails, None otherwise."""
missing_fields = get_missing_fields(data, required_fields)
missing_fields = [field for field in required_fields if field not in data]
if missing_fields:
return make_response(
jsonify(
@@ -95,8 +98,7 @@ def check_required_fields(data, required_fields):
return None
def get_field_validation_errors(data, required_fields):
"""Check for missing and empty fields. Returns dict with 'missing_fields' and 'empty_fields', or None."""
def validate_required_fields(data, required_fields):
missing_fields = []
empty_fields = []
@@ -105,24 +107,12 @@ def get_field_validation_errors(data, required_fields):
missing_fields.append(field)
elif not data[field]:
empty_fields.append(field)
if missing_fields or empty_fields:
return {"missing_fields": missing_fields, "empty_fields": empty_fields}
return None
def validate_required_fields(data, required_fields):
"""Validate required fields (must exist and be non-empty). Returns Flask 400 response if validation fails, None otherwise."""
errors_dict = get_field_validation_errors(data, required_fields)
if errors_dict:
errors = []
if errors_dict["missing_fields"]:
errors.append(
f"Missing required fields: {', '.join(errors_dict['missing_fields'])}"
)
if errors_dict["empty_fields"]:
errors.append(
f"Empty values in required fields: {', '.join(errors_dict['empty_fields'])}"
)
errors = []
if missing_fields:
errors.append(f"Missing required fields: {', '.join(missing_fields)}")
if empty_fields:
errors.append(f"Empty values in required fields: {', '.join(empty_fields)}")
if errors:
return make_response(
jsonify({"success": False, "message": " | ".join(errors)}), 400
)
@@ -134,7 +124,10 @@ def get_hash(data):
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
"""Limit chat history to fit within token limit."""
"""
Limits chat history based on token count.
Returns a list of messages that fit within the token limit.
"""
from application.core.settings import settings
max_token_limit = (
@@ -168,7 +161,7 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
def validate_function_name(function_name):
"""Validate function name matches allowed pattern (alphanumeric, underscore, hyphen)."""
"""Validates if a function name matches the allowed pattern."""
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
return False
return True

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#e3e3e3"><path d="M240-80q-33 0-56.5-23.5T160-160v-640q0-33 23.5-56.5T240-880h480q33 0 56.5 23.5T800-800v640q0 33-23.5 56.5T720-80H240Zm0-80h480v-640H240v640Zm88-104 56-56-56-56-56 56 56 56Zm0-160 56-56-56-56-56 56 56 56Zm0-160 56-56-56-56-56 56 56 56Zm120 280h232v-80H448v80Zm0-160h232v-80H448v80Zm0-160h232v-80H448v80ZM240-160v-640 640Z"/></svg>

Before

Width:  |  Height:  |  Size: 446 B

View File

@@ -9,8 +9,7 @@ import userService from './api/services/userService';
import Add from './assets/add.svg';
import DocsGPT3 from './assets/cute_docsgpt3.svg';
import Discord from './assets/discord.svg';
import PanelLeftClose from './assets/panel-left-close.svg';
import PanelLeftOpen from './assets/panel-left-open.svg';
import Expand from './assets/expand.svg';
import Github from './assets/git_nav.svg';
import Hamburger from './assets/hamburger.svg';
import openNewChat from './assets/openNewChat.svg';
@@ -303,20 +302,18 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
{
<div className="absolute top-3 left-3 z-20 hidden transition-all duration-300 ease-in-out lg:block">
<div className="flex items-center gap-3">
{!navOpen && (
<button
onClick={() => {
setNavOpen(!navOpen);
}}
className="transition-transform duration-200 hover:scale-110"
>
<img
src={PanelLeftOpen}
alt="Open navigation menu"
className="m-auto transition-all duration-300 ease-in-out"
/>
</button>
)}
<button
onClick={() => {
setNavOpen(!navOpen);
}}
className="transition-transform duration-200 hover:scale-110"
>
<img
src={Expand}
alt="Toggle navigation menu"
className="m-auto transition-all duration-300 ease-in-out"
/>
</button>
{queries?.length > 0 && (
<button
onClick={() => {
@@ -366,8 +363,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
}}
>
<img
src={navOpen ? PanelLeftClose : PanelLeftOpen}
alt={navOpen ? 'Collapse sidebar' : 'Expand sidebar'}
src={Expand}
alt="Toggle navigation menu"
className="m-auto transition-all duration-300 ease-in-out hover:scale-110"
/>
</button>

View File

@@ -109,18 +109,18 @@ export default function AgentPreview() {
} else setLastQueryReturnedErr(false);
}, [queries]);
return (
<div className="relative h-full w-full">
<div className="scrollbar-thin absolute inset-0 bottom-[180px] overflow-hidden px-4 pt-4 [&>div>div]:!w-full [&>div>div]:!max-w-none">
<ConversationMessages
handleQuestion={handleQuestion}
handleQuestionSubmission={handleQuestionSubmission}
queries={queries}
status={status}
showHeroOnEmpty={false}
/>
</div>
<div className="absolute right-0 bottom-0 left-0 flex w-full flex-col gap-4 pb-2">
<div className="w-full px-4">
<div>
<div className="dark:bg-raisin-black flex h-full flex-col items-center justify-between gap-2 overflow-y-hidden">
<div className="h-[512px] w-full overflow-y-auto">
<ConversationMessages
handleQuestion={handleQuestion}
handleQuestionSubmission={handleQuestionSubmission}
queries={queries}
status={status}
showHeroOnEmpty={false}
/>
</div>
<div className="flex w-[95%] max-w-[1500px] flex-col items-center gap-4 pb-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
<MessageInput
onSubmit={(text) => handleQuestionSubmission(text)}
loading={status === 'loading'}
@@ -128,11 +128,11 @@ export default function AgentPreview() {
showToolButton={selectedAgent ? false : true}
autoFocus={false}
/>
<p className="text-gray-4000 dark:text-sonic-silver w-full self-center bg-transparent pt-2 text-center text-xs md:inline">
This is a preview of the agent. You can publish it to start using it
in conversations.
</p>
</div>
<p className="text-gray-4000 dark:text-sonic-silver w-full bg-transparent text-center text-xs md:inline">
This is a preview of the agent. You can publish it to start using it
in conversations.
</p>
</div>
</div>
);

View File

@@ -534,7 +534,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
setHasChanges(isChanged);
}, [agent, dispatch, effectiveMode, imageFile, jsonSchemaText]);
return (
<div className="flex flex-col px-4 pt-4 pb-2 max-[1179px]:min-h-[100dvh] min-[1180px]:h-[100dvh] md:px-12 md:pt-12 md:pb-3">
<div className="p-4 md:p-12">
<div className="flex items-center gap-3 px-4">
<button
className="rounded-full border p-3 text-sm text-gray-400 dark:border-0 dark:bg-[#28292D] dark:text-gray-500 dark:hover:bg-[#2E2F34]"
@@ -615,9 +615,9 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
</button>
</div>
</div>
<div className="mt-3 flex w-full flex-1 grid-cols-5 flex-col gap-10 rounded-[30px] bg-[#F6F6F6] p-5 max-[1179px]:overflow-visible min-[1180px]:grid min-[1180px]:gap-5 min-[1180px]:overflow-hidden dark:bg-[#383838]">
<div className="scrollbar-thin col-span-2 flex flex-col gap-5 max-[1179px]:overflow-visible min-[1180px]:max-h-full min-[1180px]:overflow-y-auto min-[1180px]:pr-3">
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
<div className="mt-5 flex w-full grid-cols-5 flex-col gap-10 min-[1180px]:grid min-[1180px]:gap-5">
<div className="col-span-2 flex flex-col gap-5">
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
<h2 className="text-lg font-semibold">Meta</h2>
<input
className="border-silver text-jet dark:bg-raisin-black dark:text-bright-gray dark:placeholder:text-silver mt-3 w-full rounded-3xl border bg-white px-5 py-3 text-sm outline-hidden placeholder:text-gray-400 dark:border-[#7E7E7E]"
@@ -650,7 +650,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
/>
</div>
</div>
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
<h2 className="text-lg font-semibold">Source</h2>
<div className="mt-3">
<div className="flex flex-wrap items-center gap-1">
@@ -744,7 +744,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
</div>
</div>
</div>
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
<div className="flex flex-wrap items-end gap-1">
<div className="min-w-20 grow basis-full sm:basis-0">
<Prompts
@@ -781,7 +781,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
</button>
</div>
</div>
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
<h2 className="text-lg font-semibold">Tools</h2>
<div className="mt-3 flex flex-wrap items-center gap-1">
<button
@@ -823,7 +823,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
/>
</div>
</div>
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
<h2 className="text-lg font-semibold">Agent type</h2>
<div className="mt-3">
<Dropdown
@@ -848,7 +848,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
/>
</div>
</div>
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
<button
onClick={() =>
setIsAdvancedSectionExpanded(!isAdvancedSectionExpanded)
@@ -1032,11 +1032,9 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
)}
</div>
</div>
<div className="col-span-3 flex flex-col gap-2 max-[1179px]:h-auto max-[1179px]:px-0 max-[1179px]:py-0 min-[1180px]:h-full min-[1180px]:py-2 dark:text-[#E0E0E0]">
<div className="col-span-3 flex flex-col gap-3 rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
<h2 className="text-lg font-semibold">Preview</h2>
<div className="flex-1 max-[1179px]:overflow-visible min-[1180px]:min-h-0 min-[1180px]:overflow-hidden">
<AgentPreviewArea />
</div>
<AgentPreviewArea />
</div>
</div>
<ConfirmationModal
@@ -1073,9 +1071,9 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
function AgentPreviewArea() {
const selectedAgent = useSelector(selectSelectedAgent);
return (
<div className="dark:bg-raisin-black w-full rounded-[30px] border border-[#F6F6F6] bg-white max-[1179px]:h-[600px] min-[1180px]:h-full dark:border-[#7E7E7E]">
<div className="dark:bg-raisin-black h-full w-full rounded-[30px] border border-[#F6F6F6] bg-white max-[1180px]:h-192 dark:border-[#7E7E7E]">
{selectedAgent?.status === 'published' ? (
<div className="flex h-full w-full flex-col overflow-hidden rounded-[30px]">
<div className="flex h-full w-full flex-col justify-end overflow-auto rounded-[30px]">
<AgentPreview />
</div>
) : (

View File

@@ -177,15 +177,13 @@ export default function SharedAgent() {
/>
</div>
<div className="flex w-[95%] max-w-[1500px] flex-col items-center pb-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
<div className="w-full px-2">
<MessageInput
onSubmit={(text) => handleQuestionSubmission(text)}
loading={status === 'loading'}
showSourceButton={sharedAgent ? false : true}
showToolButton={sharedAgent ? false : true}
autoFocus={false}
/>
</div>
<MessageInput
onSubmit={(text) => handleQuestionSubmission(text)}
loading={status === 'loading'}
showSourceButton={sharedAgent ? false : true}
showToolButton={sharedAgent ? false : true}
autoFocus={false}
/>
<p className="text-gray-4000 dark:text-sonic-silver hidden w-screen self-center bg-transparent py-2 text-center text-xs md:inline md:w-full">
{t('tagline')}
</p>

View File

@@ -0,0 +1,4 @@
<svg width="27" height="26" viewBox="0 0 27 26" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4.03371 5.27275L4.1915 20.9162C4.20021 21.7802 4.90766 22.4735 5.77162 22.4648L21.4151 22.307C22.2791 22.2983 22.9724 21.5909 22.9637 20.7269L22.8059 5.0834C22.7972 4.21944 22.0897 3.52612 21.2258 3.53483L5.58228 3.69262C4.71831 3.70134 4.02499 4.40878 4.03371 5.27275Z" stroke="#949494" stroke-width="2.08591" stroke-linejoin="round"/>
<path d="M9.42289 22.428L9.23354 3.65585M17.6924 15.0436L15.5856 12.9788L17.6504 10.872M6.29419 22.4596L12.5516 22.3965M6.10484 3.68741L12.3622 3.62429" stroke="#949494" stroke-width="2.08591" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 692 B

View File

@@ -1,5 +1,5 @@
<svg width="113" height="124" viewBox="0 0 113 124" fill="none" xmlns="http://www.w3.org/2000/svg">
<circle cx="55.5" cy="71" r="53" fill="#E8E3F3" fill-opacity="0.6"/>
<circle cx="55.5" cy="71" r="53" fill="#F1F1F1" fill-opacity="0.5"/>
<rect x="-0.599797" y="0.654564" width="43.9445" height="61.5222" rx="4.39444" transform="matrix(-0.999048 0.0436194 0.0436194 0.999048 68.9873 43.3176)" fill="#EEEEEE" stroke="#999999" stroke-width="1.25556"/>
<rect x="0.704349" y="-0.540466" width="46.4556" height="64.0333" rx="5.65" transform="matrix(-0.991445 -0.130526 -0.130526 0.991445 96.3673 40.893)" fill="#FAFAFA" stroke="#999999" stroke-width="1.25556"/>
<path d="M94.3796 45.7849C94.7417 43.0349 92.8059 40.5122 90.0559 40.1501L55.2011 35.5614C52.4511 35.1994 49.9284 37.1352 49.5663 39.8851L48.3372 49.2212L93.1505 55.121L94.3796 45.7849Z" fill="#EEEEEE"/>

Before

Width:  |  Height:  |  Size: 2.0 KiB

After

Width:  |  Height:  |  Size: 2.0 KiB

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#949494" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-panel-left-close-icon lucide-panel-left-close"><rect width="18" height="18" x="3" y="3" rx="2"/><path d="M9 3v18"/><path d="m16 15-3-3 3-3"/></svg>

Before

Width:  |  Height:  |  Size: 345 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#949494" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-panel-left-open-icon lucide-panel-left-open"><rect width="18" height="18" x="3" y="3" rx="2"/><path d="M9 3v18"/><path d="m14 9 3 3-3 3"/></svg>

Before

Width:  |  Height:  |  Size: 342 B

View File

@@ -0,0 +1,16 @@
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<!-- Uploaded to: SVG Repo, www.svgrepo.com, Transformed by: SVG Repo Mixer Tools -->
<svg width="800px" height="800px" viewBox="0 0 48 48" id="b" xmlns="http://www.w3.org/2000/svg" fill="#000000" stroke="#000000" stroke-width="3.312">
<g id="SVGRepo_bgCarrier" stroke-width="0"/>
<g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"/>
<g id="SVGRepo_iconCarrier">
<defs>
<style>.c{fill:none;stroke:#000000;stroke-linecap:round;stroke-linejoin:round;}</style>
</defs>

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@@ -136,34 +136,33 @@ const Chunks: React.FC<ChunksProps> = ({
const pathParts = path ? path.split('/') : [];
const fetchChunks = async () => {
const fetchChunks = () => {
setLoading(true);
try {
const response = await userService.getDocumentChunks(
documentId,
page,
perPage,
token,
path,
searchTerm,
);
if (!response.ok) {
throw new Error('Failed to fetch chunks data');
}
const data = await response.json();
setPage(data.page);
setPerPage(data.per_page);
setTotalChunks(data.total);
setPaginatedChunks(data.chunks);
} catch (error) {
setPaginatedChunks([]);
console.error(error);
} finally {
// ✅ always runs, success or failure
userService
.getDocumentChunks(documentId, page, perPage, token, path, searchTerm)
.then((response) => {
if (!response.ok) {
setLoading(false);
setPaginatedChunks([]);
throw new Error('Failed to fetch chunks data');
}
return response.json();
})
.then((data) => {
setPage(data.page);
setPerPage(data.per_page);
setTotalChunks(data.total);
setPaginatedChunks(data.chunks);
setLoading(false);
})
.catch((error) => {
setLoading(false);
setPaginatedChunks([]);
});
} catch (e) {
setLoading(false);
setPaginatedChunks([]);
}
};

View File

@@ -0,0 +1,13 @@
const ConnectedStateSkeleton = () => (
<div className="mb-4">
<div className="flex w-full animate-pulse items-center justify-between rounded-[10px] bg-gray-200 px-4 py-2 dark:bg-gray-700">
<div className="flex items-center gap-2">
<div className="h-4 w-4 rounded bg-gray-300 dark:bg-gray-600"></div>
<div className="h-4 w-32 rounded bg-gray-300 dark:bg-gray-600"></div>
</div>
<div className="h-4 w-16 rounded bg-gray-300 dark:bg-gray-600"></div>
</div>
</div>
);
export default ConnectedStateSkeleton;

View File

@@ -150,7 +150,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
{isConnected ? (
<div className="mb-4">
<div className="flex w-full items-center justify-between rounded-[10px] bg-[#8FDD51] px-4 py-2 text-sm font-medium text-[#212121]">
<div className="flex items-center gap-2">
<div className="flex max-w-[500px] items-center gap-2">
<svg className="h-4 w-4" viewBox="0 0 24 24">
<path
fill="currentColor"

View File

@@ -20,9 +20,14 @@ type CopyButtonProps = {
const DEFAULT_ICON_SIZE = 'w-4 h-4';
const DEFAULT_PADDING = 'p-2';
const DEFAULT_COPIED_DURATION = 2000;
const DEFAULT_BG_LIGHT = '#FFFFFF';
const DEFAULT_BG_DARK = 'transparent';
const DEFAULT_HOVER_BG_LIGHT = '#EEEEEE';
const DEFAULT_HOVER_BG_DARK = '#464152';
export default function CopyButton({
textToCopy,
iconSize = DEFAULT_ICON_SIZE,
padding = DEFAULT_PADDING,
showText = false,
@@ -38,8 +43,9 @@ export default function CopyButton({
const iconWrapperClasses = clsx(
'flex items-center justify-center rounded-full transition-colors duration-150 ease-in-out',
padding,
`bg-[${DEFAULT_BG_LIGHT}] dark:bg-[${DEFAULT_BG_DARK}]`,
{
[`bg-[#FFFFFF}] dark:bg-transparent hover:bg-[#EEEEEE] dark:hover:bg-purple-taupe`]:
[`hover:bg-[${DEFAULT_HOVER_BG_LIGHT}] dark:hover:bg-[${DEFAULT_HOVER_BG_DARK}]`]:
!isCopied,
'bg-green-100 dark:bg-green-900 hover:bg-green-100 dark:hover:bg-green-900':
isCopied,

View File

@@ -0,0 +1,13 @@
const FilesSectionSkeleton = () => (
<div className="rounded-lg border border-[#EEE6FF78] dark:border-[#6A6A6A]">
<div className="p-4">
<div className="mb-4 flex items-center justify-between">
<div className="h-5 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
<div className="h-8 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
</div>
<div className="h-4 w-40 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
</div>
</div>
);
export default FilesSectionSkeleton;

View File

@@ -7,7 +7,10 @@ import {
getSessionToken,
setSessionToken,
removeSessionToken,
validateProviderSession,
} from '../utils/providerUtils';
import ConnectedStateSkeleton from './ConnectedStateSkeleton';
import FilesSectionSkeleton from './FileSelectionSkeleton';
interface PickerFile {
id: string;
@@ -50,20 +53,9 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
const validateSession = async (sessionToken: string) => {
try {
const apiHost = import.meta.env.VITE_API_HOST;
const validateResponse = await fetch(
`${apiHost}/api/connectors/validate-session`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`,
},
body: JSON.stringify({
provider: 'google_drive',
session_token: sessionToken,
}),
},
const validateResponse = await validateProviderSession(
token,
'google_drive',
);
if (!validateResponse.ok) {
@@ -234,30 +226,6 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
onSelectionChange([], []);
};
const ConnectedStateSkeleton = () => (
<div className="mb-4">
<div className="flex w-full animate-pulse items-center justify-between rounded-[10px] bg-gray-200 px-4 py-2 dark:bg-gray-700">
<div className="flex items-center gap-2">
<div className="h-4 w-4 rounded bg-gray-300 dark:bg-gray-600"></div>
<div className="h-4 w-32 rounded bg-gray-300 dark:bg-gray-600"></div>
</div>
<div className="h-4 w-16 rounded bg-gray-300 dark:bg-gray-600"></div>
</div>
</div>
);
const FilesSectionSkeleton = () => (
<div className="rounded-lg border border-[#EEE6FF78] dark:border-[#6A6A6A]">
<div className="p-4">
<div className="mb-4 flex items-center justify-between">
<div className="h-5 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
<div className="h-8 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
</div>
<div className="h-4 w-40 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
</div>
</div>
);
return (
<div>
{isValidating ? (

View File

@@ -1,6 +1,4 @@
import { useCallback, useEffect, useRef, useState } from 'react';
import { createPortal } from 'react-dom';
import { useDropzone } from 'react-dropzone';
import { useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector } from 'react-redux';
@@ -8,7 +6,6 @@ import endpoints from '../api/endpoints';
import userService from '../api/services/userService';
import AlertIcon from '../assets/alert.svg';
import ClipIcon from '../assets/clip.svg';
import DragFileUpload from '../assets/DragFileUpload.svg';
import ExitIcon from '../assets/exit.svg';
import SendArrowIcon from './SendArrowIcon';
import SourceIcon from '../assets/source.svg';
@@ -20,7 +17,6 @@ import {
selectAttachments,
updateAttachment,
} from '../upload/uploadSlice';
import { reorderAttachments } from '../upload/uploadSlice';
import { ActiveState } from '../models/misc';
import {
@@ -57,7 +53,6 @@ export default function MessageInput({
const [isToolsPopupOpen, setIsToolsPopupOpen] = useState(false);
const [uploadModalState, setUploadModalState] =
useState<ActiveState>('INACTIVE');
const [handleDragActive, setHandleDragActive] = useState<boolean>(false);
const selectedDocs = useSelector(selectSelectedDocs);
const token = useSelector(selectToken);
@@ -87,134 +82,77 @@ export default function MessageInput({
};
}, [browserOS]);
const uploadFiles = useCallback(
(files: File[]) => {
const apiHost = import.meta.env.VITE_API_HOST;
files.forEach((file) => {
const formData = new FormData();
formData.append('file', file);
const xhr = new XMLHttpRequest();
const uniqueId = crypto.randomUUID();
const newAttachment = {
id: uniqueId,
fileName: file.name,
progress: 0,
status: 'uploading' as const,
taskId: '',
};
dispatch(addAttachment(newAttachment));
xhr.upload.addEventListener('progress', (event) => {
if (event.lengthComputable) {
const progress = Math.round((event.loaded / event.total) * 100);
dispatch(
updateAttachment({
id: uniqueId,
updates: { progress },
}),
);
}
});
xhr.onload = () => {
if (xhr.status === 200) {
const response = JSON.parse(xhr.responseText);
if (response.task_id) {
dispatch(
updateAttachment({
id: uniqueId,
updates: {
taskId: response.task_id,
status: 'processing',
progress: 10,
},
}),
);
}
} else {
dispatch(
updateAttachment({
id: uniqueId,
updates: { status: 'failed' },
}),
);
}
};
xhr.onerror = () => {
dispatch(
updateAttachment({
id: uniqueId,
updates: { status: 'failed' },
}),
);
};
xhr.open('POST', `${apiHost}${endpoints.USER.STORE_ATTACHMENT}`);
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
xhr.send(formData);
});
},
[dispatch, token],
);
const handleFileAttachment = (e: React.ChangeEvent<HTMLInputElement>) => {
if (!e.target.files || e.target.files.length === 0) return;
const files = Array.from(e.target.files);
uploadFiles(files);
const file = e.target.files[0];
const formData = new FormData();
formData.append('file', file);
// clear input so same file can be selected again
const apiHost = import.meta.env.VITE_API_HOST;
const xhr = new XMLHttpRequest();
const newAttachment = {
fileName: file.name,
progress: 0,
status: 'uploading' as const,
taskId: '',
};
dispatch(addAttachment(newAttachment));
xhr.upload.addEventListener('progress', (event) => {
if (event.lengthComputable) {
const progress = Math.round((event.loaded / event.total) * 100);
dispatch(
updateAttachment({
taskId: newAttachment.taskId,
updates: { progress },
}),
);
}
});
xhr.onload = () => {
if (xhr.status === 200) {
const response = JSON.parse(xhr.responseText);
if (response.task_id) {
dispatch(
updateAttachment({
taskId: newAttachment.taskId,
updates: {
taskId: response.task_id,
status: 'processing',
progress: 10,
},
}),
);
}
} else {
dispatch(
updateAttachment({
taskId: newAttachment.taskId,
updates: { status: 'failed' },
}),
);
}
};
xhr.onerror = () => {
dispatch(
updateAttachment({
taskId: newAttachment.taskId,
updates: { status: 'failed' },
}),
);
};
xhr.open('POST', `${apiHost}${endpoints.USER.STORE_ATTACHMENT}`);
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
xhr.send(formData);
e.target.value = '';
};
// Drag and drop handler
const onDrop = useCallback(
(acceptedFiles: File[]) => {
uploadFiles(acceptedFiles);
setHandleDragActive(false);
},
[uploadFiles],
);
const { getRootProps, getInputProps } = useDropzone({
onDrop,
noClick: true,
noKeyboard: true,
multiple: true,
onDragEnter: () => {
setHandleDragActive(true);
},
onDragLeave: () => {
setHandleDragActive(false);
},
maxSize: 25000000,
accept: {
'application/pdf': ['.pdf'],
'text/plain': ['.txt'],
'text/x-rst': ['.rst'],
'text/x-markdown': ['.md'],
'application/zip': ['.zip'],
'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
['.docx'],
'application/json': ['.json'],
'text/csv': ['.csv'],
'text/html': ['.html'],
'application/epub+zip': ['.epub'],
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': [
'.xlsx',
],
'application/vnd.openxmlformats-officedocument.presentationml.presentation':
['.pptx'],
'image/png': ['.png'],
'image/jpeg': ['.jpeg'],
'image/jpg': ['.jpg'],
},
});
useEffect(() => {
const checkTaskStatus = () => {
const processingAttachments = attachments.filter(
@@ -229,7 +167,7 @@ export default function MessageInput({
if (data.status === 'SUCCESS') {
dispatch(
updateAttachment({
id: attachment.id,
taskId: attachment.taskId!,
updates: {
status: 'completed',
progress: 100,
@@ -241,14 +179,14 @@ export default function MessageInput({
} else if (data.status === 'FAILURE') {
dispatch(
updateAttachment({
id: attachment.id,
taskId: attachment.taskId!,
updates: { status: 'failed' },
}),
);
} else if (data.status === 'PROGRESS' && data.result?.current) {
dispatch(
updateAttachment({
id: attachment.id,
taskId: attachment.taskId!,
updates: { progress: data.result.current },
}),
);
@@ -257,7 +195,7 @@ export default function MessageInput({
.catch(() => {
dispatch(
updateAttachment({
id: attachment.id,
taskId: attachment.taskId!,
updates: { status: 'failed' },
}),
);
@@ -321,131 +259,90 @@ export default function MessageInput({
handleAbort();
};
// Drag state for reordering
const [draggingId, setDraggingId] = useState<string | null>(null);
// no preview object URLs to revoke (preview removed per reviewer request)
const findIndexById = (id: string) =>
attachments.findIndex((a) => a.id === id);
const handleDragStart = (e: React.DragEvent, id: string) => {
setDraggingId(id);
try {
e.dataTransfer.setData('text/plain', id);
e.dataTransfer.effectAllowed = 'move';
} catch (err) {
// ignore
}
};
const handleDragOver = (e: React.DragEvent) => {
e.preventDefault();
e.dataTransfer.dropEffect = 'move';
};
const handleDropOn = (e: React.DragEvent, targetId: string) => {
e.preventDefault();
const sourceId = e.dataTransfer.getData('text/plain');
if (!sourceId || sourceId === targetId) return;
const sourceIndex = findIndexById(sourceId);
const destIndex = findIndexById(targetId);
if (sourceIndex === -1 || destIndex === -1) return;
dispatch(reorderAttachments({ sourceIndex, destinationIndex: destIndex }));
setDraggingId(null);
};
return (
<div {...getRootProps()} className="flex w-full flex-col">
<input {...getInputProps()} />
<div className="mx-2 flex w-full flex-col">
<div className="border-dark-gray bg-lotion dark:border-grey relative flex w-full flex-col rounded-[23px] border dark:bg-transparent">
<div className="flex flex-wrap gap-1.5 px-2 py-2 sm:gap-2 sm:px-3">
{attachments.map((attachment) => {
return (
<div
key={attachment.id}
draggable={true}
onDragStart={(e) => handleDragStart(e, attachment.id)}
onDragOver={handleDragOver}
onDrop={(e) => handleDropOn(e, attachment.id)}
className={`group dark:text-bright-gray relative flex items-center rounded-xl bg-[#EFF3F4] px-2 py-1 text-[12px] text-[#5D5D5D] sm:px-3 sm:py-1.5 sm:text-[14px] dark:bg-[#393B3D] ${
attachment.status !== 'completed'
? 'opacity-70'
: 'opacity-100'
} ${draggingId === attachment.id ? 'ring-dashed opacity-60 ring-2 ring-purple-200' : ''}`}
title={attachment.fileName}
>
<div className="bg-purple-30 mr-2 flex h-8 w-8 items-center justify-center rounded-md p-1">
{attachment.status === 'completed' && (
<img
src={DocumentationDark}
alt="Attachment"
className="h-[15px] w-[15px] object-fill"
/>
)}
{attachment.status === 'failed' && (
<img
src={AlertIcon}
alt="Failed"
className="h-[15px] w-[15px] object-fill"
/>
)}
{(attachment.status === 'uploading' ||
attachment.status === 'processing') && (
<div className="flex h-[15px] w-[15px] items-center justify-center">
<svg className="h-[15px] w-[15px]" viewBox="0 0 24 24">
<circle
className="opacity-0"
cx="12"
cy="12"
r="10"
stroke="transparent"
strokeWidth="4"
fill="none"
/>
<circle
className="text-[#ECECF1]"
cx="12"
cy="12"
r="10"
stroke="currentColor"
strokeWidth="4"
fill="none"
strokeDasharray="62.83"
strokeDashoffset={
62.83 * (1 - attachment.progress / 100)
}
transform="rotate(-90 12 12)"
/>
</svg>
</div>
)}
</div>
<span className="max-w-[120px] truncate font-medium sm:max-w-[150px]">
{attachment.fileName}
</span>
<button
className="ml-1.5 flex items-center justify-center rounded-full p-1"
onClick={() => {
dispatch(removeAttachment(attachment.id));
}}
aria-label={t('conversation.attachments.remove')}
>
{attachments.map((attachment, index) => (
<div
key={index}
className={`group dark:text-bright-gray relative flex items-center rounded-xl bg-[#EFF3F4] px-2 py-1 text-[12px] text-[#5D5D5D] sm:px-3 sm:py-1.5 sm:text-[14px] dark:bg-[#393B3D] ${
attachment.status !== 'completed' ? 'opacity-70' : 'opacity-100'
}`}
title={attachment.fileName}
>
<div className="bg-purple-30 mr-2 items-center justify-center rounded-lg p-[5.5px]">
{attachment.status === 'completed' && (
<img
src={ExitIcon}
alt={t('conversation.attachments.remove')}
className="h-2.5 w-2.5 filter dark:invert"
src={DocumentationDark}
alt="Attachment"
className="h-[15px] w-[15px] object-fill"
/>
</button>
)}
{attachment.status === 'failed' && (
<img
src={AlertIcon}
alt="Failed"
className="h-[15px] w-[15px] object-fill"
/>
)}
{(attachment.status === 'uploading' ||
attachment.status === 'processing') && (
<div className="flex h-[15px] w-[15px] items-center justify-center">
<svg className="h-[15px] w-[15px]" viewBox="0 0 24 24">
<circle
className="opacity-0"
cx="12"
cy="12"
r="10"
stroke="transparent"
strokeWidth="4"
fill="none"
/>
<circle
className="text-[#ECECF1]"
cx="12"
cy="12"
r="10"
stroke="currentColor"
strokeWidth="4"
fill="none"
strokeDasharray="62.83"
strokeDashoffset={
62.83 * (1 - attachment.progress / 100)
}
transform="rotate(-90 12 12)"
/>
</svg>
</div>
)}
</div>
);
})}
<span className="max-w-[120px] truncate font-medium sm:max-w-[150px]">
{attachment.fileName}
</span>
<button
className="ml-1.5 flex items-center justify-center rounded-full p-1"
onClick={() => {
if (attachment.id) {
dispatch(removeAttachment(attachment.id));
} else if (attachment.taskId) {
dispatch(removeAttachment(attachment.taskId));
}
}}
aria-label={t('conversation.attachments.remove')}
>
<img
src={ExitIcon}
alt={t('conversation.attachments.remove')}
className="h-2.5 w-2.5 filter dark:invert"
/>
</button>
</div>
))}
</div>
<div className="w-full">
@@ -527,7 +424,6 @@ export default function MessageInput({
<input
type="file"
className="hidden"
multiple
onChange={handleFileAttachment}
/>
</label>
@@ -587,20 +483,6 @@ export default function MessageInput({
close={() => setUploadModalState('INACTIVE')}
/>
)}
{handleDragActive &&
createPortal(
<div className="dark:bg-gray-alpha/50 pointer-events-none fixed top-0 left-0 z-50 flex size-full flex-col items-center justify-center bg-white/85">
<img className="filter dark:invert" src={DragFileUpload} />
<span className="text-outer-space dark:text-silver px-2 text-2xl font-bold">
{t('modals.uploadDoc.drag.title')}
</span>
<span className="text-s text-outer-space dark:text-silver w-48 p-2 text-center">
{t('modals.uploadDoc.drag.description')}
</span>
</div>,
document.body,
)}
</div>
);
}

View File

@@ -0,0 +1,175 @@
import { useTranslation } from 'react-i18next';
import ConnectorAuth from './ConnectorAuth';
import { useEffect, useState } from 'react';
import {
getSessionToken,
setSessionToken,
removeSessionToken,
validateProviderSession,
} from '../utils/providerUtils';
import ConnectedStateSkeleton from './ConnectedStateSkeleton';
import FilesSectionSkeleton from './FileSelectionSkeleton';
interface SharePointPickerProps {
token: string | null;
}
const SharePointPicker: React.FC<SharePointPickerProps> = ({ token }) => {
const { t } = useTranslation();
const [isLoading, setIsLoading] = useState(false);
const [userEmail, setUserEmail] = useState<string>('');
const [isConnected, setIsConnected] = useState(false);
const [authError, setAuthError] = useState<string>('');
const [accessToken, setAccessToken] = useState<string | null>(null);
const [isValidating, setIsValidating] = useState(false);
useEffect(() => {
const sessionToken = getSessionToken('share_point');
if (sessionToken) {
setIsValidating(true);
setIsConnected(true); // Optimistically set as connected for skeleton
validateSession(sessionToken);
}
}, [token]);
const validateSession = async (sessionToken: string) => {
try {
const validateResponse = await validateProviderSession(
token,
'share_point',
);
if (!validateResponse.ok) {
setIsConnected(false);
setAuthError(
t('modals.uploadDoc.connectors.sharePoint.sessionExpired'),
);
setIsValidating(false);
return false;
}
const validateData = await validateResponse.json();
if (validateData.success) {
setUserEmail(
validateData.user_email ||
t('modals.uploadDoc.connectors.auth.connectedUser'),
);
setIsConnected(true);
setAuthError('');
setAccessToken(validateData.access_token || null);
setIsValidating(false);
return true;
} else {
setIsConnected(false);
setAuthError(
validateData.error ||
t('modals.uploadDoc.connectors.sharePoint.sessionExpiredGeneric'),
);
setIsValidating(false);
return false;
}
} catch (error) {
console.error('Error validating session:', error);
setAuthError(t('modals.uploadDoc.connectors.sharePoint.validateFailed'));
setIsConnected(false);
setIsValidating(false);
return false;
}
};
const handleDisconnect = async () => {
const sessionToken = getSessionToken('share_point');
if (sessionToken) {
try {
const apiHost = import.meta.env.VITE_API_HOST;
await fetch(`${apiHost}/api/connectors/disconnect`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`,
},
body: JSON.stringify({
provider: 'share_point',
session_token: sessionToken,
}),
});
} catch (err) {
console.error('Error disconnecting from SharePoint:', err);
}
}
removeSessionToken('share_point');
setIsConnected(false);
setAccessToken(null);
setUserEmail('');
setAuthError('');
};
const handleOpenPicker = async () => {
alert('Feature not supported yet.');
};
return (
<div>
{isValidating ? (
<>
<ConnectedStateSkeleton />
<FilesSectionSkeleton />
</>
) : (
<>
<ConnectorAuth
provider="share_point"
label={t('modals.uploadDoc.connectors.sharePoint.connect')}
onSuccess={(data) => {
setUserEmail(
data.user_email ||
t('modals.uploadDoc.connectors.auth.connectedUser'),
);
setIsConnected(true);
setAuthError('');
if (data.session_token) {
setSessionToken('share_point', data.session_token);
validateSession(data.session_token);
}
}}
onError={(error) => {
setAuthError(error);
setIsConnected(false);
}}
isConnected={isConnected}
userEmail={userEmail}
onDisconnect={handleDisconnect}
errorMessage={authError}
/>
{isConnected && (
<div className="rounded-lg border border-[#EEE6FF78] dark:border-[#6A6A6A]">
<div className="p-4">
<div className="mb-4 flex items-center justify-between">
<h3 className="text-sm font-medium">
{t('modals.uploadDoc.connectors.sharePoint.selectedFiles')}
</h3>
<button
onClick={() => handleOpenPicker()}
className="rounded-md bg-[#A076F6] px-3 py-1 text-sm text-white hover:bg-[#8A5FD4]"
disabled={isLoading}
>
{isLoading
? t('modals.uploadDoc.connectors.sharePoint.loading')
: t('modals.uploadDoc.connectors.sharePoint.selectFiles')}
</button>
</div>
</div>
</div>
)}
</>
)}
</div>
);
};
export default SharePointPicker;

View File

@@ -33,7 +33,7 @@ export default function Sidebar({
return (
<div ref={sidebarRef} className="h-vh relative">
<div
className={`dark:bg-chinese-black fixed top-0 right-0 z-50 h-full w-64 transform bg-white shadow-xl transition-all duration-300 sm:w-80 ${
className={`dark:bg-chinese-black fixed top-0 right-0 z-50 h-full w-72 transform bg-white shadow-xl transition-all duration-300 sm:w-96 ${
isOpen ? 'translate-x-[10px]' : 'translate-x-full'
} border-l border-[#9ca3af]/10`}
>

View File

@@ -1,5 +1,4 @@
import React, { useRef, useEffect, useState, useLayoutEffect } from 'react';
import { createPortal } from 'react-dom';
import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector } from 'react-redux';
import { Doc } from '../models/misc';
@@ -108,7 +107,7 @@ export default function SourcesPopup({
onClose();
};
const popupContent = (
return (
<div
ref={popupRef}
className="bg-lotion dark:bg-charleston-green-2 fixed z-50 flex flex-col rounded-xl shadow-[0px_9px_46px_8px_#0000001F,0px_24px_38px_3px_#00000024,0px_11px_15px_-7px_#00000033]"
@@ -219,7 +218,7 @@ export default function SourcesPopup({
</>
) : (
<div className="dark:text-bright-gray p-4 text-center text-gray-500 dark:text-[14px]">
{t('conversation.sources.noSourcesAvailable')}
{t('noSourcesAvailable')}
</div>
)}
</div>
@@ -246,6 +245,4 @@ export default function SourcesPopup({
</div>
</div>
);
return createPortal(popupContent, document.body);
}

View File

@@ -1,5 +1,4 @@
import React, { useEffect, useRef, useState, useLayoutEffect } from 'react';
import { createPortal } from 'react-dom';
import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
import { selectToken } from '../preferences/preferenceSlice';
@@ -134,10 +133,10 @@ export default function ToolsPopup({
tool.displayName.toLowerCase().includes(searchTerm.toLowerCase()),
);
const popupContent = (
return (
<div
ref={popupRef}
className="border-light-silver bg-lotion dark:border-dim-gray dark:bg-charleston-green-2 fixed z-50 rounded-lg border shadow-[0px_9px_46px_8px_#0000001F,0px_24px_38px_3px_#00000024,0px_11px_15px_-7px_#00000033]"
className="border-light-silver bg-lotion dark:border-dim-gray dark:bg-charleston-green-2 fixed z-9999 rounded-lg border shadow-[0px_9px_46px_8px_#0000001F,0px_24px_38px_3px_#00000024,0px_11px_15px_-7px_#00000033]"
style={{
top: popupPosition.showAbove ? popupPosition.top : undefined,
bottom: popupPosition.showAbove
@@ -243,6 +242,4 @@ export default function ToolsPopup({
</div>
</div>
);
return createPortal(popupContent, document.body);
}

View File

@@ -44,10 +44,7 @@ export default function UploadToast() {
};
return (
<div
className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2"
onMouseDown={(e) => e.stopPropagation()}
>
<div className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2">
{uploadTasks
.filter((task) => !task.dismissed)
.map((task) => {

View File

@@ -1,16 +1,20 @@
import { useCallback, useEffect, useRef, useState } from 'react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector } from 'react-redux';
import SharedAgentCard from '../agents/SharedAgentCard';
import DragFileUpload from '../assets/DragFileUpload.svg';
import MessageInput from '../components/MessageInput';
import { useMediaQuery } from '../hooks';
import { ActiveState } from '../models/misc';
import {
selectConversationId,
selectSelectedAgent,
selectToken,
} from '../preferences/preferenceSlice';
import { AppDispatch } from '../store';
import Upload from '../upload/Upload';
import { handleSendFeedback } from './conversationHandlers';
import ConversationMessages from './ConversationMessages';
import { FEEDBACK, Query } from './conversationModels';
@@ -41,12 +45,53 @@ export default function Conversation() {
const selectedAgent = useSelector(selectSelectedAgent);
const completedAttachments = useSelector(selectCompletedAttachments);
const [uploadModalState, setUploadModalState] =
useState<ActiveState>('INACTIVE');
const [files, setFiles] = useState<File[]>([]);
const [lastQueryReturnedErr, setLastQueryReturnedErr] =
useState<boolean>(false);
const [isShareModalOpen, setShareModalState] = useState<boolean>(false);
const [handleDragActive, setHandleDragActive] = useState<boolean>(false);
const fetchStream = useRef<any>(null);
const onDrop = useCallback((acceptedFiles: File[]) => {
setUploadModalState('ACTIVE');
setFiles(acceptedFiles);
setHandleDragActive(false);
}, []);
const { getRootProps, getInputProps } = useDropzone({
onDrop,
noClick: true,
multiple: true,
onDragEnter: () => {
setHandleDragActive(true);
},
onDragLeave: () => {
setHandleDragActive(false);
},
maxSize: 25000000,
accept: {
'application/pdf': ['.pdf'],
'text/plain': ['.txt'],
'text/x-rst': ['.rst'],
'text/x-markdown': ['.md'],
'application/zip': ['.zip'],
'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
['.docx'],
'application/json': ['.json'],
'text/csv': ['.csv'],
'text/html': ['.html'],
'application/epub+zip': ['.epub'],
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': [
'.xlsx',
],
'application/vnd.openxmlformats-officedocument.presentationml.presentation':
['.pptx'],
},
});
const handleFetchAnswer = useCallback(
({ question, index }: { question: string; index?: number }) => {
fetchStream.current = dispatch(fetchAnswer({ question, indx: index }));
@@ -177,7 +222,14 @@ export default function Conversation() {
/>
<div className="bg-opacity-0 z-3 flex h-auto w-full max-w-[1300px] flex-col items-end self-center rounded-2xl py-1 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
<div className="flex w-full items-center rounded-[40px] px-2">
<div
{...getRootProps()}
className="flex w-full items-center rounded-[40px]"
>
<label htmlFor="file-upload" className="sr-only">
{t('modals.uploadDoc.label')}
</label>
<input {...getInputProps()} id="file-upload" />
<MessageInput
onSubmit={(text) => {
handleQuestionSubmission(text);
@@ -192,6 +244,26 @@ export default function Conversation() {
{t('tagline')}
</p>
</div>
{handleDragActive && (
<div className="bg-opacity-50 dark:bg-gray-alpha pointer-events-none fixed top-0 left-0 z-30 flex size-full flex-col items-center justify-center bg-white">
<img className="filter dark:invert" src={DragFileUpload} />
<span className="text-outer-space dark:text-silver px-2 text-2xl font-bold">
{t('modals.uploadDoc.drag.title')}
</span>
<span className="text-s text-outer-space dark:text-silver w-48 p-2 text-center">
{t('modals.uploadDoc.drag.description')}
</span>
</div>
)}
{uploadModalState === 'ACTIVE' && (
<Upload
receivedFile={files}
setModalState={setUploadModalState}
isOnboarding={false}
renderTab={'file'}
close={() => setUploadModalState('INACTIVE')}
></Upload>
)}
</div>
);
}

View File

@@ -3,9 +3,9 @@
}
.list li:not(:first-child) {
margin-top: 0.5em;
margin-top: 1em;
}
.list li > .list {
margin-top: 0.5em;
margin-top: 1em;
}

View File

@@ -86,7 +86,10 @@ const ConversationBubble = forwardRef<
// const bubbleRef = useRef<HTMLDivElement | null>(null);
const chunks = useSelector(selectChunks);
const selectedDocs = useSelector(selectSelectedDocs);
const [isLikeHovered, setIsLikeHovered] = useState(false);
const [isEditClicked, setIsEditClicked] = useState(false);
const [isDislikeHovered, setIsDislikeHovered] = useState(false);
const [isQuestionHovered, setIsQuestionHovered] = useState(false);
const [editInputBox, setEditInputBox] = useState<string>('');
const messageRef = useRef<HTMLDivElement>(null);
const [shouldShowToggle, setShouldShowToggle] = useState(false);
@@ -112,7 +115,11 @@ const ConversationBubble = forwardRef<
let bubble;
if (type === 'QUESTION') {
bubble = (
<div className={`group ${className}`}>
<div
onMouseEnter={() => setIsQuestionHovered(true)}
onMouseLeave={() => setIsQuestionHovered(false)}
className={className}
>
<div className="flex flex-col items-end">
{filesAttached && filesAttached.length > 0 && (
<div className="mr-12 mb-4 flex flex-wrap justify-end gap-2">
@@ -181,7 +188,7 @@ const ConversationBubble = forwardRef<
setIsEditClicked(true);
setEditInputBox(message ?? '');
}}
className={`hover:bg-light-silver mt-3 flex h-fit shrink-0 cursor-pointer items-center rounded-full p-2 pt-1.5 pl-1.5 dark:hover:bg-[#35363B] ${isEditClicked ? 'visible' : 'invisible group-hover:visible'}`}
className={`hover:bg-light-silver mt-3 flex h-fit shrink-0 cursor-pointer items-center rounded-full p-2 pt-1.5 pl-1.5 dark:hover:bg-[#35363B] ${isQuestionHovered || isEditClicked ? 'visible' : 'invisible'}`}
>
<img src={Edit} alt="Edit" className="cursor-pointer" />
</button>
@@ -414,7 +421,7 @@ const ConversationBubble = forwardRef<
<Fragment key={index}>
{segment.type === 'text' ? (
<ReactMarkdown
className="fade-in flex flex-col gap-3 leading-normal break-words whitespace-pre-wrap"
className="fade-in leading-normal break-words whitespace-pre-wrap"
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[rehypeKatex]}
components={{
@@ -561,7 +568,13 @@ const ConversationBubble = forwardRef<
<>
<div className="relative mr-2 flex items-center justify-center">
<div>
<div className="bg-white-3000 dark:hover:bg-purple-taupe flex items-center justify-center rounded-full p-2 hover:bg-[#EEEEEE] dark:bg-transparent">
<div
className={`flex items-center justify-center rounded-full p-2 ${
isLikeHovered
? 'dark:bg-purple-taupe bg-[#EEEEEE]'
: 'bg-white-3000 dark:bg-transparent'
}`}
>
<Like
className={`${feedback === 'LIKE' ? 'fill-white-3000 stroke-purple-30 dark:fill-transparent' : 'stroke-gray-4000 fill-none'} cursor-pointer`}
onClick={() => {
@@ -571,6 +584,8 @@ const ConversationBubble = forwardRef<
handleFeedback?.('LIKE');
}
}}
onMouseEnter={() => setIsLikeHovered(true)}
onMouseLeave={() => setIsLikeHovered(false)}
></Like>
</div>
</div>
@@ -578,7 +593,13 @@ const ConversationBubble = forwardRef<
<div className="relative mr-2 flex items-center justify-center">
<div>
<div className="bg-white-3000 dark:hover:bg-purple-taupe flex items-center justify-center rounded-full p-2 hover:bg-[#EEEEEE] dark:bg-transparent">
<div
className={`flex items-center justify-center rounded-full p-2 ${
isDislikeHovered
? 'dark:bg-purple-taupe bg-[#EEEEEE]'
: 'bg-white-3000 dark:bg-transparent'
}`}
>
<Dislike
className={`${feedback === 'DISLIKE' ? 'fill-white-3000 stroke-red-2000 dark:fill-transparent' : 'stroke-gray-4000 fill-none'} cursor-pointer`}
onClick={() => {
@@ -588,6 +609,8 @@ const ConversationBubble = forwardRef<
handleFeedback?.('DISLIKE');
}
}}
onMouseEnter={() => setIsDislikeHovered(true)}
onMouseLeave={() => setIsDislikeHovered(false)}
></Dislike>
</div>
</div>
@@ -635,7 +658,7 @@ function AllSources(sources: AllSourcesProps) {
<p className="text-left text-xl">{`${sources.sources.length} ${t('conversation.sources.title')}`}</p>
<div className="mx-1 mt-2 h-[0.8px] w-full rounded-full bg-[#C4C4C4]/40 lg:w-[95%]"></div>
</div>
<div className="scrollbar-thin mt-6 flex h-[90%] w-52 flex-col gap-4 overflow-y-auto pr-3 sm:w-64">
<div className="mt-6 flex h-[90%] w-60 flex-col items-center gap-4 overflow-y-auto sm:w-80">
{sources.sources.map((source, index) => {
const isExternalSource = source.link && source.link !== 'local';
return (

View File

@@ -161,16 +161,14 @@ export const SharedConversation = () => {
/>
<div className="flex w-full max-w-[1200px] flex-col items-center gap-4 pb-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
{apiKey ? (
<div className="w-full px-2">
<MessageInput
onSubmit={(text) => {
handleQuestionSubmission(text);
}}
loading={status === 'loading'}
showSourceButton={false}
showToolButton={false}
/>
</div>
<MessageInput
onSubmit={(text) => {
handleQuestionSubmission(text);
}}
loading={status === 'loading'}
showSourceButton={false}
showToolButton={false}
/>
) : (
<button
onClick={() => navigate('/')}

View File

@@ -56,7 +56,7 @@ export const fetchAnswer = createAsyncThunk<
question,
signal,
state.preference.token,
state.preference.selectedDocs || [],
state.preference.selectedDocs!,
currentConversationId,
state.preference.prompt.id,
state.preference.chunks,
@@ -163,7 +163,7 @@ export const fetchAnswer = createAsyncThunk<
question,
signal,
state.preference.token,
state.preference.selectedDocs || [],
state.preference.selectedDocs!,
state.conversation.conversationId,
state.preference.prompt.id,
state.preference.chunks,

View File

@@ -118,34 +118,18 @@ layer(base);
background: transparent;
}
/* Light theme scrollbar */
&::-webkit-scrollbar-thumb {
background: rgba(215, 215, 215, 1);
background: rgba(156, 163, 175, 0.5);
border-radius: 3px;
}
&::-webkit-scrollbar-thumb:hover {
background: rgba(195, 195, 195, 1);
background: rgba(156, 163, 175, 0.7);
}
/* Dark theme scrollbar */
.dark &::-webkit-scrollbar-thumb {
background: rgba(77, 78, 88, 1);
border-radius: 3px;
}
.dark &::-webkit-scrollbar-thumb:hover {
background: rgba(97, 98, 108, 1);
}
/* For Firefox - Light theme */
/* For Firefox */
scrollbar-width: thin;
scrollbar-color: rgba(215, 215, 215, 1) transparent;
/* For Firefox - Dark theme */
.dark & {
scrollbar-color: rgba(77, 78, 88, 1) transparent;
}
scrollbar-color: rgba(156, 163, 175, 0.5) transparent;
}
@utility table-default {

View File

@@ -255,8 +255,8 @@
"addQuery": "Add Query"
},
"drag": {
"title": "Drop attachments here",
"description": "Release to upload your attachments"
"title": "Upload a source file",
"description": "Drop your file here to add it as a source"
},
"progress": {
"upload": "Upload is in progress",
@@ -298,6 +298,10 @@
"google_drive": {
"label": "Google Drive",
"heading": "Upload from Google Drive"
},
"share_point": {
"label": "SharePoint",
"heading": "Upload from SharePoint"
}
},
"connectors": {
@@ -327,6 +331,24 @@
"remove": "Remove",
"folderAlt": "Folder",
"fileAlt": "File"
},
"sharePoint": {
"connect": "Connect to SharePoint",
"sessionExpired": "Session expired. Please reconnect to SharePoint.",
"sessionExpiredGeneric": "Session expired. Please reconnect your account.",
"validateFailed": "Failed to validate session. Please reconnect.",
"noSession": "No valid session found. Please reconnect to SharePoint.",
"noAccessToken": "No access token available. Please reconnect to SharePoint.",
"pickerFailed": "Failed to open file picker. Please try again.",
"selectedFiles": "Selected Files",
"selectFiles": "Select Files",
"loading": "Loading...",
"noFilesSelected": "No files or folders selected",
"folders": "Folders",
"files": "Files",
"remove": "Remove",
"folderAlt": "Folder",
"fileAlt": "File"
}
}
},
@@ -421,8 +443,7 @@
"title": "Sources",
"text": "Choose Your Sources",
"link": "Source link",
"view_more": "{{count}} more sources",
"noSourcesAvailable": "No sources available"
"view_more": "{{count}} more sources"
},
"attachments": {
"attach": "Attach",

View File

@@ -218,8 +218,8 @@
"addQuery": "Agregar Consulta"
},
"drag": {
"title": "Suelta los archivos adjuntos aquí",
"description": "Suelta para subir tus archivos adjuntos"
"title": "Subir archivo fuente",
"description": "Arrastra tu archivo aquí para agregarlo como fuente"
},
"progress": {
"upload": "Subida en progreso",
@@ -261,6 +261,10 @@
"google_drive": {
"label": "Google Drive",
"heading": "Subir desde Google Drive"
},
"share_point": {
"label": "SharePoint",
"heading": "Subir desde SharePoint"
}
},
"connectors": {
@@ -290,6 +294,24 @@
"remove": "Eliminar",
"folderAlt": "Carpeta",
"fileAlt": "Archivo"
},
"sharePoint": {
"connect": "Conectar a SharePoint",
"sessionExpired": "Sesión expirada. Por favor, reconecte a SharePoint.",
"sessionExpiredGeneric": "Sesión expirada. Por favor, reconecte su cuenta.",
"validateFailed": "Error al validar la sesión. Por favor, reconecte.",
"noSession": "No se encontró una sesión válida. Por favor, reconecte a SharePoint.",
"noAccessToken": "No hay token de acceso disponible. Por favor, reconecte a SharePoint.",
"pickerFailed": "Error al abrir el selector de archivos. Por favor, inténtelo de nuevo.",
"selectedFiles": "Archivos Seleccionados",
"selectFiles": "Seleccionar Archivos",
"loading": "Cargando...",
"noFilesSelected": "No hay archivos o carpetas seleccionados",
"folders": "Carpetas",
"files": "Archivos",
"remove": "Eliminar",
"folderAlt": "Carpeta",
"fileAlt": "Archivo"
}
}
},
@@ -384,8 +406,7 @@
"title": "Fuentes",
"link": "Enlace fuente",
"view_more": "Ver {{count}} más fuentes",
"text": "Elegir tus fuentes",
"noSourcesAvailable": "No hay fuentes disponibles"
"text": "Elegir tus fuentes"
},
"attachments": {
"attach": "Adjuntar",

View File

@@ -218,8 +218,8 @@
"addQuery": "クエリを追加"
},
"drag": {
"title": "添付ファイルをここにドロップ",
"description": "リリースして添付ファイルをアップロード"
"title": "ソースファイルをアップロード",
"description": "ファイルをここにドロップしてソースとして追加してください"
},
"progress": {
"upload": "アップロード中",
@@ -261,6 +261,10 @@
"google_drive": {
"label": "Google Drive",
"heading": "Google Driveからアップロード"
},
"share_point": {
"label": "SharePoint",
"heading": "SharePointからアップロード"
}
},
"connectors": {
@@ -290,6 +294,24 @@
"remove": "削除",
"folderAlt": "フォルダ",
"fileAlt": "ファイル"
},
"sharePoint": {
"connect": "SharePointに接続",
"sessionExpired": "セッションが期限切れです。SharePointに再接続してください。",
"sessionExpiredGeneric": "セッションが期限切れです。アカウントに再接続してください。",
"validateFailed": "セッションの検証に失敗しました。再接続してください。",
"noSession": "有効なセッションが見つかりません。SharePointに再接続してください。",
"noAccessToken": "アクセストークンが利用できません。SharePointに再接続してください。",
"pickerFailed": "ファイルピッカーを開けませんでした。もう一度お試しください。",
"selectedFiles": "選択されたファイル",
"selectFiles": "ファイルを選択",
"loading": "読み込み中...",
"noFilesSelected": "ファイルまたはフォルダが選択されていません",
"folders": "フォルダ",
"files": "ファイル",
"remove": "削除",
"folderAlt": "フォルダ",
"fileAlt": "ファイル"
}
}
},
@@ -384,8 +406,7 @@
"title": "ソース",
"text": "ソーステキスト",
"link": "ソースリンク",
"view_more": "さらに{{count}}個のソース",
"noSourcesAvailable": "利用可能なソースがありません"
"view_more": "さらに{{count}}個のソース"
},
"attachments": {
"attach": "添付",

View File

@@ -218,8 +218,8 @@
"addQuery": "Добавить запрос"
},
"drag": {
"title": "Перетащите вложения сюда",
"description": "Отпустите, чтобы загрузить ваши вложения"
"title": "Загрузить исходный файл",
"description": "Перетащите файл сюда, чтобы добавить его как источник"
},
"progress": {
"upload": "Идет загрузка",
@@ -261,6 +261,10 @@
"google_drive": {
"label": "Google Drive",
"heading": "Загрузить из Google Drive"
},
"share_point": {
"label": "SharePoint",
"heading": "Загрузить из SharePoint"
}
},
"connectors": {
@@ -290,6 +294,24 @@
"remove": "Удалить",
"folderAlt": "Папка",
"fileAlt": "Файл"
},
"sharePoint": {
"connect": "Подключиться к SharePoint",
"sessionExpired": "Сеанс истек. Пожалуйста, переподключитесь к SharePoint.",
"sessionExpiredGeneric": "Сеанс истек. Пожалуйста, переподключите свою учетную запись.",
"validateFailed": "Не удалось проверить сеанс. Пожалуйста, переподключитесь.",
"noSession": "Действительный сеанс не найден. Пожалуйста, переподключитесь к SharePoint.",
"noAccessToken": "Токен доступа недоступен. Пожалуйста, переподключитесь к SharePoint.",
"pickerFailed": "Не удалось открыть средство выбора файлов. Пожалуйста, попробуйте еще раз.",
"selectedFiles": "Выбранные файлы",
"selectFiles": "Выбрать файлы",
"loading": "Загрузка...",
"noFilesSelected": "Файлы или папки не выбраны",
"folders": "Папки",
"files": "Файлы",
"remove": "Удалить",
"folderAlt": "Папка",
"fileAlt": "Файл"
}
}
},
@@ -384,8 +406,7 @@
"title": "Источники",
"text": "Выберите ваши источники",
"link": "Ссылка на источник",
"view_more": "ещё {{count}} источников",
"noSourcesAvailable": "Нет доступных источников"
"view_more": "ещё {{count}} источников"
},
"attachments": {
"attach": "Прикрепить",

View File

@@ -218,8 +218,8 @@
"addQuery": "新增查詢"
},
"drag": {
"title": "將附件拖放到此處",
"description": "釋放以上傳您的附件"
"title": "上傳來源檔案",
"description": "將檔案拖放到此處以新增為來源"
},
"progress": {
"upload": "正在上傳",
@@ -261,6 +261,10 @@
"google_drive": {
"label": "Google Drive",
"heading": "從Google Drive上傳"
},
"share_point": {
"label": "SharePoint",
"heading": "從SharePoint上傳"
}
},
"connectors": {
@@ -290,6 +294,24 @@
"remove": "移除",
"folderAlt": "資料夾",
"fileAlt": "檔案"
},
"sharePoint": {
"connect": "連接到 SharePoint",
"sessionExpired": "工作階段已過期。請重新連接到 SharePoint。",
"sessionExpiredGeneric": "工作階段已過期。請重新連接您的帳戶。",
"validateFailed": "驗證工作階段失敗。請重新連接。",
"noSession": "未找到有效工作階段。請重新連接到 SharePoint。",
"noAccessToken": "存取權杖不可用。請重新連接到 SharePoint。",
"pickerFailed": "無法開啟檔案選擇器。請重試。",
"selectedFiles": "已選擇的檔案",
"selectFiles": "選擇檔案",
"loading": "載入中...",
"noFilesSelected": "未選擇檔案或資料夾",
"folders": "資料夾",
"files": "檔案",
"remove": "移除",
"folderAlt": "資料夾",
"fileAlt": "檔案"
}
}
},
@@ -384,8 +406,7 @@
"title": "來源",
"text": "來源文字",
"link": "來源連結",
"view_more": "查看更多 {{count}} 個來源",
"noSourcesAvailable": "沒有可用的來源"
"view_more": "查看更多 {{count}} 個來源"
},
"attachments": {
"attach": "附件",

View File

@@ -218,8 +218,8 @@
"addQuery": "添加查询"
},
"drag": {
"title": "将附件拖放到此处",
"description": "释放以上传您的附件"
"title": "上传源文件",
"description": "将文件拖放到此处以添加为源"
},
"progress": {
"upload": "正在上传",
@@ -261,6 +261,10 @@
"google_drive": {
"label": "Google Drive",
"heading": "从Google Drive上传"
},
"share_point": {
"label": "SharePoint",
"heading": "从SharePoint上传"
}
},
"connectors": {
@@ -290,6 +294,24 @@
"remove": "删除",
"folderAlt": "文件夹",
"fileAlt": "文件"
},
"sharePoint": {
"connect": "连接到 SharePoint",
"sessionExpired": "会话已过期。请重新连接到 SharePoint。",
"sessionExpiredGeneric": "会话已过期。请重新连接您的账户。",
"validateFailed": "验证会话失败。请重新连接。",
"noSession": "未找到有效会话。请重新连接到 SharePoint。",
"noAccessToken": "访问令牌不可用。请重新连接到 SharePoint。",
"pickerFailed": "无法打开文件选择器。请重试。",
"selectedFiles": "已选择的文件",
"selectFiles": "选择文件",
"loading": "加载中...",
"noFilesSelected": "未选择文件或文件夹",
"folders": "文件夹",
"files": "文件",
"remove": "删除",
"folderAlt": "文件夹",
"fileAlt": "文件"
}
}
},
@@ -384,8 +406,7 @@
"title": "来源",
"text": "来源文本",
"link": "来源链接",
"view_more": "还有{{count}}个来源",
"noSourcesAvailable": "没有可用的来源"
"view_more": "还有{{count}}个来源"
},
"attachments": {
"attach": "附件",

View File

@@ -103,8 +103,8 @@ export const ShareConversationModal = ({
};
return (
<WrapperModal close={close} contentClassName="!overflow-visible">
<div className="flex w-[600px] max-w-[80vw] flex-col gap-2">
<WrapperModal close={close}>
<div className="flex max-h-[80vh] w-[600px] max-w-[80vw] flex-col gap-2 overflow-y-auto">
<h2 className="text-eerie-black dark:text-chinese-white text-xl font-medium">
{t('modals.shareConv.label')}
</h2>

View File

@@ -32,6 +32,7 @@ import { FormField, IngestorConfig, IngestorType } from './types/ingestor';
import { FilePicker } from '../components/FilePicker';
import GoogleDrivePicker from '../components/GoogleDrivePicker';
import SharePointPicker from '../components/SharePointPicker';
import ChevronRight from '../assets/chevron-right.svg';
@@ -252,6 +253,8 @@ function Upload({
token={token}
/>
);
case 'share_point_picker':
return <SharePointPicker key={field.name} token={token} />;
default:
return null;
}

View File

@@ -4,6 +4,7 @@ import UrlIcon from '../../assets/url.svg';
import GithubIcon from '../../assets/github.svg';
import RedditIcon from '../../assets/reddit.svg';
import DriveIcon from '../../assets/drive.svg';
import SharePoint from '../../assets/sharepoint.svg';
export type IngestorType =
| 'crawler'
@@ -11,7 +12,8 @@ export type IngestorType =
| 'reddit'
| 'url'
| 'google_drive'
| 'local_file';
| 'local_file'
| 'share_point';
export interface IngestorConfig {
type: IngestorType | null;
@@ -33,7 +35,8 @@ export type FieldType =
| 'boolean'
| 'local_file_picker'
| 'remote_file_picker'
| 'google_drive_picker';
| 'google_drive_picker'
| 'share_point_picker';
export interface FormField {
name: string;
@@ -147,6 +150,24 @@ export const IngestorFormSchemas: IngestorSchema[] = [
},
],
},
{
key: 'share_point',
label: 'Share Point',
icon: SharePoint,
heading: 'Upload from Share Point',
validate: () => {
const sharePointClientId = import.meta.env.VITE_SHARE_POINT_CLIENT_ID;
return !!sharePointClientId;
},
fields: [
{
name: 'files',
label: 'Select Files from Share Point',
type: 'share_point_picker',
required: true,
},
],
},
];
export const IngestorDefaultConfigs: Record<
@@ -175,6 +196,14 @@ export const IngestorDefaultConfigs: Record<
},
},
local_file: { name: '', config: { files: [] } },
share_point: {
name: '',
config: {
file_ids: '',
folder_ids: '',
recursive: true,
},
},
};
export interface IngestorOption {

View File

@@ -2,11 +2,11 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from '../store';
export interface Attachment {
id: string; // Unique identifier for the attachment (required for state management)
fileName: string;
progress: number;
status: 'uploading' | 'processing' | 'completed' | 'failed';
taskId: string; // Server-assigned task ID (used for API calls)
taskId: string;
id?: string;
token_count?: number;
}
@@ -47,12 +47,12 @@ export const uploadSlice = createSlice({
updateAttachment: (
state,
action: PayloadAction<{
id: string;
taskId: string;
updates: Partial<Attachment>;
}>,
) => {
const index = state.attachments.findIndex(
(att) => att.id === action.payload.id,
(att) => att.taskId === action.payload.taskId,
);
if (index !== -1) {
state.attachments[index] = {
@@ -63,26 +63,9 @@ export const uploadSlice = createSlice({
},
removeAttachment: (state, action: PayloadAction<string>) => {
state.attachments = state.attachments.filter(
(att) => att.id !== action.payload,
(att) => att.taskId !== action.payload && att.id !== action.payload,
);
},
// Reorder attachments array by moving item from sourceIndex to destinationIndex
reorderAttachments: (
state,
action: PayloadAction<{ sourceIndex: number; destinationIndex: number }>,
) => {
const { sourceIndex, destinationIndex } = action.payload;
if (
sourceIndex < 0 ||
destinationIndex < 0 ||
sourceIndex >= state.attachments.length ||
destinationIndex >= state.attachments.length
)
return;
const [moved] = state.attachments.splice(sourceIndex, 1);
state.attachments.splice(destinationIndex, 0, moved);
},
clearAttachments: (state) => {
state.attachments = state.attachments.filter(
(att) => att.status === 'uploading' || att.status === 'processing',
@@ -138,7 +121,6 @@ export const {
addAttachment,
updateAttachment,
removeAttachment,
reorderAttachments,
clearAttachments,
addUploadTask,
updateUploadTask,

View File

@@ -14,3 +14,21 @@ export const setSessionToken = (provider: string, token: string): void => {
export const removeSessionToken = (provider: string): void => {
localStorage.removeItem(`${provider}_session_token`);
};
export const validateProviderSession = async (
token: string | null,
provider: string,
) => {
const apiHost = import.meta.env.VITE_API_HOST;
return await fetch(`${apiHost}/api/connectors/validate-session`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`,
},
body: JSON.stringify({
provider: provider,
session_token: getSessionToken(provider),
}),
});
};

View File

@@ -3,6 +3,7 @@ from unittest.mock import Mock
import pytest
from application.agents.classic_agent import ClassicAgent
from application.core.settings import settings
from tests.conftest import FakeMongoCollection
@pytest.mark.unit
@@ -167,13 +168,10 @@ class TestBaseAgentTools:
mock_llm_creator,
mock_llm_handler_creator,
):
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
user_tools.insert_one(
{"_id": "1", "user": "test_user", "name": "tool1", "status": True}
)
user_tools.insert_one(
{"_id": "2", "user": "test_user", "name": "tool2", "status": True}
)
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": True},
}
agent = ClassicAgent(**agent_base_params)
tools = agent._get_user_tools("test_user")
@@ -189,13 +187,10 @@ class TestBaseAgentTools:
mock_llm_creator,
mock_llm_handler_creator,
):
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
user_tools.insert_one(
{"_id": "1", "user": "test_user", "name": "tool1", "status": True}
)
user_tools.insert_one(
{"_id": "2", "user": "test_user", "name": "tool2", "status": False}
)
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": False},
}
agent = ClassicAgent(**agent_base_params)
tools = agent._get_user_tools("test_user")
@@ -214,16 +209,17 @@ class TestBaseAgentTools:
tool_id = str(ObjectId())
tool_obj_id = ObjectId(tool_id)
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agents_collection.insert_one(
{
"key": "api_key_123",
"tools": [tool_id],
}
)
fake_agent_collection = FakeMongoCollection()
fake_agent_collection.docs["api_key_123"] = {
"key": "api_key_123",
"tools": [tool_id],
}
tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
tools_collection.insert_one({"_id": tool_obj_id, "name": "api_tool"})
fake_tools_collection = FakeMongoCollection()
fake_tools_collection.docs[tool_id] = {"_id": tool_obj_id, "name": "api_tool"}
mock_mongo_db[settings.MONGO_DB_NAME]["agents"] = fake_agent_collection
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] = fake_tools_collection
agent = ClassicAgent(**agent_base_params)
tools = agent._get_tools("api_key_123")

View File

View File

@@ -1,552 +0,0 @@
import datetime
from unittest.mock import MagicMock, patch
import pytest
from bson import ObjectId
@pytest.mark.unit
class TestBaseAnswerValidation:
def test_validate_request_passes_with_required_fields(
self, mock_mongo_db, flask_app
):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
data = {"question": "What is Python?"}
result = resource.validate_request(data)
assert result is None
def test_validate_request_fails_without_question(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
data = {}
result = resource.validate_request(data)
assert result is not None
assert result.status_code == 400
assert "question" in result.json["message"].lower()
def test_validate_with_conversation_id_required(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
data = {"question": "Test"}
result = resource.validate_request(data, require_conversation_id=True)
assert result is not None
assert result.status_code == 400
assert "conversation_id" in result.json["message"].lower()
def test_validate_passes_with_all_required_fields(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
data = {"question": "Test", "conversation_id": str(ObjectId())}
result = resource.validate_request(data, require_conversation_id=True)
assert result is None
@pytest.mark.unit
class TestUsageChecking:
def test_returns_none_when_no_api_key(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
agent_config = {}
result = resource.check_usage(agent_config)
assert result is None
def test_returns_error_for_invalid_api_key(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
agent_config = {"user_api_key": "invalid_key_123"}
result = resource.check_usage(agent_config)
assert result is not None
assert result.status_code == 401
assert result.json["success"] is False
assert "invalid" in result.json["message"].lower()
def test_checks_token_limit_when_enabled(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_key",
"limited_token_mode": True,
"token_limit": 1000,
"limited_request_mode": False,
}
)
resource = BaseAnswerResource()
agent_config = {"user_api_key": "test_key"}
result = resource.check_usage(agent_config)
assert result is None
def test_checks_request_limit_when_enabled(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_key",
"limited_token_mode": False,
"limited_request_mode": True,
"request_limit": 100,
}
)
resource = BaseAnswerResource()
agent_config = {"user_api_key": "test_key"}
result = resource.check_usage(agent_config)
assert result is None
def test_uses_default_limits_when_not_specified(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_key",
"limited_token_mode": True,
"limited_request_mode": True,
}
)
resource = BaseAnswerResource()
agent_config = {"user_api_key": "test_key"}
result = resource.check_usage(agent_config)
assert result is None
def test_exceeds_token_limit(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
token_usage_collection = mock_mongo_db[settings.MONGO_DB_NAME][
"token_usage"
]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_key",
"limited_token_mode": True,
"token_limit": 100,
"limited_request_mode": False,
}
)
token_usage_collection.insert_one(
{
"_id": ObjectId(),
"api_key": "test_key",
"prompt_tokens": 60,
"generated_tokens": 50,
"timestamp": datetime.datetime.now(),
}
)
resource = BaseAnswerResource()
agent_config = {"user_api_key": "test_key"}
result = resource.check_usage(agent_config)
assert result is not None
assert result.status_code == 429
assert result.json["success"] is False
assert "usage limit" in result.json["message"].lower()
def test_exceeds_request_limit(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
token_usage_collection = mock_mongo_db[settings.MONGO_DB_NAME][
"token_usage"
]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_key",
"limited_token_mode": False,
"limited_request_mode": True,
"request_limit": 2,
}
)
now = datetime.datetime.now()
for i in range(3):
token_usage_collection.insert_one(
{
"_id": ObjectId(),
"api_key": "test_key",
"prompt_tokens": 10,
"generated_tokens": 10,
"timestamp": now,
}
)
resource = BaseAnswerResource()
agent_config = {"user_api_key": "test_key"}
result = resource.check_usage(agent_config)
assert result is not None
assert result.status_code == 429
assert result.json["success"] is False
def test_both_limits_disabled_returns_none(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_key",
"limited_token_mode": False,
"limited_request_mode": False,
}
)
resource = BaseAnswerResource()
agent_config = {"user_api_key": "test_key"}
result = resource.check_usage(agent_config)
assert result is None
@pytest.mark.unit
class TestGPTModelRetrieval:
def test_initializes_gpt_model(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
assert hasattr(resource, "gpt_model")
assert resource.gpt_model is not None
@pytest.mark.unit
class TestConversationServiceIntegration:
def test_initializes_conversation_service(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
assert hasattr(resource, "conversation_service")
assert resource.conversation_service is not None
def test_has_access_to_user_logs_collection(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
assert hasattr(resource, "user_logs_collection")
assert resource.user_logs_collection is not None
@pytest.mark.unit
class TestCompleteStreamMethod:
def test_streams_answer_chunks(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"answer": "Hello "},
{"answer": "world!"},
]
)
mock_retriever = MagicMock()
mock_retriever.get_params.return_value = {}
decoded_token = {"sub": "user123"}
stream = list(
resource.complete_stream(
question="Test question",
agent=mock_agent,
retriever=mock_retriever,
conversation_id=None,
user_api_key=None,
decoded_token=decoded_token,
should_save_conversation=False,
)
)
answer_chunks = [s for s in stream if '"type": "answer"' in s]
assert len(answer_chunks) == 2
assert '"answer": "Hello "' in answer_chunks[0]
assert '"answer": "world!"' in answer_chunks[1]
def test_streams_sources(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"answer": "Test answer"},
{"sources": [{"title": "doc1.txt", "text": "x" * 200}]},
]
)
mock_retriever = MagicMock()
mock_retriever.get_params.return_value = {}
decoded_token = {"sub": "user123"}
stream = list(
resource.complete_stream(
question="Test?",
agent=mock_agent,
retriever=mock_retriever,
conversation_id=None,
user_api_key=None,
decoded_token=decoded_token,
should_save_conversation=False,
)
)
source_chunks = [s for s in stream if '"type": "source"' in s]
assert len(source_chunks) == 1
assert '"title": "doc1.txt"' in source_chunks[0]
def test_handles_error_during_streaming(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.side_effect = Exception("Test error")
mock_retriever = MagicMock()
mock_retriever.get_params.return_value = {}
decoded_token = {"sub": "user123"}
stream = list(
resource.complete_stream(
question="Test?",
agent=mock_agent,
retriever=mock_retriever,
conversation_id=None,
user_api_key=None,
decoded_token=decoded_token,
should_save_conversation=False,
)
)
assert any('"type": "error"' in s for s in stream)
def test_saves_conversation_when_enabled(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"answer": "Test answer"},
]
)
mock_retriever = MagicMock()
mock_retriever.get_params.return_value = {}
decoded_token = {"sub": "user123"}
with patch.object(
resource.conversation_service, "save_conversation"
) as mock_save:
mock_save.return_value = str(ObjectId())
list(
resource.complete_stream(
question="Test?",
agent=mock_agent,
retriever=mock_retriever,
conversation_id=None,
user_api_key=None,
decoded_token=decoded_token,
should_save_conversation=True,
)
)
mock_save.assert_called_once()
def test_logs_to_user_logs_collection(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
from application.core.settings import settings
with flask_app.app_context():
resource = BaseAnswerResource()
user_logs = mock_mongo_db[settings.MONGO_DB_NAME]["user_logs"]
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[
{"answer": "Test answer"},
]
)
mock_retriever = MagicMock()
mock_retriever.get_params.return_value = {"retriever": "test"}
decoded_token = {"sub": "user123"}
list(
resource.complete_stream(
question="Test question?",
agent=mock_agent,
retriever=mock_retriever,
conversation_id=None,
user_api_key="test_key",
decoded_token=decoded_token,
should_save_conversation=False,
)
)
assert user_logs.count_documents({}) == 1
log_entry = user_logs.find_one({})
assert log_entry["action"] == "stream_answer"
assert log_entry["user"] == "user123"
assert log_entry["api_key"] == "test_key"
assert log_entry["question"] == "Test question?"
@pytest.mark.unit
class TestProcessResponseStream:
def test_processes_complete_stream(self, mock_mongo_db, flask_app):
import json
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
conv_id = str(ObjectId())
stream = [
f'data: {json.dumps({"type": "answer", "answer": "Hello "})}\n\n',
f'data: {json.dumps({"type": "answer", "answer": "world"})}\n\n',
f'data: {json.dumps({"type": "source", "source": [{"title": "doc1"}]})}\n\n',
f'data: {json.dumps({"type": "id", "id": conv_id})}\n\n',
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[0] == conv_id
assert result[1] == "Hello world"
assert result[2] == [{"title": "doc1"}]
assert result[5] is None
def test_handles_stream_error(self, mock_mongo_db, flask_app):
import json
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
stream = [
f'data: {json.dumps({"type": "error", "error": "Test error"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert len(result) == 5
assert result[0] is None
assert result[4] == "Test error"
def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
stream = [
"data: invalid json\n\n",
'data: {"type": "end"}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result is not None
@pytest.mark.unit
class TestErrorStreamGenerate:
def test_generates_error_stream(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
error_stream = list(resource.error_stream_generate("Test error message"))
assert len(error_stream) == 1
assert '"type": "error"' in error_stream[0]
assert '"error": "Test error message"' in error_stream[0]

View File

@@ -1,242 +0,0 @@
from unittest.mock import Mock
import pytest
from bson import ObjectId
@pytest.mark.unit
class TestConversationServiceGet:
def test_returns_none_when_no_conversation_id(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
result = service.get_conversation("", "user_123")
assert result is None
def test_returns_none_when_no_user_id(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
result = service.get_conversation(str(ObjectId()), "")
assert result is None
def test_returns_conversation_for_owner(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
conversation = {
"_id": conv_id,
"user": "user_123",
"name": "Test Conv",
"queries": [],
}
collection.insert_one(conversation)
result = service.get_conversation(str(conv_id), "user_123")
assert result is not None
assert result["name"] == "Test Conv"
assert result["_id"] == str(conv_id)
def test_returns_none_for_unauthorized_user(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one(
{"_id": conv_id, "user": "owner_123", "name": "Private Conv"}
)
result = service.get_conversation(str(conv_id), "hacker_456")
assert result is None
def test_converts_objectid_to_string(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one({"_id": conv_id, "user": "user_123", "name": "Test"})
result = service.get_conversation(str(conv_id), "user_123")
assert isinstance(result["_id"], str)
assert result["_id"] == str(conv_id)
@pytest.mark.unit
class TestConversationServiceSave:
def test_raises_error_when_no_user_in_token(self, mock_mongo_db):
"""Test validation: user ID required"""
from application.api.answer.services.conversation_service import (
ConversationService,
)
service = ConversationService()
mock_llm = Mock()
with pytest.raises(ValueError, match="User ID not found"):
service.save_conversation(
conversation_id=None,
question="Test?",
response="Answer",
thought="",
sources=[],
tool_calls=[],
llm=mock_llm,
gpt_model="gpt-4",
decoded_token={}, # No 'sub' key
)
def test_truncates_long_source_text(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
from bson import ObjectId
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
mock_llm = Mock()
mock_llm.gen.return_value = "Test Summary"
long_text = "x" * 2000
sources = [{"text": long_text, "title": "Doc"}]
conv_id = service.save_conversation(
conversation_id=None,
question="Question",
response="Response",
thought="",
sources=sources,
tool_calls=[],
llm=mock_llm,
gpt_model="gpt-4",
decoded_token={"sub": "user_123"},
)
saved_conv = collection.find_one({"_id": ObjectId(conv_id)})
saved_source_text = saved_conv["queries"][0]["sources"][0]["text"]
assert len(saved_source_text) == 1000
assert saved_source_text == "x" * 1000
def test_creates_new_conversation_with_summary(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
from bson import ObjectId
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
mock_llm = Mock()
mock_llm.gen.return_value = "Python Basics"
conv_id = service.save_conversation(
conversation_id=None,
question="What is Python?",
response="Python is a programming language",
thought="",
sources=[],
tool_calls=[],
llm=mock_llm,
gpt_model="gpt-4",
decoded_token={"sub": "user_123"},
)
assert conv_id is not None
saved_conv = collection.find_one({"_id": ObjectId(conv_id)})
assert saved_conv["name"] == "Python Basics"
assert saved_conv["user"] == "user_123"
assert len(saved_conv["queries"]) == 1
assert saved_conv["queries"][0]["prompt"] == "What is Python?"
def test_appends_to_existing_conversation(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
from bson import ObjectId
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
existing_conv_id = ObjectId()
collection.insert_one(
{
"_id": existing_conv_id,
"user": "user_123",
"name": "Old Conv",
"queries": [{"prompt": "Q1", "response": "A1"}],
}
)
mock_llm = Mock()
result = service.save_conversation(
conversation_id=str(existing_conv_id),
question="Q2",
response="A2",
thought="",
sources=[],
tool_calls=[],
llm=mock_llm,
gpt_model="gpt-4",
decoded_token={"sub": "user_123"},
)
assert result == str(existing_conv_id)
def test_prevents_unauthorized_conversation_update(self, mock_mongo_db):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.core.settings import settings
service = ConversationService()
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
conv_id = ObjectId()
collection.insert_one({"_id": conv_id, "user": "owner_123", "queries": []})
mock_llm = Mock()
with pytest.raises(ValueError, match="not found or unauthorized"):
service.save_conversation(
conversation_id=str(conv_id),
question="Hack",
response="Attempt",
thought="",
sources=[],
tool_calls=[],
llm=mock_llm,
gpt_model="gpt-4",
decoded_token={"sub": "hacker_456"},
)

View File

@@ -1,252 +0,0 @@
import pytest
from bson import ObjectId
@pytest.mark.unit
class TestGetPromptFunction:
def test_loads_custom_prompt_from_database(self, mock_mongo_db):
from application.api.answer.services.stream_processor import get_prompt
from application.core.settings import settings
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
prompt_id = ObjectId()
prompts_collection.insert_one(
{
"_id": prompt_id,
"content": "Custom prompt from database",
"user": "user_123",
}
)
result = get_prompt(str(prompt_id), prompts_collection)
assert result == "Custom prompt from database"
def test_raises_error_for_invalid_prompt_id(self, mock_mongo_db):
from application.api.answer.services.stream_processor import get_prompt
from application.core.settings import settings
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
with pytest.raises(ValueError, match="Invalid prompt ID"):
get_prompt(str(ObjectId()), prompts_collection)
def test_raises_error_for_malformed_id(self, mock_mongo_db):
from application.api.answer.services.stream_processor import get_prompt
from application.core.settings import settings
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
with pytest.raises(ValueError, match="Invalid prompt ID"):
get_prompt("not_a_valid_id", prompts_collection)
@pytest.mark.unit
class TestStreamProcessorInitialization:
def test_initializes_with_decoded_token(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
request_data = {
"question": "What is Python?",
"conversation_id": str(ObjectId()),
}
decoded_token = {"sub": "user_123", "email": "test@example.com"}
processor = StreamProcessor(request_data, decoded_token)
assert processor.data == request_data
assert processor.decoded_token == decoded_token
assert processor.initial_user_id == "user_123"
assert processor.conversation_id == request_data["conversation_id"]
def test_initializes_without_token(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
request_data = {"question": "Test question"}
processor = StreamProcessor(request_data, None)
assert processor.decoded_token is None
assert processor.initial_user_id is None
assert processor.data == request_data
def test_initializes_default_attributes(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
processor = StreamProcessor({"question": "Test"}, {"sub": "user_123"})
assert processor.source == {}
assert processor.all_sources == []
assert processor.attachments == []
assert processor.history == []
assert processor.agent_config == {}
assert processor.retriever_config == {}
assert processor.is_shared_usage is False
assert processor.shared_token is None
def test_extracts_conversation_id_from_request(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
conv_id = str(ObjectId())
request_data = {"question": "Test", "conversation_id": conv_id}
processor = StreamProcessor(request_data, {"sub": "user_123"})
assert processor.conversation_id == conv_id
@pytest.mark.unit
class TestStreamProcessorHistoryLoading:
def test_loads_history_from_existing_conversation(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
from application.core.settings import settings
conversations_collection = mock_mongo_db[settings.MONGO_DB_NAME][
"conversations"
]
conv_id = ObjectId()
conversations_collection.insert_one(
{
"_id": conv_id,
"user": "user_123",
"name": "Test Conv",
"queries": [
{"prompt": "What is Python?", "response": "Python is a language"},
{"prompt": "Tell me more", "response": "Python is versatile"},
],
}
)
request_data = {
"question": "How to install it?",
"conversation_id": str(conv_id),
}
processor = StreamProcessor(request_data, {"sub": "user_123"})
processor._load_conversation_history()
assert len(processor.history) == 2
assert processor.history[0]["prompt"] == "What is Python?"
assert processor.history[1]["response"] == "Python is versatile"
def test_raises_error_for_unauthorized_conversation(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
from application.core.settings import settings
conversations_collection = mock_mongo_db[settings.MONGO_DB_NAME][
"conversations"
]
conv_id = ObjectId()
conversations_collection.insert_one(
{
"_id": conv_id,
"user": "owner_123",
"name": "Private Conv",
"queries": [],
}
)
request_data = {"question": "Hack attempt", "conversation_id": str(conv_id)}
processor = StreamProcessor(request_data, {"sub": "hacker_456"})
with pytest.raises(ValueError, match="Conversation not found or unauthorized"):
processor._load_conversation_history()
def test_uses_request_history_when_no_conversation_id(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
request_data = {
"question": "What is Python?",
"history": [{"prompt": "Hello", "response": "Hi there!"}],
}
processor = StreamProcessor(request_data, {"sub": "user_123"})
assert processor.conversation_id is None
@pytest.mark.unit
class TestStreamProcessorAgentConfiguration:
def test_configures_agent_from_valid_api_key(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
from application.core.settings import settings
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "test_api_key_123",
"endpoint": "openai",
"model": "gpt-4",
"prompt_id": "default",
"user": "user_123",
}
)
request_data = {"question": "Test", "api_key": "test_api_key_123"}
processor = StreamProcessor(request_data, None)
try:
processor._configure_agent()
assert processor.agent_config is not None
except Exception as e:
assert "Invalid API Key" in str(e)
def test_uses_default_config_without_api_key(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
request_data = {"question": "Test"}
processor = StreamProcessor(request_data, {"sub": "user_123"})
processor._configure_agent()
assert isinstance(processor.agent_config, dict)
@pytest.mark.unit
class TestStreamProcessorAttachments:
def test_processes_attachments_from_request(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
from application.core.settings import settings
attachments_collection = mock_mongo_db[settings.MONGO_DB_NAME]["attachments"]
att_id = ObjectId()
attachments_collection.insert_one(
{
"_id": att_id,
"filename": "document.pdf",
"content": "Document content",
"user": "user_123",
}
)
request_data = {"question": "Analyze this", "attachments": [str(att_id)]}
processor = StreamProcessor(request_data, {"sub": "user_123"})
assert processor.data.get("attachments") == [str(att_id)]
def test_handles_empty_attachments(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
request_data = {"question": "Simple question"}
processor = StreamProcessor(request_data, {"sub": "user_123"})
assert processor.attachments == []
assert (
"attachments" not in processor.data
or processor.data.get("attachments") is None
)

View File

@@ -1,89 +0,0 @@
"""API-specific test fixtures."""
import pytest
from bson import ObjectId
@pytest.fixture
def auth_headers():
return {"Authorization": "Bearer test_token"}
@pytest.fixture
def mock_request_token(monkeypatch, decoded_token):
def mock_decorator(f):
def wrapper(*args, **kwargs):
from flask import request
request.decoded_token = decoded_token
return f(*args, **kwargs)
return wrapper
monkeypatch.setattr("application.auth.api_key_required", lambda: mock_decorator)
return decoded_token
@pytest.fixture
def sample_conversation():
return {
"_id": ObjectId(),
"user": "test_user",
"name": "Test Conversation",
"queries": [
{
"prompt": "What is Python?",
"response": "Python is a programming language",
}
],
"date": "2025-01-01T00:00:00",
}
@pytest.fixture
def sample_prompt():
return {
"_id": ObjectId(),
"user": "test_user",
"name": "Helpful Assistant",
"content": "You are a helpful assistant that provides clear and concise answers.",
"type": "custom",
}
@pytest.fixture
def sample_agent():
return {
"_id": ObjectId(),
"user": "test_user",
"name": "Test Agent",
"type": "classic",
"endpoint": "openai",
"model": "gpt-4",
"prompt_id": "default",
"status": "active",
}
@pytest.fixture
def sample_answer_request():
return {
"question": "What is Python?",
"history": [],
"conversation_id": None,
"prompt_id": "default",
"chunks": 2,
"token_limit": 1000,
"retriever": "classic_rag",
"active_docs": "local/test/",
"isNoneDoc": False,
"save_conversation": True,
}
@pytest.fixture
def flask_app():
from flask import Flask
app = Flask(__name__)
return app

View File

@@ -1,311 +0,0 @@
import datetime
import io
from unittest.mock import Mock, patch
import pytest
from bson import ObjectId
from werkzeug.datastructures import FileStorage
@pytest.mark.unit
class TestTimeRangeGenerators:
def test_generate_minute_range(self):
from application.api.user.base import generate_minute_range
start = datetime.datetime(2024, 1, 1, 10, 0, 0)
end = datetime.datetime(2024, 1, 1, 10, 5, 0)
result = generate_minute_range(start, end)
assert len(result) == 6
assert "2024-01-01 10:00:00" in result
assert "2024-01-01 10:05:00" in result
assert all(val == 0 for val in result.values())
def test_generate_hourly_range(self):
from application.api.user.base import generate_hourly_range
start = datetime.datetime(2024, 1, 1, 10, 0, 0)
end = datetime.datetime(2024, 1, 1, 15, 0, 0)
result = generate_hourly_range(start, end)
assert len(result) == 6
assert "2024-01-01 10:00" in result
assert "2024-01-01 15:00" in result
assert all(val == 0 for val in result.values())
def test_generate_date_range(self):
from application.api.user.base import generate_date_range
start = datetime.date(2024, 1, 1)
end = datetime.date(2024, 1, 5)
result = generate_date_range(start, end)
assert len(result) == 5
assert "2024-01-01" in result
assert "2024-01-05" in result
assert all(val == 0 for val in result.values())
def test_single_minute_range(self):
from application.api.user.base import generate_minute_range
time = datetime.datetime(2024, 1, 1, 10, 30, 0)
result = generate_minute_range(time, time)
assert len(result) == 1
assert "2024-01-01 10:30:00" in result
@pytest.mark.unit
class TestEnsureUserDoc:
def test_creates_new_user_with_defaults(self, mock_mongo_db):
from application.api.user.base import ensure_user_doc
user_id = "test_user_123"
result = ensure_user_doc(user_id)
assert result is not None
assert result["user_id"] == user_id
assert "agent_preferences" in result
assert result["agent_preferences"]["pinned"] == []
assert result["agent_preferences"]["shared_with_me"] == []
def test_returns_existing_user(self, mock_mongo_db):
from application.api.user.base import ensure_user_doc
from application.core.settings import settings
users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"]
user_id = "existing_user"
existing_doc = {
"user_id": user_id,
"agent_preferences": {"pinned": ["agent1"], "shared_with_me": ["agent2"]},
}
users_collection.insert_one(existing_doc)
result = ensure_user_doc(user_id)
assert result["user_id"] == user_id
assert result["agent_preferences"]["pinned"] == ["agent1"]
assert result["agent_preferences"]["shared_with_me"] == ["agent2"]
def test_adds_missing_preferences_fields(self, mock_mongo_db):
from application.api.user.base import ensure_user_doc
from application.core.settings import settings
users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"]
user_id = "incomplete_user"
users_collection.insert_one(
{"user_id": user_id, "agent_preferences": {"pinned": ["agent1"]}}
)
result = ensure_user_doc(user_id)
assert "shared_with_me" in result["agent_preferences"]
assert result["agent_preferences"]["shared_with_me"] == []
@pytest.mark.unit
class TestResolveToolDetails:
def test_resolves_tool_ids_to_details(self, mock_mongo_db):
from application.api.user.base import resolve_tool_details
from application.core.settings import settings
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
tool_id1 = ObjectId()
tool_id2 = ObjectId()
user_tools.insert_one(
{"_id": tool_id1, "name": "calculator", "displayName": "Calculator Tool"}
)
user_tools.insert_one(
{"_id": tool_id2, "name": "weather", "displayName": "Weather API"}
)
result = resolve_tool_details([str(tool_id1), str(tool_id2)])
assert len(result) == 2
assert result[0]["id"] == str(tool_id1)
assert result[0]["name"] == "calculator"
assert result[0]["display_name"] == "Calculator Tool"
assert result[1]["name"] == "weather"
def test_handles_missing_display_name(self, mock_mongo_db):
from application.api.user.base import resolve_tool_details
from application.core.settings import settings
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
tool_id = ObjectId()
user_tools.insert_one({"_id": tool_id, "name": "test_tool"})
result = resolve_tool_details([str(tool_id)])
assert result[0]["display_name"] == "test_tool"
def test_empty_tool_ids_list(self, mock_mongo_db):
from application.api.user.base import resolve_tool_details
result = resolve_tool_details([])
assert result == []
@pytest.mark.unit
class TestGetVectorStore:
@patch("application.api.user.base.VectorCreator.create_vectorstore")
def test_creates_vector_store(self, mock_create, mock_mongo_db):
from application.api.user.base import get_vector_store
mock_store = Mock()
mock_create.return_value = mock_store
source_id = "test_source_123"
result = get_vector_store(source_id)
assert result == mock_store
mock_create.assert_called_once()
args, kwargs = mock_create.call_args
assert kwargs.get("source_id") == source_id
@pytest.mark.unit
class TestHandleImageUpload:
def test_returns_existing_url_when_no_file(self, flask_app):
from application.api.user.base import handle_image_upload
with flask_app.test_request_context():
mock_request = Mock()
mock_request.files = {}
mock_storage = Mock()
existing_url = "existing/path/image.jpg"
url, error = handle_image_upload(
mock_request, existing_url, "user123", mock_storage
)
assert url == existing_url
assert error is None
def test_uploads_new_image(self, flask_app):
from application.api.user.base import handle_image_upload
with flask_app.test_request_context():
mock_file = FileStorage(
stream=io.BytesIO(b"fake image data"), filename="test_image.png"
)
mock_request = Mock()
mock_request.files = {"image": mock_file}
mock_storage = Mock()
mock_storage.save_file.return_value = {"success": True}
url, error = handle_image_upload(
mock_request, "old_url", "user123", mock_storage
)
assert error is None
assert url is not None
assert "test_image.png" in url
assert "user123" in url
mock_storage.save_file.assert_called_once()
def test_ignores_empty_filename(self, flask_app):
from application.api.user.base import handle_image_upload
with flask_app.test_request_context():
mock_file = Mock()
mock_file.filename = ""
mock_request = Mock()
mock_request.files = {"image": mock_file}
mock_storage = Mock()
existing_url = "existing.jpg"
url, error = handle_image_upload(
mock_request, existing_url, "user123", mock_storage
)
assert url == existing_url
assert error is None
mock_storage.save_file.assert_not_called()
def test_handles_upload_error(self, flask_app):
from application.api.user.base import handle_image_upload
with flask_app.app_context():
mock_file = FileStorage(stream=io.BytesIO(b"data"), filename="test.png")
mock_request = Mock()
mock_request.files = {"image": mock_file}
mock_storage = Mock()
mock_storage.save_file.side_effect = Exception("Storage error")
url, error = handle_image_upload(
mock_request, "old.jpg", "user123", mock_storage
)
assert url is None
assert error is not None
assert error.status_code == 400
@pytest.mark.unit
class TestRequireAgentDecorator:
def test_validates_webhook_token(self, mock_mongo_db, flask_app):
from application.api.user.base import require_agent
from application.core.settings import settings
with flask_app.app_context():
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
agent_id = ObjectId()
webhook_token = "valid_webhook_token_123"
agents_collection.insert_one(
{"_id": agent_id, "incoming_webhook_token": webhook_token}
)
@require_agent
def test_func(webhook_token=None, agent=None, agent_id_str=None):
return {"agent_id": agent_id_str}
result = test_func(webhook_token=webhook_token)
assert result["agent_id"] == str(agent_id)
def test_returns_400_for_missing_token(self, mock_mongo_db, flask_app):
from application.api.user.base import require_agent
with flask_app.app_context():
@require_agent
def test_func(webhook_token=None, agent=None, agent_id_str=None):
return {"success": True}
result = test_func()
assert result.status_code == 400
assert result.json["success"] is False
assert "missing" in result.json["message"].lower()
def test_returns_404_for_invalid_token(self, mock_mongo_db, flask_app):
from application.api.user.base import require_agent
with flask_app.app_context():
@require_agent
def test_func(webhook_token=None, agent=None, agent_id_str=None):
return {"success": True}
result = test_func(webhook_token="invalid_token_999")
assert result.status_code == 404
assert result.json["success"] is False
assert "not found" in result.json["message"].lower()

View File

@@ -1,15 +1,7 @@
from unittest.mock import Mock
import mongomock
import pytest
def get_settings():
"""Lazy load settings to avoid import-time errors."""
from application.core.settings import settings
return settings
from application.core.settings import settings
@pytest.fixture
@@ -43,51 +35,18 @@ def mock_retriever():
@pytest.fixture
def mock_mongo_db(monkeypatch):
"""Mock MongoDB using mongomock - industry standard MongoDB mocking library."""
settings = get_settings()
fake_collection = FakeMongoCollection()
fake_db = {
"agents": fake_collection,
"user_tools": fake_collection,
"memories": fake_collection,
}
fake_client = {settings.MONGO_DB_NAME: fake_db}
mock_client = mongomock.MongoClient()
mock_db = mock_client[settings.MONGO_DB_NAME]
def get_mock_client():
return {settings.MONGO_DB_NAME: mock_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", get_mock_client)
monkeypatch.setattr("application.api.user.base.users_collection", mock_db["users"])
monkeypatch.setattr(
"application.api.user.base.user_tools_collection", mock_db["user_tools"]
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
)
monkeypatch.setattr(
"application.api.user.base.agents_collection", mock_db["agents"]
)
monkeypatch.setattr(
"application.api.user.base.conversations_collection", mock_db["conversations"]
)
monkeypatch.setattr(
"application.api.user.base.sources_collection", mock_db["sources"]
)
monkeypatch.setattr(
"application.api.user.base.prompts_collection", mock_db["prompts"]
)
monkeypatch.setattr(
"application.api.user.base.feedback_collection", mock_db["feedback"]
)
monkeypatch.setattr(
"application.api.user.base.token_usage_collection", mock_db["token_usage"]
)
monkeypatch.setattr(
"application.api.user.base.attachments_collection", mock_db["attachments"]
)
monkeypatch.setattr(
"application.api.user.base.user_logs_collection", mock_db["user_logs"]
)
monkeypatch.setattr(
"application.api.user.base.shared_conversations_collections",
mock_db["shared_conversations"],
)
return get_mock_client()
return fake_client
@pytest.fixture
@@ -128,6 +87,53 @@ def log_context():
return context
class FakeMongoCollection:
def __init__(self):
self.docs = {}
def find_one(self, query, projection=None):
if "key" in query:
return self.docs.get(query["key"])
if "_id" in query:
return self.docs.get(str(query["_id"]))
if "user" in query:
for doc in self.docs.values():
if doc.get("user") == query["user"]:
return doc
return None
def find(self, query, projection=None):
results = []
if "_id" in query and "$in" in query["_id"]:
for doc_id in query["_id"]["$in"]:
doc = self.docs.get(str(doc_id))
if doc:
results.append(doc)
elif "user" in query:
for doc in self.docs.values():
if doc.get("user") == query["user"]:
if "status" in query:
if doc.get("status") == query["status"]:
results.append(doc)
else:
results.append(doc)
return results
def insert_one(self, doc):
doc_id = doc.get("_id", len(self.docs))
self.docs[str(doc_id)] = doc
return Mock(inserted_id=doc_id)
def update_one(self, query, update, upsert=False):
return Mock(modified_count=1)
def delete_one(self, query):
return Mock(deleted_count=1)
def delete_many(self, query):
return Mock(deleted_count=0)
@pytest.fixture
def mock_llm_creator(mock_llm, monkeypatch):
monkeypatch.setattr(

View File

@@ -1,13 +1,19 @@
from typing import Any, Dict, Generator
from unittest.mock import Mock, patch
from typing import Any, Dict, Generator
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
class TestToolCall:
"""Test ToolCall dataclass."""
def test_tool_call_creation(self):
"""Test basic ToolCall creation."""
tool_call = ToolCall(
id="test_id", name="test_function", arguments={"arg1": "value1"}, index=0
id="test_id",
name="test_function",
arguments={"arg1": "value1"},
index=0
)
assert tool_call.id == "test_id"
assert tool_call.name == "test_function"
@@ -15,11 +21,12 @@ class TestToolCall:
assert tool_call.index == 0
def test_tool_call_from_dict(self):
"""Test ToolCall creation from dictionary."""
data = {
"id": "call_123",
"name": "get_weather",
"arguments": {"location": "New York"},
"index": 1,
"index": 1
}
tool_call = ToolCall.from_dict(data)
assert tool_call.id == "call_123"
@@ -28,6 +35,7 @@ class TestToolCall:
assert tool_call.index == 1
def test_tool_call_from_dict_missing_fields(self):
"""Test ToolCall creation with missing fields."""
data = {"name": "test_func"}
tool_call = ToolCall.from_dict(data)
assert tool_call.id == ""
@@ -37,13 +45,16 @@ class TestToolCall:
class TestLLMResponse:
"""Test LLMResponse dataclass."""
def test_llm_response_creation(self):
"""Test basic LLMResponse creation."""
tool_calls = [ToolCall(id="1", name="func", arguments={})]
response = LLMResponse(
content="Hello",
tool_calls=tool_calls,
finish_reason="tool_calls",
raw_response={"test": "data"},
raw_response={"test": "data"}
)
assert response.content == "Hello"
assert len(response.tool_calls) == 1
@@ -51,43 +62,55 @@ class TestLLMResponse:
assert response.raw_response == {"test": "data"}
def test_requires_tool_call_true(self):
"""Test requires_tool_call property when tool calls are needed."""
tool_calls = [ToolCall(id="1", name="func", arguments={})]
response = LLMResponse(
content="",
tool_calls=tool_calls,
finish_reason="tool_calls",
raw_response={},
raw_response={}
)
assert response.requires_tool_call is True
def test_requires_tool_call_false_no_tools(self):
"""Test requires_tool_call property when no tool calls."""
response = LLMResponse(
content="Hello", tool_calls=[], finish_reason="stop", raw_response={}
content="Hello",
tool_calls=[],
finish_reason="stop",
raw_response={}
)
assert response.requires_tool_call is False
def test_requires_tool_call_false_wrong_finish_reason(self):
"""Test requires_tool_call property with tools but wrong finish reason."""
tool_calls = [ToolCall(id="1", name="func", arguments={})]
response = LLMResponse(
content="Hello",
tool_calls=tool_calls,
finish_reason="stop",
raw_response={},
raw_response={}
)
assert response.requires_tool_call is False
class ConcreteHandler(LLMHandler):
"""Concrete implementation for testing abstract base class."""
def parse_response(self, response: Any) -> LLMResponse:
return LLMResponse(
content=str(response),
tool_calls=[],
finish_reason="stop",
raw_response=response,
raw_response=response
)
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
return {"role": "tool", "content": str(result), "tool_call_id": tool_call.id}
return {
"role": "tool",
"content": str(result),
"tool_call_id": tool_call.id
}
def _iterate_stream(self, response: Any) -> Generator:
for chunk in response:
@@ -95,119 +118,114 @@ class ConcreteHandler(LLMHandler):
class TestLLMHandler:
"""Test LLMHandler base class."""
def test_handler_initialization(self):
"""Test handler initialization."""
handler = ConcreteHandler()
assert handler.llm_calls == []
assert handler.tool_calls == []
def test_prepare_messages_no_attachments(self):
"""Test prepare_messages with no attachments."""
handler = ConcreteHandler()
messages = [{"role": "user", "content": "Hello"}]
mock_agent = Mock()
result = handler.prepare_messages(mock_agent, messages, None)
assert result == messages
def test_prepare_messages_with_supported_attachments(self):
"""Test prepare_messages with supported attachments."""
handler = ConcreteHandler()
messages = [{"role": "user", "content": "Hello"}]
attachments = [{"mime_type": "image/png", "path": "/test.png"}]
mock_agent = Mock()
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
mock_agent.llm.prepare_messages_with_attachments.return_value = messages
result = handler.prepare_messages(mock_agent, messages, attachments)
mock_agent.llm.prepare_messages_with_attachments.assert_called_once_with(
messages, attachments
)
assert result == messages
@patch("application.llm.handlers.base.logger")
@patch('application.llm.handlers.base.logger')
def test_prepare_messages_with_unsupported_attachments(self, mock_logger):
"""Test prepare_messages with unsupported attachments."""
handler = ConcreteHandler()
messages = [{"role": "user", "content": "Hello"}]
attachments = [{"mime_type": "text/plain", "path": "/test.txt"}]
mock_agent = Mock()
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
with patch.object(
handler, "_append_unsupported_attachments", return_value=messages
) as mock_append:
with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append:
result = handler.prepare_messages(mock_agent, messages, attachments)
mock_append.assert_called_once_with(messages, attachments)
assert result == messages
def test_prepare_messages_mixed_attachments(self):
"""Test prepare_messages with both supported and unsupported attachments."""
handler = ConcreteHandler()
messages = [{"role": "user", "content": "Hello"}]
attachments = [
{"mime_type": "image/png", "path": "/test.png"},
{"mime_type": "text/plain", "path": "/test.txt"},
{"mime_type": "text/plain", "path": "/test.txt"}
]
mock_agent = Mock()
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
mock_agent.llm.prepare_messages_with_attachments.return_value = messages
with patch.object(
handler, "_append_unsupported_attachments", return_value=messages
) as mock_append:
with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append:
result = handler.prepare_messages(mock_agent, messages, attachments)
# Should call both methods
mock_agent.llm.prepare_messages_with_attachments.assert_called_once()
mock_append.assert_called_once()
assert result == messages
def test_process_message_flow_non_streaming(self):
"""Test process_message_flow for non-streaming."""
handler = ConcreteHandler()
mock_agent = Mock()
initial_response = "test response"
tools_dict = {}
messages = [{"role": "user", "content": "Hello"}]
with patch.object(
handler, "prepare_messages", return_value=messages
) as mock_prepare:
with patch.object(
handler, "handle_non_streaming", return_value="final"
) as mock_handle:
with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare:
with patch.object(handler, 'handle_non_streaming', return_value="final") as mock_handle:
result = handler.process_message_flow(
mock_agent, initial_response, tools_dict, messages, stream=False
)
mock_prepare.assert_called_once_with(mock_agent, messages, None)
mock_handle.assert_called_once_with(
mock_agent, initial_response, tools_dict, messages
)
mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages)
assert result == "final"
def test_process_message_flow_streaming(self):
"""Test process_message_flow for streaming."""
handler = ConcreteHandler()
mock_agent = Mock()
initial_response = "test response"
tools_dict = {}
messages = [{"role": "user", "content": "Hello"}]
def mock_generator():
yield "chunk1"
yield "chunk2"
with patch.object(
handler, "prepare_messages", return_value=messages
) as mock_prepare:
with patch.object(
handler, "handle_streaming", return_value=mock_generator()
) as mock_handle:
with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare:
with patch.object(handler, 'handle_streaming', return_value=mock_generator()) as mock_handle:
result = handler.process_message_flow(
mock_agent, initial_response, tools_dict, messages, stream=True
)
mock_prepare.assert_called_once_with(mock_agent, messages, None)
mock_handle.assert_called_once_with(
mock_agent, initial_response, tools_dict, messages
)
mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages)
# Verify it's a generator
chunks = list(result)
assert chunks == ["chunk1", "chunk2"]

View File

@@ -1,4 +1,3 @@
pytest>=8.0.0
pytest-cov>=4.1.0
coverage>=7.4.0
mongomock>=4.3.0

View File

@@ -1,156 +0,0 @@
import pytest
from application.agents.tools.todo_list import TodoListTool
from application.core.settings import settings
class FakeCursor(list):
def sort(self, key, direction):
reverse = direction == -1
sorted_list = sorted(self, key=lambda d: d.get(key, 0), reverse=reverse)
return FakeCursor(sorted_list)
def limit(self, count):
return FakeCursor(self[:count])
def __iter__(self):
return self
def __next__(self):
if not self:
raise StopIteration
return self.pop(0)
class FakeCollection:
def __init__(self):
self.docs = {}
def create_index(self, *args, **kwargs):
pass
def insert_one(self, doc):
key = (doc["user_id"], doc["tool_id"], int(doc["todo_id"]))
self.docs[key] = doc
return type("res", (), {"inserted_id": key})
def find_one(self, query):
key = (query.get("user_id"), query.get("tool_id"), int(query.get("todo_id")))
return self.docs.get(key)
def find(self, query):
user_id = query.get("user_id")
tool_id = query.get("tool_id")
filtered = [
doc for (uid, tid, _), doc in self.docs.items()
if uid == user_id and tid == tool_id
]
return FakeCursor(filtered)
def update_one(self, query, update, upsert=False):
key = (query.get("user_id"), query.get("tool_id"), int(query.get("todo_id")))
if key in self.docs:
self.docs[key].update(update.get("$set", {}))
return type("res", (), {"matched_count": 1})
elif upsert:
new_doc = {**query, **update.get("$set", {})}
self.docs[key] = new_doc
return type("res", (), {"matched_count": 1})
else:
return type("res", (), {"matched_count": 0})
def delete_one(self, query):
key = (query.get("user_id"), query.get("tool_id"), int(query.get("todo_id")))
if key in self.docs:
del self.docs[key]
return type("res", (), {"deleted_count": 1})
return type("res", (), {"deleted_count": 0})
@pytest.fixture
def todo_tool(monkeypatch) -> TodoListTool:
"""Provides a TodoListTool with a fake MongoDB backend."""
fake_collection = FakeCollection()
fake_client = {settings.MONGO_DB_NAME: {"todos": fake_collection}}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
return TodoListTool({"tool_id": "test_tool"}, user_id="test_user")
def test_create_and_get(todo_tool: TodoListTool):
res = todo_tool.execute_action("todo_create", title="Write tests", description="Write pytest cases")
assert res["status_code"] == 201
todo_id = res["todo_id"]
get_res = todo_tool.execute_action("todo_get", todo_id=todo_id)
assert get_res["status_code"] == 200
assert get_res["todo"]["title"] == "Write tests"
assert get_res["todo"]["description"] == "Write pytest cases"
def test_get_all_todos(todo_tool: TodoListTool):
todo_tool.execute_action("todo_create", title="Task 1")
todo_tool.execute_action("todo_create", title="Task 2")
list_res = todo_tool.execute_action("todo_list")
assert list_res["status_code"] == 200
titles = [todo["title"] for todo in list_res["todos"]]
assert "Task 1" in titles
assert "Task 2" in titles
def test_update_todo(todo_tool: TodoListTool):
create_res = todo_tool.execute_action("todo_create", title="Initial Title")
todo_id = create_res["todo_id"]
update_res = todo_tool.execute_action("todo_update", todo_id=todo_id, updates={"title": "Updated Title", "status": "done"})
assert update_res["status_code"] == 200
get_res = todo_tool.execute_action("todo_get", todo_id=todo_id)
assert get_res["todo"]["title"] == "Updated Title"
assert get_res["todo"]["status"] == "done"
def test_delete_todo(todo_tool: TodoListTool):
create_res = todo_tool.execute_action("todo_create", title="To Delete")
todo_id = create_res["todo_id"]
delete_res = todo_tool.execute_action("todo_delete", todo_id=todo_id)
assert delete_res["status_code"] == 200
get_res = todo_tool.execute_action("todo_get", todo_id=todo_id)
assert get_res["status_code"] == 404
def test_isolation_and_default_tool_id(monkeypatch):
"""Ensure todos are isolated by tool_id and user_id."""
fake_collection = FakeCollection()
fake_client = {settings.MONGO_DB_NAME: {"todos": fake_collection}}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# Same user, different tool_id
tool1 = TodoListTool({"tool_id": "tool_1"}, user_id="u1")
tool2 = TodoListTool({"tool_id": "tool_2"}, user_id="u1")
r1_create = tool1.execute_action("todo_create", title="from tool 1")
r2_create = tool2.execute_action("todo_create", title="from tool 2")
r1 = tool1.execute_action("todo_get", todo_id=r1_create["todo_id"])
r2 = tool2.execute_action("todo_get", todo_id=r2_create["todo_id"])
assert r1["status_code"] == 200
assert r1["todo"]["title"] == "from tool 1"
assert r2["status_code"] == 200
assert r2["todo"]["title"] == "from tool 2"
# Same user, no tool_id → should default to same value
t3 = TodoListTool({}, user_id="default_user")
t4 = TodoListTool({}, user_id="default_user")
assert t3.tool_id == "default_default_user"
assert t4.tool_id == "default_default_user"
create_res = t3.execute_action("todo_create", title="shared default")
r = t4.execute_action("todo_get", todo_id=create_res["todo_id"])
assert r["status_code"] == 200
assert r["todo"]["title"] == "shared default"

View File

@@ -16,25 +16,12 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
class DummyClient:
def __init__(self, api_key):
created["api_key"] = api_key
self.convert_calls = []
self.generate_calls = []
class TextToSpeech:
def __init__(self, outer):
self._outer = outer
def convert(self, *, voice_id, model_id, text, output_format):
self._outer.convert_calls.append(
{
"voice_id": voice_id,
"model_id": model_id,
"text": text,
"output_format": output_format,
}
)
yield b"chunk-one"
yield b"chunk-two"
self.text_to_speech = TextToSpeech(self)
def generate(self, *, text, model, voice):
self.generate_calls.append({"text": text, "model": model, "voice": voice})
yield b"chunk-one"
yield b"chunk-two"
client_module = ModuleType("elevenlabs.client")
client_module.ElevenLabs = DummyClient
@@ -48,13 +35,8 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
audio_base64, lang = tts.text_to_speech("Speak")
assert created["api_key"] == "api-key"
assert tts.client.convert_calls == [
{
"voice_id": "nPczCjzI2devNBz1zQrb",
"model_id": "eleven_multilingual_v2",
"text": "Speak",
"output_format": "mp3_44100_128",
}
assert tts.client.generate_calls == [
{"text": "Speak", "model": "eleven_multilingual_v2", "voice": "Brian"}
]
assert lang == "en"
assert base64.b64decode(audio_base64.encode()) == b"chunk-onechunk-two"

View File

@@ -1,61 +0,0 @@
import pytest
from unittest.mock import patch, MagicMock
from application.tts.tts_creator import TTSCreator
@pytest.fixture
def tts_creator():
return TTSCreator()
def test_create_google_tts(tts_creator):
# Patch the provider registry so the factory calls our mock class
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
mock_google_tts = TTSCreator.tts_providers["google_tts"]
instance = MagicMock()
mock_google_tts.return_value = instance
result = tts_creator.create_tts("google_tts", "arg1", key="value")
mock_google_tts.assert_called_once_with("arg1", key="value")
assert result == instance
def test_create_elevenlabs_tts(tts_creator):
# Patch the provider registry so the factory calls our mock class
with patch.dict(TTSCreator.tts_providers, {"elevenlabs": MagicMock()}):
mock_elevenlabs_tts = TTSCreator.tts_providers["elevenlabs"]
instance = MagicMock()
mock_elevenlabs_tts.return_value = instance
result = tts_creator.create_tts("elevenlabs", "voice", lang="en")
mock_elevenlabs_tts.assert_called_once_with("voice", lang="en")
assert result == instance
def test_invalid_tts_type(tts_creator):
with pytest.raises(ValueError) as excinfo:
tts_creator.create_tts("unknown_tts")
assert "No tts class found" in str(excinfo.value)
def test_tts_type_case_insensitivity(tts_creator):
# Patch the provider registry to ensure case-insensitive lookup hits our mock
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
mock_google_tts = TTSCreator.tts_providers["google_tts"]
instance = MagicMock()
mock_google_tts.return_value = instance
result = tts_creator.create_tts("GoOgLe_TtS")
mock_google_tts.assert_called_once_with()
assert result == instance
def test_tts_providers_integrity(tts_creator):
providers = tts_creator.tts_providers
assert "google_tts" in providers
assert "elevenlabs" in providers
assert callable(providers["google_tts"])
assert callable(providers["elevenlabs"])