Compare commits
8 Commits
dependabot
...
sharepoint
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c46d8a094 | ||
|
|
065939302b | ||
|
|
5fa87db9e7 | ||
|
|
cc54cea783 | ||
|
|
d9f0072112 | ||
|
|
2b73c0c9a0 | ||
|
|
da62133d21 | ||
|
|
8edb6dcf2a |
@@ -6,4 +6,17 @@ VITE_API_STREAMING=true
|
||||
OPENAI_API_BASE=
|
||||
OPENAI_API_VERSION=
|
||||
AZURE_DEPLOYMENT_NAME=
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
|
||||
#Azure AD Application (client) ID
|
||||
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
|
||||
#Azure AD Application client secret
|
||||
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
|
||||
#Azure AD Tenant ID (or 'common' for multi-tenant)
|
||||
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
|
||||
#If you are using a Microsoft Entra ID tenant,
|
||||
#configure the AUTHORITY variable as
|
||||
#"https://login.microsoftonline.com/TENANT_GUID"
|
||||
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
|
||||
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
|
||||
MICROSOFT_AUTHORITY=https://{tenentId}.ciamlogin.com/{tenentId}
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
|
||||
|
||||
All Meaningful contributors with accepted PRs that were created for issues with the `hacktoberfest` label (set by our maintainer team: dartpain, siiddhantt, pabik, ManishMadan2882) will receive a cool T-shirt! 🤩.
|
||||
<img width="1331" height="678" alt="hacktoberfest-mocks-preview" src="https://github.com/user-attachments/assets/633f6377-38db-48f5-b519-a8b3855a9eb4" />
|
||||
|
||||
Fill in [this form](https://forms.gle/Npaba4n9Epfyx56S8
|
||||
) after your PR was merged please
|
||||
|
||||
@@ -1,321 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class TodoListTool(Tool):
|
||||
"""Todo List
|
||||
|
||||
Manages todo items for users. Supports creating, viewing, updating, and deleting todos.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this todo list tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated todos
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["todos"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of list, create, get, update, complete, delete.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: TodoListTool requires a valid user_id."
|
||||
|
||||
if action_name == "list":
|
||||
return self._list()
|
||||
|
||||
if action_name == "create":
|
||||
return self._create(kwargs.get("title", ""))
|
||||
|
||||
if action_name == "get":
|
||||
return self._get(kwargs.get("todo_id"))
|
||||
|
||||
if action_name == "update":
|
||||
return self._update(
|
||||
kwargs.get("todo_id"),
|
||||
kwargs.get("title", "")
|
||||
)
|
||||
|
||||
if action_name == "complete":
|
||||
return self._complete(kwargs.get("todo_id"))
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete(kwargs.get("todo_id"))
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "list",
|
||||
"description": "List all todos for the user.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "create",
|
||||
"description": "Create a new todo item.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Title of the todo item."
|
||||
}
|
||||
},
|
||||
"required": ["title"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "get",
|
||||
"description": "Get a specific todo by ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to retrieve."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "update",
|
||||
"description": "Update a todo's title by ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to update."
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "The new title for the todo."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id", "title"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "complete",
|
||||
"description": "Mark a todo as completed.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to mark as completed."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete a specific todo by ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to delete."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements."""
|
||||
return {}
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
def _coerce_todo_id(self, value: Optional[Any]) -> Optional[int]:
|
||||
"""Convert todo identifiers to sequential integers."""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, int):
|
||||
return value if value > 0 else None
|
||||
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
if stripped.isdigit():
|
||||
numeric_value = int(stripped)
|
||||
return numeric_value if numeric_value > 0 else None
|
||||
|
||||
return None
|
||||
|
||||
def _get_next_todo_id(self) -> int:
|
||||
"""Get the next sequential todo_id for this user and tool.
|
||||
|
||||
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
|
||||
With 5-10 todos max, scanning is negligible.
|
||||
"""
|
||||
# Find all todos for this user/tool and get their IDs
|
||||
todos = list(self.collection.find(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"todo_id": 1}
|
||||
))
|
||||
|
||||
# Find the maximum todo_id
|
||||
max_id = 0
|
||||
for todo in todos:
|
||||
todo_id = self._coerce_todo_id(todo.get("todo_id"))
|
||||
if todo_id is not None:
|
||||
max_id = max(max_id, todo_id)
|
||||
|
||||
return max_id + 1
|
||||
|
||||
def _list(self) -> str:
|
||||
"""List all todos for the user."""
|
||||
cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
todos = list(cursor)
|
||||
|
||||
if not todos:
|
||||
return "No todos found."
|
||||
|
||||
result_lines = ["Todos:"]
|
||||
for doc in todos:
|
||||
todo_id = doc.get("todo_id")
|
||||
title = doc.get("title", "Untitled")
|
||||
status = doc.get("status", "open")
|
||||
|
||||
line = f"[{todo_id}] {title} ({status})"
|
||||
result_lines.append(line)
|
||||
|
||||
return "\n".join(result_lines)
|
||||
|
||||
def _create(self, title: str) -> str:
|
||||
"""Create a new todo item."""
|
||||
title = (title or "").strip()
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
now = datetime.now()
|
||||
todo_id = self._get_next_todo_id()
|
||||
|
||||
doc = {
|
||||
"todo_id": todo_id,
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"title": title,
|
||||
"status": "open",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
self.collection.insert_one(doc)
|
||||
return f"Todo created with ID {todo_id}: {title}"
|
||||
|
||||
def _get(self, todo_id: Optional[Any]) -> str:
|
||||
"""Get a specific todo by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
doc = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"todo_id": parsed_todo_id
|
||||
})
|
||||
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
title = doc.get("title", "Untitled")
|
||||
status = doc.get("status", "open")
|
||||
|
||||
result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
|
||||
|
||||
return result
|
||||
|
||||
def _update(self, todo_id: Optional[Any], title: str) -> str:
|
||||
"""Update a todo's title by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
title = (title or "").strip()
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
result = self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
|
||||
{"$set": {"title": title, "updated_at": datetime.now()}}
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
return f"Todo {parsed_todo_id} updated to: {title}"
|
||||
|
||||
def _complete(self, todo_id: Optional[Any]) -> str:
|
||||
"""Mark a todo as completed."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
result = self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
|
||||
{"$set": {"status": "completed", "updated_at": datetime.now()}}
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
return f"Todo {parsed_todo_id} marked as completed."
|
||||
|
||||
def _delete(self, todo_id: Optional[Any]) -> str:
|
||||
"""Delete a specific todo by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
result = self.collection.delete_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"todo_id": parsed_todo_id
|
||||
})
|
||||
|
||||
if result.deleted_count == 0:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
return f"Todo {parsed_todo_id} deleted."
|
||||
@@ -28,7 +28,7 @@ class ToolManager:
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
if tool_name in {"mcp_tool", "notes", "memory", "todo_list"} and user_id:
|
||||
if tool_name in {"mcp_tool", "notes", "memory"} and user_id:
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
return obj(tool_config)
|
||||
@@ -36,7 +36,7 @@ class ToolManager:
|
||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id:
|
||||
if tool_name in {"mcp_tool", "memory"} and user_id:
|
||||
tool_config = self.config.get(tool_name, {})
|
||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||
return tool.execute_action(action_name, **kwargs)
|
||||
|
||||
@@ -113,10 +113,14 @@ class ConnectorsCallback(Resource):
|
||||
session_token = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
user_info = service.about().get(fields="user").execute()
|
||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||
if provider == "google_drive":
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
user_info = service.about().get(fields="user").execute()
|
||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||
else:
|
||||
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.warning(f"Could not get user info: {e}")
|
||||
user_email = 'Connected User'
|
||||
|
||||
@@ -10,7 +10,7 @@ from application.api import api
|
||||
from application.api.user.base import agents_collection, storage
|
||||
from application.api.user.tasks import store_attachment
|
||||
from application.core.settings import settings
|
||||
from application.tts.tts_creator import TTSCreator
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
from application.utils import safe_filename
|
||||
|
||||
|
||||
@@ -133,7 +133,7 @@ class TextToSpeech(Resource):
|
||||
data = request.get_json()
|
||||
text = data["text"]
|
||||
try:
|
||||
tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
|
||||
tts_instance = GoogleTTS()
|
||||
audio_base64, detected_language = tts_instance.text_to_speech(text)
|
||||
return make_response(
|
||||
jsonify(
|
||||
|
||||
@@ -55,6 +55,11 @@ class Settings(BaseSettings):
|
||||
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||
)
|
||||
|
||||
# Microsoft Entra ID (Azure AD) integration
|
||||
MICROSOFT_CLIENT_ID: Optional[str] = None # Azure AD Application (client) ID
|
||||
MICROSOFT_CLIENT_SECRET: Optional[str] = None # Azure AD Application client secret
|
||||
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
|
||||
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
|
||||
# GitHub source
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
|
||||
@@ -130,7 +135,6 @@ class Settings(BaseSettings):
|
||||
# Encryption settings
|
||||
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
||||
|
||||
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
|
||||
ELEVENLABS_API_KEY: Optional[str] = None
|
||||
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
from application.parser.connectors.share_point.auth import SharePointAuth
|
||||
from application.parser.connectors.share_point.loader import SharePointLoader
|
||||
|
||||
|
||||
class ConnectorCreator:
|
||||
@@ -12,10 +14,12 @@ class ConnectorCreator:
|
||||
|
||||
connectors = {
|
||||
"google_drive": GoogleDriveLoader,
|
||||
"share_point": SharePointLoader,
|
||||
}
|
||||
|
||||
auth_providers = {
|
||||
"google_drive": GoogleDriveAuth,
|
||||
"share_point": SharePointAuth,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
10
application/parser/connectors/share_point/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Share Point connector package for DocsGPT.
|
||||
|
||||
This module provides authentication and document loading capabilities for Share Point.
|
||||
"""
|
||||
|
||||
from .auth import SharePointAuth
|
||||
from .loader import SharePointLoader
|
||||
|
||||
__all__ = ['SharePointAuth', 'SharePointLoader']
|
||||
91
application/parser/connectors/share_point/auth.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import logging
|
||||
import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from msal import ConfidentialClientApplication
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.base import BaseConnectorAuth
|
||||
|
||||
|
||||
class SharePointAuth(BaseConnectorAuth):
|
||||
"""
|
||||
Handles Microsoft OAuth 2.0 authentication.
|
||||
|
||||
# Documentation:
|
||||
- https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow
|
||||
- https://learn.microsoft.com/en-gb/entra/msal/python/
|
||||
"""
|
||||
|
||||
# Microsoft Graph scopes for SharePoint access
|
||||
SCOPES = [
|
||||
"User.Read",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.client_id = settings.MICROSOFT_CLIENT_ID
|
||||
self.client_secret = settings.MICROSOFT_CLIENT_SECRET
|
||||
|
||||
if not self.client_id or not self.client_secret:
|
||||
raise ValueError(
|
||||
"Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID and MICROSOFT_CLIENT_SECRET in settings."
|
||||
)
|
||||
|
||||
self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
|
||||
self.tenant_id = settings.MICROSOFT_TENANT_ID
|
||||
self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://{self.tenant_id}.ciamlogin.com/{self.tenant_id}")
|
||||
|
||||
self.auth_app = ConfidentialClientApplication(
|
||||
client_id=self.client_id, client_credential=self.client_secret, authority=self.authority
|
||||
)
|
||||
|
||||
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||
return self.auth_app.get_authorization_request_url(
|
||||
scopes=self.SCOPES, state=state, redirect_uri=self.redirect_uri
|
||||
)
|
||||
|
||||
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||
result = self.auth_app.acquire_token_by_authorization_code(
|
||||
code=authorization_code, scopes=self.SCOPES, redirect_uri=self.redirect_uri
|
||||
)
|
||||
|
||||
if "error" in result:
|
||||
logging.error(f"Error acquiring token: {result.get('error_description')}")
|
||||
raise ValueError(f"Error acquiring token: {result.get('error_description')}")
|
||||
|
||||
return self.map_token_response(result)
|
||||
|
||||
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
result = self.auth_app.acquire_token_by_refresh_token(refresh_token=refresh_token, scopes=self.SCOPES)
|
||||
|
||||
if "error" in result:
|
||||
logging.error(f"Error acquiring token: {result.get('error_description')}")
|
||||
raise ValueError(f"Error acquiring token: {result.get('error_description')}")
|
||||
|
||||
return self.map_token_response(result)
|
||||
|
||||
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
|
||||
if not token_info or "expiry" not in token_info:
|
||||
# If no expiry info, consider token expired to be safe
|
||||
return True
|
||||
|
||||
# Get expiry timestamp and current time
|
||||
expiry_timestamp = token_info["expiry"]
|
||||
current_timestamp = int(datetime.datetime.now().timestamp())
|
||||
|
||||
# Token is expired if current time is greater than or equal to expiry time
|
||||
return current_timestamp >= expiry_timestamp
|
||||
|
||||
def map_token_response(self, result) -> Dict[str, Any]:
|
||||
return {
|
||||
"access_token": result.get("access_token"),
|
||||
"refresh_token": result.get("refresh_token"),
|
||||
"token_uri": result.get("id_token_claims", {}).get("iss"),
|
||||
"scopes": result.get("scope"),
|
||||
"expiry": result.get("id_token_claims", {}).get("exp"),
|
||||
"user_info": {
|
||||
"name": result.get("id_token_claims", {}).get("name"),
|
||||
"email": result.get("id_token_claims", {}).get("preferred_username"),
|
||||
},
|
||||
"raw_token": result,
|
||||
}
|
||||
44
application/parser/connectors/share_point/loader.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import List, Dict, Any
|
||||
from application.parser.connectors.base import BaseConnectorLoader
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
class SharePointLoader(BaseConnectorLoader):
|
||||
def __init__(self, session_token: str):
|
||||
pass
|
||||
|
||||
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""
|
||||
Load documents from the external knowledge base.
|
||||
|
||||
Args:
|
||||
inputs: Configuration dictionary containing:
|
||||
- file_ids: Optional list of specific file IDs to load
|
||||
- folder_ids: Optional list of folder IDs to browse/download
|
||||
- limit: Maximum number of items to return
|
||||
- list_only: If True, return metadata without content
|
||||
- recursive: Whether to recursively process folders
|
||||
|
||||
Returns:
|
||||
List of Document objects
|
||||
"""
|
||||
pass
|
||||
|
||||
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Download files/folders to a local directory.
|
||||
|
||||
Args:
|
||||
local_dir: Local directory path to download files to
|
||||
source_config: Configuration for what to download
|
||||
|
||||
Returns:
|
||||
Dictionary containing download results:
|
||||
- files_downloaded: Number of files downloaded
|
||||
- directory_path: Path where files were downloaded
|
||||
- empty_result: Whether no files were downloaded
|
||||
- source_type: Type of connector
|
||||
- config_used: Configuration that was used
|
||||
- error: Error message if download failed (optional)
|
||||
"""
|
||||
pass
|
||||
@@ -10,7 +10,6 @@ ebooklib==0.18
|
||||
escodegen==1.0.11
|
||||
esprima==4.0.1
|
||||
esutils==1.0.1
|
||||
elevenlabs==2.17.0
|
||||
Flask==3.1.1
|
||||
faiss-cpu==1.9.0.post1
|
||||
fastmcp==2.11.0
|
||||
@@ -30,7 +29,7 @@ jsonpatch==1.33
|
||||
jsonpointer==3.0.0
|
||||
kombu==5.4.2
|
||||
langchain==0.3.20
|
||||
langchain-community==0.4.1
|
||||
langchain-community==0.3.19
|
||||
langchain-core==0.3.59
|
||||
langchain-openai==0.3.16
|
||||
langchain-text-splitters==0.3.8
|
||||
@@ -41,6 +40,7 @@ markupsafe==3.0.2
|
||||
marshmallow==3.26.1
|
||||
mpmath==1.3.0
|
||||
multidict==6.4.3
|
||||
msal==1.34.0
|
||||
mypy-extensions==1.0.0
|
||||
networkx==3.4.2
|
||||
numpy==2.2.1
|
||||
@@ -88,4 +88,4 @@ werkzeug>=3.1.0,<3.1.2
|
||||
yarl==1.20.0
|
||||
markdownify==1.1.0
|
||||
tldextract==5.1.3
|
||||
websockets==14.1
|
||||
websockets==14.1
|
||||
@@ -15,11 +15,10 @@ class ElevenlabsTTS(BaseTTS):
|
||||
|
||||
def text_to_speech(self, text):
|
||||
lang = "en"
|
||||
audio = self.client.text_to_speech.convert(
|
||||
voice_id="nPczCjzI2devNBz1zQrb",
|
||||
model_id="eleven_multilingual_v2",
|
||||
audio = self.client.generate(
|
||||
text=text,
|
||||
output_format="mp3_44100_128"
|
||||
model="eleven_multilingual_v2",
|
||||
voice="Brian",
|
||||
)
|
||||
audio_data = BytesIO()
|
||||
for chunk in audio:
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
from application.tts.elevenlabs import ElevenlabsTTS
|
||||
from application.tts.base import BaseTTS
|
||||
|
||||
|
||||
|
||||
class TTSCreator:
|
||||
tts_providers = {
|
||||
"google_tts": GoogleTTS,
|
||||
"elevenlabs": ElevenlabsTTS,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_tts(cls, tts_type, *args, **kwargs)-> BaseTTS:
|
||||
tts_class = cls.tts_providers.get(tts_type.lower())
|
||||
if not tts_class:
|
||||
raise ValueError(f"No tts class found for type {tts_type}")
|
||||
return tts_class(*args, **kwargs)
|
||||
@@ -21,7 +21,7 @@ def get_encoding():
|
||||
|
||||
|
||||
def get_gpt_model() -> str:
|
||||
"""Get GPT model based on provider"""
|
||||
"""Get the appropriate GPT model based on provider"""
|
||||
model_map = {
|
||||
"openai": "gpt-4o-mini",
|
||||
"anthropic": "claude-2",
|
||||
@@ -32,7 +32,16 @@ def get_gpt_model() -> str:
|
||||
|
||||
|
||||
def safe_filename(filename):
|
||||
"""Create safe filename, preserving extension. Handles non-Latin characters."""
|
||||
"""
|
||||
Creates a safe filename that preserves the original extension.
|
||||
Uses secure_filename, but ensures a proper filename is returned even with non-Latin characters.
|
||||
|
||||
Args:
|
||||
filename (str): The original filename
|
||||
|
||||
Returns:
|
||||
str: A safe filename that can be used for storage
|
||||
"""
|
||||
if not filename:
|
||||
return str(uuid.uuid4())
|
||||
_, extension = os.path.splitext(filename)
|
||||
@@ -74,14 +83,8 @@ def count_tokens_docs(docs):
|
||||
return tokens
|
||||
|
||||
|
||||
def get_missing_fields(data, required_fields):
|
||||
"""Check for missing required fields. Returns list of missing field names."""
|
||||
return [field for field in required_fields if field not in data]
|
||||
|
||||
|
||||
def check_required_fields(data, required_fields):
|
||||
"""Validate required fields. Returns Flask 400 response if validation fails, None otherwise."""
|
||||
missing_fields = get_missing_fields(data, required_fields)
|
||||
missing_fields = [field for field in required_fields if field not in data]
|
||||
if missing_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -95,8 +98,7 @@ def check_required_fields(data, required_fields):
|
||||
return None
|
||||
|
||||
|
||||
def get_field_validation_errors(data, required_fields):
|
||||
"""Check for missing and empty fields. Returns dict with 'missing_fields' and 'empty_fields', or None."""
|
||||
def validate_required_fields(data, required_fields):
|
||||
missing_fields = []
|
||||
empty_fields = []
|
||||
|
||||
@@ -105,24 +107,12 @@ def get_field_validation_errors(data, required_fields):
|
||||
missing_fields.append(field)
|
||||
elif not data[field]:
|
||||
empty_fields.append(field)
|
||||
if missing_fields or empty_fields:
|
||||
return {"missing_fields": missing_fields, "empty_fields": empty_fields}
|
||||
return None
|
||||
|
||||
|
||||
def validate_required_fields(data, required_fields):
|
||||
"""Validate required fields (must exist and be non-empty). Returns Flask 400 response if validation fails, None otherwise."""
|
||||
errors_dict = get_field_validation_errors(data, required_fields)
|
||||
if errors_dict:
|
||||
errors = []
|
||||
if errors_dict["missing_fields"]:
|
||||
errors.append(
|
||||
f"Missing required fields: {', '.join(errors_dict['missing_fields'])}"
|
||||
)
|
||||
if errors_dict["empty_fields"]:
|
||||
errors.append(
|
||||
f"Empty values in required fields: {', '.join(errors_dict['empty_fields'])}"
|
||||
)
|
||||
errors = []
|
||||
if missing_fields:
|
||||
errors.append(f"Missing required fields: {', '.join(missing_fields)}")
|
||||
if empty_fields:
|
||||
errors.append(f"Empty values in required fields: {', '.join(empty_fields)}")
|
||||
if errors:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": " | ".join(errors)}), 400
|
||||
)
|
||||
@@ -134,7 +124,10 @@ def get_hash(data):
|
||||
|
||||
|
||||
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
|
||||
"""Limit chat history to fit within token limit."""
|
||||
"""
|
||||
Limits chat history based on token count.
|
||||
Returns a list of messages that fit within the token limit.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
|
||||
max_token_limit = (
|
||||
@@ -168,7 +161,7 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
|
||||
|
||||
|
||||
def validate_function_name(function_name):
|
||||
"""Validate function name matches allowed pattern (alphanumeric, underscore, hyphen)."""
|
||||
"""Validates if a function name matches the allowed pattern."""
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#e3e3e3"><path d="M240-80q-33 0-56.5-23.5T160-160v-640q0-33 23.5-56.5T240-880h480q33 0 56.5 23.5T800-800v640q0 33-23.5 56.5T720-80H240Zm0-80h480v-640H240v640Zm88-104 56-56-56-56-56 56 56 56Zm0-160 56-56-56-56-56 56 56 56Zm0-160 56-56-56-56-56 56 56 56Zm120 280h232v-80H448v80Zm0-160h232v-80H448v80Zm0-160h232v-80H448v80ZM240-160v-640 640Z"/></svg>
|
||||
|
Before Width: | Height: | Size: 446 B |
@@ -9,8 +9,7 @@ import userService from './api/services/userService';
|
||||
import Add from './assets/add.svg';
|
||||
import DocsGPT3 from './assets/cute_docsgpt3.svg';
|
||||
import Discord from './assets/discord.svg';
|
||||
import PanelLeftClose from './assets/panel-left-close.svg';
|
||||
import PanelLeftOpen from './assets/panel-left-open.svg';
|
||||
import Expand from './assets/expand.svg';
|
||||
import Github from './assets/git_nav.svg';
|
||||
import Hamburger from './assets/hamburger.svg';
|
||||
import openNewChat from './assets/openNewChat.svg';
|
||||
@@ -303,20 +302,18 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
{
|
||||
<div className="absolute top-3 left-3 z-20 hidden transition-all duration-300 ease-in-out lg:block">
|
||||
<div className="flex items-center gap-3">
|
||||
{!navOpen && (
|
||||
<button
|
||||
onClick={() => {
|
||||
setNavOpen(!navOpen);
|
||||
}}
|
||||
className="transition-transform duration-200 hover:scale-110"
|
||||
>
|
||||
<img
|
||||
src={PanelLeftOpen}
|
||||
alt="Open navigation menu"
|
||||
className="m-auto transition-all duration-300 ease-in-out"
|
||||
/>
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
onClick={() => {
|
||||
setNavOpen(!navOpen);
|
||||
}}
|
||||
className="transition-transform duration-200 hover:scale-110"
|
||||
>
|
||||
<img
|
||||
src={Expand}
|
||||
alt="Toggle navigation menu"
|
||||
className="m-auto transition-all duration-300 ease-in-out"
|
||||
/>
|
||||
</button>
|
||||
{queries?.length > 0 && (
|
||||
<button
|
||||
onClick={() => {
|
||||
@@ -366,8 +363,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
}}
|
||||
>
|
||||
<img
|
||||
src={navOpen ? PanelLeftClose : PanelLeftOpen}
|
||||
alt={navOpen ? 'Collapse sidebar' : 'Expand sidebar'}
|
||||
src={Expand}
|
||||
alt="Toggle navigation menu"
|
||||
className="m-auto transition-all duration-300 ease-in-out hover:scale-110"
|
||||
/>
|
||||
</button>
|
||||
|
||||
@@ -109,18 +109,18 @@ export default function AgentPreview() {
|
||||
} else setLastQueryReturnedErr(false);
|
||||
}, [queries]);
|
||||
return (
|
||||
<div className="relative h-full w-full">
|
||||
<div className="scrollbar-thin absolute inset-0 bottom-[180px] overflow-hidden px-4 pt-4 [&>div>div]:!w-full [&>div>div]:!max-w-none">
|
||||
<ConversationMessages
|
||||
handleQuestion={handleQuestion}
|
||||
handleQuestionSubmission={handleQuestionSubmission}
|
||||
queries={queries}
|
||||
status={status}
|
||||
showHeroOnEmpty={false}
|
||||
/>
|
||||
</div>
|
||||
<div className="absolute right-0 bottom-0 left-0 flex w-full flex-col gap-4 pb-2">
|
||||
<div className="w-full px-4">
|
||||
<div>
|
||||
<div className="dark:bg-raisin-black flex h-full flex-col items-center justify-between gap-2 overflow-y-hidden">
|
||||
<div className="h-[512px] w-full overflow-y-auto">
|
||||
<ConversationMessages
|
||||
handleQuestion={handleQuestion}
|
||||
handleQuestionSubmission={handleQuestionSubmission}
|
||||
queries={queries}
|
||||
status={status}
|
||||
showHeroOnEmpty={false}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex w-[95%] max-w-[1500px] flex-col items-center gap-4 pb-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
|
||||
<MessageInput
|
||||
onSubmit={(text) => handleQuestionSubmission(text)}
|
||||
loading={status === 'loading'}
|
||||
@@ -128,11 +128,11 @@ export default function AgentPreview() {
|
||||
showToolButton={selectedAgent ? false : true}
|
||||
autoFocus={false}
|
||||
/>
|
||||
<p className="text-gray-4000 dark:text-sonic-silver w-full self-center bg-transparent pt-2 text-center text-xs md:inline">
|
||||
This is a preview of the agent. You can publish it to start using it
|
||||
in conversations.
|
||||
</p>
|
||||
</div>
|
||||
<p className="text-gray-4000 dark:text-sonic-silver w-full bg-transparent text-center text-xs md:inline">
|
||||
This is a preview of the agent. You can publish it to start using it
|
||||
in conversations.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -534,7 +534,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
setHasChanges(isChanged);
|
||||
}, [agent, dispatch, effectiveMode, imageFile, jsonSchemaText]);
|
||||
return (
|
||||
<div className="flex flex-col px-4 pt-4 pb-2 max-[1179px]:min-h-[100dvh] min-[1180px]:h-[100dvh] md:px-12 md:pt-12 md:pb-3">
|
||||
<div className="p-4 md:p-12">
|
||||
<div className="flex items-center gap-3 px-4">
|
||||
<button
|
||||
className="rounded-full border p-3 text-sm text-gray-400 dark:border-0 dark:bg-[#28292D] dark:text-gray-500 dark:hover:bg-[#2E2F34]"
|
||||
@@ -615,9 +615,9 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-3 flex w-full flex-1 grid-cols-5 flex-col gap-10 rounded-[30px] bg-[#F6F6F6] p-5 max-[1179px]:overflow-visible min-[1180px]:grid min-[1180px]:gap-5 min-[1180px]:overflow-hidden dark:bg-[#383838]">
|
||||
<div className="scrollbar-thin col-span-2 flex flex-col gap-5 max-[1179px]:overflow-visible min-[1180px]:max-h-full min-[1180px]:overflow-y-auto min-[1180px]:pr-3">
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<div className="mt-5 flex w-full grid-cols-5 flex-col gap-10 min-[1180px]:grid min-[1180px]:gap-5">
|
||||
<div className="col-span-2 flex flex-col gap-5">
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<h2 className="text-lg font-semibold">Meta</h2>
|
||||
<input
|
||||
className="border-silver text-jet dark:bg-raisin-black dark:text-bright-gray dark:placeholder:text-silver mt-3 w-full rounded-3xl border bg-white px-5 py-3 text-sm outline-hidden placeholder:text-gray-400 dark:border-[#7E7E7E]"
|
||||
@@ -650,7 +650,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<h2 className="text-lg font-semibold">Source</h2>
|
||||
<div className="mt-3">
|
||||
<div className="flex flex-wrap items-center gap-1">
|
||||
@@ -744,7 +744,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<div className="flex flex-wrap items-end gap-1">
|
||||
<div className="min-w-20 grow basis-full sm:basis-0">
|
||||
<Prompts
|
||||
@@ -781,7 +781,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<h2 className="text-lg font-semibold">Tools</h2>
|
||||
<div className="mt-3 flex flex-wrap items-center gap-1">
|
||||
<button
|
||||
@@ -823,7 +823,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<h2 className="text-lg font-semibold">Agent type</h2>
|
||||
<div className="mt-3">
|
||||
<Dropdown
|
||||
@@ -848,7 +848,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<button
|
||||
onClick={() =>
|
||||
setIsAdvancedSectionExpanded(!isAdvancedSectionExpanded)
|
||||
@@ -1032,11 +1032,9 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="col-span-3 flex flex-col gap-2 max-[1179px]:h-auto max-[1179px]:px-0 max-[1179px]:py-0 min-[1180px]:h-full min-[1180px]:py-2 dark:text-[#E0E0E0]">
|
||||
<div className="col-span-3 flex flex-col gap-3 rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<h2 className="text-lg font-semibold">Preview</h2>
|
||||
<div className="flex-1 max-[1179px]:overflow-visible min-[1180px]:min-h-0 min-[1180px]:overflow-hidden">
|
||||
<AgentPreviewArea />
|
||||
</div>
|
||||
<AgentPreviewArea />
|
||||
</div>
|
||||
</div>
|
||||
<ConfirmationModal
|
||||
@@ -1073,9 +1071,9 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
function AgentPreviewArea() {
|
||||
const selectedAgent = useSelector(selectSelectedAgent);
|
||||
return (
|
||||
<div className="dark:bg-raisin-black w-full rounded-[30px] border border-[#F6F6F6] bg-white max-[1179px]:h-[600px] min-[1180px]:h-full dark:border-[#7E7E7E]">
|
||||
<div className="dark:bg-raisin-black h-full w-full rounded-[30px] border border-[#F6F6F6] bg-white max-[1180px]:h-192 dark:border-[#7E7E7E]">
|
||||
{selectedAgent?.status === 'published' ? (
|
||||
<div className="flex h-full w-full flex-col overflow-hidden rounded-[30px]">
|
||||
<div className="flex h-full w-full flex-col justify-end overflow-auto rounded-[30px]">
|
||||
<AgentPreview />
|
||||
</div>
|
||||
) : (
|
||||
|
||||
@@ -177,15 +177,13 @@ export default function SharedAgent() {
|
||||
/>
|
||||
</div>
|
||||
<div className="flex w-[95%] max-w-[1500px] flex-col items-center pb-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
|
||||
<div className="w-full px-2">
|
||||
<MessageInput
|
||||
onSubmit={(text) => handleQuestionSubmission(text)}
|
||||
loading={status === 'loading'}
|
||||
showSourceButton={sharedAgent ? false : true}
|
||||
showToolButton={sharedAgent ? false : true}
|
||||
autoFocus={false}
|
||||
/>
|
||||
</div>
|
||||
<MessageInput
|
||||
onSubmit={(text) => handleQuestionSubmission(text)}
|
||||
loading={status === 'loading'}
|
||||
showSourceButton={sharedAgent ? false : true}
|
||||
showToolButton={sharedAgent ? false : true}
|
||||
autoFocus={false}
|
||||
/>
|
||||
<p className="text-gray-4000 dark:text-sonic-silver hidden w-screen self-center bg-transparent py-2 text-center text-xs md:inline md:w-full">
|
||||
{t('tagline')}
|
||||
</p>
|
||||
|
||||
4
frontend/src/assets/expand.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg width="27" height="26" viewBox="0 0 27 26" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M4.03371 5.27275L4.1915 20.9162C4.20021 21.7802 4.90766 22.4735 5.77162 22.4648L21.4151 22.307C22.2791 22.2983 22.9724 21.5909 22.9637 20.7269L22.8059 5.0834C22.7972 4.21944 22.0897 3.52612 21.2258 3.53483L5.58228 3.69262C4.71831 3.70134 4.02499 4.40878 4.03371 5.27275Z" stroke="#949494" stroke-width="2.08591" stroke-linejoin="round"/>
|
||||
<path d="M9.42289 22.428L9.23354 3.65585M17.6924 15.0436L15.5856 12.9788L17.6504 10.872M6.29419 22.4596L12.5516 22.3965M6.10484 3.68741L12.3622 3.62429" stroke="#949494" stroke-width="2.08591" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 692 B |
@@ -1,5 +1,5 @@
|
||||
<svg width="113" height="124" viewBox="0 0 113 124" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<circle cx="55.5" cy="71" r="53" fill="#E8E3F3" fill-opacity="0.6"/>
|
||||
<circle cx="55.5" cy="71" r="53" fill="#F1F1F1" fill-opacity="0.5"/>
|
||||
<rect x="-0.599797" y="0.654564" width="43.9445" height="61.5222" rx="4.39444" transform="matrix(-0.999048 0.0436194 0.0436194 0.999048 68.9873 43.3176)" fill="#EEEEEE" stroke="#999999" stroke-width="1.25556"/>
|
||||
<rect x="0.704349" y="-0.540466" width="46.4556" height="64.0333" rx="5.65" transform="matrix(-0.991445 -0.130526 -0.130526 0.991445 96.3673 40.893)" fill="#FAFAFA" stroke="#999999" stroke-width="1.25556"/>
|
||||
<path d="M94.3796 45.7849C94.7417 43.0349 92.8059 40.5122 90.0559 40.1501L55.2011 35.5614C52.4511 35.1994 49.9284 37.1352 49.5663 39.8851L48.3372 49.2212L93.1505 55.121L94.3796 45.7849Z" fill="#EEEEEE"/>
|
||||
|
||||
|
Before Width: | Height: | Size: 2.0 KiB After Width: | Height: | Size: 2.0 KiB |
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#949494" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-panel-left-close-icon lucide-panel-left-close"><rect width="18" height="18" x="3" y="3" rx="2"/><path d="M9 3v18"/><path d="m16 15-3-3 3-3"/></svg>
|
||||
|
Before Width: | Height: | Size: 345 B |
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#949494" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-panel-left-open-icon lucide-panel-left-open"><rect width="18" height="18" x="3" y="3" rx="2"/><path d="M9 3v18"/><path d="m14 9 3 3-3 3"/></svg>
|
||||
|
Before Width: | Height: | Size: 342 B |
16
frontend/src/assets/sharepoint.svg
Normal file
@@ -0,0 +1,16 @@
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
|
||||
<!-- Uploaded to: SVG Repo, www.svgrepo.com, Transformed by: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 0 48 48" id="b" xmlns="http://www.w3.org/2000/svg" fill="#000000" stroke="#000000" stroke-width="3.312">
|
||||
|
||||
<g id="SVGRepo_bgCarrier" stroke-width="0"/>
|
||||
|
||||
<g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
|
||||
<g id="SVGRepo_iconCarrier">
|
||||
|
||||
<defs>
|
||||
|
||||
<style>.c{fill:none;stroke:#000000;stroke-linecap:round;stroke-linejoin:round;}</style>
|
||||
|
||||
</defs>
|
||||
|
After Width: | Height: | Size: 1.2 KiB |
@@ -136,34 +136,33 @@ const Chunks: React.FC<ChunksProps> = ({
|
||||
|
||||
const pathParts = path ? path.split('/') : [];
|
||||
|
||||
const fetchChunks = async () => {
|
||||
const fetchChunks = () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const response = await userService.getDocumentChunks(
|
||||
documentId,
|
||||
page,
|
||||
perPage,
|
||||
token,
|
||||
path,
|
||||
searchTerm,
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to fetch chunks data');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
setPage(data.page);
|
||||
setPerPage(data.per_page);
|
||||
setTotalChunks(data.total);
|
||||
setPaginatedChunks(data.chunks);
|
||||
} catch (error) {
|
||||
setPaginatedChunks([]);
|
||||
console.error(error);
|
||||
} finally {
|
||||
// ✅ always runs, success or failure
|
||||
userService
|
||||
.getDocumentChunks(documentId, page, perPage, token, path, searchTerm)
|
||||
.then((response) => {
|
||||
if (!response.ok) {
|
||||
setLoading(false);
|
||||
setPaginatedChunks([]);
|
||||
throw new Error('Failed to fetch chunks data');
|
||||
}
|
||||
return response.json();
|
||||
})
|
||||
.then((data) => {
|
||||
setPage(data.page);
|
||||
setPerPage(data.per_page);
|
||||
setTotalChunks(data.total);
|
||||
setPaginatedChunks(data.chunks);
|
||||
setLoading(false);
|
||||
})
|
||||
.catch((error) => {
|
||||
setLoading(false);
|
||||
setPaginatedChunks([]);
|
||||
});
|
||||
} catch (e) {
|
||||
setLoading(false);
|
||||
setPaginatedChunks([]);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
13
frontend/src/components/ConnectedStateSkeleton.tsx
Normal file
@@ -0,0 +1,13 @@
|
||||
const ConnectedStateSkeleton = () => (
|
||||
<div className="mb-4">
|
||||
<div className="flex w-full animate-pulse items-center justify-between rounded-[10px] bg-gray-200 px-4 py-2 dark:bg-gray-700">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="h-4 w-4 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
<div className="h-4 w-32 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
<div className="h-4 w-16 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
export default ConnectedStateSkeleton;
|
||||
@@ -150,7 +150,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
{isConnected ? (
|
||||
<div className="mb-4">
|
||||
<div className="flex w-full items-center justify-between rounded-[10px] bg-[#8FDD51] px-4 py-2 text-sm font-medium text-[#212121]">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex max-w-[500px] items-center gap-2">
|
||||
<svg className="h-4 w-4" viewBox="0 0 24 24">
|
||||
<path
|
||||
fill="currentColor"
|
||||
|
||||
@@ -20,9 +20,14 @@ type CopyButtonProps = {
|
||||
const DEFAULT_ICON_SIZE = 'w-4 h-4';
|
||||
const DEFAULT_PADDING = 'p-2';
|
||||
const DEFAULT_COPIED_DURATION = 2000;
|
||||
const DEFAULT_BG_LIGHT = '#FFFFFF';
|
||||
const DEFAULT_BG_DARK = 'transparent';
|
||||
const DEFAULT_HOVER_BG_LIGHT = '#EEEEEE';
|
||||
const DEFAULT_HOVER_BG_DARK = '#464152';
|
||||
|
||||
export default function CopyButton({
|
||||
textToCopy,
|
||||
|
||||
iconSize = DEFAULT_ICON_SIZE,
|
||||
padding = DEFAULT_PADDING,
|
||||
showText = false,
|
||||
@@ -38,8 +43,9 @@ export default function CopyButton({
|
||||
const iconWrapperClasses = clsx(
|
||||
'flex items-center justify-center rounded-full transition-colors duration-150 ease-in-out',
|
||||
padding,
|
||||
`bg-[${DEFAULT_BG_LIGHT}] dark:bg-[${DEFAULT_BG_DARK}]`,
|
||||
{
|
||||
[`bg-[#FFFFFF}] dark:bg-transparent hover:bg-[#EEEEEE] dark:hover:bg-purple-taupe`]:
|
||||
[`hover:bg-[${DEFAULT_HOVER_BG_LIGHT}] dark:hover:bg-[${DEFAULT_HOVER_BG_DARK}]`]:
|
||||
!isCopied,
|
||||
'bg-green-100 dark:bg-green-900 hover:bg-green-100 dark:hover:bg-green-900':
|
||||
isCopied,
|
||||
|
||||
13
frontend/src/components/FileSelectionSkeleton.tsx
Normal file
@@ -0,0 +1,13 @@
|
||||
const FilesSectionSkeleton = () => (
|
||||
<div className="rounded-lg border border-[#EEE6FF78] dark:border-[#6A6A6A]">
|
||||
<div className="p-4">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<div className="h-5 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div className="h-8 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
<div className="h-4 w-40 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
export default FilesSectionSkeleton;
|
||||
@@ -7,7 +7,10 @@ import {
|
||||
getSessionToken,
|
||||
setSessionToken,
|
||||
removeSessionToken,
|
||||
validateProviderSession,
|
||||
} from '../utils/providerUtils';
|
||||
import ConnectedStateSkeleton from './ConnectedStateSkeleton';
|
||||
import FilesSectionSkeleton from './FileSelectionSkeleton';
|
||||
|
||||
interface PickerFile {
|
||||
id: string;
|
||||
@@ -50,20 +53,9 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
|
||||
const validateSession = async (sessionToken: string) => {
|
||||
try {
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
const validateResponse = await fetch(
|
||||
`${apiHost}/api/connectors/validate-session`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: 'google_drive',
|
||||
session_token: sessionToken,
|
||||
}),
|
||||
},
|
||||
const validateResponse = await validateProviderSession(
|
||||
token,
|
||||
'google_drive',
|
||||
);
|
||||
|
||||
if (!validateResponse.ok) {
|
||||
@@ -234,30 +226,6 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
onSelectionChange([], []);
|
||||
};
|
||||
|
||||
const ConnectedStateSkeleton = () => (
|
||||
<div className="mb-4">
|
||||
<div className="flex w-full animate-pulse items-center justify-between rounded-[10px] bg-gray-200 px-4 py-2 dark:bg-gray-700">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="h-4 w-4 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
<div className="h-4 w-32 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
<div className="h-4 w-16 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
const FilesSectionSkeleton = () => (
|
||||
<div className="rounded-lg border border-[#EEE6FF78] dark:border-[#6A6A6A]">
|
||||
<div className="p-4">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<div className="h-5 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div className="h-8 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
<div className="h-4 w-40 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<div>
|
||||
{isValidating ? (
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { createPortal } from 'react-dom';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
|
||||
@@ -8,7 +6,6 @@ import endpoints from '../api/endpoints';
|
||||
import userService from '../api/services/userService';
|
||||
import AlertIcon from '../assets/alert.svg';
|
||||
import ClipIcon from '../assets/clip.svg';
|
||||
import DragFileUpload from '../assets/DragFileUpload.svg';
|
||||
import ExitIcon from '../assets/exit.svg';
|
||||
import SendArrowIcon from './SendArrowIcon';
|
||||
import SourceIcon from '../assets/source.svg';
|
||||
@@ -20,7 +17,6 @@ import {
|
||||
selectAttachments,
|
||||
updateAttachment,
|
||||
} from '../upload/uploadSlice';
|
||||
import { reorderAttachments } from '../upload/uploadSlice';
|
||||
|
||||
import { ActiveState } from '../models/misc';
|
||||
import {
|
||||
@@ -57,7 +53,6 @@ export default function MessageInput({
|
||||
const [isToolsPopupOpen, setIsToolsPopupOpen] = useState(false);
|
||||
const [uploadModalState, setUploadModalState] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
const [handleDragActive, setHandleDragActive] = useState<boolean>(false);
|
||||
|
||||
const selectedDocs = useSelector(selectSelectedDocs);
|
||||
const token = useSelector(selectToken);
|
||||
@@ -87,134 +82,77 @@ export default function MessageInput({
|
||||
};
|
||||
}, [browserOS]);
|
||||
|
||||
const uploadFiles = useCallback(
|
||||
(files: File[]) => {
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
|
||||
files.forEach((file) => {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
const xhr = new XMLHttpRequest();
|
||||
const uniqueId = crypto.randomUUID();
|
||||
|
||||
const newAttachment = {
|
||||
id: uniqueId,
|
||||
fileName: file.name,
|
||||
progress: 0,
|
||||
status: 'uploading' as const,
|
||||
taskId: '',
|
||||
};
|
||||
|
||||
dispatch(addAttachment(newAttachment));
|
||||
|
||||
xhr.upload.addEventListener('progress', (event) => {
|
||||
if (event.lengthComputable) {
|
||||
const progress = Math.round((event.loaded / event.total) * 100);
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uniqueId,
|
||||
updates: { progress },
|
||||
}),
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
xhr.onload = () => {
|
||||
if (xhr.status === 200) {
|
||||
const response = JSON.parse(xhr.responseText);
|
||||
if (response.task_id) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uniqueId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uniqueId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
xhr.onerror = () => {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: uniqueId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
};
|
||||
|
||||
xhr.open('POST', `${apiHost}${endpoints.USER.STORE_ATTACHMENT}`);
|
||||
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
|
||||
xhr.send(formData);
|
||||
});
|
||||
},
|
||||
[dispatch, token],
|
||||
);
|
||||
|
||||
const handleFileAttachment = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
if (!e.target.files || e.target.files.length === 0) return;
|
||||
|
||||
const files = Array.from(e.target.files);
|
||||
uploadFiles(files);
|
||||
const file = e.target.files[0];
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
|
||||
// clear input so same file can be selected again
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
const xhr = new XMLHttpRequest();
|
||||
|
||||
const newAttachment = {
|
||||
fileName: file.name,
|
||||
progress: 0,
|
||||
status: 'uploading' as const,
|
||||
taskId: '',
|
||||
};
|
||||
|
||||
dispatch(addAttachment(newAttachment));
|
||||
|
||||
xhr.upload.addEventListener('progress', (event) => {
|
||||
if (event.lengthComputable) {
|
||||
const progress = Math.round((event.loaded / event.total) * 100);
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
taskId: newAttachment.taskId,
|
||||
updates: { progress },
|
||||
}),
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
xhr.onload = () => {
|
||||
if (xhr.status === 200) {
|
||||
const response = JSON.parse(xhr.responseText);
|
||||
if (response.task_id) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
taskId: newAttachment.taskId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
taskId: newAttachment.taskId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
xhr.onerror = () => {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
taskId: newAttachment.taskId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
};
|
||||
|
||||
xhr.open('POST', `${apiHost}${endpoints.USER.STORE_ATTACHMENT}`);
|
||||
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
|
||||
xhr.send(formData);
|
||||
e.target.value = '';
|
||||
};
|
||||
|
||||
// Drag and drop handler
|
||||
const onDrop = useCallback(
|
||||
(acceptedFiles: File[]) => {
|
||||
uploadFiles(acceptedFiles);
|
||||
setHandleDragActive(false);
|
||||
},
|
||||
[uploadFiles],
|
||||
);
|
||||
|
||||
const { getRootProps, getInputProps } = useDropzone({
|
||||
onDrop,
|
||||
noClick: true,
|
||||
noKeyboard: true,
|
||||
multiple: true,
|
||||
onDragEnter: () => {
|
||||
setHandleDragActive(true);
|
||||
},
|
||||
onDragLeave: () => {
|
||||
setHandleDragActive(false);
|
||||
},
|
||||
maxSize: 25000000,
|
||||
accept: {
|
||||
'application/pdf': ['.pdf'],
|
||||
'text/plain': ['.txt'],
|
||||
'text/x-rst': ['.rst'],
|
||||
'text/x-markdown': ['.md'],
|
||||
'application/zip': ['.zip'],
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
|
||||
['.docx'],
|
||||
'application/json': ['.json'],
|
||||
'text/csv': ['.csv'],
|
||||
'text/html': ['.html'],
|
||||
'application/epub+zip': ['.epub'],
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': [
|
||||
'.xlsx',
|
||||
],
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation':
|
||||
['.pptx'],
|
||||
'image/png': ['.png'],
|
||||
'image/jpeg': ['.jpeg'],
|
||||
'image/jpg': ['.jpg'],
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
const checkTaskStatus = () => {
|
||||
const processingAttachments = attachments.filter(
|
||||
@@ -229,7 +167,7 @@ export default function MessageInput({
|
||||
if (data.status === 'SUCCESS') {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
taskId: attachment.taskId!,
|
||||
updates: {
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
@@ -241,14 +179,14 @@ export default function MessageInput({
|
||||
} else if (data.status === 'FAILURE') {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
taskId: attachment.taskId!,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
} else if (data.status === 'PROGRESS' && data.result?.current) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
taskId: attachment.taskId!,
|
||||
updates: { progress: data.result.current },
|
||||
}),
|
||||
);
|
||||
@@ -257,7 +195,7 @@ export default function MessageInput({
|
||||
.catch(() => {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
taskId: attachment.taskId!,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
@@ -321,131 +259,90 @@ export default function MessageInput({
|
||||
handleAbort();
|
||||
};
|
||||
|
||||
// Drag state for reordering
|
||||
const [draggingId, setDraggingId] = useState<string | null>(null);
|
||||
|
||||
// no preview object URLs to revoke (preview removed per reviewer request)
|
||||
|
||||
const findIndexById = (id: string) =>
|
||||
attachments.findIndex((a) => a.id === id);
|
||||
|
||||
const handleDragStart = (e: React.DragEvent, id: string) => {
|
||||
setDraggingId(id);
|
||||
try {
|
||||
e.dataTransfer.setData('text/plain', id);
|
||||
e.dataTransfer.effectAllowed = 'move';
|
||||
} catch (err) {
|
||||
// ignore
|
||||
}
|
||||
};
|
||||
|
||||
const handleDragOver = (e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
e.dataTransfer.dropEffect = 'move';
|
||||
};
|
||||
|
||||
const handleDropOn = (e: React.DragEvent, targetId: string) => {
|
||||
e.preventDefault();
|
||||
const sourceId = e.dataTransfer.getData('text/plain');
|
||||
if (!sourceId || sourceId === targetId) return;
|
||||
|
||||
const sourceIndex = findIndexById(sourceId);
|
||||
const destIndex = findIndexById(targetId);
|
||||
if (sourceIndex === -1 || destIndex === -1) return;
|
||||
|
||||
dispatch(reorderAttachments({ sourceIndex, destinationIndex: destIndex }));
|
||||
setDraggingId(null);
|
||||
};
|
||||
|
||||
return (
|
||||
<div {...getRootProps()} className="flex w-full flex-col">
|
||||
<input {...getInputProps()} />
|
||||
<div className="mx-2 flex w-full flex-col">
|
||||
<div className="border-dark-gray bg-lotion dark:border-grey relative flex w-full flex-col rounded-[23px] border dark:bg-transparent">
|
||||
<div className="flex flex-wrap gap-1.5 px-2 py-2 sm:gap-2 sm:px-3">
|
||||
{attachments.map((attachment) => {
|
||||
return (
|
||||
<div
|
||||
key={attachment.id}
|
||||
draggable={true}
|
||||
onDragStart={(e) => handleDragStart(e, attachment.id)}
|
||||
onDragOver={handleDragOver}
|
||||
onDrop={(e) => handleDropOn(e, attachment.id)}
|
||||
className={`group dark:text-bright-gray relative flex items-center rounded-xl bg-[#EFF3F4] px-2 py-1 text-[12px] text-[#5D5D5D] sm:px-3 sm:py-1.5 sm:text-[14px] dark:bg-[#393B3D] ${
|
||||
attachment.status !== 'completed'
|
||||
? 'opacity-70'
|
||||
: 'opacity-100'
|
||||
} ${draggingId === attachment.id ? 'ring-dashed opacity-60 ring-2 ring-purple-200' : ''}`}
|
||||
title={attachment.fileName}
|
||||
>
|
||||
<div className="bg-purple-30 mr-2 flex h-8 w-8 items-center justify-center rounded-md p-1">
|
||||
{attachment.status === 'completed' && (
|
||||
<img
|
||||
src={DocumentationDark}
|
||||
alt="Attachment"
|
||||
className="h-[15px] w-[15px] object-fill"
|
||||
/>
|
||||
)}
|
||||
|
||||
{attachment.status === 'failed' && (
|
||||
<img
|
||||
src={AlertIcon}
|
||||
alt="Failed"
|
||||
className="h-[15px] w-[15px] object-fill"
|
||||
/>
|
||||
)}
|
||||
|
||||
{(attachment.status === 'uploading' ||
|
||||
attachment.status === 'processing') && (
|
||||
<div className="flex h-[15px] w-[15px] items-center justify-center">
|
||||
<svg className="h-[15px] w-[15px]" viewBox="0 0 24 24">
|
||||
<circle
|
||||
className="opacity-0"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r="10"
|
||||
stroke="transparent"
|
||||
strokeWidth="4"
|
||||
fill="none"
|
||||
/>
|
||||
<circle
|
||||
className="text-[#ECECF1]"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r="10"
|
||||
stroke="currentColor"
|
||||
strokeWidth="4"
|
||||
fill="none"
|
||||
strokeDasharray="62.83"
|
||||
strokeDashoffset={
|
||||
62.83 * (1 - attachment.progress / 100)
|
||||
}
|
||||
transform="rotate(-90 12 12)"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<span className="max-w-[120px] truncate font-medium sm:max-w-[150px]">
|
||||
{attachment.fileName}
|
||||
</span>
|
||||
|
||||
<button
|
||||
className="ml-1.5 flex items-center justify-center rounded-full p-1"
|
||||
onClick={() => {
|
||||
dispatch(removeAttachment(attachment.id));
|
||||
}}
|
||||
aria-label={t('conversation.attachments.remove')}
|
||||
>
|
||||
{attachments.map((attachment, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className={`group dark:text-bright-gray relative flex items-center rounded-xl bg-[#EFF3F4] px-2 py-1 text-[12px] text-[#5D5D5D] sm:px-3 sm:py-1.5 sm:text-[14px] dark:bg-[#393B3D] ${
|
||||
attachment.status !== 'completed' ? 'opacity-70' : 'opacity-100'
|
||||
}`}
|
||||
title={attachment.fileName}
|
||||
>
|
||||
<div className="bg-purple-30 mr-2 items-center justify-center rounded-lg p-[5.5px]">
|
||||
{attachment.status === 'completed' && (
|
||||
<img
|
||||
src={ExitIcon}
|
||||
alt={t('conversation.attachments.remove')}
|
||||
className="h-2.5 w-2.5 filter dark:invert"
|
||||
src={DocumentationDark}
|
||||
alt="Attachment"
|
||||
className="h-[15px] w-[15px] object-fill"
|
||||
/>
|
||||
</button>
|
||||
)}
|
||||
|
||||
{attachment.status === 'failed' && (
|
||||
<img
|
||||
src={AlertIcon}
|
||||
alt="Failed"
|
||||
className="h-[15px] w-[15px] object-fill"
|
||||
/>
|
||||
)}
|
||||
|
||||
{(attachment.status === 'uploading' ||
|
||||
attachment.status === 'processing') && (
|
||||
<div className="flex h-[15px] w-[15px] items-center justify-center">
|
||||
<svg className="h-[15px] w-[15px]" viewBox="0 0 24 24">
|
||||
<circle
|
||||
className="opacity-0"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r="10"
|
||||
stroke="transparent"
|
||||
strokeWidth="4"
|
||||
fill="none"
|
||||
/>
|
||||
<circle
|
||||
className="text-[#ECECF1]"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r="10"
|
||||
stroke="currentColor"
|
||||
strokeWidth="4"
|
||||
fill="none"
|
||||
strokeDasharray="62.83"
|
||||
strokeDashoffset={
|
||||
62.83 * (1 - attachment.progress / 100)
|
||||
}
|
||||
transform="rotate(-90 12 12)"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
|
||||
<span className="max-w-[120px] truncate font-medium sm:max-w-[150px]">
|
||||
{attachment.fileName}
|
||||
</span>
|
||||
|
||||
<button
|
||||
className="ml-1.5 flex items-center justify-center rounded-full p-1"
|
||||
onClick={() => {
|
||||
if (attachment.id) {
|
||||
dispatch(removeAttachment(attachment.id));
|
||||
} else if (attachment.taskId) {
|
||||
dispatch(removeAttachment(attachment.taskId));
|
||||
}
|
||||
}}
|
||||
aria-label={t('conversation.attachments.remove')}
|
||||
>
|
||||
<img
|
||||
src={ExitIcon}
|
||||
alt={t('conversation.attachments.remove')}
|
||||
className="h-2.5 w-2.5 filter dark:invert"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<div className="w-full">
|
||||
@@ -527,7 +424,6 @@ export default function MessageInput({
|
||||
<input
|
||||
type="file"
|
||||
className="hidden"
|
||||
multiple
|
||||
onChange={handleFileAttachment}
|
||||
/>
|
||||
</label>
|
||||
@@ -587,20 +483,6 @@ export default function MessageInput({
|
||||
close={() => setUploadModalState('INACTIVE')}
|
||||
/>
|
||||
)}
|
||||
|
||||
{handleDragActive &&
|
||||
createPortal(
|
||||
<div className="dark:bg-gray-alpha/50 pointer-events-none fixed top-0 left-0 z-50 flex size-full flex-col items-center justify-center bg-white/85">
|
||||
<img className="filter dark:invert" src={DragFileUpload} />
|
||||
<span className="text-outer-space dark:text-silver px-2 text-2xl font-bold">
|
||||
{t('modals.uploadDoc.drag.title')}
|
||||
</span>
|
||||
<span className="text-s text-outer-space dark:text-silver w-48 p-2 text-center">
|
||||
{t('modals.uploadDoc.drag.description')}
|
||||
</span>
|
||||
</div>,
|
||||
document.body,
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
175
frontend/src/components/SharePointPicker.tsx
Normal file
@@ -0,0 +1,175 @@
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import ConnectorAuth from './ConnectorAuth';
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
import {
|
||||
getSessionToken,
|
||||
setSessionToken,
|
||||
removeSessionToken,
|
||||
validateProviderSession,
|
||||
} from '../utils/providerUtils';
|
||||
import ConnectedStateSkeleton from './ConnectedStateSkeleton';
|
||||
import FilesSectionSkeleton from './FileSelectionSkeleton';
|
||||
|
||||
interface SharePointPickerProps {
|
||||
token: string | null;
|
||||
}
|
||||
|
||||
const SharePointPicker: React.FC<SharePointPickerProps> = ({ token }) => {
|
||||
const { t } = useTranslation();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [userEmail, setUserEmail] = useState<string>('');
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
const [authError, setAuthError] = useState<string>('');
|
||||
const [accessToken, setAccessToken] = useState<string | null>(null);
|
||||
const [isValidating, setIsValidating] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const sessionToken = getSessionToken('share_point');
|
||||
if (sessionToken) {
|
||||
setIsValidating(true);
|
||||
setIsConnected(true); // Optimistically set as connected for skeleton
|
||||
validateSession(sessionToken);
|
||||
}
|
||||
}, [token]);
|
||||
|
||||
const validateSession = async (sessionToken: string) => {
|
||||
try {
|
||||
const validateResponse = await validateProviderSession(
|
||||
token,
|
||||
'share_point',
|
||||
);
|
||||
|
||||
if (!validateResponse.ok) {
|
||||
setIsConnected(false);
|
||||
setAuthError(
|
||||
t('modals.uploadDoc.connectors.sharePoint.sessionExpired'),
|
||||
);
|
||||
setIsValidating(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
const validateData = await validateResponse.json();
|
||||
if (validateData.success) {
|
||||
setUserEmail(
|
||||
validateData.user_email ||
|
||||
t('modals.uploadDoc.connectors.auth.connectedUser'),
|
||||
);
|
||||
setIsConnected(true);
|
||||
setAuthError('');
|
||||
setAccessToken(validateData.access_token || null);
|
||||
setIsValidating(false);
|
||||
|
||||
return true;
|
||||
} else {
|
||||
setIsConnected(false);
|
||||
setAuthError(
|
||||
validateData.error ||
|
||||
t('modals.uploadDoc.connectors.sharePoint.sessionExpiredGeneric'),
|
||||
);
|
||||
setIsValidating(false);
|
||||
return false;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error validating session:', error);
|
||||
setAuthError(t('modals.uploadDoc.connectors.sharePoint.validateFailed'));
|
||||
setIsConnected(false);
|
||||
setIsValidating(false);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const handleDisconnect = async () => {
|
||||
const sessionToken = getSessionToken('share_point');
|
||||
if (sessionToken) {
|
||||
try {
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
await fetch(`${apiHost}/api/connectors/disconnect`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: 'share_point',
|
||||
session_token: sessionToken,
|
||||
}),
|
||||
});
|
||||
} catch (err) {
|
||||
console.error('Error disconnecting from SharePoint:', err);
|
||||
}
|
||||
}
|
||||
|
||||
removeSessionToken('share_point');
|
||||
setIsConnected(false);
|
||||
setAccessToken(null);
|
||||
setUserEmail('');
|
||||
setAuthError('');
|
||||
};
|
||||
|
||||
const handleOpenPicker = async () => {
|
||||
alert('Feature not supported yet.');
|
||||
};
|
||||
|
||||
return (
|
||||
<div>
|
||||
{isValidating ? (
|
||||
<>
|
||||
<ConnectedStateSkeleton />
|
||||
<FilesSectionSkeleton />
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<ConnectorAuth
|
||||
provider="share_point"
|
||||
label={t('modals.uploadDoc.connectors.sharePoint.connect')}
|
||||
onSuccess={(data) => {
|
||||
setUserEmail(
|
||||
data.user_email ||
|
||||
t('modals.uploadDoc.connectors.auth.connectedUser'),
|
||||
);
|
||||
setIsConnected(true);
|
||||
setAuthError('');
|
||||
|
||||
if (data.session_token) {
|
||||
setSessionToken('share_point', data.session_token);
|
||||
validateSession(data.session_token);
|
||||
}
|
||||
}}
|
||||
onError={(error) => {
|
||||
setAuthError(error);
|
||||
setIsConnected(false);
|
||||
}}
|
||||
isConnected={isConnected}
|
||||
userEmail={userEmail}
|
||||
onDisconnect={handleDisconnect}
|
||||
errorMessage={authError}
|
||||
/>
|
||||
|
||||
{isConnected && (
|
||||
<div className="rounded-lg border border-[#EEE6FF78] dark:border-[#6A6A6A]">
|
||||
<div className="p-4">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<h3 className="text-sm font-medium">
|
||||
{t('modals.uploadDoc.connectors.sharePoint.selectedFiles')}
|
||||
</h3>
|
||||
<button
|
||||
onClick={() => handleOpenPicker()}
|
||||
className="rounded-md bg-[#A076F6] px-3 py-1 text-sm text-white hover:bg-[#8A5FD4]"
|
||||
disabled={isLoading}
|
||||
>
|
||||
{isLoading
|
||||
? t('modals.uploadDoc.connectors.sharePoint.loading')
|
||||
: t('modals.uploadDoc.connectors.sharePoint.selectFiles')}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default SharePointPicker;
|
||||
@@ -33,7 +33,7 @@ export default function Sidebar({
|
||||
return (
|
||||
<div ref={sidebarRef} className="h-vh relative">
|
||||
<div
|
||||
className={`dark:bg-chinese-black fixed top-0 right-0 z-50 h-full w-64 transform bg-white shadow-xl transition-all duration-300 sm:w-80 ${
|
||||
className={`dark:bg-chinese-black fixed top-0 right-0 z-50 h-full w-72 transform bg-white shadow-xl transition-all duration-300 sm:w-96 ${
|
||||
isOpen ? 'translate-x-[10px]' : 'translate-x-full'
|
||||
} border-l border-[#9ca3af]/10`}
|
||||
>
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import React, { useRef, useEffect, useState, useLayoutEffect } from 'react';
|
||||
import { createPortal } from 'react-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { Doc } from '../models/misc';
|
||||
@@ -108,7 +107,7 @@ export default function SourcesPopup({
|
||||
onClose();
|
||||
};
|
||||
|
||||
const popupContent = (
|
||||
return (
|
||||
<div
|
||||
ref={popupRef}
|
||||
className="bg-lotion dark:bg-charleston-green-2 fixed z-50 flex flex-col rounded-xl shadow-[0px_9px_46px_8px_#0000001F,0px_24px_38px_3px_#00000024,0px_11px_15px_-7px_#00000033]"
|
||||
@@ -219,7 +218,7 @@ export default function SourcesPopup({
|
||||
</>
|
||||
) : (
|
||||
<div className="dark:text-bright-gray p-4 text-center text-gray-500 dark:text-[14px]">
|
||||
{t('conversation.sources.noSourcesAvailable')}
|
||||
{t('noSourcesAvailable')}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -246,6 +245,4 @@ export default function SourcesPopup({
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return createPortal(popupContent, document.body);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import React, { useEffect, useRef, useState, useLayoutEffect } from 'react';
|
||||
import { createPortal } from 'react-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
@@ -134,10 +133,10 @@ export default function ToolsPopup({
|
||||
tool.displayName.toLowerCase().includes(searchTerm.toLowerCase()),
|
||||
);
|
||||
|
||||
const popupContent = (
|
||||
return (
|
||||
<div
|
||||
ref={popupRef}
|
||||
className="border-light-silver bg-lotion dark:border-dim-gray dark:bg-charleston-green-2 fixed z-50 rounded-lg border shadow-[0px_9px_46px_8px_#0000001F,0px_24px_38px_3px_#00000024,0px_11px_15px_-7px_#00000033]"
|
||||
className="border-light-silver bg-lotion dark:border-dim-gray dark:bg-charleston-green-2 fixed z-9999 rounded-lg border shadow-[0px_9px_46px_8px_#0000001F,0px_24px_38px_3px_#00000024,0px_11px_15px_-7px_#00000033]"
|
||||
style={{
|
||||
top: popupPosition.showAbove ? popupPosition.top : undefined,
|
||||
bottom: popupPosition.showAbove
|
||||
@@ -243,6 +242,4 @@ export default function ToolsPopup({
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return createPortal(popupContent, document.body);
|
||||
}
|
||||
|
||||
@@ -44,10 +44,7 @@ export default function UploadToast() {
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2"
|
||||
onMouseDown={(e) => e.stopPropagation()}
|
||||
>
|
||||
<div className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2">
|
||||
{uploadTasks
|
||||
.filter((task) => !task.dismissed)
|
||||
.map((task) => {
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
|
||||
import SharedAgentCard from '../agents/SharedAgentCard';
|
||||
import DragFileUpload from '../assets/DragFileUpload.svg';
|
||||
import MessageInput from '../components/MessageInput';
|
||||
import { useMediaQuery } from '../hooks';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import {
|
||||
selectConversationId,
|
||||
selectSelectedAgent,
|
||||
selectToken,
|
||||
} from '../preferences/preferenceSlice';
|
||||
import { AppDispatch } from '../store';
|
||||
import Upload from '../upload/Upload';
|
||||
import { handleSendFeedback } from './conversationHandlers';
|
||||
import ConversationMessages from './ConversationMessages';
|
||||
import { FEEDBACK, Query } from './conversationModels';
|
||||
@@ -41,12 +45,53 @@ export default function Conversation() {
|
||||
const selectedAgent = useSelector(selectSelectedAgent);
|
||||
const completedAttachments = useSelector(selectCompletedAttachments);
|
||||
|
||||
const [uploadModalState, setUploadModalState] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
const [files, setFiles] = useState<File[]>([]);
|
||||
const [lastQueryReturnedErr, setLastQueryReturnedErr] =
|
||||
useState<boolean>(false);
|
||||
const [isShareModalOpen, setShareModalState] = useState<boolean>(false);
|
||||
const [handleDragActive, setHandleDragActive] = useState<boolean>(false);
|
||||
|
||||
const fetchStream = useRef<any>(null);
|
||||
|
||||
const onDrop = useCallback((acceptedFiles: File[]) => {
|
||||
setUploadModalState('ACTIVE');
|
||||
setFiles(acceptedFiles);
|
||||
setHandleDragActive(false);
|
||||
}, []);
|
||||
|
||||
const { getRootProps, getInputProps } = useDropzone({
|
||||
onDrop,
|
||||
noClick: true,
|
||||
multiple: true,
|
||||
onDragEnter: () => {
|
||||
setHandleDragActive(true);
|
||||
},
|
||||
onDragLeave: () => {
|
||||
setHandleDragActive(false);
|
||||
},
|
||||
maxSize: 25000000,
|
||||
accept: {
|
||||
'application/pdf': ['.pdf'],
|
||||
'text/plain': ['.txt'],
|
||||
'text/x-rst': ['.rst'],
|
||||
'text/x-markdown': ['.md'],
|
||||
'application/zip': ['.zip'],
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
|
||||
['.docx'],
|
||||
'application/json': ['.json'],
|
||||
'text/csv': ['.csv'],
|
||||
'text/html': ['.html'],
|
||||
'application/epub+zip': ['.epub'],
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': [
|
||||
'.xlsx',
|
||||
],
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation':
|
||||
['.pptx'],
|
||||
},
|
||||
});
|
||||
|
||||
const handleFetchAnswer = useCallback(
|
||||
({ question, index }: { question: string; index?: number }) => {
|
||||
fetchStream.current = dispatch(fetchAnswer({ question, indx: index }));
|
||||
@@ -177,7 +222,14 @@ export default function Conversation() {
|
||||
/>
|
||||
|
||||
<div className="bg-opacity-0 z-3 flex h-auto w-full max-w-[1300px] flex-col items-end self-center rounded-2xl py-1 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
|
||||
<div className="flex w-full items-center rounded-[40px] px-2">
|
||||
<div
|
||||
{...getRootProps()}
|
||||
className="flex w-full items-center rounded-[40px]"
|
||||
>
|
||||
<label htmlFor="file-upload" className="sr-only">
|
||||
{t('modals.uploadDoc.label')}
|
||||
</label>
|
||||
<input {...getInputProps()} id="file-upload" />
|
||||
<MessageInput
|
||||
onSubmit={(text) => {
|
||||
handleQuestionSubmission(text);
|
||||
@@ -192,6 +244,26 @@ export default function Conversation() {
|
||||
{t('tagline')}
|
||||
</p>
|
||||
</div>
|
||||
{handleDragActive && (
|
||||
<div className="bg-opacity-50 dark:bg-gray-alpha pointer-events-none fixed top-0 left-0 z-30 flex size-full flex-col items-center justify-center bg-white">
|
||||
<img className="filter dark:invert" src={DragFileUpload} />
|
||||
<span className="text-outer-space dark:text-silver px-2 text-2xl font-bold">
|
||||
{t('modals.uploadDoc.drag.title')}
|
||||
</span>
|
||||
<span className="text-s text-outer-space dark:text-silver w-48 p-2 text-center">
|
||||
{t('modals.uploadDoc.drag.description')}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
{uploadModalState === 'ACTIVE' && (
|
||||
<Upload
|
||||
receivedFile={files}
|
||||
setModalState={setUploadModalState}
|
||||
isOnboarding={false}
|
||||
renderTab={'file'}
|
||||
close={() => setUploadModalState('INACTIVE')}
|
||||
></Upload>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
}
|
||||
|
||||
.list li:not(:first-child) {
|
||||
margin-top: 0.5em;
|
||||
margin-top: 1em;
|
||||
}
|
||||
|
||||
.list li > .list {
|
||||
margin-top: 0.5em;
|
||||
margin-top: 1em;
|
||||
}
|
||||
|
||||
@@ -86,7 +86,10 @@ const ConversationBubble = forwardRef<
|
||||
// const bubbleRef = useRef<HTMLDivElement | null>(null);
|
||||
const chunks = useSelector(selectChunks);
|
||||
const selectedDocs = useSelector(selectSelectedDocs);
|
||||
const [isLikeHovered, setIsLikeHovered] = useState(false);
|
||||
const [isEditClicked, setIsEditClicked] = useState(false);
|
||||
const [isDislikeHovered, setIsDislikeHovered] = useState(false);
|
||||
const [isQuestionHovered, setIsQuestionHovered] = useState(false);
|
||||
const [editInputBox, setEditInputBox] = useState<string>('');
|
||||
const messageRef = useRef<HTMLDivElement>(null);
|
||||
const [shouldShowToggle, setShouldShowToggle] = useState(false);
|
||||
@@ -112,7 +115,11 @@ const ConversationBubble = forwardRef<
|
||||
let bubble;
|
||||
if (type === 'QUESTION') {
|
||||
bubble = (
|
||||
<div className={`group ${className}`}>
|
||||
<div
|
||||
onMouseEnter={() => setIsQuestionHovered(true)}
|
||||
onMouseLeave={() => setIsQuestionHovered(false)}
|
||||
className={className}
|
||||
>
|
||||
<div className="flex flex-col items-end">
|
||||
{filesAttached && filesAttached.length > 0 && (
|
||||
<div className="mr-12 mb-4 flex flex-wrap justify-end gap-2">
|
||||
@@ -181,7 +188,7 @@ const ConversationBubble = forwardRef<
|
||||
setIsEditClicked(true);
|
||||
setEditInputBox(message ?? '');
|
||||
}}
|
||||
className={`hover:bg-light-silver mt-3 flex h-fit shrink-0 cursor-pointer items-center rounded-full p-2 pt-1.5 pl-1.5 dark:hover:bg-[#35363B] ${isEditClicked ? 'visible' : 'invisible group-hover:visible'}`}
|
||||
className={`hover:bg-light-silver mt-3 flex h-fit shrink-0 cursor-pointer items-center rounded-full p-2 pt-1.5 pl-1.5 dark:hover:bg-[#35363B] ${isQuestionHovered || isEditClicked ? 'visible' : 'invisible'}`}
|
||||
>
|
||||
<img src={Edit} alt="Edit" className="cursor-pointer" />
|
||||
</button>
|
||||
@@ -414,7 +421,7 @@ const ConversationBubble = forwardRef<
|
||||
<Fragment key={index}>
|
||||
{segment.type === 'text' ? (
|
||||
<ReactMarkdown
|
||||
className="fade-in flex flex-col gap-3 leading-normal break-words whitespace-pre-wrap"
|
||||
className="fade-in leading-normal break-words whitespace-pre-wrap"
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[rehypeKatex]}
|
||||
components={{
|
||||
@@ -561,7 +568,13 @@ const ConversationBubble = forwardRef<
|
||||
<>
|
||||
<div className="relative mr-2 flex items-center justify-center">
|
||||
<div>
|
||||
<div className="bg-white-3000 dark:hover:bg-purple-taupe flex items-center justify-center rounded-full p-2 hover:bg-[#EEEEEE] dark:bg-transparent">
|
||||
<div
|
||||
className={`flex items-center justify-center rounded-full p-2 ${
|
||||
isLikeHovered
|
||||
? 'dark:bg-purple-taupe bg-[#EEEEEE]'
|
||||
: 'bg-white-3000 dark:bg-transparent'
|
||||
}`}
|
||||
>
|
||||
<Like
|
||||
className={`${feedback === 'LIKE' ? 'fill-white-3000 stroke-purple-30 dark:fill-transparent' : 'stroke-gray-4000 fill-none'} cursor-pointer`}
|
||||
onClick={() => {
|
||||
@@ -571,6 +584,8 @@ const ConversationBubble = forwardRef<
|
||||
handleFeedback?.('LIKE');
|
||||
}
|
||||
}}
|
||||
onMouseEnter={() => setIsLikeHovered(true)}
|
||||
onMouseLeave={() => setIsLikeHovered(false)}
|
||||
></Like>
|
||||
</div>
|
||||
</div>
|
||||
@@ -578,7 +593,13 @@ const ConversationBubble = forwardRef<
|
||||
|
||||
<div className="relative mr-2 flex items-center justify-center">
|
||||
<div>
|
||||
<div className="bg-white-3000 dark:hover:bg-purple-taupe flex items-center justify-center rounded-full p-2 hover:bg-[#EEEEEE] dark:bg-transparent">
|
||||
<div
|
||||
className={`flex items-center justify-center rounded-full p-2 ${
|
||||
isDislikeHovered
|
||||
? 'dark:bg-purple-taupe bg-[#EEEEEE]'
|
||||
: 'bg-white-3000 dark:bg-transparent'
|
||||
}`}
|
||||
>
|
||||
<Dislike
|
||||
className={`${feedback === 'DISLIKE' ? 'fill-white-3000 stroke-red-2000 dark:fill-transparent' : 'stroke-gray-4000 fill-none'} cursor-pointer`}
|
||||
onClick={() => {
|
||||
@@ -588,6 +609,8 @@ const ConversationBubble = forwardRef<
|
||||
handleFeedback?.('DISLIKE');
|
||||
}
|
||||
}}
|
||||
onMouseEnter={() => setIsDislikeHovered(true)}
|
||||
onMouseLeave={() => setIsDislikeHovered(false)}
|
||||
></Dislike>
|
||||
</div>
|
||||
</div>
|
||||
@@ -635,7 +658,7 @@ function AllSources(sources: AllSourcesProps) {
|
||||
<p className="text-left text-xl">{`${sources.sources.length} ${t('conversation.sources.title')}`}</p>
|
||||
<div className="mx-1 mt-2 h-[0.8px] w-full rounded-full bg-[#C4C4C4]/40 lg:w-[95%]"></div>
|
||||
</div>
|
||||
<div className="scrollbar-thin mt-6 flex h-[90%] w-52 flex-col gap-4 overflow-y-auto pr-3 sm:w-64">
|
||||
<div className="mt-6 flex h-[90%] w-60 flex-col items-center gap-4 overflow-y-auto sm:w-80">
|
||||
{sources.sources.map((source, index) => {
|
||||
const isExternalSource = source.link && source.link !== 'local';
|
||||
return (
|
||||
|
||||
@@ -161,16 +161,14 @@ export const SharedConversation = () => {
|
||||
/>
|
||||
<div className="flex w-full max-w-[1200px] flex-col items-center gap-4 pb-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
|
||||
{apiKey ? (
|
||||
<div className="w-full px-2">
|
||||
<MessageInput
|
||||
onSubmit={(text) => {
|
||||
handleQuestionSubmission(text);
|
||||
}}
|
||||
loading={status === 'loading'}
|
||||
showSourceButton={false}
|
||||
showToolButton={false}
|
||||
/>
|
||||
</div>
|
||||
<MessageInput
|
||||
onSubmit={(text) => {
|
||||
handleQuestionSubmission(text);
|
||||
}}
|
||||
loading={status === 'loading'}
|
||||
showSourceButton={false}
|
||||
showToolButton={false}
|
||||
/>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => navigate('/')}
|
||||
|
||||
@@ -56,7 +56,7 @@ export const fetchAnswer = createAsyncThunk<
|
||||
question,
|
||||
signal,
|
||||
state.preference.token,
|
||||
state.preference.selectedDocs || [],
|
||||
state.preference.selectedDocs!,
|
||||
currentConversationId,
|
||||
state.preference.prompt.id,
|
||||
state.preference.chunks,
|
||||
@@ -163,7 +163,7 @@ export const fetchAnswer = createAsyncThunk<
|
||||
question,
|
||||
signal,
|
||||
state.preference.token,
|
||||
state.preference.selectedDocs || [],
|
||||
state.preference.selectedDocs!,
|
||||
state.conversation.conversationId,
|
||||
state.preference.prompt.id,
|
||||
state.preference.chunks,
|
||||
|
||||
@@ -118,34 +118,18 @@ layer(base);
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
/* Light theme scrollbar */
|
||||
&::-webkit-scrollbar-thumb {
|
||||
background: rgba(215, 215, 215, 1);
|
||||
background: rgba(156, 163, 175, 0.5);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
&::-webkit-scrollbar-thumb:hover {
|
||||
background: rgba(195, 195, 195, 1);
|
||||
background: rgba(156, 163, 175, 0.7);
|
||||
}
|
||||
|
||||
/* Dark theme scrollbar */
|
||||
.dark &::-webkit-scrollbar-thumb {
|
||||
background: rgba(77, 78, 88, 1);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.dark &::-webkit-scrollbar-thumb:hover {
|
||||
background: rgba(97, 98, 108, 1);
|
||||
}
|
||||
|
||||
/* For Firefox - Light theme */
|
||||
/* For Firefox */
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: rgba(215, 215, 215, 1) transparent;
|
||||
|
||||
/* For Firefox - Dark theme */
|
||||
.dark & {
|
||||
scrollbar-color: rgba(77, 78, 88, 1) transparent;
|
||||
}
|
||||
scrollbar-color: rgba(156, 163, 175, 0.5) transparent;
|
||||
}
|
||||
|
||||
@utility table-default {
|
||||
|
||||
@@ -255,8 +255,8 @@
|
||||
"addQuery": "Add Query"
|
||||
},
|
||||
"drag": {
|
||||
"title": "Drop attachments here",
|
||||
"description": "Release to upload your attachments"
|
||||
"title": "Upload a source file",
|
||||
"description": "Drop your file here to add it as a source"
|
||||
},
|
||||
"progress": {
|
||||
"upload": "Upload is in progress",
|
||||
@@ -298,6 +298,10 @@
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Upload from Google Drive"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Upload from SharePoint"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
@@ -327,6 +331,24 @@
|
||||
"remove": "Remove",
|
||||
"folderAlt": "Folder",
|
||||
"fileAlt": "File"
|
||||
},
|
||||
"sharePoint": {
|
||||
"connect": "Connect to SharePoint",
|
||||
"sessionExpired": "Session expired. Please reconnect to SharePoint.",
|
||||
"sessionExpiredGeneric": "Session expired. Please reconnect your account.",
|
||||
"validateFailed": "Failed to validate session. Please reconnect.",
|
||||
"noSession": "No valid session found. Please reconnect to SharePoint.",
|
||||
"noAccessToken": "No access token available. Please reconnect to SharePoint.",
|
||||
"pickerFailed": "Failed to open file picker. Please try again.",
|
||||
"selectedFiles": "Selected Files",
|
||||
"selectFiles": "Select Files",
|
||||
"loading": "Loading...",
|
||||
"noFilesSelected": "No files or folders selected",
|
||||
"folders": "Folders",
|
||||
"files": "Files",
|
||||
"remove": "Remove",
|
||||
"folderAlt": "Folder",
|
||||
"fileAlt": "File"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -421,8 +443,7 @@
|
||||
"title": "Sources",
|
||||
"text": "Choose Your Sources",
|
||||
"link": "Source link",
|
||||
"view_more": "{{count}} more sources",
|
||||
"noSourcesAvailable": "No sources available"
|
||||
"view_more": "{{count}} more sources"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "Attach",
|
||||
|
||||
@@ -218,8 +218,8 @@
|
||||
"addQuery": "Agregar Consulta"
|
||||
},
|
||||
"drag": {
|
||||
"title": "Suelta los archivos adjuntos aquí",
|
||||
"description": "Suelta para subir tus archivos adjuntos"
|
||||
"title": "Subir archivo fuente",
|
||||
"description": "Arrastra tu archivo aquí para agregarlo como fuente"
|
||||
},
|
||||
"progress": {
|
||||
"upload": "Subida en progreso",
|
||||
@@ -261,6 +261,10 @@
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Subir desde Google Drive"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Subir desde SharePoint"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
@@ -290,6 +294,24 @@
|
||||
"remove": "Eliminar",
|
||||
"folderAlt": "Carpeta",
|
||||
"fileAlt": "Archivo"
|
||||
},
|
||||
"sharePoint": {
|
||||
"connect": "Conectar a SharePoint",
|
||||
"sessionExpired": "Sesión expirada. Por favor, reconecte a SharePoint.",
|
||||
"sessionExpiredGeneric": "Sesión expirada. Por favor, reconecte su cuenta.",
|
||||
"validateFailed": "Error al validar la sesión. Por favor, reconecte.",
|
||||
"noSession": "No se encontró una sesión válida. Por favor, reconecte a SharePoint.",
|
||||
"noAccessToken": "No hay token de acceso disponible. Por favor, reconecte a SharePoint.",
|
||||
"pickerFailed": "Error al abrir el selector de archivos. Por favor, inténtelo de nuevo.",
|
||||
"selectedFiles": "Archivos Seleccionados",
|
||||
"selectFiles": "Seleccionar Archivos",
|
||||
"loading": "Cargando...",
|
||||
"noFilesSelected": "No hay archivos o carpetas seleccionados",
|
||||
"folders": "Carpetas",
|
||||
"files": "Archivos",
|
||||
"remove": "Eliminar",
|
||||
"folderAlt": "Carpeta",
|
||||
"fileAlt": "Archivo"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -384,8 +406,7 @@
|
||||
"title": "Fuentes",
|
||||
"link": "Enlace fuente",
|
||||
"view_more": "Ver {{count}} más fuentes",
|
||||
"text": "Elegir tus fuentes",
|
||||
"noSourcesAvailable": "No hay fuentes disponibles"
|
||||
"text": "Elegir tus fuentes"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "Adjuntar",
|
||||
|
||||
@@ -218,8 +218,8 @@
|
||||
"addQuery": "クエリを追加"
|
||||
},
|
||||
"drag": {
|
||||
"title": "添付ファイルをここにドロップ",
|
||||
"description": "リリースして添付ファイルをアップロード"
|
||||
"title": "ソースファイルをアップロード",
|
||||
"description": "ファイルをここにドロップしてソースとして追加してください"
|
||||
},
|
||||
"progress": {
|
||||
"upload": "アップロード中",
|
||||
@@ -261,6 +261,10 @@
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Google Driveからアップロード"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "SharePointからアップロード"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
@@ -290,6 +294,24 @@
|
||||
"remove": "削除",
|
||||
"folderAlt": "フォルダ",
|
||||
"fileAlt": "ファイル"
|
||||
},
|
||||
"sharePoint": {
|
||||
"connect": "SharePointに接続",
|
||||
"sessionExpired": "セッションが期限切れです。SharePointに再接続してください。",
|
||||
"sessionExpiredGeneric": "セッションが期限切れです。アカウントに再接続してください。",
|
||||
"validateFailed": "セッションの検証に失敗しました。再接続してください。",
|
||||
"noSession": "有効なセッションが見つかりません。SharePointに再接続してください。",
|
||||
"noAccessToken": "アクセストークンが利用できません。SharePointに再接続してください。",
|
||||
"pickerFailed": "ファイルピッカーを開けませんでした。もう一度お試しください。",
|
||||
"selectedFiles": "選択されたファイル",
|
||||
"selectFiles": "ファイルを選択",
|
||||
"loading": "読み込み中...",
|
||||
"noFilesSelected": "ファイルまたはフォルダが選択されていません",
|
||||
"folders": "フォルダ",
|
||||
"files": "ファイル",
|
||||
"remove": "削除",
|
||||
"folderAlt": "フォルダ",
|
||||
"fileAlt": "ファイル"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -384,8 +406,7 @@
|
||||
"title": "ソース",
|
||||
"text": "ソーステキスト",
|
||||
"link": "ソースリンク",
|
||||
"view_more": "さらに{{count}}個のソース",
|
||||
"noSourcesAvailable": "利用可能なソースがありません"
|
||||
"view_more": "さらに{{count}}個のソース"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "添付",
|
||||
|
||||
@@ -218,8 +218,8 @@
|
||||
"addQuery": "Добавить запрос"
|
||||
},
|
||||
"drag": {
|
||||
"title": "Перетащите вложения сюда",
|
||||
"description": "Отпустите, чтобы загрузить ваши вложения"
|
||||
"title": "Загрузить исходный файл",
|
||||
"description": "Перетащите файл сюда, чтобы добавить его как источник"
|
||||
},
|
||||
"progress": {
|
||||
"upload": "Идет загрузка",
|
||||
@@ -261,6 +261,10 @@
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Загрузить из Google Drive"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Загрузить из SharePoint"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
@@ -290,6 +294,24 @@
|
||||
"remove": "Удалить",
|
||||
"folderAlt": "Папка",
|
||||
"fileAlt": "Файл"
|
||||
},
|
||||
"sharePoint": {
|
||||
"connect": "Подключиться к SharePoint",
|
||||
"sessionExpired": "Сеанс истек. Пожалуйста, переподключитесь к SharePoint.",
|
||||
"sessionExpiredGeneric": "Сеанс истек. Пожалуйста, переподключите свою учетную запись.",
|
||||
"validateFailed": "Не удалось проверить сеанс. Пожалуйста, переподключитесь.",
|
||||
"noSession": "Действительный сеанс не найден. Пожалуйста, переподключитесь к SharePoint.",
|
||||
"noAccessToken": "Токен доступа недоступен. Пожалуйста, переподключитесь к SharePoint.",
|
||||
"pickerFailed": "Не удалось открыть средство выбора файлов. Пожалуйста, попробуйте еще раз.",
|
||||
"selectedFiles": "Выбранные файлы",
|
||||
"selectFiles": "Выбрать файлы",
|
||||
"loading": "Загрузка...",
|
||||
"noFilesSelected": "Файлы или папки не выбраны",
|
||||
"folders": "Папки",
|
||||
"files": "Файлы",
|
||||
"remove": "Удалить",
|
||||
"folderAlt": "Папка",
|
||||
"fileAlt": "Файл"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -384,8 +406,7 @@
|
||||
"title": "Источники",
|
||||
"text": "Выберите ваши источники",
|
||||
"link": "Ссылка на источник",
|
||||
"view_more": "ещё {{count}} источников",
|
||||
"noSourcesAvailable": "Нет доступных источников"
|
||||
"view_more": "ещё {{count}} источников"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "Прикрепить",
|
||||
|
||||
@@ -218,8 +218,8 @@
|
||||
"addQuery": "新增查詢"
|
||||
},
|
||||
"drag": {
|
||||
"title": "將附件拖放到此處",
|
||||
"description": "釋放以上傳您的附件"
|
||||
"title": "上傳來源檔案",
|
||||
"description": "將檔案拖放到此處以新增為來源"
|
||||
},
|
||||
"progress": {
|
||||
"upload": "正在上傳",
|
||||
@@ -261,6 +261,10 @@
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "從Google Drive上傳"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "從SharePoint上傳"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
@@ -290,6 +294,24 @@
|
||||
"remove": "移除",
|
||||
"folderAlt": "資料夾",
|
||||
"fileAlt": "檔案"
|
||||
},
|
||||
"sharePoint": {
|
||||
"connect": "連接到 SharePoint",
|
||||
"sessionExpired": "工作階段已過期。請重新連接到 SharePoint。",
|
||||
"sessionExpiredGeneric": "工作階段已過期。請重新連接您的帳戶。",
|
||||
"validateFailed": "驗證工作階段失敗。請重新連接。",
|
||||
"noSession": "未找到有效工作階段。請重新連接到 SharePoint。",
|
||||
"noAccessToken": "存取權杖不可用。請重新連接到 SharePoint。",
|
||||
"pickerFailed": "無法開啟檔案選擇器。請重試。",
|
||||
"selectedFiles": "已選擇的檔案",
|
||||
"selectFiles": "選擇檔案",
|
||||
"loading": "載入中...",
|
||||
"noFilesSelected": "未選擇檔案或資料夾",
|
||||
"folders": "資料夾",
|
||||
"files": "檔案",
|
||||
"remove": "移除",
|
||||
"folderAlt": "資料夾",
|
||||
"fileAlt": "檔案"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -384,8 +406,7 @@
|
||||
"title": "來源",
|
||||
"text": "來源文字",
|
||||
"link": "來源連結",
|
||||
"view_more": "查看更多 {{count}} 個來源",
|
||||
"noSourcesAvailable": "沒有可用的來源"
|
||||
"view_more": "查看更多 {{count}} 個來源"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "附件",
|
||||
|
||||
@@ -218,8 +218,8 @@
|
||||
"addQuery": "添加查询"
|
||||
},
|
||||
"drag": {
|
||||
"title": "将附件拖放到此处",
|
||||
"description": "释放以上传您的附件"
|
||||
"title": "上传源文件",
|
||||
"description": "将文件拖放到此处以添加为源"
|
||||
},
|
||||
"progress": {
|
||||
"upload": "正在上传",
|
||||
@@ -261,6 +261,10 @@
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "从Google Drive上传"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "从SharePoint上传"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
@@ -290,6 +294,24 @@
|
||||
"remove": "删除",
|
||||
"folderAlt": "文件夹",
|
||||
"fileAlt": "文件"
|
||||
},
|
||||
"sharePoint": {
|
||||
"connect": "连接到 SharePoint",
|
||||
"sessionExpired": "会话已过期。请重新连接到 SharePoint。",
|
||||
"sessionExpiredGeneric": "会话已过期。请重新连接您的账户。",
|
||||
"validateFailed": "验证会话失败。请重新连接。",
|
||||
"noSession": "未找到有效会话。请重新连接到 SharePoint。",
|
||||
"noAccessToken": "访问令牌不可用。请重新连接到 SharePoint。",
|
||||
"pickerFailed": "无法打开文件选择器。请重试。",
|
||||
"selectedFiles": "已选择的文件",
|
||||
"selectFiles": "选择文件",
|
||||
"loading": "加载中...",
|
||||
"noFilesSelected": "未选择文件或文件夹",
|
||||
"folders": "文件夹",
|
||||
"files": "文件",
|
||||
"remove": "删除",
|
||||
"folderAlt": "文件夹",
|
||||
"fileAlt": "文件"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -384,8 +406,7 @@
|
||||
"title": "来源",
|
||||
"text": "来源文本",
|
||||
"link": "来源链接",
|
||||
"view_more": "还有{{count}}个来源",
|
||||
"noSourcesAvailable": "没有可用的来源"
|
||||
"view_more": "还有{{count}}个来源"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "附件",
|
||||
|
||||
@@ -103,8 +103,8 @@ export const ShareConversationModal = ({
|
||||
};
|
||||
|
||||
return (
|
||||
<WrapperModal close={close} contentClassName="!overflow-visible">
|
||||
<div className="flex w-[600px] max-w-[80vw] flex-col gap-2">
|
||||
<WrapperModal close={close}>
|
||||
<div className="flex max-h-[80vh] w-[600px] max-w-[80vw] flex-col gap-2 overflow-y-auto">
|
||||
<h2 className="text-eerie-black dark:text-chinese-white text-xl font-medium">
|
||||
{t('modals.shareConv.label')}
|
||||
</h2>
|
||||
|
||||
@@ -32,6 +32,7 @@ import { FormField, IngestorConfig, IngestorType } from './types/ingestor';
|
||||
|
||||
import { FilePicker } from '../components/FilePicker';
|
||||
import GoogleDrivePicker from '../components/GoogleDrivePicker';
|
||||
import SharePointPicker from '../components/SharePointPicker';
|
||||
|
||||
import ChevronRight from '../assets/chevron-right.svg';
|
||||
|
||||
@@ -252,6 +253,8 @@ function Upload({
|
||||
token={token}
|
||||
/>
|
||||
);
|
||||
case 'share_point_picker':
|
||||
return <SharePointPicker key={field.name} token={token} />;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import UrlIcon from '../../assets/url.svg';
|
||||
import GithubIcon from '../../assets/github.svg';
|
||||
import RedditIcon from '../../assets/reddit.svg';
|
||||
import DriveIcon from '../../assets/drive.svg';
|
||||
import SharePoint from '../../assets/sharepoint.svg';
|
||||
|
||||
export type IngestorType =
|
||||
| 'crawler'
|
||||
@@ -11,7 +12,8 @@ export type IngestorType =
|
||||
| 'reddit'
|
||||
| 'url'
|
||||
| 'google_drive'
|
||||
| 'local_file';
|
||||
| 'local_file'
|
||||
| 'share_point';
|
||||
|
||||
export interface IngestorConfig {
|
||||
type: IngestorType | null;
|
||||
@@ -33,7 +35,8 @@ export type FieldType =
|
||||
| 'boolean'
|
||||
| 'local_file_picker'
|
||||
| 'remote_file_picker'
|
||||
| 'google_drive_picker';
|
||||
| 'google_drive_picker'
|
||||
| 'share_point_picker';
|
||||
|
||||
export interface FormField {
|
||||
name: string;
|
||||
@@ -147,6 +150,24 @@ export const IngestorFormSchemas: IngestorSchema[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
key: 'share_point',
|
||||
label: 'Share Point',
|
||||
icon: SharePoint,
|
||||
heading: 'Upload from Share Point',
|
||||
validate: () => {
|
||||
const sharePointClientId = import.meta.env.VITE_SHARE_POINT_CLIENT_ID;
|
||||
return !!sharePointClientId;
|
||||
},
|
||||
fields: [
|
||||
{
|
||||
name: 'files',
|
||||
label: 'Select Files from Share Point',
|
||||
type: 'share_point_picker',
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
export const IngestorDefaultConfigs: Record<
|
||||
@@ -175,6 +196,14 @@ export const IngestorDefaultConfigs: Record<
|
||||
},
|
||||
},
|
||||
local_file: { name: '', config: { files: [] } },
|
||||
share_point: {
|
||||
name: '',
|
||||
config: {
|
||||
file_ids: '',
|
||||
folder_ids: '',
|
||||
recursive: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export interface IngestorOption {
|
||||
|
||||
@@ -2,11 +2,11 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import { RootState } from '../store';
|
||||
|
||||
export interface Attachment {
|
||||
id: string; // Unique identifier for the attachment (required for state management)
|
||||
fileName: string;
|
||||
progress: number;
|
||||
status: 'uploading' | 'processing' | 'completed' | 'failed';
|
||||
taskId: string; // Server-assigned task ID (used for API calls)
|
||||
taskId: string;
|
||||
id?: string;
|
||||
token_count?: number;
|
||||
}
|
||||
|
||||
@@ -47,12 +47,12 @@ export const uploadSlice = createSlice({
|
||||
updateAttachment: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
id: string;
|
||||
taskId: string;
|
||||
updates: Partial<Attachment>;
|
||||
}>,
|
||||
) => {
|
||||
const index = state.attachments.findIndex(
|
||||
(att) => att.id === action.payload.id,
|
||||
(att) => att.taskId === action.payload.taskId,
|
||||
);
|
||||
if (index !== -1) {
|
||||
state.attachments[index] = {
|
||||
@@ -63,26 +63,9 @@ export const uploadSlice = createSlice({
|
||||
},
|
||||
removeAttachment: (state, action: PayloadAction<string>) => {
|
||||
state.attachments = state.attachments.filter(
|
||||
(att) => att.id !== action.payload,
|
||||
(att) => att.taskId !== action.payload && att.id !== action.payload,
|
||||
);
|
||||
},
|
||||
// Reorder attachments array by moving item from sourceIndex to destinationIndex
|
||||
reorderAttachments: (
|
||||
state,
|
||||
action: PayloadAction<{ sourceIndex: number; destinationIndex: number }>,
|
||||
) => {
|
||||
const { sourceIndex, destinationIndex } = action.payload;
|
||||
if (
|
||||
sourceIndex < 0 ||
|
||||
destinationIndex < 0 ||
|
||||
sourceIndex >= state.attachments.length ||
|
||||
destinationIndex >= state.attachments.length
|
||||
)
|
||||
return;
|
||||
|
||||
const [moved] = state.attachments.splice(sourceIndex, 1);
|
||||
state.attachments.splice(destinationIndex, 0, moved);
|
||||
},
|
||||
clearAttachments: (state) => {
|
||||
state.attachments = state.attachments.filter(
|
||||
(att) => att.status === 'uploading' || att.status === 'processing',
|
||||
@@ -138,7 +121,6 @@ export const {
|
||||
addAttachment,
|
||||
updateAttachment,
|
||||
removeAttachment,
|
||||
reorderAttachments,
|
||||
clearAttachments,
|
||||
addUploadTask,
|
||||
updateUploadTask,
|
||||
|
||||
@@ -14,3 +14,21 @@ export const setSessionToken = (provider: string, token: string): void => {
|
||||
export const removeSessionToken = (provider: string): void => {
|
||||
localStorage.removeItem(`${provider}_session_token`);
|
||||
};
|
||||
|
||||
export const validateProviderSession = async (
|
||||
token: string | null,
|
||||
provider: string,
|
||||
) => {
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
return await fetch(`${apiHost}/api/connectors/validate-session`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: provider,
|
||||
session_token: getSessionToken(provider),
|
||||
}),
|
||||
});
|
||||
};
|
||||
|
||||
@@ -3,6 +3,7 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.core.settings import settings
|
||||
from tests.conftest import FakeMongoCollection
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -167,13 +168,10 @@ class TestBaseAgentTools:
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
user_tools.insert_one(
|
||||
{"_id": "1", "user": "test_user", "name": "tool1", "status": True}
|
||||
)
|
||||
user_tools.insert_one(
|
||||
{"_id": "2", "user": "test_user", "name": "tool2", "status": True}
|
||||
)
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
|
||||
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
|
||||
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": True},
|
||||
}
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
tools = agent._get_user_tools("test_user")
|
||||
@@ -189,13 +187,10 @@ class TestBaseAgentTools:
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
user_tools.insert_one(
|
||||
{"_id": "1", "user": "test_user", "name": "tool1", "status": True}
|
||||
)
|
||||
user_tools.insert_one(
|
||||
{"_id": "2", "user": "test_user", "name": "tool2", "status": False}
|
||||
)
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
|
||||
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
|
||||
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": False},
|
||||
}
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
tools = agent._get_user_tools("test_user")
|
||||
@@ -214,16 +209,17 @@ class TestBaseAgentTools:
|
||||
tool_id = str(ObjectId())
|
||||
tool_obj_id = ObjectId(tool_id)
|
||||
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"key": "api_key_123",
|
||||
"tools": [tool_id],
|
||||
}
|
||||
)
|
||||
fake_agent_collection = FakeMongoCollection()
|
||||
fake_agent_collection.docs["api_key_123"] = {
|
||||
"key": "api_key_123",
|
||||
"tools": [tool_id],
|
||||
}
|
||||
|
||||
tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
tools_collection.insert_one({"_id": tool_obj_id, "name": "api_tool"})
|
||||
fake_tools_collection = FakeMongoCollection()
|
||||
fake_tools_collection.docs[tool_id] = {"_id": tool_obj_id, "name": "api_tool"}
|
||||
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["agents"] = fake_agent_collection
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] = fake_tools_collection
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
tools = agent._get_tools("api_key_123")
|
||||
|
||||
@@ -1,552 +0,0 @@
|
||||
import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAnswerValidation:
|
||||
def test_validate_request_passes_with_required_fields(
|
||||
self, mock_mongo_db, flask_app
|
||||
):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
data = {"question": "What is Python?"}
|
||||
|
||||
result = resource.validate_request(data)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_validate_request_fails_without_question(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
data = {}
|
||||
|
||||
result = resource.validate_request(data)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 400
|
||||
assert "question" in result.json["message"].lower()
|
||||
|
||||
def test_validate_with_conversation_id_required(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
data = {"question": "Test"}
|
||||
|
||||
result = resource.validate_request(data, require_conversation_id=True)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 400
|
||||
assert "conversation_id" in result.json["message"].lower()
|
||||
|
||||
def test_validate_passes_with_all_required_fields(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
data = {"question": "Test", "conversation_id": str(ObjectId())}
|
||||
|
||||
result = resource.validate_request(data, require_conversation_id=True)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestUsageChecking:
|
||||
def test_returns_none_when_no_api_key(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_error_for_invalid_api_key(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "invalid_key_123"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 401
|
||||
assert result.json["success"] is False
|
||||
assert "invalid" in result.json["message"].lower()
|
||||
|
||||
def test_checks_token_limit_when_enabled(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": True,
|
||||
"token_limit": 1000,
|
||||
"limited_request_mode": False,
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_checks_request_limit_when_enabled(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": False,
|
||||
"limited_request_mode": True,
|
||||
"request_limit": 100,
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_uses_default_limits_when_not_specified(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": True,
|
||||
"limited_request_mode": True,
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_exceeds_token_limit(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
token_usage_collection = mock_mongo_db[settings.MONGO_DB_NAME][
|
||||
"token_usage"
|
||||
]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": True,
|
||||
"token_limit": 100,
|
||||
"limited_request_mode": False,
|
||||
}
|
||||
)
|
||||
|
||||
token_usage_collection.insert_one(
|
||||
{
|
||||
"_id": ObjectId(),
|
||||
"api_key": "test_key",
|
||||
"prompt_tokens": 60,
|
||||
"generated_tokens": 50,
|
||||
"timestamp": datetime.datetime.now(),
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 429
|
||||
assert result.json["success"] is False
|
||||
assert "usage limit" in result.json["message"].lower()
|
||||
|
||||
def test_exceeds_request_limit(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
token_usage_collection = mock_mongo_db[settings.MONGO_DB_NAME][
|
||||
"token_usage"
|
||||
]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": False,
|
||||
"limited_request_mode": True,
|
||||
"request_limit": 2,
|
||||
}
|
||||
)
|
||||
|
||||
now = datetime.datetime.now()
|
||||
for i in range(3):
|
||||
token_usage_collection.insert_one(
|
||||
{
|
||||
"_id": ObjectId(),
|
||||
"api_key": "test_key",
|
||||
"prompt_tokens": 10,
|
||||
"generated_tokens": 10,
|
||||
"timestamp": now,
|
||||
}
|
||||
)
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 429
|
||||
assert result.json["success"] is False
|
||||
|
||||
def test_both_limits_disabled_returns_none(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": False,
|
||||
"limited_request_mode": False,
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGPTModelRetrieval:
|
||||
def test_initializes_gpt_model(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
assert hasattr(resource, "gpt_model")
|
||||
assert resource.gpt_model is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConversationServiceIntegration:
|
||||
def test_initializes_conversation_service(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
assert hasattr(resource, "conversation_service")
|
||||
assert resource.conversation_service is not None
|
||||
|
||||
def test_has_access_to_user_logs_collection(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
assert hasattr(resource, "user_logs_collection")
|
||||
assert resource.user_logs_collection is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompleteStreamMethod:
|
||||
def test_streams_answer_chunks(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Hello "},
|
||||
{"answer": "world!"},
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Test question",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
answer_chunks = [s for s in stream if '"type": "answer"' in s]
|
||||
assert len(answer_chunks) == 2
|
||||
assert '"answer": "Hello "' in answer_chunks[0]
|
||||
assert '"answer": "world!"' in answer_chunks[1]
|
||||
|
||||
def test_streams_sources(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Test answer"},
|
||||
{"sources": [{"title": "doc1.txt", "text": "x" * 200}]},
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
source_chunks = [s for s in stream if '"type": "source"' in s]
|
||||
assert len(source_chunks) == 1
|
||||
assert '"title": "doc1.txt"' in source_chunks[0]
|
||||
|
||||
def test_handles_error_during_streaming(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.side_effect = Exception("Test error")
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert any('"type": "error"' in s for s in stream)
|
||||
|
||||
def test_saves_conversation_when_enabled(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Test answer"},
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
with patch.object(
|
||||
resource.conversation_service, "save_conversation"
|
||||
) as mock_save:
|
||||
mock_save.return_value = str(ObjectId())
|
||||
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=True,
|
||||
)
|
||||
)
|
||||
|
||||
mock_save.assert_called_once()
|
||||
|
||||
def test_logs_to_user_logs_collection(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
user_logs = mock_mongo_db[settings.MONGO_DB_NAME]["user_logs"]
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Test answer"},
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {"retriever": "test"}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="Test question?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key="test_key",
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert user_logs.count_documents({}) == 1
|
||||
log_entry = user_logs.find_one({})
|
||||
assert log_entry["action"] == "stream_answer"
|
||||
assert log_entry["user"] == "user123"
|
||||
assert log_entry["api_key"] == "test_key"
|
||||
assert log_entry["question"] == "Test question?"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestProcessResponseStream:
|
||||
def test_processes_complete_stream(self, mock_mongo_db, flask_app):
|
||||
import json
|
||||
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
conv_id = str(ObjectId())
|
||||
stream = [
|
||||
f'data: {json.dumps({"type": "answer", "answer": "Hello "})}\n\n',
|
||||
f'data: {json.dumps({"type": "answer", "answer": "world"})}\n\n',
|
||||
f'data: {json.dumps({"type": "source", "source": [{"title": "doc1"}]})}\n\n',
|
||||
f'data: {json.dumps({"type": "id", "id": conv_id})}\n\n',
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert result[0] == conv_id
|
||||
assert result[1] == "Hello world"
|
||||
assert result[2] == [{"title": "doc1"}]
|
||||
assert result[5] is None
|
||||
|
||||
def test_handles_stream_error(self, mock_mongo_db, flask_app):
|
||||
import json
|
||||
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
stream = [
|
||||
f'data: {json.dumps({"type": "error", "error": "Test error"})}\n\n',
|
||||
]
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert len(result) == 5
|
||||
assert result[0] is None
|
||||
assert result[4] == "Test error"
|
||||
|
||||
def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
stream = [
|
||||
"data: invalid json\n\n",
|
||||
'data: {"type": "end"}\n\n',
|
||||
]
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestErrorStreamGenerate:
|
||||
def test_generates_error_stream(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
error_stream = list(resource.error_stream_generate("Test error message"))
|
||||
|
||||
assert len(error_stream) == 1
|
||||
assert '"type": "error"' in error_stream[0]
|
||||
assert '"error": "Test error message"' in error_stream[0]
|
||||
@@ -1,242 +0,0 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConversationServiceGet:
|
||||
|
||||
def test_returns_none_when_no_conversation_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
result = service.get_conversation("", "user_123")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_no_user_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
result = service.get_conversation(str(ObjectId()), "")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_conversation_for_owner(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
conversation = {
|
||||
"_id": conv_id,
|
||||
"user": "user_123",
|
||||
"name": "Test Conv",
|
||||
"queries": [],
|
||||
}
|
||||
collection.insert_one(conversation)
|
||||
|
||||
result = service.get_conversation(str(conv_id), "user_123")
|
||||
|
||||
assert result is not None
|
||||
assert result["name"] == "Test Conv"
|
||||
assert result["_id"] == str(conv_id)
|
||||
|
||||
def test_returns_none_for_unauthorized_user(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{"_id": conv_id, "user": "owner_123", "name": "Private Conv"}
|
||||
)
|
||||
|
||||
result = service.get_conversation(str(conv_id), "hacker_456")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_converts_objectid_to_string(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one({"_id": conv_id, "user": "user_123", "name": "Test"})
|
||||
|
||||
result = service.get_conversation(str(conv_id), "user_123")
|
||||
|
||||
assert isinstance(result["_id"], str)
|
||||
assert result["_id"] == str(conv_id)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConversationServiceSave:
|
||||
|
||||
def test_raises_error_when_no_user_in_token(self, mock_mongo_db):
|
||||
"""Test validation: user ID required"""
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
mock_llm = Mock()
|
||||
|
||||
with pytest.raises(ValueError, match="User ID not found"):
|
||||
service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="Test?",
|
||||
response="Answer",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={}, # No 'sub' key
|
||||
)
|
||||
|
||||
def test_truncates_long_source_text(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from bson import ObjectId
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.gen.return_value = "Test Summary"
|
||||
|
||||
long_text = "x" * 2000
|
||||
sources = [{"text": long_text, "title": "Doc"}]
|
||||
|
||||
conv_id = service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="Question",
|
||||
response="Response",
|
||||
thought="",
|
||||
sources=sources,
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
saved_conv = collection.find_one({"_id": ObjectId(conv_id)})
|
||||
saved_source_text = saved_conv["queries"][0]["sources"][0]["text"]
|
||||
|
||||
assert len(saved_source_text) == 1000
|
||||
assert saved_source_text == "x" * 1000
|
||||
|
||||
def test_creates_new_conversation_with_summary(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from bson import ObjectId
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.gen.return_value = "Python Basics"
|
||||
|
||||
conv_id = service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="What is Python?",
|
||||
response="Python is a programming language",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
assert conv_id is not None
|
||||
saved_conv = collection.find_one({"_id": ObjectId(conv_id)})
|
||||
assert saved_conv["name"] == "Python Basics"
|
||||
assert saved_conv["user"] == "user_123"
|
||||
assert len(saved_conv["queries"]) == 1
|
||||
assert saved_conv["queries"][0]["prompt"] == "What is Python?"
|
||||
|
||||
def test_appends_to_existing_conversation(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from bson import ObjectId
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
existing_conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{
|
||||
"_id": existing_conv_id,
|
||||
"user": "user_123",
|
||||
"name": "Old Conv",
|
||||
"queries": [{"prompt": "Q1", "response": "A1"}],
|
||||
}
|
||||
)
|
||||
|
||||
mock_llm = Mock()
|
||||
|
||||
result = service.save_conversation(
|
||||
conversation_id=str(existing_conv_id),
|
||||
question="Q2",
|
||||
response="A2",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
assert result == str(existing_conv_id)
|
||||
|
||||
def test_prevents_unauthorized_conversation_update(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one({"_id": conv_id, "user": "owner_123", "queries": []})
|
||||
|
||||
mock_llm = Mock()
|
||||
|
||||
with pytest.raises(ValueError, match="not found or unauthorized"):
|
||||
service.save_conversation(
|
||||
conversation_id=str(conv_id),
|
||||
question="Hack",
|
||||
response="Attempt",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={"sub": "hacker_456"},
|
||||
)
|
||||
@@ -1,252 +0,0 @@
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetPromptFunction:
|
||||
|
||||
def test_loads_custom_prompt_from_database(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
from application.core.settings import settings
|
||||
|
||||
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
|
||||
prompt_id = ObjectId()
|
||||
|
||||
prompts_collection.insert_one(
|
||||
{
|
||||
"_id": prompt_id,
|
||||
"content": "Custom prompt from database",
|
||||
"user": "user_123",
|
||||
}
|
||||
)
|
||||
|
||||
result = get_prompt(str(prompt_id), prompts_collection)
|
||||
assert result == "Custom prompt from database"
|
||||
|
||||
def test_raises_error_for_invalid_prompt_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
from application.core.settings import settings
|
||||
|
||||
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid prompt ID"):
|
||||
get_prompt(str(ObjectId()), prompts_collection)
|
||||
|
||||
def test_raises_error_for_malformed_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
from application.core.settings import settings
|
||||
|
||||
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid prompt ID"):
|
||||
get_prompt("not_a_valid_id", prompts_collection)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorInitialization:
|
||||
|
||||
def test_initializes_with_decoded_token(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {
|
||||
"question": "What is Python?",
|
||||
"conversation_id": str(ObjectId()),
|
||||
}
|
||||
decoded_token = {"sub": "user_123", "email": "test@example.com"}
|
||||
|
||||
processor = StreamProcessor(request_data, decoded_token)
|
||||
|
||||
assert processor.data == request_data
|
||||
assert processor.decoded_token == decoded_token
|
||||
assert processor.initial_user_id == "user_123"
|
||||
assert processor.conversation_id == request_data["conversation_id"]
|
||||
|
||||
def test_initializes_without_token(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {"question": "Test question"}
|
||||
|
||||
processor = StreamProcessor(request_data, None)
|
||||
|
||||
assert processor.decoded_token is None
|
||||
assert processor.initial_user_id is None
|
||||
assert processor.data == request_data
|
||||
|
||||
def test_initializes_default_attributes(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
processor = StreamProcessor({"question": "Test"}, {"sub": "user_123"})
|
||||
|
||||
assert processor.source == {}
|
||||
assert processor.all_sources == []
|
||||
assert processor.attachments == []
|
||||
assert processor.history == []
|
||||
assert processor.agent_config == {}
|
||||
assert processor.retriever_config == {}
|
||||
assert processor.is_shared_usage is False
|
||||
assert processor.shared_token is None
|
||||
|
||||
def test_extracts_conversation_id_from_request(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
conv_id = str(ObjectId())
|
||||
request_data = {"question": "Test", "conversation_id": conv_id}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
assert processor.conversation_id == conv_id
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorHistoryLoading:
|
||||
|
||||
def test_loads_history_from_existing_conversation(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
conversations_collection = mock_mongo_db[settings.MONGO_DB_NAME][
|
||||
"conversations"
|
||||
]
|
||||
conv_id = ObjectId()
|
||||
|
||||
conversations_collection.insert_one(
|
||||
{
|
||||
"_id": conv_id,
|
||||
"user": "user_123",
|
||||
"name": "Test Conv",
|
||||
"queries": [
|
||||
{"prompt": "What is Python?", "response": "Python is a language"},
|
||||
{"prompt": "Tell me more", "response": "Python is versatile"},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"question": "How to install it?",
|
||||
"conversation_id": str(conv_id),
|
||||
}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
processor._load_conversation_history()
|
||||
|
||||
assert len(processor.history) == 2
|
||||
assert processor.history[0]["prompt"] == "What is Python?"
|
||||
assert processor.history[1]["response"] == "Python is versatile"
|
||||
|
||||
def test_raises_error_for_unauthorized_conversation(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
conversations_collection = mock_mongo_db[settings.MONGO_DB_NAME][
|
||||
"conversations"
|
||||
]
|
||||
conv_id = ObjectId()
|
||||
|
||||
conversations_collection.insert_one(
|
||||
{
|
||||
"_id": conv_id,
|
||||
"user": "owner_123",
|
||||
"name": "Private Conv",
|
||||
"queries": [],
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {"question": "Hack attempt", "conversation_id": str(conv_id)}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "hacker_456"})
|
||||
|
||||
with pytest.raises(ValueError, match="Conversation not found or unauthorized"):
|
||||
processor._load_conversation_history()
|
||||
|
||||
def test_uses_request_history_when_no_conversation_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {
|
||||
"question": "What is Python?",
|
||||
"history": [{"prompt": "Hello", "response": "Hi there!"}],
|
||||
}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
assert processor.conversation_id is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorAgentConfiguration:
|
||||
|
||||
def test_configures_agent_from_valid_api_key(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_api_key_123",
|
||||
"endpoint": "openai",
|
||||
"model": "gpt-4",
|
||||
"prompt_id": "default",
|
||||
"user": "user_123",
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {"question": "Test", "api_key": "test_api_key_123"}
|
||||
|
||||
processor = StreamProcessor(request_data, None)
|
||||
|
||||
try:
|
||||
processor._configure_agent()
|
||||
assert processor.agent_config is not None
|
||||
except Exception as e:
|
||||
assert "Invalid API Key" in str(e)
|
||||
|
||||
def test_uses_default_config_without_api_key(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {"question": "Test"}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
processor._configure_agent()
|
||||
|
||||
assert isinstance(processor.agent_config, dict)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorAttachments:
|
||||
|
||||
def test_processes_attachments_from_request(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
attachments_collection = mock_mongo_db[settings.MONGO_DB_NAME]["attachments"]
|
||||
att_id = ObjectId()
|
||||
|
||||
attachments_collection.insert_one(
|
||||
{
|
||||
"_id": att_id,
|
||||
"filename": "document.pdf",
|
||||
"content": "Document content",
|
||||
"user": "user_123",
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {"question": "Analyze this", "attachments": [str(att_id)]}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
assert processor.data.get("attachments") == [str(att_id)]
|
||||
|
||||
def test_handles_empty_attachments(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {"question": "Simple question"}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
assert processor.attachments == []
|
||||
assert (
|
||||
"attachments" not in processor.data
|
||||
or processor.data.get("attachments") is None
|
||||
)
|
||||
@@ -1,89 +0,0 @@
|
||||
"""API-specific test fixtures."""
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers():
|
||||
return {"Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_token(monkeypatch, decoded_token):
|
||||
def mock_decorator(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = decoded_token
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
monkeypatch.setattr("application.auth.api_key_required", lambda: mock_decorator)
|
||||
return decoded_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation():
|
||||
return {
|
||||
"_id": ObjectId(),
|
||||
"user": "test_user",
|
||||
"name": "Test Conversation",
|
||||
"queries": [
|
||||
{
|
||||
"prompt": "What is Python?",
|
||||
"response": "Python is a programming language",
|
||||
}
|
||||
],
|
||||
"date": "2025-01-01T00:00:00",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_prompt():
|
||||
return {
|
||||
"_id": ObjectId(),
|
||||
"user": "test_user",
|
||||
"name": "Helpful Assistant",
|
||||
"content": "You are a helpful assistant that provides clear and concise answers.",
|
||||
"type": "custom",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent():
|
||||
return {
|
||||
"_id": ObjectId(),
|
||||
"user": "test_user",
|
||||
"name": "Test Agent",
|
||||
"type": "classic",
|
||||
"endpoint": "openai",
|
||||
"model": "gpt-4",
|
||||
"prompt_id": "default",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_answer_request():
|
||||
return {
|
||||
"question": "What is Python?",
|
||||
"history": [],
|
||||
"conversation_id": None,
|
||||
"prompt_id": "default",
|
||||
"chunks": 2,
|
||||
"token_limit": 1000,
|
||||
"retriever": "classic_rag",
|
||||
"active_docs": "local/test/",
|
||||
"isNoneDoc": False,
|
||||
"save_conversation": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app():
|
||||
from flask import Flask
|
||||
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
@@ -1,311 +0,0 @@
|
||||
import datetime
|
||||
import io
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTimeRangeGenerators:
|
||||
|
||||
def test_generate_minute_range(self):
|
||||
from application.api.user.base import generate_minute_range
|
||||
|
||||
start = datetime.datetime(2024, 1, 1, 10, 0, 0)
|
||||
end = datetime.datetime(2024, 1, 1, 10, 5, 0)
|
||||
|
||||
result = generate_minute_range(start, end)
|
||||
|
||||
assert len(result) == 6
|
||||
assert "2024-01-01 10:00:00" in result
|
||||
assert "2024-01-01 10:05:00" in result
|
||||
assert all(val == 0 for val in result.values())
|
||||
|
||||
def test_generate_hourly_range(self):
|
||||
from application.api.user.base import generate_hourly_range
|
||||
|
||||
start = datetime.datetime(2024, 1, 1, 10, 0, 0)
|
||||
end = datetime.datetime(2024, 1, 1, 15, 0, 0)
|
||||
|
||||
result = generate_hourly_range(start, end)
|
||||
|
||||
assert len(result) == 6
|
||||
assert "2024-01-01 10:00" in result
|
||||
assert "2024-01-01 15:00" in result
|
||||
assert all(val == 0 for val in result.values())
|
||||
|
||||
def test_generate_date_range(self):
|
||||
from application.api.user.base import generate_date_range
|
||||
|
||||
start = datetime.date(2024, 1, 1)
|
||||
end = datetime.date(2024, 1, 5)
|
||||
|
||||
result = generate_date_range(start, end)
|
||||
|
||||
assert len(result) == 5
|
||||
assert "2024-01-01" in result
|
||||
assert "2024-01-05" in result
|
||||
assert all(val == 0 for val in result.values())
|
||||
|
||||
def test_single_minute_range(self):
|
||||
from application.api.user.base import generate_minute_range
|
||||
|
||||
time = datetime.datetime(2024, 1, 1, 10, 30, 0)
|
||||
result = generate_minute_range(time, time)
|
||||
|
||||
assert len(result) == 1
|
||||
assert "2024-01-01 10:30:00" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestEnsureUserDoc:
|
||||
|
||||
def test_creates_new_user_with_defaults(self, mock_mongo_db):
|
||||
from application.api.user.base import ensure_user_doc
|
||||
|
||||
user_id = "test_user_123"
|
||||
|
||||
result = ensure_user_doc(user_id)
|
||||
|
||||
assert result is not None
|
||||
assert result["user_id"] == user_id
|
||||
assert "agent_preferences" in result
|
||||
assert result["agent_preferences"]["pinned"] == []
|
||||
assert result["agent_preferences"]["shared_with_me"] == []
|
||||
|
||||
def test_returns_existing_user(self, mock_mongo_db):
|
||||
from application.api.user.base import ensure_user_doc
|
||||
from application.core.settings import settings
|
||||
|
||||
users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"]
|
||||
user_id = "existing_user"
|
||||
|
||||
existing_doc = {
|
||||
"user_id": user_id,
|
||||
"agent_preferences": {"pinned": ["agent1"], "shared_with_me": ["agent2"]},
|
||||
}
|
||||
users_collection.insert_one(existing_doc)
|
||||
|
||||
result = ensure_user_doc(user_id)
|
||||
|
||||
assert result["user_id"] == user_id
|
||||
assert result["agent_preferences"]["pinned"] == ["agent1"]
|
||||
assert result["agent_preferences"]["shared_with_me"] == ["agent2"]
|
||||
|
||||
def test_adds_missing_preferences_fields(self, mock_mongo_db):
|
||||
from application.api.user.base import ensure_user_doc
|
||||
from application.core.settings import settings
|
||||
|
||||
users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"]
|
||||
user_id = "incomplete_user"
|
||||
|
||||
users_collection.insert_one(
|
||||
{"user_id": user_id, "agent_preferences": {"pinned": ["agent1"]}}
|
||||
)
|
||||
|
||||
result = ensure_user_doc(user_id)
|
||||
|
||||
assert "shared_with_me" in result["agent_preferences"]
|
||||
assert result["agent_preferences"]["shared_with_me"] == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResolveToolDetails:
|
||||
|
||||
def test_resolves_tool_ids_to_details(self, mock_mongo_db):
|
||||
from application.api.user.base import resolve_tool_details
|
||||
from application.core.settings import settings
|
||||
|
||||
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
tool_id1 = ObjectId()
|
||||
tool_id2 = ObjectId()
|
||||
|
||||
user_tools.insert_one(
|
||||
{"_id": tool_id1, "name": "calculator", "displayName": "Calculator Tool"}
|
||||
)
|
||||
user_tools.insert_one(
|
||||
{"_id": tool_id2, "name": "weather", "displayName": "Weather API"}
|
||||
)
|
||||
|
||||
result = resolve_tool_details([str(tool_id1), str(tool_id2)])
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == str(tool_id1)
|
||||
assert result[0]["name"] == "calculator"
|
||||
assert result[0]["display_name"] == "Calculator Tool"
|
||||
assert result[1]["name"] == "weather"
|
||||
|
||||
def test_handles_missing_display_name(self, mock_mongo_db):
|
||||
from application.api.user.base import resolve_tool_details
|
||||
from application.core.settings import settings
|
||||
|
||||
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
tool_id = ObjectId()
|
||||
|
||||
user_tools.insert_one({"_id": tool_id, "name": "test_tool"})
|
||||
|
||||
result = resolve_tool_details([str(tool_id)])
|
||||
|
||||
assert result[0]["display_name"] == "test_tool"
|
||||
|
||||
def test_empty_tool_ids_list(self, mock_mongo_db):
|
||||
from application.api.user.base import resolve_tool_details
|
||||
|
||||
result = resolve_tool_details([])
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetVectorStore:
|
||||
|
||||
@patch("application.api.user.base.VectorCreator.create_vectorstore")
|
||||
def test_creates_vector_store(self, mock_create, mock_mongo_db):
|
||||
from application.api.user.base import get_vector_store
|
||||
|
||||
mock_store = Mock()
|
||||
mock_create.return_value = mock_store
|
||||
source_id = "test_source_123"
|
||||
|
||||
result = get_vector_store(source_id)
|
||||
|
||||
assert result == mock_store
|
||||
mock_create.assert_called_once()
|
||||
args, kwargs = mock_create.call_args
|
||||
assert kwargs.get("source_id") == source_id
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleImageUpload:
|
||||
|
||||
def test_returns_existing_url_when_no_file(self, flask_app):
|
||||
from application.api.user.base import handle_image_upload
|
||||
|
||||
with flask_app.test_request_context():
|
||||
mock_request = Mock()
|
||||
mock_request.files = {}
|
||||
mock_storage = Mock()
|
||||
existing_url = "existing/path/image.jpg"
|
||||
|
||||
url, error = handle_image_upload(
|
||||
mock_request, existing_url, "user123", mock_storage
|
||||
)
|
||||
|
||||
assert url == existing_url
|
||||
assert error is None
|
||||
|
||||
def test_uploads_new_image(self, flask_app):
|
||||
from application.api.user.base import handle_image_upload
|
||||
|
||||
with flask_app.test_request_context():
|
||||
mock_file = FileStorage(
|
||||
stream=io.BytesIO(b"fake image data"), filename="test_image.png"
|
||||
)
|
||||
mock_request = Mock()
|
||||
mock_request.files = {"image": mock_file}
|
||||
mock_storage = Mock()
|
||||
mock_storage.save_file.return_value = {"success": True}
|
||||
|
||||
url, error = handle_image_upload(
|
||||
mock_request, "old_url", "user123", mock_storage
|
||||
)
|
||||
|
||||
assert error is None
|
||||
assert url is not None
|
||||
assert "test_image.png" in url
|
||||
assert "user123" in url
|
||||
mock_storage.save_file.assert_called_once()
|
||||
|
||||
def test_ignores_empty_filename(self, flask_app):
|
||||
from application.api.user.base import handle_image_upload
|
||||
|
||||
with flask_app.test_request_context():
|
||||
mock_file = Mock()
|
||||
mock_file.filename = ""
|
||||
mock_request = Mock()
|
||||
mock_request.files = {"image": mock_file}
|
||||
mock_storage = Mock()
|
||||
existing_url = "existing.jpg"
|
||||
|
||||
url, error = handle_image_upload(
|
||||
mock_request, existing_url, "user123", mock_storage
|
||||
)
|
||||
|
||||
assert url == existing_url
|
||||
assert error is None
|
||||
mock_storage.save_file.assert_not_called()
|
||||
|
||||
def test_handles_upload_error(self, flask_app):
|
||||
from application.api.user.base import handle_image_upload
|
||||
|
||||
with flask_app.app_context():
|
||||
mock_file = FileStorage(stream=io.BytesIO(b"data"), filename="test.png")
|
||||
mock_request = Mock()
|
||||
mock_request.files = {"image": mock_file}
|
||||
mock_storage = Mock()
|
||||
mock_storage.save_file.side_effect = Exception("Storage error")
|
||||
|
||||
url, error = handle_image_upload(
|
||||
mock_request, "old.jpg", "user123", mock_storage
|
||||
)
|
||||
|
||||
assert url is None
|
||||
assert error is not None
|
||||
assert error.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRequireAgentDecorator:
|
||||
|
||||
def test_validates_webhook_token(self, mock_mongo_db, flask_app):
|
||||
from application.api.user.base import require_agent
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
webhook_token = "valid_webhook_token_123"
|
||||
|
||||
agents_collection.insert_one(
|
||||
{"_id": agent_id, "incoming_webhook_token": webhook_token}
|
||||
)
|
||||
|
||||
@require_agent
|
||||
def test_func(webhook_token=None, agent=None, agent_id_str=None):
|
||||
return {"agent_id": agent_id_str}
|
||||
|
||||
result = test_func(webhook_token=webhook_token)
|
||||
|
||||
assert result["agent_id"] == str(agent_id)
|
||||
|
||||
def test_returns_400_for_missing_token(self, mock_mongo_db, flask_app):
|
||||
from application.api.user.base import require_agent
|
||||
|
||||
with flask_app.app_context():
|
||||
|
||||
@require_agent
|
||||
def test_func(webhook_token=None, agent=None, agent_id_str=None):
|
||||
return {"success": True}
|
||||
|
||||
result = test_func()
|
||||
|
||||
assert result.status_code == 400
|
||||
assert result.json["success"] is False
|
||||
assert "missing" in result.json["message"].lower()
|
||||
|
||||
def test_returns_404_for_invalid_token(self, mock_mongo_db, flask_app):
|
||||
from application.api.user.base import require_agent
|
||||
|
||||
with flask_app.app_context():
|
||||
|
||||
@require_agent
|
||||
def test_func(webhook_token=None, agent=None, agent_id_str=None):
|
||||
return {"success": True}
|
||||
|
||||
result = test_func(webhook_token="invalid_token_999")
|
||||
|
||||
assert result.status_code == 404
|
||||
assert result.json["success"] is False
|
||||
assert "not found" in result.json["message"].lower()
|
||||
@@ -1,15 +1,7 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import mongomock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def get_settings():
|
||||
"""Lazy load settings to avoid import-time errors."""
|
||||
from application.core.settings import settings
|
||||
|
||||
return settings
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -43,51 +35,18 @@ def mock_retriever():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mongo_db(monkeypatch):
|
||||
"""Mock MongoDB using mongomock - industry standard MongoDB mocking library."""
|
||||
settings = get_settings()
|
||||
fake_collection = FakeMongoCollection()
|
||||
fake_db = {
|
||||
"agents": fake_collection,
|
||||
"user_tools": fake_collection,
|
||||
"memories": fake_collection,
|
||||
}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_db = mock_client[settings.MONGO_DB_NAME]
|
||||
|
||||
def get_mock_client():
|
||||
return {settings.MONGO_DB_NAME: mock_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", get_mock_client)
|
||||
|
||||
monkeypatch.setattr("application.api.user.base.users_collection", mock_db["users"])
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.user_tools_collection", mock_db["user_tools"]
|
||||
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.agents_collection", mock_db["agents"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.conversations_collection", mock_db["conversations"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.sources_collection", mock_db["sources"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.prompts_collection", mock_db["prompts"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.feedback_collection", mock_db["feedback"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.token_usage_collection", mock_db["token_usage"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.attachments_collection", mock_db["attachments"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.user_logs_collection", mock_db["user_logs"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.shared_conversations_collections",
|
||||
mock_db["shared_conversations"],
|
||||
)
|
||||
|
||||
return get_mock_client()
|
||||
return fake_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -128,6 +87,53 @@ def log_context():
|
||||
return context
|
||||
|
||||
|
||||
class FakeMongoCollection:
|
||||
def __init__(self):
|
||||
self.docs = {}
|
||||
|
||||
def find_one(self, query, projection=None):
|
||||
if "key" in query:
|
||||
return self.docs.get(query["key"])
|
||||
if "_id" in query:
|
||||
return self.docs.get(str(query["_id"]))
|
||||
if "user" in query:
|
||||
for doc in self.docs.values():
|
||||
if doc.get("user") == query["user"]:
|
||||
return doc
|
||||
return None
|
||||
|
||||
def find(self, query, projection=None):
|
||||
results = []
|
||||
if "_id" in query and "$in" in query["_id"]:
|
||||
for doc_id in query["_id"]["$in"]:
|
||||
doc = self.docs.get(str(doc_id))
|
||||
if doc:
|
||||
results.append(doc)
|
||||
elif "user" in query:
|
||||
for doc in self.docs.values():
|
||||
if doc.get("user") == query["user"]:
|
||||
if "status" in query:
|
||||
if doc.get("status") == query["status"]:
|
||||
results.append(doc)
|
||||
else:
|
||||
results.append(doc)
|
||||
return results
|
||||
|
||||
def insert_one(self, doc):
|
||||
doc_id = doc.get("_id", len(self.docs))
|
||||
self.docs[str(doc_id)] = doc
|
||||
return Mock(inserted_id=doc_id)
|
||||
|
||||
def update_one(self, query, update, upsert=False):
|
||||
return Mock(modified_count=1)
|
||||
|
||||
def delete_one(self, query):
|
||||
return Mock(deleted_count=1)
|
||||
|
||||
def delete_many(self, query):
|
||||
return Mock(deleted_count=0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_creator(mock_llm, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
from typing import Any, Dict, Generator
|
||||
from unittest.mock import Mock, patch
|
||||
from typing import Any, Dict, Generator
|
||||
|
||||
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||
|
||||
|
||||
class TestToolCall:
|
||||
"""Test ToolCall dataclass."""
|
||||
|
||||
def test_tool_call_creation(self):
|
||||
"""Test basic ToolCall creation."""
|
||||
tool_call = ToolCall(
|
||||
id="test_id", name="test_function", arguments={"arg1": "value1"}, index=0
|
||||
id="test_id",
|
||||
name="test_function",
|
||||
arguments={"arg1": "value1"},
|
||||
index=0
|
||||
)
|
||||
assert tool_call.id == "test_id"
|
||||
assert tool_call.name == "test_function"
|
||||
@@ -15,11 +21,12 @@ class TestToolCall:
|
||||
assert tool_call.index == 0
|
||||
|
||||
def test_tool_call_from_dict(self):
|
||||
"""Test ToolCall creation from dictionary."""
|
||||
data = {
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"arguments": {"location": "New York"},
|
||||
"index": 1,
|
||||
"index": 1
|
||||
}
|
||||
tool_call = ToolCall.from_dict(data)
|
||||
assert tool_call.id == "call_123"
|
||||
@@ -28,6 +35,7 @@ class TestToolCall:
|
||||
assert tool_call.index == 1
|
||||
|
||||
def test_tool_call_from_dict_missing_fields(self):
|
||||
"""Test ToolCall creation with missing fields."""
|
||||
data = {"name": "test_func"}
|
||||
tool_call = ToolCall.from_dict(data)
|
||||
assert tool_call.id == ""
|
||||
@@ -37,13 +45,16 @@ class TestToolCall:
|
||||
|
||||
|
||||
class TestLLMResponse:
|
||||
"""Test LLMResponse dataclass."""
|
||||
|
||||
def test_llm_response_creation(self):
|
||||
"""Test basic LLMResponse creation."""
|
||||
tool_calls = [ToolCall(id="1", name="func", arguments={})]
|
||||
response = LLMResponse(
|
||||
content="Hello",
|
||||
tool_calls=tool_calls,
|
||||
finish_reason="tool_calls",
|
||||
raw_response={"test": "data"},
|
||||
raw_response={"test": "data"}
|
||||
)
|
||||
assert response.content == "Hello"
|
||||
assert len(response.tool_calls) == 1
|
||||
@@ -51,43 +62,55 @@ class TestLLMResponse:
|
||||
assert response.raw_response == {"test": "data"}
|
||||
|
||||
def test_requires_tool_call_true(self):
|
||||
"""Test requires_tool_call property when tool calls are needed."""
|
||||
tool_calls = [ToolCall(id="1", name="func", arguments={})]
|
||||
response = LLMResponse(
|
||||
content="",
|
||||
tool_calls=tool_calls,
|
||||
finish_reason="tool_calls",
|
||||
raw_response={},
|
||||
raw_response={}
|
||||
)
|
||||
assert response.requires_tool_call is True
|
||||
|
||||
def test_requires_tool_call_false_no_tools(self):
|
||||
"""Test requires_tool_call property when no tool calls."""
|
||||
response = LLMResponse(
|
||||
content="Hello", tool_calls=[], finish_reason="stop", raw_response={}
|
||||
content="Hello",
|
||||
tool_calls=[],
|
||||
finish_reason="stop",
|
||||
raw_response={}
|
||||
)
|
||||
assert response.requires_tool_call is False
|
||||
|
||||
def test_requires_tool_call_false_wrong_finish_reason(self):
|
||||
"""Test requires_tool_call property with tools but wrong finish reason."""
|
||||
tool_calls = [ToolCall(id="1", name="func", arguments={})]
|
||||
response = LLMResponse(
|
||||
content="Hello",
|
||||
tool_calls=tool_calls,
|
||||
finish_reason="stop",
|
||||
raw_response={},
|
||||
raw_response={}
|
||||
)
|
||||
assert response.requires_tool_call is False
|
||||
|
||||
|
||||
class ConcreteHandler(LLMHandler):
|
||||
"""Concrete implementation for testing abstract base class."""
|
||||
|
||||
def parse_response(self, response: Any) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
content=str(response),
|
||||
tool_calls=[],
|
||||
finish_reason="stop",
|
||||
raw_response=response,
|
||||
raw_response=response
|
||||
)
|
||||
|
||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||
return {"role": "tool", "content": str(result), "tool_call_id": tool_call.id}
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": str(result),
|
||||
"tool_call_id": tool_call.id
|
||||
}
|
||||
|
||||
def _iterate_stream(self, response: Any) -> Generator:
|
||||
for chunk in response:
|
||||
@@ -95,119 +118,114 @@ class ConcreteHandler(LLMHandler):
|
||||
|
||||
|
||||
class TestLLMHandler:
|
||||
"""Test LLMHandler base class."""
|
||||
|
||||
def test_handler_initialization(self):
|
||||
"""Test handler initialization."""
|
||||
handler = ConcreteHandler()
|
||||
assert handler.llm_calls == []
|
||||
assert handler.tool_calls == []
|
||||
|
||||
def test_prepare_messages_no_attachments(self):
|
||||
"""Test prepare_messages with no attachments."""
|
||||
handler = ConcreteHandler()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
|
||||
mock_agent = Mock()
|
||||
result = handler.prepare_messages(mock_agent, messages, None)
|
||||
assert result == messages
|
||||
|
||||
def test_prepare_messages_with_supported_attachments(self):
|
||||
"""Test prepare_messages with supported attachments."""
|
||||
handler = ConcreteHandler()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
attachments = [{"mime_type": "image/png", "path": "/test.png"}]
|
||||
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
|
||||
mock_agent.llm.prepare_messages_with_attachments.return_value = messages
|
||||
|
||||
|
||||
result = handler.prepare_messages(mock_agent, messages, attachments)
|
||||
mock_agent.llm.prepare_messages_with_attachments.assert_called_once_with(
|
||||
messages, attachments
|
||||
)
|
||||
assert result == messages
|
||||
|
||||
@patch("application.llm.handlers.base.logger")
|
||||
@patch('application.llm.handlers.base.logger')
|
||||
def test_prepare_messages_with_unsupported_attachments(self, mock_logger):
|
||||
"""Test prepare_messages with unsupported attachments."""
|
||||
handler = ConcreteHandler()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
attachments = [{"mime_type": "text/plain", "path": "/test.txt"}]
|
||||
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
|
||||
|
||||
with patch.object(
|
||||
handler, "_append_unsupported_attachments", return_value=messages
|
||||
) as mock_append:
|
||||
|
||||
with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append:
|
||||
result = handler.prepare_messages(mock_agent, messages, attachments)
|
||||
mock_append.assert_called_once_with(messages, attachments)
|
||||
assert result == messages
|
||||
|
||||
def test_prepare_messages_mixed_attachments(self):
|
||||
"""Test prepare_messages with both supported and unsupported attachments."""
|
||||
handler = ConcreteHandler()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
attachments = [
|
||||
{"mime_type": "image/png", "path": "/test.png"},
|
||||
{"mime_type": "text/plain", "path": "/test.txt"},
|
||||
{"mime_type": "text/plain", "path": "/test.txt"}
|
||||
]
|
||||
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
|
||||
mock_agent.llm.prepare_messages_with_attachments.return_value = messages
|
||||
|
||||
with patch.object(
|
||||
handler, "_append_unsupported_attachments", return_value=messages
|
||||
) as mock_append:
|
||||
|
||||
with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append:
|
||||
result = handler.prepare_messages(mock_agent, messages, attachments)
|
||||
|
||||
|
||||
# Should call both methods
|
||||
mock_agent.llm.prepare_messages_with_attachments.assert_called_once()
|
||||
mock_append.assert_called_once()
|
||||
assert result == messages
|
||||
|
||||
def test_process_message_flow_non_streaming(self):
|
||||
"""Test process_message_flow for non-streaming."""
|
||||
handler = ConcreteHandler()
|
||||
mock_agent = Mock()
|
||||
initial_response = "test response"
|
||||
tools_dict = {}
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
with patch.object(
|
||||
handler, "prepare_messages", return_value=messages
|
||||
) as mock_prepare:
|
||||
with patch.object(
|
||||
handler, "handle_non_streaming", return_value="final"
|
||||
) as mock_handle:
|
||||
|
||||
with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare:
|
||||
with patch.object(handler, 'handle_non_streaming', return_value="final") as mock_handle:
|
||||
result = handler.process_message_flow(
|
||||
mock_agent, initial_response, tools_dict, messages, stream=False
|
||||
)
|
||||
|
||||
|
||||
mock_prepare.assert_called_once_with(mock_agent, messages, None)
|
||||
mock_handle.assert_called_once_with(
|
||||
mock_agent, initial_response, tools_dict, messages
|
||||
)
|
||||
mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages)
|
||||
assert result == "final"
|
||||
|
||||
def test_process_message_flow_streaming(self):
|
||||
"""Test process_message_flow for streaming."""
|
||||
handler = ConcreteHandler()
|
||||
mock_agent = Mock()
|
||||
initial_response = "test response"
|
||||
tools_dict = {}
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
|
||||
def mock_generator():
|
||||
yield "chunk1"
|
||||
yield "chunk2"
|
||||
|
||||
with patch.object(
|
||||
handler, "prepare_messages", return_value=messages
|
||||
) as mock_prepare:
|
||||
with patch.object(
|
||||
handler, "handle_streaming", return_value=mock_generator()
|
||||
) as mock_handle:
|
||||
|
||||
with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare:
|
||||
with patch.object(handler, 'handle_streaming', return_value=mock_generator()) as mock_handle:
|
||||
result = handler.process_message_flow(
|
||||
mock_agent, initial_response, tools_dict, messages, stream=True
|
||||
)
|
||||
|
||||
|
||||
mock_prepare.assert_called_once_with(mock_agent, messages, None)
|
||||
mock_handle.assert_called_once_with(
|
||||
mock_agent, initial_response, tools_dict, messages
|
||||
)
|
||||
|
||||
mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages)
|
||||
|
||||
# Verify it's a generator
|
||||
chunks = list(result)
|
||||
assert chunks == ["chunk1", "chunk2"]
|
||||
@@ -1,4 +1,3 @@
|
||||
pytest>=8.0.0
|
||||
pytest-cov>=4.1.0
|
||||
coverage>=7.4.0
|
||||
mongomock>=4.3.0
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
import pytest
|
||||
from application.agents.tools.todo_list import TodoListTool
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class FakeCursor(list):
|
||||
def sort(self, key, direction):
|
||||
reverse = direction == -1
|
||||
sorted_list = sorted(self, key=lambda d: d.get(key, 0), reverse=reverse)
|
||||
return FakeCursor(sorted_list)
|
||||
|
||||
def limit(self, count):
|
||||
return FakeCursor(self[:count])
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if not self:
|
||||
raise StopIteration
|
||||
return self.pop(0)
|
||||
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self):
|
||||
self.docs = {}
|
||||
|
||||
def create_index(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def insert_one(self, doc):
|
||||
key = (doc["user_id"], doc["tool_id"], int(doc["todo_id"]))
|
||||
self.docs[key] = doc
|
||||
return type("res", (), {"inserted_id": key})
|
||||
|
||||
def find_one(self, query):
|
||||
key = (query.get("user_id"), query.get("tool_id"), int(query.get("todo_id")))
|
||||
return self.docs.get(key)
|
||||
|
||||
def find(self, query):
|
||||
user_id = query.get("user_id")
|
||||
tool_id = query.get("tool_id")
|
||||
filtered = [
|
||||
doc for (uid, tid, _), doc in self.docs.items()
|
||||
if uid == user_id and tid == tool_id
|
||||
]
|
||||
return FakeCursor(filtered)
|
||||
|
||||
def update_one(self, query, update, upsert=False):
|
||||
key = (query.get("user_id"), query.get("tool_id"), int(query.get("todo_id")))
|
||||
if key in self.docs:
|
||||
self.docs[key].update(update.get("$set", {}))
|
||||
return type("res", (), {"matched_count": 1})
|
||||
elif upsert:
|
||||
new_doc = {**query, **update.get("$set", {})}
|
||||
self.docs[key] = new_doc
|
||||
return type("res", (), {"matched_count": 1})
|
||||
else:
|
||||
return type("res", (), {"matched_count": 0})
|
||||
|
||||
def delete_one(self, query):
|
||||
key = (query.get("user_id"), query.get("tool_id"), int(query.get("todo_id")))
|
||||
if key in self.docs:
|
||||
del self.docs[key]
|
||||
return type("res", (), {"deleted_count": 1})
|
||||
return type("res", (), {"deleted_count": 0})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def todo_tool(monkeypatch) -> TodoListTool:
|
||||
"""Provides a TodoListTool with a fake MongoDB backend."""
|
||||
fake_collection = FakeCollection()
|
||||
fake_client = {settings.MONGO_DB_NAME: {"todos": fake_collection}}
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
return TodoListTool({"tool_id": "test_tool"}, user_id="test_user")
|
||||
|
||||
|
||||
def test_create_and_get(todo_tool: TodoListTool):
|
||||
res = todo_tool.execute_action("todo_create", title="Write tests", description="Write pytest cases")
|
||||
assert res["status_code"] == 201
|
||||
todo_id = res["todo_id"]
|
||||
|
||||
get_res = todo_tool.execute_action("todo_get", todo_id=todo_id)
|
||||
assert get_res["status_code"] == 200
|
||||
assert get_res["todo"]["title"] == "Write tests"
|
||||
assert get_res["todo"]["description"] == "Write pytest cases"
|
||||
|
||||
|
||||
def test_get_all_todos(todo_tool: TodoListTool):
|
||||
todo_tool.execute_action("todo_create", title="Task 1")
|
||||
todo_tool.execute_action("todo_create", title="Task 2")
|
||||
|
||||
list_res = todo_tool.execute_action("todo_list")
|
||||
assert list_res["status_code"] == 200
|
||||
titles = [todo["title"] for todo in list_res["todos"]]
|
||||
assert "Task 1" in titles
|
||||
assert "Task 2" in titles
|
||||
|
||||
|
||||
def test_update_todo(todo_tool: TodoListTool):
|
||||
create_res = todo_tool.execute_action("todo_create", title="Initial Title")
|
||||
todo_id = create_res["todo_id"]
|
||||
|
||||
update_res = todo_tool.execute_action("todo_update", todo_id=todo_id, updates={"title": "Updated Title", "status": "done"})
|
||||
assert update_res["status_code"] == 200
|
||||
|
||||
get_res = todo_tool.execute_action("todo_get", todo_id=todo_id)
|
||||
assert get_res["todo"]["title"] == "Updated Title"
|
||||
assert get_res["todo"]["status"] == "done"
|
||||
|
||||
|
||||
def test_delete_todo(todo_tool: TodoListTool):
|
||||
create_res = todo_tool.execute_action("todo_create", title="To Delete")
|
||||
todo_id = create_res["todo_id"]
|
||||
|
||||
delete_res = todo_tool.execute_action("todo_delete", todo_id=todo_id)
|
||||
assert delete_res["status_code"] == 200
|
||||
|
||||
get_res = todo_tool.execute_action("todo_get", todo_id=todo_id)
|
||||
assert get_res["status_code"] == 404
|
||||
|
||||
|
||||
def test_isolation_and_default_tool_id(monkeypatch):
|
||||
"""Ensure todos are isolated by tool_id and user_id."""
|
||||
fake_collection = FakeCollection()
|
||||
fake_client = {settings.MONGO_DB_NAME: {"todos": fake_collection}}
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Same user, different tool_id
|
||||
tool1 = TodoListTool({"tool_id": "tool_1"}, user_id="u1")
|
||||
tool2 = TodoListTool({"tool_id": "tool_2"}, user_id="u1")
|
||||
|
||||
r1_create = tool1.execute_action("todo_create", title="from tool 1")
|
||||
r2_create = tool2.execute_action("todo_create", title="from tool 2")
|
||||
|
||||
r1 = tool1.execute_action("todo_get", todo_id=r1_create["todo_id"])
|
||||
r2 = tool2.execute_action("todo_get", todo_id=r2_create["todo_id"])
|
||||
|
||||
assert r1["status_code"] == 200
|
||||
assert r1["todo"]["title"] == "from tool 1"
|
||||
|
||||
assert r2["status_code"] == 200
|
||||
assert r2["todo"]["title"] == "from tool 2"
|
||||
|
||||
# Same user, no tool_id → should default to same value
|
||||
t3 = TodoListTool({}, user_id="default_user")
|
||||
t4 = TodoListTool({}, user_id="default_user")
|
||||
|
||||
assert t3.tool_id == "default_default_user"
|
||||
assert t4.tool_id == "default_default_user"
|
||||
|
||||
create_res = t3.execute_action("todo_create", title="shared default")
|
||||
r = t4.execute_action("todo_get", todo_id=create_res["todo_id"])
|
||||
|
||||
assert r["status_code"] == 200
|
||||
assert r["todo"]["title"] == "shared default"
|
||||
@@ -16,25 +16,12 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
|
||||
class DummyClient:
|
||||
def __init__(self, api_key):
|
||||
created["api_key"] = api_key
|
||||
self.convert_calls = []
|
||||
self.generate_calls = []
|
||||
|
||||
class TextToSpeech:
|
||||
def __init__(self, outer):
|
||||
self._outer = outer
|
||||
|
||||
def convert(self, *, voice_id, model_id, text, output_format):
|
||||
self._outer.convert_calls.append(
|
||||
{
|
||||
"voice_id": voice_id,
|
||||
"model_id": model_id,
|
||||
"text": text,
|
||||
"output_format": output_format,
|
||||
}
|
||||
)
|
||||
yield b"chunk-one"
|
||||
yield b"chunk-two"
|
||||
|
||||
self.text_to_speech = TextToSpeech(self)
|
||||
def generate(self, *, text, model, voice):
|
||||
self.generate_calls.append({"text": text, "model": model, "voice": voice})
|
||||
yield b"chunk-one"
|
||||
yield b"chunk-two"
|
||||
|
||||
client_module = ModuleType("elevenlabs.client")
|
||||
client_module.ElevenLabs = DummyClient
|
||||
@@ -48,13 +35,8 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
|
||||
audio_base64, lang = tts.text_to_speech("Speak")
|
||||
|
||||
assert created["api_key"] == "api-key"
|
||||
assert tts.client.convert_calls == [
|
||||
{
|
||||
"voice_id": "nPczCjzI2devNBz1zQrb",
|
||||
"model_id": "eleven_multilingual_v2",
|
||||
"text": "Speak",
|
||||
"output_format": "mp3_44100_128",
|
||||
}
|
||||
assert tts.client.generate_calls == [
|
||||
{"text": "Speak", "model": "eleven_multilingual_v2", "voice": "Brian"}
|
||||
]
|
||||
assert lang == "en"
|
||||
assert base64.b64decode(audio_base64.encode()) == b"chunk-onechunk-two"
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from application.tts.tts_creator import TTSCreator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tts_creator():
|
||||
return TTSCreator()
|
||||
|
||||
|
||||
def test_create_google_tts(tts_creator):
|
||||
# Patch the provider registry so the factory calls our mock class
|
||||
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
|
||||
mock_google_tts = TTSCreator.tts_providers["google_tts"]
|
||||
instance = MagicMock()
|
||||
mock_google_tts.return_value = instance
|
||||
|
||||
result = tts_creator.create_tts("google_tts", "arg1", key="value")
|
||||
|
||||
mock_google_tts.assert_called_once_with("arg1", key="value")
|
||||
assert result == instance
|
||||
|
||||
|
||||
def test_create_elevenlabs_tts(tts_creator):
|
||||
# Patch the provider registry so the factory calls our mock class
|
||||
with patch.dict(TTSCreator.tts_providers, {"elevenlabs": MagicMock()}):
|
||||
mock_elevenlabs_tts = TTSCreator.tts_providers["elevenlabs"]
|
||||
instance = MagicMock()
|
||||
mock_elevenlabs_tts.return_value = instance
|
||||
|
||||
result = tts_creator.create_tts("elevenlabs", "voice", lang="en")
|
||||
|
||||
mock_elevenlabs_tts.assert_called_once_with("voice", lang="en")
|
||||
assert result == instance
|
||||
|
||||
|
||||
def test_invalid_tts_type(tts_creator):
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
tts_creator.create_tts("unknown_tts")
|
||||
assert "No tts class found" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_tts_type_case_insensitivity(tts_creator):
|
||||
# Patch the provider registry to ensure case-insensitive lookup hits our mock
|
||||
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
|
||||
mock_google_tts = TTSCreator.tts_providers["google_tts"]
|
||||
instance = MagicMock()
|
||||
mock_google_tts.return_value = instance
|
||||
|
||||
result = tts_creator.create_tts("GoOgLe_TtS")
|
||||
|
||||
mock_google_tts.assert_called_once_with()
|
||||
assert result == instance
|
||||
|
||||
|
||||
def test_tts_providers_integrity(tts_creator):
|
||||
providers = tts_creator.tts_providers
|
||||
assert "google_tts" in providers
|
||||
assert "elevenlabs" in providers
|
||||
assert callable(providers["google_tts"])
|
||||
assert callable(providers["elevenlabs"])
|
||||