mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-30 00:53:14 +00:00
Compare commits
10 Commits
sharepoint
...
hacktoberf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b2383b074 | ||
|
|
e4e9910575 | ||
|
|
f448e4a615 | ||
|
|
c4e8daf50e | ||
|
|
5aa4ec1b9f | ||
|
|
125ce0aad3 | ||
|
|
ababc9ae04 | ||
|
|
62ac90746e | ||
|
|
096f6d91a2 | ||
|
|
d28ef6b094 |
@@ -6,17 +6,4 @@ VITE_API_STREAMING=true
|
||||
OPENAI_API_BASE=
|
||||
OPENAI_API_VERSION=
|
||||
AZURE_DEPLOYMENT_NAME=
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
|
||||
#Azure AD Application (client) ID
|
||||
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
|
||||
#Azure AD Application client secret
|
||||
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
|
||||
#Azure AD Tenant ID (or 'common' for multi-tenant)
|
||||
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
|
||||
#If you are using a Microsoft Entra ID tenant,
|
||||
#configure the AUTHORITY variable as
|
||||
#"https://login.microsoftonline.com/TENANT_GUID"
|
||||
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
|
||||
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
|
||||
MICROSOFT_AUTHORITY=https://{tenentId}.ciamlogin.com/{tenentId}
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
@@ -3,6 +3,7 @@
|
||||
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
|
||||
|
||||
All Meaningful contributors with accepted PRs that were created for issues with the `hacktoberfest` label (set by our maintainer team: dartpain, siiddhantt, pabik, ManishMadan2882) will receive a cool T-shirt! 🤩.
|
||||
<img width="1331" height="678" alt="hacktoberfest-mocks-preview" src="https://github.com/user-attachments/assets/633f6377-38db-48f5-b519-a8b3855a9eb4" />
|
||||
|
||||
Fill in [this form](https://forms.gle/Npaba4n9Epfyx56S8
|
||||
) after your PR was merged please
|
||||
|
||||
@@ -113,14 +113,10 @@ class ConnectorsCallback(Resource):
|
||||
session_token = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
if provider == "google_drive":
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
user_info = service.about().get(fields="user").execute()
|
||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||
else:
|
||||
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
|
||||
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
user_info = service.about().get(fields="user").execute()
|
||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||
except Exception as e:
|
||||
current_app.logger.warning(f"Could not get user info: {e}")
|
||||
user_email = 'Connected User'
|
||||
|
||||
@@ -10,7 +10,7 @@ from application.api import api
|
||||
from application.api.user.base import agents_collection, storage
|
||||
from application.api.user.tasks import store_attachment
|
||||
from application.core.settings import settings
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
from application.tts.tts_creator import TTSCreator
|
||||
from application.utils import safe_filename
|
||||
|
||||
|
||||
@@ -133,7 +133,7 @@ class TextToSpeech(Resource):
|
||||
data = request.get_json()
|
||||
text = data["text"]
|
||||
try:
|
||||
tts_instance = GoogleTTS()
|
||||
tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
|
||||
audio_base64, detected_language = tts_instance.text_to_speech(text)
|
||||
return make_response(
|
||||
jsonify(
|
||||
|
||||
@@ -55,11 +55,6 @@ class Settings(BaseSettings):
|
||||
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||
)
|
||||
|
||||
# Microsoft Entra ID (Azure AD) integration
|
||||
MICROSOFT_CLIENT_ID: Optional[str] = None # Azure AD Application (client) ID
|
||||
MICROSOFT_CLIENT_SECRET: Optional[str] = None # Azure AD Application client secret
|
||||
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
|
||||
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
|
||||
# GitHub source
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
|
||||
@@ -135,6 +130,7 @@ class Settings(BaseSettings):
|
||||
# Encryption settings
|
||||
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
||||
|
||||
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
|
||||
ELEVENLABS_API_KEY: Optional[str] = None
|
||||
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
from application.parser.connectors.share_point.auth import SharePointAuth
|
||||
from application.parser.connectors.share_point.loader import SharePointLoader
|
||||
|
||||
|
||||
class ConnectorCreator:
|
||||
@@ -14,12 +12,10 @@ class ConnectorCreator:
|
||||
|
||||
connectors = {
|
||||
"google_drive": GoogleDriveLoader,
|
||||
"share_point": SharePointLoader,
|
||||
}
|
||||
|
||||
auth_providers = {
|
||||
"google_drive": GoogleDriveAuth,
|
||||
"share_point": SharePointAuth,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
"""
|
||||
Share Point connector package for DocsGPT.
|
||||
|
||||
This module provides authentication and document loading capabilities for Share Point.
|
||||
"""
|
||||
|
||||
from .auth import SharePointAuth
|
||||
from .loader import SharePointLoader
|
||||
|
||||
__all__ = ['SharePointAuth', 'SharePointLoader']
|
||||
@@ -1,91 +0,0 @@
|
||||
import logging
|
||||
import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from msal import ConfidentialClientApplication
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.base import BaseConnectorAuth
|
||||
|
||||
|
||||
class SharePointAuth(BaseConnectorAuth):
|
||||
"""
|
||||
Handles Microsoft OAuth 2.0 authentication.
|
||||
|
||||
# Documentation:
|
||||
- https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow
|
||||
- https://learn.microsoft.com/en-gb/entra/msal/python/
|
||||
"""
|
||||
|
||||
# Microsoft Graph scopes for SharePoint access
|
||||
SCOPES = [
|
||||
"User.Read",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.client_id = settings.MICROSOFT_CLIENT_ID
|
||||
self.client_secret = settings.MICROSOFT_CLIENT_SECRET
|
||||
|
||||
if not self.client_id or not self.client_secret:
|
||||
raise ValueError(
|
||||
"Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID and MICROSOFT_CLIENT_SECRET in settings."
|
||||
)
|
||||
|
||||
self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
|
||||
self.tenant_id = settings.MICROSOFT_TENANT_ID
|
||||
self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://{self.tenant_id}.ciamlogin.com/{self.tenant_id}")
|
||||
|
||||
self.auth_app = ConfidentialClientApplication(
|
||||
client_id=self.client_id, client_credential=self.client_secret, authority=self.authority
|
||||
)
|
||||
|
||||
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||
return self.auth_app.get_authorization_request_url(
|
||||
scopes=self.SCOPES, state=state, redirect_uri=self.redirect_uri
|
||||
)
|
||||
|
||||
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||
result = self.auth_app.acquire_token_by_authorization_code(
|
||||
code=authorization_code, scopes=self.SCOPES, redirect_uri=self.redirect_uri
|
||||
)
|
||||
|
||||
if "error" in result:
|
||||
logging.error(f"Error acquiring token: {result.get('error_description')}")
|
||||
raise ValueError(f"Error acquiring token: {result.get('error_description')}")
|
||||
|
||||
return self.map_token_response(result)
|
||||
|
||||
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
result = self.auth_app.acquire_token_by_refresh_token(refresh_token=refresh_token, scopes=self.SCOPES)
|
||||
|
||||
if "error" in result:
|
||||
logging.error(f"Error acquiring token: {result.get('error_description')}")
|
||||
raise ValueError(f"Error acquiring token: {result.get('error_description')}")
|
||||
|
||||
return self.map_token_response(result)
|
||||
|
||||
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
|
||||
if not token_info or "expiry" not in token_info:
|
||||
# If no expiry info, consider token expired to be safe
|
||||
return True
|
||||
|
||||
# Get expiry timestamp and current time
|
||||
expiry_timestamp = token_info["expiry"]
|
||||
current_timestamp = int(datetime.datetime.now().timestamp())
|
||||
|
||||
# Token is expired if current time is greater than or equal to expiry time
|
||||
return current_timestamp >= expiry_timestamp
|
||||
|
||||
def map_token_response(self, result) -> Dict[str, Any]:
|
||||
return {
|
||||
"access_token": result.get("access_token"),
|
||||
"refresh_token": result.get("refresh_token"),
|
||||
"token_uri": result.get("id_token_claims", {}).get("iss"),
|
||||
"scopes": result.get("scope"),
|
||||
"expiry": result.get("id_token_claims", {}).get("exp"),
|
||||
"user_info": {
|
||||
"name": result.get("id_token_claims", {}).get("name"),
|
||||
"email": result.get("id_token_claims", {}).get("preferred_username"),
|
||||
},
|
||||
"raw_token": result,
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
from typing import List, Dict, Any
|
||||
from application.parser.connectors.base import BaseConnectorLoader
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
class SharePointLoader(BaseConnectorLoader):
|
||||
def __init__(self, session_token: str):
|
||||
pass
|
||||
|
||||
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""
|
||||
Load documents from the external knowledge base.
|
||||
|
||||
Args:
|
||||
inputs: Configuration dictionary containing:
|
||||
- file_ids: Optional list of specific file IDs to load
|
||||
- folder_ids: Optional list of folder IDs to browse/download
|
||||
- limit: Maximum number of items to return
|
||||
- list_only: If True, return metadata without content
|
||||
- recursive: Whether to recursively process folders
|
||||
|
||||
Returns:
|
||||
List of Document objects
|
||||
"""
|
||||
pass
|
||||
|
||||
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Download files/folders to a local directory.
|
||||
|
||||
Args:
|
||||
local_dir: Local directory path to download files to
|
||||
source_config: Configuration for what to download
|
||||
|
||||
Returns:
|
||||
Dictionary containing download results:
|
||||
- files_downloaded: Number of files downloaded
|
||||
- directory_path: Path where files were downloaded
|
||||
- empty_result: Whether no files were downloaded
|
||||
- source_type: Type of connector
|
||||
- config_used: Configuration that was used
|
||||
- error: Error message if download failed (optional)
|
||||
"""
|
||||
pass
|
||||
@@ -10,6 +10,7 @@ ebooklib==0.18
|
||||
escodegen==1.0.11
|
||||
esprima==4.0.1
|
||||
esutils==1.0.1
|
||||
elevenlabs==2.17.0
|
||||
Flask==3.1.1
|
||||
faiss-cpu==1.9.0.post1
|
||||
fastmcp==2.11.0
|
||||
@@ -40,7 +41,6 @@ markupsafe==3.0.2
|
||||
marshmallow==3.26.1
|
||||
mpmath==1.3.0
|
||||
multidict==6.4.3
|
||||
msal==1.34.0
|
||||
mypy-extensions==1.0.0
|
||||
networkx==3.4.2
|
||||
numpy==2.2.1
|
||||
@@ -88,4 +88,4 @@ werkzeug>=3.1.0,<3.1.2
|
||||
yarl==1.20.0
|
||||
markdownify==1.1.0
|
||||
tldextract==5.1.3
|
||||
websockets==14.1
|
||||
websockets==14.1
|
||||
|
||||
@@ -15,10 +15,11 @@ class ElevenlabsTTS(BaseTTS):
|
||||
|
||||
def text_to_speech(self, text):
|
||||
lang = "en"
|
||||
audio = self.client.generate(
|
||||
audio = self.client.text_to_speech.convert(
|
||||
voice_id="nPczCjzI2devNBz1zQrb",
|
||||
model_id="eleven_multilingual_v2",
|
||||
text=text,
|
||||
model="eleven_multilingual_v2",
|
||||
voice="Brian",
|
||||
output_format="mp3_44100_128"
|
||||
)
|
||||
audio_data = BytesIO()
|
||||
for chunk in audio:
|
||||
|
||||
18
application/tts/tts_creator.py
Normal file
18
application/tts/tts_creator.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
from application.tts.elevenlabs import ElevenlabsTTS
|
||||
from application.tts.base import BaseTTS
|
||||
|
||||
|
||||
|
||||
class TTSCreator:
|
||||
tts_providers = {
|
||||
"google_tts": GoogleTTS,
|
||||
"elevenlabs": ElevenlabsTTS,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_tts(cls, tts_type, *args, **kwargs)-> BaseTTS:
|
||||
tts_class = cls.tts_providers.get(tts_type.lower())
|
||||
if not tts_class:
|
||||
raise ValueError(f"No tts class found for type {tts_type}")
|
||||
return tts_class(*args, **kwargs)
|
||||
@@ -21,7 +21,7 @@ def get_encoding():
|
||||
|
||||
|
||||
def get_gpt_model() -> str:
|
||||
"""Get the appropriate GPT model based on provider"""
|
||||
"""Get GPT model based on provider"""
|
||||
model_map = {
|
||||
"openai": "gpt-4o-mini",
|
||||
"anthropic": "claude-2",
|
||||
@@ -32,16 +32,7 @@ def get_gpt_model() -> str:
|
||||
|
||||
|
||||
def safe_filename(filename):
|
||||
"""
|
||||
Creates a safe filename that preserves the original extension.
|
||||
Uses secure_filename, but ensures a proper filename is returned even with non-Latin characters.
|
||||
|
||||
Args:
|
||||
filename (str): The original filename
|
||||
|
||||
Returns:
|
||||
str: A safe filename that can be used for storage
|
||||
"""
|
||||
"""Create safe filename, preserving extension. Handles non-Latin characters."""
|
||||
if not filename:
|
||||
return str(uuid.uuid4())
|
||||
_, extension = os.path.splitext(filename)
|
||||
@@ -83,8 +74,14 @@ def count_tokens_docs(docs):
|
||||
return tokens
|
||||
|
||||
|
||||
def get_missing_fields(data, required_fields):
|
||||
"""Check for missing required fields. Returns list of missing field names."""
|
||||
return [field for field in required_fields if field not in data]
|
||||
|
||||
|
||||
def check_required_fields(data, required_fields):
|
||||
missing_fields = [field for field in required_fields if field not in data]
|
||||
"""Validate required fields. Returns Flask 400 response if validation fails, None otherwise."""
|
||||
missing_fields = get_missing_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -98,7 +95,8 @@ def check_required_fields(data, required_fields):
|
||||
return None
|
||||
|
||||
|
||||
def validate_required_fields(data, required_fields):
|
||||
def get_field_validation_errors(data, required_fields):
|
||||
"""Check for missing and empty fields. Returns dict with 'missing_fields' and 'empty_fields', or None."""
|
||||
missing_fields = []
|
||||
empty_fields = []
|
||||
|
||||
@@ -107,12 +105,24 @@ def validate_required_fields(data, required_fields):
|
||||
missing_fields.append(field)
|
||||
elif not data[field]:
|
||||
empty_fields.append(field)
|
||||
errors = []
|
||||
if missing_fields:
|
||||
errors.append(f"Missing required fields: {', '.join(missing_fields)}")
|
||||
if empty_fields:
|
||||
errors.append(f"Empty values in required fields: {', '.join(empty_fields)}")
|
||||
if errors:
|
||||
if missing_fields or empty_fields:
|
||||
return {"missing_fields": missing_fields, "empty_fields": empty_fields}
|
||||
return None
|
||||
|
||||
|
||||
def validate_required_fields(data, required_fields):
|
||||
"""Validate required fields (must exist and be non-empty). Returns Flask 400 response if validation fails, None otherwise."""
|
||||
errors_dict = get_field_validation_errors(data, required_fields)
|
||||
if errors_dict:
|
||||
errors = []
|
||||
if errors_dict["missing_fields"]:
|
||||
errors.append(
|
||||
f"Missing required fields: {', '.join(errors_dict['missing_fields'])}"
|
||||
)
|
||||
if errors_dict["empty_fields"]:
|
||||
errors.append(
|
||||
f"Empty values in required fields: {', '.join(errors_dict['empty_fields'])}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": " | ".join(errors)}), 400
|
||||
)
|
||||
@@ -124,10 +134,7 @@ def get_hash(data):
|
||||
|
||||
|
||||
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
|
||||
"""
|
||||
Limits chat history based on token count.
|
||||
Returns a list of messages that fit within the token limit.
|
||||
"""
|
||||
"""Limit chat history to fit within token limit."""
|
||||
from application.core.settings import settings
|
||||
|
||||
max_token_limit = (
|
||||
@@ -161,7 +168,7 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
|
||||
|
||||
|
||||
def validate_function_name(function_name):
|
||||
"""Validates if a function name matches the allowed pattern."""
|
||||
"""Validate function name matches allowed pattern (alphanumeric, underscore, hyphen)."""
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -9,7 +9,8 @@ import userService from './api/services/userService';
|
||||
import Add from './assets/add.svg';
|
||||
import DocsGPT3 from './assets/cute_docsgpt3.svg';
|
||||
import Discord from './assets/discord.svg';
|
||||
import Expand from './assets/expand.svg';
|
||||
import PanelLeftClose from './assets/panel-left-close.svg';
|
||||
import PanelLeftOpen from './assets/panel-left-open.svg';
|
||||
import Github from './assets/git_nav.svg';
|
||||
import Hamburger from './assets/hamburger.svg';
|
||||
import openNewChat from './assets/openNewChat.svg';
|
||||
@@ -302,18 +303,20 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
{
|
||||
<div className="absolute top-3 left-3 z-20 hidden transition-all duration-300 ease-in-out lg:block">
|
||||
<div className="flex items-center gap-3">
|
||||
<button
|
||||
onClick={() => {
|
||||
setNavOpen(!navOpen);
|
||||
}}
|
||||
className="transition-transform duration-200 hover:scale-110"
|
||||
>
|
||||
<img
|
||||
src={Expand}
|
||||
alt="Toggle navigation menu"
|
||||
className="m-auto transition-all duration-300 ease-in-out"
|
||||
/>
|
||||
</button>
|
||||
{!navOpen && (
|
||||
<button
|
||||
onClick={() => {
|
||||
setNavOpen(!navOpen);
|
||||
}}
|
||||
className="transition-transform duration-200 hover:scale-110"
|
||||
>
|
||||
<img
|
||||
src={PanelLeftOpen}
|
||||
alt="Open navigation menu"
|
||||
className="m-auto transition-all duration-300 ease-in-out"
|
||||
/>
|
||||
</button>
|
||||
)}
|
||||
{queries?.length > 0 && (
|
||||
<button
|
||||
onClick={() => {
|
||||
@@ -363,8 +366,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
}}
|
||||
>
|
||||
<img
|
||||
src={Expand}
|
||||
alt="Toggle navigation menu"
|
||||
src={navOpen ? PanelLeftClose : PanelLeftOpen}
|
||||
alt={navOpen ? 'Collapse sidebar' : 'Expand sidebar'}
|
||||
className="m-auto transition-all duration-300 ease-in-out hover:scale-110"
|
||||
/>
|
||||
</button>
|
||||
|
||||
@@ -109,18 +109,18 @@ export default function AgentPreview() {
|
||||
} else setLastQueryReturnedErr(false);
|
||||
}, [queries]);
|
||||
return (
|
||||
<div>
|
||||
<div className="dark:bg-raisin-black flex h-full flex-col items-center justify-between gap-2 overflow-y-hidden">
|
||||
<div className="h-[512px] w-full overflow-y-auto">
|
||||
<ConversationMessages
|
||||
handleQuestion={handleQuestion}
|
||||
handleQuestionSubmission={handleQuestionSubmission}
|
||||
queries={queries}
|
||||
status={status}
|
||||
showHeroOnEmpty={false}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex w-[95%] max-w-[1500px] flex-col items-center gap-4 pb-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
|
||||
<div className="relative h-full w-full">
|
||||
<div className="scrollbar-thin absolute inset-0 bottom-[180px] overflow-hidden px-4 pt-4 [&>div>div]:!w-full [&>div>div]:!max-w-none">
|
||||
<ConversationMessages
|
||||
handleQuestion={handleQuestion}
|
||||
handleQuestionSubmission={handleQuestionSubmission}
|
||||
queries={queries}
|
||||
status={status}
|
||||
showHeroOnEmpty={false}
|
||||
/>
|
||||
</div>
|
||||
<div className="absolute right-0 bottom-0 left-0 flex w-full flex-col gap-4 pb-2">
|
||||
<div className="w-full px-4">
|
||||
<MessageInput
|
||||
onSubmit={(text) => handleQuestionSubmission(text)}
|
||||
loading={status === 'loading'}
|
||||
@@ -128,11 +128,11 @@ export default function AgentPreview() {
|
||||
showToolButton={selectedAgent ? false : true}
|
||||
autoFocus={false}
|
||||
/>
|
||||
<p className="text-gray-4000 dark:text-sonic-silver w-full self-center bg-transparent pt-2 text-center text-xs md:inline">
|
||||
This is a preview of the agent. You can publish it to start using it
|
||||
in conversations.
|
||||
</p>
|
||||
</div>
|
||||
<p className="text-gray-4000 dark:text-sonic-silver w-full bg-transparent text-center text-xs md:inline">
|
||||
This is a preview of the agent. You can publish it to start using it
|
||||
in conversations.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -534,7 +534,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
setHasChanges(isChanged);
|
||||
}, [agent, dispatch, effectiveMode, imageFile, jsonSchemaText]);
|
||||
return (
|
||||
<div className="p-4 md:p-12">
|
||||
<div className="flex flex-col px-4 pt-4 pb-2 max-[1179px]:min-h-[100dvh] min-[1180px]:h-[100dvh] md:px-12 md:pt-12 md:pb-3">
|
||||
<div className="flex items-center gap-3 px-4">
|
||||
<button
|
||||
className="rounded-full border p-3 text-sm text-gray-400 dark:border-0 dark:bg-[#28292D] dark:text-gray-500 dark:hover:bg-[#2E2F34]"
|
||||
@@ -615,9 +615,9 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-5 flex w-full grid-cols-5 flex-col gap-10 min-[1180px]:grid min-[1180px]:gap-5">
|
||||
<div className="col-span-2 flex flex-col gap-5">
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<div className="mt-3 flex w-full flex-1 grid-cols-5 flex-col gap-10 rounded-[30px] bg-[#F6F6F6] p-5 max-[1179px]:overflow-visible min-[1180px]:grid min-[1180px]:gap-5 min-[1180px]:overflow-hidden dark:bg-[#383838]">
|
||||
<div className="scrollbar-thin col-span-2 flex flex-col gap-5 max-[1179px]:overflow-visible min-[1180px]:max-h-full min-[1180px]:overflow-y-auto min-[1180px]:pr-3">
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<h2 className="text-lg font-semibold">Meta</h2>
|
||||
<input
|
||||
className="border-silver text-jet dark:bg-raisin-black dark:text-bright-gray dark:placeholder:text-silver mt-3 w-full rounded-3xl border bg-white px-5 py-3 text-sm outline-hidden placeholder:text-gray-400 dark:border-[#7E7E7E]"
|
||||
@@ -650,7 +650,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<h2 className="text-lg font-semibold">Source</h2>
|
||||
<div className="mt-3">
|
||||
<div className="flex flex-wrap items-center gap-1">
|
||||
@@ -744,7 +744,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<div className="flex flex-wrap items-end gap-1">
|
||||
<div className="min-w-20 grow basis-full sm:basis-0">
|
||||
<Prompts
|
||||
@@ -781,7 +781,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<h2 className="text-lg font-semibold">Tools</h2>
|
||||
<div className="mt-3 flex flex-wrap items-center gap-1">
|
||||
<button
|
||||
@@ -823,7 +823,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<h2 className="text-lg font-semibold">Agent type</h2>
|
||||
<div className="mt-3">
|
||||
<Dropdown
|
||||
@@ -848,7 +848,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
|
||||
<button
|
||||
onClick={() =>
|
||||
setIsAdvancedSectionExpanded(!isAdvancedSectionExpanded)
|
||||
@@ -1032,9 +1032,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="col-span-3 flex flex-col gap-3 rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<div className="col-span-3 flex flex-col gap-2 max-[1179px]:h-auto max-[1179px]:px-0 max-[1179px]:py-0 min-[1180px]:h-full min-[1180px]:py-2 dark:text-[#E0E0E0]">
|
||||
<h2 className="text-lg font-semibold">Preview</h2>
|
||||
<AgentPreviewArea />
|
||||
<div className="flex-1 max-[1179px]:overflow-visible min-[1180px]:min-h-0 min-[1180px]:overflow-hidden">
|
||||
<AgentPreviewArea />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<ConfirmationModal
|
||||
@@ -1071,9 +1073,9 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
function AgentPreviewArea() {
|
||||
const selectedAgent = useSelector(selectSelectedAgent);
|
||||
return (
|
||||
<div className="dark:bg-raisin-black h-full w-full rounded-[30px] border border-[#F6F6F6] bg-white max-[1180px]:h-192 dark:border-[#7E7E7E]">
|
||||
<div className="dark:bg-raisin-black w-full rounded-[30px] border border-[#F6F6F6] bg-white max-[1179px]:h-[600px] min-[1180px]:h-full dark:border-[#7E7E7E]">
|
||||
{selectedAgent?.status === 'published' ? (
|
||||
<div className="flex h-full w-full flex-col justify-end overflow-auto rounded-[30px]">
|
||||
<div className="flex h-full w-full flex-col overflow-hidden rounded-[30px]">
|
||||
<AgentPreview />
|
||||
</div>
|
||||
) : (
|
||||
|
||||
@@ -177,13 +177,15 @@ export default function SharedAgent() {
|
||||
/>
|
||||
</div>
|
||||
<div className="flex w-[95%] max-w-[1500px] flex-col items-center pb-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
|
||||
<MessageInput
|
||||
onSubmit={(text) => handleQuestionSubmission(text)}
|
||||
loading={status === 'loading'}
|
||||
showSourceButton={sharedAgent ? false : true}
|
||||
showToolButton={sharedAgent ? false : true}
|
||||
autoFocus={false}
|
||||
/>
|
||||
<div className="w-full px-2">
|
||||
<MessageInput
|
||||
onSubmit={(text) => handleQuestionSubmission(text)}
|
||||
loading={status === 'loading'}
|
||||
showSourceButton={sharedAgent ? false : true}
|
||||
showToolButton={sharedAgent ? false : true}
|
||||
autoFocus={false}
|
||||
/>
|
||||
</div>
|
||||
<p className="text-gray-4000 dark:text-sonic-silver hidden w-screen self-center bg-transparent py-2 text-center text-xs md:inline md:w-full">
|
||||
{t('tagline')}
|
||||
</p>
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
<svg width="27" height="26" viewBox="0 0 27 26" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M4.03371 5.27275L4.1915 20.9162C4.20021 21.7802 4.90766 22.4735 5.77162 22.4648L21.4151 22.307C22.2791 22.2983 22.9724 21.5909 22.9637 20.7269L22.8059 5.0834C22.7972 4.21944 22.0897 3.52612 21.2258 3.53483L5.58228 3.69262C4.71831 3.70134 4.02499 4.40878 4.03371 5.27275Z" stroke="#949494" stroke-width="2.08591" stroke-linejoin="round"/>
|
||||
<path d="M9.42289 22.428L9.23354 3.65585M17.6924 15.0436L15.5856 12.9788L17.6504 10.872M6.29419 22.4596L12.5516 22.3965M6.10484 3.68741L12.3622 3.62429" stroke="#949494" stroke-width="2.08591" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 692 B |
@@ -1,5 +1,5 @@
|
||||
<svg width="113" height="124" viewBox="0 0 113 124" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<circle cx="55.5" cy="71" r="53" fill="#F1F1F1" fill-opacity="0.5"/>
|
||||
<circle cx="55.5" cy="71" r="53" fill="#E8E3F3" fill-opacity="0.6"/>
|
||||
<rect x="-0.599797" y="0.654564" width="43.9445" height="61.5222" rx="4.39444" transform="matrix(-0.999048 0.0436194 0.0436194 0.999048 68.9873 43.3176)" fill="#EEEEEE" stroke="#999999" stroke-width="1.25556"/>
|
||||
<rect x="0.704349" y="-0.540466" width="46.4556" height="64.0333" rx="5.65" transform="matrix(-0.991445 -0.130526 -0.130526 0.991445 96.3673 40.893)" fill="#FAFAFA" stroke="#999999" stroke-width="1.25556"/>
|
||||
<path d="M94.3796 45.7849C94.7417 43.0349 92.8059 40.5122 90.0559 40.1501L55.2011 35.5614C52.4511 35.1994 49.9284 37.1352 49.5663 39.8851L48.3372 49.2212L93.1505 55.121L94.3796 45.7849Z" fill="#EEEEEE"/>
|
||||
|
||||
|
Before Width: | Height: | Size: 2.0 KiB After Width: | Height: | Size: 2.0 KiB |
1
frontend/src/assets/panel-left-close.svg
Normal file
1
frontend/src/assets/panel-left-close.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#949494" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-panel-left-close-icon lucide-panel-left-close"><rect width="18" height="18" x="3" y="3" rx="2"/><path d="M9 3v18"/><path d="m16 15-3-3 3-3"/></svg>
|
||||
|
After Width: | Height: | Size: 345 B |
1
frontend/src/assets/panel-left-open.svg
Normal file
1
frontend/src/assets/panel-left-open.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#949494" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-panel-left-open-icon lucide-panel-left-open"><rect width="18" height="18" x="3" y="3" rx="2"/><path d="M9 3v18"/><path d="m14 9 3 3-3 3"/></svg>
|
||||
|
After Width: | Height: | Size: 342 B |
@@ -1,16 +0,0 @@
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
|
||||
<!-- Uploaded to: SVG Repo, www.svgrepo.com, Transformed by: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 0 48 48" id="b" xmlns="http://www.w3.org/2000/svg" fill="#000000" stroke="#000000" stroke-width="3.312">
|
||||
|
||||
<g id="SVGRepo_bgCarrier" stroke-width="0"/>
|
||||
|
||||
<g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
|
||||
<g id="SVGRepo_iconCarrier">
|
||||
|
||||
<defs>
|
||||
|
||||
<style>.c{fill:none;stroke:#000000;stroke-linecap:round;stroke-linejoin:round;}</style>
|
||||
|
||||
</defs>
|
||||
|
Before Width: | Height: | Size: 1.2 KiB |
@@ -136,33 +136,34 @@ const Chunks: React.FC<ChunksProps> = ({
|
||||
|
||||
const pathParts = path ? path.split('/') : [];
|
||||
|
||||
const fetchChunks = () => {
|
||||
const fetchChunks = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
userService
|
||||
.getDocumentChunks(documentId, page, perPage, token, path, searchTerm)
|
||||
.then((response) => {
|
||||
if (!response.ok) {
|
||||
setLoading(false);
|
||||
setPaginatedChunks([]);
|
||||
throw new Error('Failed to fetch chunks data');
|
||||
}
|
||||
return response.json();
|
||||
})
|
||||
.then((data) => {
|
||||
setPage(data.page);
|
||||
setPerPage(data.per_page);
|
||||
setTotalChunks(data.total);
|
||||
setPaginatedChunks(data.chunks);
|
||||
setLoading(false);
|
||||
})
|
||||
.catch((error) => {
|
||||
setLoading(false);
|
||||
setPaginatedChunks([]);
|
||||
});
|
||||
} catch (e) {
|
||||
setLoading(false);
|
||||
const response = await userService.getDocumentChunks(
|
||||
documentId,
|
||||
page,
|
||||
perPage,
|
||||
token,
|
||||
path,
|
||||
searchTerm,
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to fetch chunks data');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
setPage(data.page);
|
||||
setPerPage(data.per_page);
|
||||
setTotalChunks(data.total);
|
||||
setPaginatedChunks(data.chunks);
|
||||
} catch (error) {
|
||||
setPaginatedChunks([]);
|
||||
console.error(error);
|
||||
} finally {
|
||||
// ✅ always runs, success or failure
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
const ConnectedStateSkeleton = () => (
|
||||
<div className="mb-4">
|
||||
<div className="flex w-full animate-pulse items-center justify-between rounded-[10px] bg-gray-200 px-4 py-2 dark:bg-gray-700">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="h-4 w-4 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
<div className="h-4 w-32 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
<div className="h-4 w-16 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
export default ConnectedStateSkeleton;
|
||||
@@ -150,7 +150,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
{isConnected ? (
|
||||
<div className="mb-4">
|
||||
<div className="flex w-full items-center justify-between rounded-[10px] bg-[#8FDD51] px-4 py-2 text-sm font-medium text-[#212121]">
|
||||
<div className="flex max-w-[500px] items-center gap-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<svg className="h-4 w-4" viewBox="0 0 24 24">
|
||||
<path
|
||||
fill="currentColor"
|
||||
|
||||
@@ -20,14 +20,9 @@ type CopyButtonProps = {
|
||||
const DEFAULT_ICON_SIZE = 'w-4 h-4';
|
||||
const DEFAULT_PADDING = 'p-2';
|
||||
const DEFAULT_COPIED_DURATION = 2000;
|
||||
const DEFAULT_BG_LIGHT = '#FFFFFF';
|
||||
const DEFAULT_BG_DARK = 'transparent';
|
||||
const DEFAULT_HOVER_BG_LIGHT = '#EEEEEE';
|
||||
const DEFAULT_HOVER_BG_DARK = '#464152';
|
||||
|
||||
export default function CopyButton({
|
||||
textToCopy,
|
||||
|
||||
iconSize = DEFAULT_ICON_SIZE,
|
||||
padding = DEFAULT_PADDING,
|
||||
showText = false,
|
||||
@@ -43,9 +38,8 @@ export default function CopyButton({
|
||||
const iconWrapperClasses = clsx(
|
||||
'flex items-center justify-center rounded-full transition-colors duration-150 ease-in-out',
|
||||
padding,
|
||||
`bg-[${DEFAULT_BG_LIGHT}] dark:bg-[${DEFAULT_BG_DARK}]`,
|
||||
{
|
||||
[`hover:bg-[${DEFAULT_HOVER_BG_LIGHT}] dark:hover:bg-[${DEFAULT_HOVER_BG_DARK}]`]:
|
||||
[`bg-[#FFFFFF}] dark:bg-transparent hover:bg-[#EEEEEE] dark:hover:bg-purple-taupe`]:
|
||||
!isCopied,
|
||||
'bg-green-100 dark:bg-green-900 hover:bg-green-100 dark:hover:bg-green-900':
|
||||
isCopied,
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
const FilesSectionSkeleton = () => (
|
||||
<div className="rounded-lg border border-[#EEE6FF78] dark:border-[#6A6A6A]">
|
||||
<div className="p-4">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<div className="h-5 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div className="h-8 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
<div className="h-4 w-40 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
export default FilesSectionSkeleton;
|
||||
@@ -7,10 +7,7 @@ import {
|
||||
getSessionToken,
|
||||
setSessionToken,
|
||||
removeSessionToken,
|
||||
validateProviderSession,
|
||||
} from '../utils/providerUtils';
|
||||
import ConnectedStateSkeleton from './ConnectedStateSkeleton';
|
||||
import FilesSectionSkeleton from './FileSelectionSkeleton';
|
||||
|
||||
interface PickerFile {
|
||||
id: string;
|
||||
@@ -53,9 +50,20 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
|
||||
const validateSession = async (sessionToken: string) => {
|
||||
try {
|
||||
const validateResponse = await validateProviderSession(
|
||||
token,
|
||||
'google_drive',
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
const validateResponse = await fetch(
|
||||
`${apiHost}/api/connectors/validate-session`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: 'google_drive',
|
||||
session_token: sessionToken,
|
||||
}),
|
||||
},
|
||||
);
|
||||
|
||||
if (!validateResponse.ok) {
|
||||
@@ -226,6 +234,30 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
onSelectionChange([], []);
|
||||
};
|
||||
|
||||
const ConnectedStateSkeleton = () => (
|
||||
<div className="mb-4">
|
||||
<div className="flex w-full animate-pulse items-center justify-between rounded-[10px] bg-gray-200 px-4 py-2 dark:bg-gray-700">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="h-4 w-4 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
<div className="h-4 w-32 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
<div className="h-4 w-16 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
const FilesSectionSkeleton = () => (
|
||||
<div className="rounded-lg border border-[#EEE6FF78] dark:border-[#6A6A6A]">
|
||||
<div className="p-4">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<div className="h-5 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div className="h-8 w-24 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
<div className="h-4 w-40 animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<div>
|
||||
{isValidating ? (
|
||||
|
||||
@@ -91,8 +91,10 @@ export default function MessageInput({
|
||||
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
const xhr = new XMLHttpRequest();
|
||||
const uniqueId = crypto.randomUUID();
|
||||
|
||||
const newAttachment = {
|
||||
id: uniqueId,
|
||||
fileName: file.name,
|
||||
progress: 0,
|
||||
status: 'uploading' as const,
|
||||
@@ -106,7 +108,7 @@ export default function MessageInput({
|
||||
const progress = Math.round((event.loaded / event.total) * 100);
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
taskId: newAttachment.taskId,
|
||||
id: uniqueId,
|
||||
updates: { progress },
|
||||
}),
|
||||
);
|
||||
@@ -119,7 +121,7 @@ export default function MessageInput({
|
||||
if (response.task_id) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
taskId: newAttachment.taskId,
|
||||
id: uniqueId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
status: 'processing',
|
||||
@@ -131,7 +133,7 @@ export default function MessageInput({
|
||||
} else {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
taskId: newAttachment.taskId,
|
||||
id: uniqueId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
@@ -141,7 +143,7 @@ export default function MessageInput({
|
||||
xhr.onerror = () => {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
taskId: newAttachment.taskId,
|
||||
id: uniqueId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
@@ -167,7 +169,7 @@ export default function MessageInput({
|
||||
if (data.status === 'SUCCESS') {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
taskId: attachment.taskId!,
|
||||
id: attachment.id,
|
||||
updates: {
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
@@ -179,14 +181,14 @@ export default function MessageInput({
|
||||
} else if (data.status === 'FAILURE') {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
taskId: attachment.taskId!,
|
||||
id: attachment.id,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
} else if (data.status === 'PROGRESS' && data.result?.current) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
taskId: attachment.taskId!,
|
||||
id: attachment.id,
|
||||
updates: { progress: data.result.current },
|
||||
}),
|
||||
);
|
||||
@@ -195,7 +197,7 @@ export default function MessageInput({
|
||||
.catch(() => {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
taskId: attachment.taskId!,
|
||||
id: attachment.id,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
@@ -260,12 +262,12 @@ export default function MessageInput({
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="mx-2 flex w-full flex-col">
|
||||
<div className="flex w-full flex-col">
|
||||
<div className="border-dark-gray bg-lotion dark:border-grey relative flex w-full flex-col rounded-[23px] border dark:bg-transparent">
|
||||
<div className="flex flex-wrap gap-1.5 px-2 py-2 sm:gap-2 sm:px-3">
|
||||
{attachments.map((attachment, index) => (
|
||||
{attachments.map((attachment) => (
|
||||
<div
|
||||
key={index}
|
||||
key={attachment.id}
|
||||
className={`group dark:text-bright-gray relative flex items-center rounded-xl bg-[#EFF3F4] px-2 py-1 text-[12px] text-[#5D5D5D] sm:px-3 sm:py-1.5 sm:text-[14px] dark:bg-[#393B3D] ${
|
||||
attachment.status !== 'completed' ? 'opacity-70' : 'opacity-100'
|
||||
}`}
|
||||
@@ -327,11 +329,7 @@ export default function MessageInput({
|
||||
<button
|
||||
className="ml-1.5 flex items-center justify-center rounded-full p-1"
|
||||
onClick={() => {
|
||||
if (attachment.id) {
|
||||
dispatch(removeAttachment(attachment.id));
|
||||
} else if (attachment.taskId) {
|
||||
dispatch(removeAttachment(attachment.taskId));
|
||||
}
|
||||
dispatch(removeAttachment(attachment.id));
|
||||
}}
|
||||
aria-label={t('conversation.attachments.remove')}
|
||||
>
|
||||
|
||||
@@ -1,175 +0,0 @@
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import ConnectorAuth from './ConnectorAuth';
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
import {
|
||||
getSessionToken,
|
||||
setSessionToken,
|
||||
removeSessionToken,
|
||||
validateProviderSession,
|
||||
} from '../utils/providerUtils';
|
||||
import ConnectedStateSkeleton from './ConnectedStateSkeleton';
|
||||
import FilesSectionSkeleton from './FileSelectionSkeleton';
|
||||
|
||||
interface SharePointPickerProps {
|
||||
token: string | null;
|
||||
}
|
||||
|
||||
const SharePointPicker: React.FC<SharePointPickerProps> = ({ token }) => {
|
||||
const { t } = useTranslation();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [userEmail, setUserEmail] = useState<string>('');
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
const [authError, setAuthError] = useState<string>('');
|
||||
const [accessToken, setAccessToken] = useState<string | null>(null);
|
||||
const [isValidating, setIsValidating] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const sessionToken = getSessionToken('share_point');
|
||||
if (sessionToken) {
|
||||
setIsValidating(true);
|
||||
setIsConnected(true); // Optimistically set as connected for skeleton
|
||||
validateSession(sessionToken);
|
||||
}
|
||||
}, [token]);
|
||||
|
||||
const validateSession = async (sessionToken: string) => {
|
||||
try {
|
||||
const validateResponse = await validateProviderSession(
|
||||
token,
|
||||
'share_point',
|
||||
);
|
||||
|
||||
if (!validateResponse.ok) {
|
||||
setIsConnected(false);
|
||||
setAuthError(
|
||||
t('modals.uploadDoc.connectors.sharePoint.sessionExpired'),
|
||||
);
|
||||
setIsValidating(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
const validateData = await validateResponse.json();
|
||||
if (validateData.success) {
|
||||
setUserEmail(
|
||||
validateData.user_email ||
|
||||
t('modals.uploadDoc.connectors.auth.connectedUser'),
|
||||
);
|
||||
setIsConnected(true);
|
||||
setAuthError('');
|
||||
setAccessToken(validateData.access_token || null);
|
||||
setIsValidating(false);
|
||||
|
||||
return true;
|
||||
} else {
|
||||
setIsConnected(false);
|
||||
setAuthError(
|
||||
validateData.error ||
|
||||
t('modals.uploadDoc.connectors.sharePoint.sessionExpiredGeneric'),
|
||||
);
|
||||
setIsValidating(false);
|
||||
return false;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error validating session:', error);
|
||||
setAuthError(t('modals.uploadDoc.connectors.sharePoint.validateFailed'));
|
||||
setIsConnected(false);
|
||||
setIsValidating(false);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const handleDisconnect = async () => {
|
||||
const sessionToken = getSessionToken('share_point');
|
||||
if (sessionToken) {
|
||||
try {
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
await fetch(`${apiHost}/api/connectors/disconnect`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: 'share_point',
|
||||
session_token: sessionToken,
|
||||
}),
|
||||
});
|
||||
} catch (err) {
|
||||
console.error('Error disconnecting from SharePoint:', err);
|
||||
}
|
||||
}
|
||||
|
||||
removeSessionToken('share_point');
|
||||
setIsConnected(false);
|
||||
setAccessToken(null);
|
||||
setUserEmail('');
|
||||
setAuthError('');
|
||||
};
|
||||
|
||||
const handleOpenPicker = async () => {
|
||||
alert('Feature not supported yet.');
|
||||
};
|
||||
|
||||
return (
|
||||
<div>
|
||||
{isValidating ? (
|
||||
<>
|
||||
<ConnectedStateSkeleton />
|
||||
<FilesSectionSkeleton />
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<ConnectorAuth
|
||||
provider="share_point"
|
||||
label={t('modals.uploadDoc.connectors.sharePoint.connect')}
|
||||
onSuccess={(data) => {
|
||||
setUserEmail(
|
||||
data.user_email ||
|
||||
t('modals.uploadDoc.connectors.auth.connectedUser'),
|
||||
);
|
||||
setIsConnected(true);
|
||||
setAuthError('');
|
||||
|
||||
if (data.session_token) {
|
||||
setSessionToken('share_point', data.session_token);
|
||||
validateSession(data.session_token);
|
||||
}
|
||||
}}
|
||||
onError={(error) => {
|
||||
setAuthError(error);
|
||||
setIsConnected(false);
|
||||
}}
|
||||
isConnected={isConnected}
|
||||
userEmail={userEmail}
|
||||
onDisconnect={handleDisconnect}
|
||||
errorMessage={authError}
|
||||
/>
|
||||
|
||||
{isConnected && (
|
||||
<div className="rounded-lg border border-[#EEE6FF78] dark:border-[#6A6A6A]">
|
||||
<div className="p-4">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<h3 className="text-sm font-medium">
|
||||
{t('modals.uploadDoc.connectors.sharePoint.selectedFiles')}
|
||||
</h3>
|
||||
<button
|
||||
onClick={() => handleOpenPicker()}
|
||||
className="rounded-md bg-[#A076F6] px-3 py-1 text-sm text-white hover:bg-[#8A5FD4]"
|
||||
disabled={isLoading}
|
||||
>
|
||||
{isLoading
|
||||
? t('modals.uploadDoc.connectors.sharePoint.loading')
|
||||
: t('modals.uploadDoc.connectors.sharePoint.selectFiles')}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default SharePointPicker;
|
||||
@@ -33,7 +33,7 @@ export default function Sidebar({
|
||||
return (
|
||||
<div ref={sidebarRef} className="h-vh relative">
|
||||
<div
|
||||
className={`dark:bg-chinese-black fixed top-0 right-0 z-50 h-full w-72 transform bg-white shadow-xl transition-all duration-300 sm:w-96 ${
|
||||
className={`dark:bg-chinese-black fixed top-0 right-0 z-50 h-full w-64 transform bg-white shadow-xl transition-all duration-300 sm:w-80 ${
|
||||
isOpen ? 'translate-x-[10px]' : 'translate-x-full'
|
||||
} border-l border-[#9ca3af]/10`}
|
||||
>
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import React, { useRef, useEffect, useState, useLayoutEffect } from 'react';
|
||||
import { createPortal } from 'react-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { Doc } from '../models/misc';
|
||||
@@ -107,7 +108,7 @@ export default function SourcesPopup({
|
||||
onClose();
|
||||
};
|
||||
|
||||
return (
|
||||
const popupContent = (
|
||||
<div
|
||||
ref={popupRef}
|
||||
className="bg-lotion dark:bg-charleston-green-2 fixed z-50 flex flex-col rounded-xl shadow-[0px_9px_46px_8px_#0000001F,0px_24px_38px_3px_#00000024,0px_11px_15px_-7px_#00000033]"
|
||||
@@ -218,7 +219,7 @@ export default function SourcesPopup({
|
||||
</>
|
||||
) : (
|
||||
<div className="dark:text-bright-gray p-4 text-center text-gray-500 dark:text-[14px]">
|
||||
{t('noSourcesAvailable')}
|
||||
{t('conversation.sources.noSourcesAvailable')}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -245,4 +246,6 @@ export default function SourcesPopup({
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return createPortal(popupContent, document.body);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import React, { useEffect, useRef, useState, useLayoutEffect } from 'react';
|
||||
import { createPortal } from 'react-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
@@ -133,10 +134,10 @@ export default function ToolsPopup({
|
||||
tool.displayName.toLowerCase().includes(searchTerm.toLowerCase()),
|
||||
);
|
||||
|
||||
return (
|
||||
const popupContent = (
|
||||
<div
|
||||
ref={popupRef}
|
||||
className="border-light-silver bg-lotion dark:border-dim-gray dark:bg-charleston-green-2 fixed z-9999 rounded-lg border shadow-[0px_9px_46px_8px_#0000001F,0px_24px_38px_3px_#00000024,0px_11px_15px_-7px_#00000033]"
|
||||
className="border-light-silver bg-lotion dark:border-dim-gray dark:bg-charleston-green-2 fixed z-50 rounded-lg border shadow-[0px_9px_46px_8px_#0000001F,0px_24px_38px_3px_#00000024,0px_11px_15px_-7px_#00000033]"
|
||||
style={{
|
||||
top: popupPosition.showAbove ? popupPosition.top : undefined,
|
||||
bottom: popupPosition.showAbove
|
||||
@@ -242,4 +243,6 @@ export default function ToolsPopup({
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return createPortal(popupContent, document.body);
|
||||
}
|
||||
|
||||
@@ -44,7 +44,10 @@ export default function UploadToast() {
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2">
|
||||
<div
|
||||
className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2"
|
||||
onMouseDown={(e) => e.stopPropagation()}
|
||||
>
|
||||
{uploadTasks
|
||||
.filter((task) => !task.dismissed)
|
||||
.map((task) => {
|
||||
|
||||
@@ -224,7 +224,7 @@ export default function Conversation() {
|
||||
<div className="bg-opacity-0 z-3 flex h-auto w-full max-w-[1300px] flex-col items-end self-center rounded-2xl py-1 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
|
||||
<div
|
||||
{...getRootProps()}
|
||||
className="flex w-full items-center rounded-[40px]"
|
||||
className="flex w-full items-center rounded-[40px] px-2"
|
||||
>
|
||||
<label htmlFor="file-upload" className="sr-only">
|
||||
{t('modals.uploadDoc.label')}
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
}
|
||||
|
||||
.list li:not(:first-child) {
|
||||
margin-top: 1em;
|
||||
margin-top: 0.5em;
|
||||
}
|
||||
|
||||
.list li > .list {
|
||||
margin-top: 1em;
|
||||
margin-top: 0.5em;
|
||||
}
|
||||
|
||||
@@ -86,10 +86,7 @@ const ConversationBubble = forwardRef<
|
||||
// const bubbleRef = useRef<HTMLDivElement | null>(null);
|
||||
const chunks = useSelector(selectChunks);
|
||||
const selectedDocs = useSelector(selectSelectedDocs);
|
||||
const [isLikeHovered, setIsLikeHovered] = useState(false);
|
||||
const [isEditClicked, setIsEditClicked] = useState(false);
|
||||
const [isDislikeHovered, setIsDislikeHovered] = useState(false);
|
||||
const [isQuestionHovered, setIsQuestionHovered] = useState(false);
|
||||
const [editInputBox, setEditInputBox] = useState<string>('');
|
||||
const messageRef = useRef<HTMLDivElement>(null);
|
||||
const [shouldShowToggle, setShouldShowToggle] = useState(false);
|
||||
@@ -115,11 +112,7 @@ const ConversationBubble = forwardRef<
|
||||
let bubble;
|
||||
if (type === 'QUESTION') {
|
||||
bubble = (
|
||||
<div
|
||||
onMouseEnter={() => setIsQuestionHovered(true)}
|
||||
onMouseLeave={() => setIsQuestionHovered(false)}
|
||||
className={className}
|
||||
>
|
||||
<div className={`group ${className}`}>
|
||||
<div className="flex flex-col items-end">
|
||||
{filesAttached && filesAttached.length > 0 && (
|
||||
<div className="mr-12 mb-4 flex flex-wrap justify-end gap-2">
|
||||
@@ -188,7 +181,7 @@ const ConversationBubble = forwardRef<
|
||||
setIsEditClicked(true);
|
||||
setEditInputBox(message ?? '');
|
||||
}}
|
||||
className={`hover:bg-light-silver mt-3 flex h-fit shrink-0 cursor-pointer items-center rounded-full p-2 pt-1.5 pl-1.5 dark:hover:bg-[#35363B] ${isQuestionHovered || isEditClicked ? 'visible' : 'invisible'}`}
|
||||
className={`hover:bg-light-silver mt-3 flex h-fit shrink-0 cursor-pointer items-center rounded-full p-2 pt-1.5 pl-1.5 dark:hover:bg-[#35363B] ${isEditClicked ? 'visible' : 'invisible group-hover:visible'}`}
|
||||
>
|
||||
<img src={Edit} alt="Edit" className="cursor-pointer" />
|
||||
</button>
|
||||
@@ -421,7 +414,7 @@ const ConversationBubble = forwardRef<
|
||||
<Fragment key={index}>
|
||||
{segment.type === 'text' ? (
|
||||
<ReactMarkdown
|
||||
className="fade-in leading-normal break-words whitespace-pre-wrap"
|
||||
className="fade-in flex flex-col gap-3 leading-normal break-words whitespace-pre-wrap"
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[rehypeKatex]}
|
||||
components={{
|
||||
@@ -568,13 +561,7 @@ const ConversationBubble = forwardRef<
|
||||
<>
|
||||
<div className="relative mr-2 flex items-center justify-center">
|
||||
<div>
|
||||
<div
|
||||
className={`flex items-center justify-center rounded-full p-2 ${
|
||||
isLikeHovered
|
||||
? 'dark:bg-purple-taupe bg-[#EEEEEE]'
|
||||
: 'bg-white-3000 dark:bg-transparent'
|
||||
}`}
|
||||
>
|
||||
<div className="bg-white-3000 dark:hover:bg-purple-taupe flex items-center justify-center rounded-full p-2 hover:bg-[#EEEEEE] dark:bg-transparent">
|
||||
<Like
|
||||
className={`${feedback === 'LIKE' ? 'fill-white-3000 stroke-purple-30 dark:fill-transparent' : 'stroke-gray-4000 fill-none'} cursor-pointer`}
|
||||
onClick={() => {
|
||||
@@ -584,8 +571,6 @@ const ConversationBubble = forwardRef<
|
||||
handleFeedback?.('LIKE');
|
||||
}
|
||||
}}
|
||||
onMouseEnter={() => setIsLikeHovered(true)}
|
||||
onMouseLeave={() => setIsLikeHovered(false)}
|
||||
></Like>
|
||||
</div>
|
||||
</div>
|
||||
@@ -593,13 +578,7 @@ const ConversationBubble = forwardRef<
|
||||
|
||||
<div className="relative mr-2 flex items-center justify-center">
|
||||
<div>
|
||||
<div
|
||||
className={`flex items-center justify-center rounded-full p-2 ${
|
||||
isDislikeHovered
|
||||
? 'dark:bg-purple-taupe bg-[#EEEEEE]'
|
||||
: 'bg-white-3000 dark:bg-transparent'
|
||||
}`}
|
||||
>
|
||||
<div className="bg-white-3000 dark:hover:bg-purple-taupe flex items-center justify-center rounded-full p-2 hover:bg-[#EEEEEE] dark:bg-transparent">
|
||||
<Dislike
|
||||
className={`${feedback === 'DISLIKE' ? 'fill-white-3000 stroke-red-2000 dark:fill-transparent' : 'stroke-gray-4000 fill-none'} cursor-pointer`}
|
||||
onClick={() => {
|
||||
@@ -609,8 +588,6 @@ const ConversationBubble = forwardRef<
|
||||
handleFeedback?.('DISLIKE');
|
||||
}
|
||||
}}
|
||||
onMouseEnter={() => setIsDislikeHovered(true)}
|
||||
onMouseLeave={() => setIsDislikeHovered(false)}
|
||||
></Dislike>
|
||||
</div>
|
||||
</div>
|
||||
@@ -658,7 +635,7 @@ function AllSources(sources: AllSourcesProps) {
|
||||
<p className="text-left text-xl">{`${sources.sources.length} ${t('conversation.sources.title')}`}</p>
|
||||
<div className="mx-1 mt-2 h-[0.8px] w-full rounded-full bg-[#C4C4C4]/40 lg:w-[95%]"></div>
|
||||
</div>
|
||||
<div className="mt-6 flex h-[90%] w-60 flex-col items-center gap-4 overflow-y-auto sm:w-80">
|
||||
<div className="scrollbar-thin mt-6 flex h-[90%] w-52 flex-col gap-4 overflow-y-auto pr-3 sm:w-64">
|
||||
{sources.sources.map((source, index) => {
|
||||
const isExternalSource = source.link && source.link !== 'local';
|
||||
return (
|
||||
|
||||
@@ -161,14 +161,16 @@ export const SharedConversation = () => {
|
||||
/>
|
||||
<div className="flex w-full max-w-[1200px] flex-col items-center gap-4 pb-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
|
||||
{apiKey ? (
|
||||
<MessageInput
|
||||
onSubmit={(text) => {
|
||||
handleQuestionSubmission(text);
|
||||
}}
|
||||
loading={status === 'loading'}
|
||||
showSourceButton={false}
|
||||
showToolButton={false}
|
||||
/>
|
||||
<div className="w-full px-2">
|
||||
<MessageInput
|
||||
onSubmit={(text) => {
|
||||
handleQuestionSubmission(text);
|
||||
}}
|
||||
loading={status === 'loading'}
|
||||
showSourceButton={false}
|
||||
showToolButton={false}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => navigate('/')}
|
||||
|
||||
@@ -56,7 +56,7 @@ export const fetchAnswer = createAsyncThunk<
|
||||
question,
|
||||
signal,
|
||||
state.preference.token,
|
||||
state.preference.selectedDocs!,
|
||||
state.preference.selectedDocs || [],
|
||||
currentConversationId,
|
||||
state.preference.prompt.id,
|
||||
state.preference.chunks,
|
||||
@@ -163,7 +163,7 @@ export const fetchAnswer = createAsyncThunk<
|
||||
question,
|
||||
signal,
|
||||
state.preference.token,
|
||||
state.preference.selectedDocs!,
|
||||
state.preference.selectedDocs || [],
|
||||
state.conversation.conversationId,
|
||||
state.preference.prompt.id,
|
||||
state.preference.chunks,
|
||||
|
||||
@@ -118,18 +118,34 @@ layer(base);
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
/* Light theme scrollbar */
|
||||
&::-webkit-scrollbar-thumb {
|
||||
background: rgba(156, 163, 175, 0.5);
|
||||
background: rgba(215, 215, 215, 1);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
&::-webkit-scrollbar-thumb:hover {
|
||||
background: rgba(156, 163, 175, 0.7);
|
||||
background: rgba(195, 195, 195, 1);
|
||||
}
|
||||
|
||||
/* For Firefox */
|
||||
/* Dark theme scrollbar */
|
||||
.dark &::-webkit-scrollbar-thumb {
|
||||
background: rgba(77, 78, 88, 1);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.dark &::-webkit-scrollbar-thumb:hover {
|
||||
background: rgba(97, 98, 108, 1);
|
||||
}
|
||||
|
||||
/* For Firefox - Light theme */
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: rgba(156, 163, 175, 0.5) transparent;
|
||||
scrollbar-color: rgba(215, 215, 215, 1) transparent;
|
||||
|
||||
/* For Firefox - Dark theme */
|
||||
.dark & {
|
||||
scrollbar-color: rgba(77, 78, 88, 1) transparent;
|
||||
}
|
||||
}
|
||||
|
||||
@utility table-default {
|
||||
|
||||
@@ -298,10 +298,6 @@
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Upload from Google Drive"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Upload from SharePoint"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
@@ -331,24 +327,6 @@
|
||||
"remove": "Remove",
|
||||
"folderAlt": "Folder",
|
||||
"fileAlt": "File"
|
||||
},
|
||||
"sharePoint": {
|
||||
"connect": "Connect to SharePoint",
|
||||
"sessionExpired": "Session expired. Please reconnect to SharePoint.",
|
||||
"sessionExpiredGeneric": "Session expired. Please reconnect your account.",
|
||||
"validateFailed": "Failed to validate session. Please reconnect.",
|
||||
"noSession": "No valid session found. Please reconnect to SharePoint.",
|
||||
"noAccessToken": "No access token available. Please reconnect to SharePoint.",
|
||||
"pickerFailed": "Failed to open file picker. Please try again.",
|
||||
"selectedFiles": "Selected Files",
|
||||
"selectFiles": "Select Files",
|
||||
"loading": "Loading...",
|
||||
"noFilesSelected": "No files or folders selected",
|
||||
"folders": "Folders",
|
||||
"files": "Files",
|
||||
"remove": "Remove",
|
||||
"folderAlt": "Folder",
|
||||
"fileAlt": "File"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -443,7 +421,8 @@
|
||||
"title": "Sources",
|
||||
"text": "Choose Your Sources",
|
||||
"link": "Source link",
|
||||
"view_more": "{{count}} more sources"
|
||||
"view_more": "{{count}} more sources",
|
||||
"noSourcesAvailable": "No sources available"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "Attach",
|
||||
|
||||
@@ -261,10 +261,6 @@
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Subir desde Google Drive"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Subir desde SharePoint"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
@@ -294,24 +290,6 @@
|
||||
"remove": "Eliminar",
|
||||
"folderAlt": "Carpeta",
|
||||
"fileAlt": "Archivo"
|
||||
},
|
||||
"sharePoint": {
|
||||
"connect": "Conectar a SharePoint",
|
||||
"sessionExpired": "Sesión expirada. Por favor, reconecte a SharePoint.",
|
||||
"sessionExpiredGeneric": "Sesión expirada. Por favor, reconecte su cuenta.",
|
||||
"validateFailed": "Error al validar la sesión. Por favor, reconecte.",
|
||||
"noSession": "No se encontró una sesión válida. Por favor, reconecte a SharePoint.",
|
||||
"noAccessToken": "No hay token de acceso disponible. Por favor, reconecte a SharePoint.",
|
||||
"pickerFailed": "Error al abrir el selector de archivos. Por favor, inténtelo de nuevo.",
|
||||
"selectedFiles": "Archivos Seleccionados",
|
||||
"selectFiles": "Seleccionar Archivos",
|
||||
"loading": "Cargando...",
|
||||
"noFilesSelected": "No hay archivos o carpetas seleccionados",
|
||||
"folders": "Carpetas",
|
||||
"files": "Archivos",
|
||||
"remove": "Eliminar",
|
||||
"folderAlt": "Carpeta",
|
||||
"fileAlt": "Archivo"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -406,7 +384,8 @@
|
||||
"title": "Fuentes",
|
||||
"link": "Enlace fuente",
|
||||
"view_more": "Ver {{count}} más fuentes",
|
||||
"text": "Elegir tus fuentes"
|
||||
"text": "Elegir tus fuentes",
|
||||
"noSourcesAvailable": "No hay fuentes disponibles"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "Adjuntar",
|
||||
|
||||
@@ -261,10 +261,6 @@
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Google Driveからアップロード"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "SharePointからアップロード"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
@@ -294,24 +290,6 @@
|
||||
"remove": "削除",
|
||||
"folderAlt": "フォルダ",
|
||||
"fileAlt": "ファイル"
|
||||
},
|
||||
"sharePoint": {
|
||||
"connect": "SharePointに接続",
|
||||
"sessionExpired": "セッションが期限切れです。SharePointに再接続してください。",
|
||||
"sessionExpiredGeneric": "セッションが期限切れです。アカウントに再接続してください。",
|
||||
"validateFailed": "セッションの検証に失敗しました。再接続してください。",
|
||||
"noSession": "有効なセッションが見つかりません。SharePointに再接続してください。",
|
||||
"noAccessToken": "アクセストークンが利用できません。SharePointに再接続してください。",
|
||||
"pickerFailed": "ファイルピッカーを開けませんでした。もう一度お試しください。",
|
||||
"selectedFiles": "選択されたファイル",
|
||||
"selectFiles": "ファイルを選択",
|
||||
"loading": "読み込み中...",
|
||||
"noFilesSelected": "ファイルまたはフォルダが選択されていません",
|
||||
"folders": "フォルダ",
|
||||
"files": "ファイル",
|
||||
"remove": "削除",
|
||||
"folderAlt": "フォルダ",
|
||||
"fileAlt": "ファイル"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -406,7 +384,8 @@
|
||||
"title": "ソース",
|
||||
"text": "ソーステキスト",
|
||||
"link": "ソースリンク",
|
||||
"view_more": "さらに{{count}}個のソース"
|
||||
"view_more": "さらに{{count}}個のソース",
|
||||
"noSourcesAvailable": "利用可能なソースがありません"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "添付",
|
||||
|
||||
@@ -261,10 +261,6 @@
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Загрузить из Google Drive"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Загрузить из SharePoint"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
@@ -294,24 +290,6 @@
|
||||
"remove": "Удалить",
|
||||
"folderAlt": "Папка",
|
||||
"fileAlt": "Файл"
|
||||
},
|
||||
"sharePoint": {
|
||||
"connect": "Подключиться к SharePoint",
|
||||
"sessionExpired": "Сеанс истек. Пожалуйста, переподключитесь к SharePoint.",
|
||||
"sessionExpiredGeneric": "Сеанс истек. Пожалуйста, переподключите свою учетную запись.",
|
||||
"validateFailed": "Не удалось проверить сеанс. Пожалуйста, переподключитесь.",
|
||||
"noSession": "Действительный сеанс не найден. Пожалуйста, переподключитесь к SharePoint.",
|
||||
"noAccessToken": "Токен доступа недоступен. Пожалуйста, переподключитесь к SharePoint.",
|
||||
"pickerFailed": "Не удалось открыть средство выбора файлов. Пожалуйста, попробуйте еще раз.",
|
||||
"selectedFiles": "Выбранные файлы",
|
||||
"selectFiles": "Выбрать файлы",
|
||||
"loading": "Загрузка...",
|
||||
"noFilesSelected": "Файлы или папки не выбраны",
|
||||
"folders": "Папки",
|
||||
"files": "Файлы",
|
||||
"remove": "Удалить",
|
||||
"folderAlt": "Папка",
|
||||
"fileAlt": "Файл"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -406,7 +384,8 @@
|
||||
"title": "Источники",
|
||||
"text": "Выберите ваши источники",
|
||||
"link": "Ссылка на источник",
|
||||
"view_more": "ещё {{count}} источников"
|
||||
"view_more": "ещё {{count}} источников",
|
||||
"noSourcesAvailable": "Нет доступных источников"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "Прикрепить",
|
||||
|
||||
@@ -261,10 +261,6 @@
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "從Google Drive上傳"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "從SharePoint上傳"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
@@ -294,24 +290,6 @@
|
||||
"remove": "移除",
|
||||
"folderAlt": "資料夾",
|
||||
"fileAlt": "檔案"
|
||||
},
|
||||
"sharePoint": {
|
||||
"connect": "連接到 SharePoint",
|
||||
"sessionExpired": "工作階段已過期。請重新連接到 SharePoint。",
|
||||
"sessionExpiredGeneric": "工作階段已過期。請重新連接您的帳戶。",
|
||||
"validateFailed": "驗證工作階段失敗。請重新連接。",
|
||||
"noSession": "未找到有效工作階段。請重新連接到 SharePoint。",
|
||||
"noAccessToken": "存取權杖不可用。請重新連接到 SharePoint。",
|
||||
"pickerFailed": "無法開啟檔案選擇器。請重試。",
|
||||
"selectedFiles": "已選擇的檔案",
|
||||
"selectFiles": "選擇檔案",
|
||||
"loading": "載入中...",
|
||||
"noFilesSelected": "未選擇檔案或資料夾",
|
||||
"folders": "資料夾",
|
||||
"files": "檔案",
|
||||
"remove": "移除",
|
||||
"folderAlt": "資料夾",
|
||||
"fileAlt": "檔案"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -406,7 +384,8 @@
|
||||
"title": "來源",
|
||||
"text": "來源文字",
|
||||
"link": "來源連結",
|
||||
"view_more": "查看更多 {{count}} 個來源"
|
||||
"view_more": "查看更多 {{count}} 個來源",
|
||||
"noSourcesAvailable": "沒有可用的來源"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "附件",
|
||||
|
||||
@@ -261,10 +261,6 @@
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "从Google Drive上传"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "从SharePoint上传"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
@@ -294,24 +290,6 @@
|
||||
"remove": "删除",
|
||||
"folderAlt": "文件夹",
|
||||
"fileAlt": "文件"
|
||||
},
|
||||
"sharePoint": {
|
||||
"connect": "连接到 SharePoint",
|
||||
"sessionExpired": "会话已过期。请重新连接到 SharePoint。",
|
||||
"sessionExpiredGeneric": "会话已过期。请重新连接您的账户。",
|
||||
"validateFailed": "验证会话失败。请重新连接。",
|
||||
"noSession": "未找到有效会话。请重新连接到 SharePoint。",
|
||||
"noAccessToken": "访问令牌不可用。请重新连接到 SharePoint。",
|
||||
"pickerFailed": "无法打开文件选择器。请重试。",
|
||||
"selectedFiles": "已选择的文件",
|
||||
"selectFiles": "选择文件",
|
||||
"loading": "加载中...",
|
||||
"noFilesSelected": "未选择文件或文件夹",
|
||||
"folders": "文件夹",
|
||||
"files": "文件",
|
||||
"remove": "删除",
|
||||
"folderAlt": "文件夹",
|
||||
"fileAlt": "文件"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -406,7 +384,8 @@
|
||||
"title": "来源",
|
||||
"text": "来源文本",
|
||||
"link": "来源链接",
|
||||
"view_more": "还有{{count}}个来源"
|
||||
"view_more": "还有{{count}}个来源",
|
||||
"noSourcesAvailable": "没有可用的来源"
|
||||
},
|
||||
"attachments": {
|
||||
"attach": "附件",
|
||||
|
||||
@@ -103,8 +103,8 @@ export const ShareConversationModal = ({
|
||||
};
|
||||
|
||||
return (
|
||||
<WrapperModal close={close}>
|
||||
<div className="flex max-h-[80vh] w-[600px] max-w-[80vw] flex-col gap-2 overflow-y-auto">
|
||||
<WrapperModal close={close} contentClassName="!overflow-visible">
|
||||
<div className="flex w-[600px] max-w-[80vw] flex-col gap-2">
|
||||
<h2 className="text-eerie-black dark:text-chinese-white text-xl font-medium">
|
||||
{t('modals.shareConv.label')}
|
||||
</h2>
|
||||
|
||||
@@ -32,7 +32,6 @@ import { FormField, IngestorConfig, IngestorType } from './types/ingestor';
|
||||
|
||||
import { FilePicker } from '../components/FilePicker';
|
||||
import GoogleDrivePicker from '../components/GoogleDrivePicker';
|
||||
import SharePointPicker from '../components/SharePointPicker';
|
||||
|
||||
import ChevronRight from '../assets/chevron-right.svg';
|
||||
|
||||
@@ -253,8 +252,6 @@ function Upload({
|
||||
token={token}
|
||||
/>
|
||||
);
|
||||
case 'share_point_picker':
|
||||
return <SharePointPicker key={field.name} token={token} />;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import UrlIcon from '../../assets/url.svg';
|
||||
import GithubIcon from '../../assets/github.svg';
|
||||
import RedditIcon from '../../assets/reddit.svg';
|
||||
import DriveIcon from '../../assets/drive.svg';
|
||||
import SharePoint from '../../assets/sharepoint.svg';
|
||||
|
||||
export type IngestorType =
|
||||
| 'crawler'
|
||||
@@ -12,8 +11,7 @@ export type IngestorType =
|
||||
| 'reddit'
|
||||
| 'url'
|
||||
| 'google_drive'
|
||||
| 'local_file'
|
||||
| 'share_point';
|
||||
| 'local_file';
|
||||
|
||||
export interface IngestorConfig {
|
||||
type: IngestorType | null;
|
||||
@@ -35,8 +33,7 @@ export type FieldType =
|
||||
| 'boolean'
|
||||
| 'local_file_picker'
|
||||
| 'remote_file_picker'
|
||||
| 'google_drive_picker'
|
||||
| 'share_point_picker';
|
||||
| 'google_drive_picker';
|
||||
|
||||
export interface FormField {
|
||||
name: string;
|
||||
@@ -150,24 +147,6 @@ export const IngestorFormSchemas: IngestorSchema[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
key: 'share_point',
|
||||
label: 'Share Point',
|
||||
icon: SharePoint,
|
||||
heading: 'Upload from Share Point',
|
||||
validate: () => {
|
||||
const sharePointClientId = import.meta.env.VITE_SHARE_POINT_CLIENT_ID;
|
||||
return !!sharePointClientId;
|
||||
},
|
||||
fields: [
|
||||
{
|
||||
name: 'files',
|
||||
label: 'Select Files from Share Point',
|
||||
type: 'share_point_picker',
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
export const IngestorDefaultConfigs: Record<
|
||||
@@ -196,14 +175,6 @@ export const IngestorDefaultConfigs: Record<
|
||||
},
|
||||
},
|
||||
local_file: { name: '', config: { files: [] } },
|
||||
share_point: {
|
||||
name: '',
|
||||
config: {
|
||||
file_ids: '',
|
||||
folder_ids: '',
|
||||
recursive: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export interface IngestorOption {
|
||||
|
||||
@@ -2,11 +2,11 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import { RootState } from '../store';
|
||||
|
||||
export interface Attachment {
|
||||
id: string; // Unique identifier for the attachment (required for state management)
|
||||
fileName: string;
|
||||
progress: number;
|
||||
status: 'uploading' | 'processing' | 'completed' | 'failed';
|
||||
taskId: string;
|
||||
id?: string;
|
||||
taskId: string; // Server-assigned task ID (used for API calls)
|
||||
token_count?: number;
|
||||
}
|
||||
|
||||
@@ -47,12 +47,12 @@ export const uploadSlice = createSlice({
|
||||
updateAttachment: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
taskId: string;
|
||||
id: string;
|
||||
updates: Partial<Attachment>;
|
||||
}>,
|
||||
) => {
|
||||
const index = state.attachments.findIndex(
|
||||
(att) => att.taskId === action.payload.taskId,
|
||||
(att) => att.id === action.payload.id,
|
||||
);
|
||||
if (index !== -1) {
|
||||
state.attachments[index] = {
|
||||
@@ -63,7 +63,7 @@ export const uploadSlice = createSlice({
|
||||
},
|
||||
removeAttachment: (state, action: PayloadAction<string>) => {
|
||||
state.attachments = state.attachments.filter(
|
||||
(att) => att.taskId !== action.payload && att.id !== action.payload,
|
||||
(att) => att.id !== action.payload,
|
||||
);
|
||||
},
|
||||
clearAttachments: (state) => {
|
||||
|
||||
@@ -14,21 +14,3 @@ export const setSessionToken = (provider: string, token: string): void => {
|
||||
export const removeSessionToken = (provider: string): void => {
|
||||
localStorage.removeItem(`${provider}_session_token`);
|
||||
};
|
||||
|
||||
export const validateProviderSession = async (
|
||||
token: string | null,
|
||||
provider: string,
|
||||
) => {
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
return await fetch(`${apiHost}/api/connectors/validate-session`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: provider,
|
||||
session_token: getSessionToken(provider),
|
||||
}),
|
||||
});
|
||||
};
|
||||
|
||||
@@ -3,7 +3,6 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.core.settings import settings
|
||||
from tests.conftest import FakeMongoCollection
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -168,10 +167,13 @@ class TestBaseAgentTools:
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
|
||||
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
|
||||
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": True},
|
||||
}
|
||||
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
user_tools.insert_one(
|
||||
{"_id": "1", "user": "test_user", "name": "tool1", "status": True}
|
||||
)
|
||||
user_tools.insert_one(
|
||||
{"_id": "2", "user": "test_user", "name": "tool2", "status": True}
|
||||
)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
tools = agent._get_user_tools("test_user")
|
||||
@@ -187,10 +189,13 @@ class TestBaseAgentTools:
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
|
||||
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
|
||||
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": False},
|
||||
}
|
||||
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
user_tools.insert_one(
|
||||
{"_id": "1", "user": "test_user", "name": "tool1", "status": True}
|
||||
)
|
||||
user_tools.insert_one(
|
||||
{"_id": "2", "user": "test_user", "name": "tool2", "status": False}
|
||||
)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
tools = agent._get_user_tools("test_user")
|
||||
@@ -209,17 +214,16 @@ class TestBaseAgentTools:
|
||||
tool_id = str(ObjectId())
|
||||
tool_obj_id = ObjectId(tool_id)
|
||||
|
||||
fake_agent_collection = FakeMongoCollection()
|
||||
fake_agent_collection.docs["api_key_123"] = {
|
||||
"key": "api_key_123",
|
||||
"tools": [tool_id],
|
||||
}
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"key": "api_key_123",
|
||||
"tools": [tool_id],
|
||||
}
|
||||
)
|
||||
|
||||
fake_tools_collection = FakeMongoCollection()
|
||||
fake_tools_collection.docs[tool_id] = {"_id": tool_obj_id, "name": "api_tool"}
|
||||
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["agents"] = fake_agent_collection
|
||||
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] = fake_tools_collection
|
||||
tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
tools_collection.insert_one({"_id": tool_obj_id, "name": "api_tool"})
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
tools = agent._get_tools("api_key_123")
|
||||
|
||||
0
tests/api/__init__.py
Normal file
0
tests/api/__init__.py
Normal file
0
tests/api/answer/__init__.py
Normal file
0
tests/api/answer/__init__.py
Normal file
0
tests/api/answer/routes/__init__.py
Normal file
0
tests/api/answer/routes/__init__.py
Normal file
552
tests/api/answer/routes/test_base.py
Normal file
552
tests/api/answer/routes/test_base.py
Normal file
@@ -0,0 +1,552 @@
|
||||
import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAnswerValidation:
|
||||
def test_validate_request_passes_with_required_fields(
|
||||
self, mock_mongo_db, flask_app
|
||||
):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
data = {"question": "What is Python?"}
|
||||
|
||||
result = resource.validate_request(data)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_validate_request_fails_without_question(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
data = {}
|
||||
|
||||
result = resource.validate_request(data)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 400
|
||||
assert "question" in result.json["message"].lower()
|
||||
|
||||
def test_validate_with_conversation_id_required(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
data = {"question": "Test"}
|
||||
|
||||
result = resource.validate_request(data, require_conversation_id=True)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 400
|
||||
assert "conversation_id" in result.json["message"].lower()
|
||||
|
||||
def test_validate_passes_with_all_required_fields(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
data = {"question": "Test", "conversation_id": str(ObjectId())}
|
||||
|
||||
result = resource.validate_request(data, require_conversation_id=True)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestUsageChecking:
|
||||
def test_returns_none_when_no_api_key(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_error_for_invalid_api_key(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "invalid_key_123"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 401
|
||||
assert result.json["success"] is False
|
||||
assert "invalid" in result.json["message"].lower()
|
||||
|
||||
def test_checks_token_limit_when_enabled(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": True,
|
||||
"token_limit": 1000,
|
||||
"limited_request_mode": False,
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_checks_request_limit_when_enabled(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": False,
|
||||
"limited_request_mode": True,
|
||||
"request_limit": 100,
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_uses_default_limits_when_not_specified(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": True,
|
||||
"limited_request_mode": True,
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_exceeds_token_limit(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
token_usage_collection = mock_mongo_db[settings.MONGO_DB_NAME][
|
||||
"token_usage"
|
||||
]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": True,
|
||||
"token_limit": 100,
|
||||
"limited_request_mode": False,
|
||||
}
|
||||
)
|
||||
|
||||
token_usage_collection.insert_one(
|
||||
{
|
||||
"_id": ObjectId(),
|
||||
"api_key": "test_key",
|
||||
"prompt_tokens": 60,
|
||||
"generated_tokens": 50,
|
||||
"timestamp": datetime.datetime.now(),
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 429
|
||||
assert result.json["success"] is False
|
||||
assert "usage limit" in result.json["message"].lower()
|
||||
|
||||
def test_exceeds_request_limit(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
token_usage_collection = mock_mongo_db[settings.MONGO_DB_NAME][
|
||||
"token_usage"
|
||||
]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": False,
|
||||
"limited_request_mode": True,
|
||||
"request_limit": 2,
|
||||
}
|
||||
)
|
||||
|
||||
now = datetime.datetime.now()
|
||||
for i in range(3):
|
||||
token_usage_collection.insert_one(
|
||||
{
|
||||
"_id": ObjectId(),
|
||||
"api_key": "test_key",
|
||||
"prompt_tokens": 10,
|
||||
"generated_tokens": 10,
|
||||
"timestamp": now,
|
||||
}
|
||||
)
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is not None
|
||||
assert result.status_code == 429
|
||||
assert result.json["success"] is False
|
||||
|
||||
def test_both_limits_disabled_returns_none(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_key",
|
||||
"limited_token_mode": False,
|
||||
"limited_request_mode": False,
|
||||
}
|
||||
)
|
||||
|
||||
resource = BaseAnswerResource()
|
||||
agent_config = {"user_api_key": "test_key"}
|
||||
|
||||
result = resource.check_usage(agent_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGPTModelRetrieval:
|
||||
def test_initializes_gpt_model(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
assert hasattr(resource, "gpt_model")
|
||||
assert resource.gpt_model is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConversationServiceIntegration:
|
||||
def test_initializes_conversation_service(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
assert hasattr(resource, "conversation_service")
|
||||
assert resource.conversation_service is not None
|
||||
|
||||
def test_has_access_to_user_logs_collection(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
assert hasattr(resource, "user_logs_collection")
|
||||
assert resource.user_logs_collection is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompleteStreamMethod:
|
||||
def test_streams_answer_chunks(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Hello "},
|
||||
{"answer": "world!"},
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Test question",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
answer_chunks = [s for s in stream if '"type": "answer"' in s]
|
||||
assert len(answer_chunks) == 2
|
||||
assert '"answer": "Hello "' in answer_chunks[0]
|
||||
assert '"answer": "world!"' in answer_chunks[1]
|
||||
|
||||
def test_streams_sources(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Test answer"},
|
||||
{"sources": [{"title": "doc1.txt", "text": "x" * 200}]},
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
source_chunks = [s for s in stream if '"type": "source"' in s]
|
||||
assert len(source_chunks) == 1
|
||||
assert '"title": "doc1.txt"' in source_chunks[0]
|
||||
|
||||
def test_handles_error_during_streaming(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.side_effect = Exception("Test error")
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
stream = list(
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert any('"type": "error"' in s for s in stream)
|
||||
|
||||
def test_saves_conversation_when_enabled(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Test answer"},
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
with patch.object(
|
||||
resource.conversation_service, "save_conversation"
|
||||
) as mock_save:
|
||||
mock_save.return_value = str(ObjectId())
|
||||
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="Test?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=True,
|
||||
)
|
||||
)
|
||||
|
||||
mock_save.assert_called_once()
|
||||
|
||||
def test_logs_to_user_logs_collection(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
user_logs = mock_mongo_db[settings.MONGO_DB_NAME]["user_logs"]
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{"answer": "Test answer"},
|
||||
]
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_params.return_value = {"retriever": "test"}
|
||||
|
||||
decoded_token = {"sub": "user123"}
|
||||
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="Test question?",
|
||||
agent=mock_agent,
|
||||
retriever=mock_retriever,
|
||||
conversation_id=None,
|
||||
user_api_key="test_key",
|
||||
decoded_token=decoded_token,
|
||||
should_save_conversation=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert user_logs.count_documents({}) == 1
|
||||
log_entry = user_logs.find_one({})
|
||||
assert log_entry["action"] == "stream_answer"
|
||||
assert log_entry["user"] == "user123"
|
||||
assert log_entry["api_key"] == "test_key"
|
||||
assert log_entry["question"] == "Test question?"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestProcessResponseStream:
|
||||
def test_processes_complete_stream(self, mock_mongo_db, flask_app):
|
||||
import json
|
||||
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
conv_id = str(ObjectId())
|
||||
stream = [
|
||||
f'data: {json.dumps({"type": "answer", "answer": "Hello "})}\n\n',
|
||||
f'data: {json.dumps({"type": "answer", "answer": "world"})}\n\n',
|
||||
f'data: {json.dumps({"type": "source", "source": [{"title": "doc1"}]})}\n\n',
|
||||
f'data: {json.dumps({"type": "id", "id": conv_id})}\n\n',
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert result[0] == conv_id
|
||||
assert result[1] == "Hello world"
|
||||
assert result[2] == [{"title": "doc1"}]
|
||||
assert result[5] is None
|
||||
|
||||
def test_handles_stream_error(self, mock_mongo_db, flask_app):
|
||||
import json
|
||||
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
stream = [
|
||||
f'data: {json.dumps({"type": "error", "error": "Test error"})}\n\n',
|
||||
]
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert len(result) == 5
|
||||
assert result[0] is None
|
||||
assert result[4] == "Test error"
|
||||
|
||||
def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
stream = [
|
||||
"data: invalid json\n\n",
|
||||
'data: {"type": "end"}\n\n',
|
||||
]
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestErrorStreamGenerate:
|
||||
def test_generates_error_stream(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
|
||||
error_stream = list(resource.error_stream_generate("Test error message"))
|
||||
|
||||
assert len(error_stream) == 1
|
||||
assert '"type": "error"' in error_stream[0]
|
||||
assert '"error": "Test error message"' in error_stream[0]
|
||||
0
tests/api/answer/services/__init__.py
Normal file
0
tests/api/answer/services/__init__.py
Normal file
242
tests/api/answer/services/test_conversation_service.py
Normal file
242
tests/api/answer/services/test_conversation_service.py
Normal file
@@ -0,0 +1,242 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConversationServiceGet:
|
||||
|
||||
def test_returns_none_when_no_conversation_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
result = service.get_conversation("", "user_123")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_no_user_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
result = service.get_conversation(str(ObjectId()), "")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_conversation_for_owner(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
conversation = {
|
||||
"_id": conv_id,
|
||||
"user": "user_123",
|
||||
"name": "Test Conv",
|
||||
"queries": [],
|
||||
}
|
||||
collection.insert_one(conversation)
|
||||
|
||||
result = service.get_conversation(str(conv_id), "user_123")
|
||||
|
||||
assert result is not None
|
||||
assert result["name"] == "Test Conv"
|
||||
assert result["_id"] == str(conv_id)
|
||||
|
||||
def test_returns_none_for_unauthorized_user(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{"_id": conv_id, "user": "owner_123", "name": "Private Conv"}
|
||||
)
|
||||
|
||||
result = service.get_conversation(str(conv_id), "hacker_456")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_converts_objectid_to_string(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one({"_id": conv_id, "user": "user_123", "name": "Test"})
|
||||
|
||||
result = service.get_conversation(str(conv_id), "user_123")
|
||||
|
||||
assert isinstance(result["_id"], str)
|
||||
assert result["_id"] == str(conv_id)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestConversationServiceSave:
|
||||
|
||||
def test_raises_error_when_no_user_in_token(self, mock_mongo_db):
|
||||
"""Test validation: user ID required"""
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
service = ConversationService()
|
||||
mock_llm = Mock()
|
||||
|
||||
with pytest.raises(ValueError, match="User ID not found"):
|
||||
service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="Test?",
|
||||
response="Answer",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={}, # No 'sub' key
|
||||
)
|
||||
|
||||
def test_truncates_long_source_text(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from bson import ObjectId
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.gen.return_value = "Test Summary"
|
||||
|
||||
long_text = "x" * 2000
|
||||
sources = [{"text": long_text, "title": "Doc"}]
|
||||
|
||||
conv_id = service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="Question",
|
||||
response="Response",
|
||||
thought="",
|
||||
sources=sources,
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
saved_conv = collection.find_one({"_id": ObjectId(conv_id)})
|
||||
saved_source_text = saved_conv["queries"][0]["sources"][0]["text"]
|
||||
|
||||
assert len(saved_source_text) == 1000
|
||||
assert saved_source_text == "x" * 1000
|
||||
|
||||
def test_creates_new_conversation_with_summary(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from bson import ObjectId
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.gen.return_value = "Python Basics"
|
||||
|
||||
conv_id = service.save_conversation(
|
||||
conversation_id=None,
|
||||
question="What is Python?",
|
||||
response="Python is a programming language",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
assert conv_id is not None
|
||||
saved_conv = collection.find_one({"_id": ObjectId(conv_id)})
|
||||
assert saved_conv["name"] == "Python Basics"
|
||||
assert saved_conv["user"] == "user_123"
|
||||
assert len(saved_conv["queries"]) == 1
|
||||
assert saved_conv["queries"][0]["prompt"] == "What is Python?"
|
||||
|
||||
def test_appends_to_existing_conversation(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from bson import ObjectId
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
existing_conv_id = ObjectId()
|
||||
collection.insert_one(
|
||||
{
|
||||
"_id": existing_conv_id,
|
||||
"user": "user_123",
|
||||
"name": "Old Conv",
|
||||
"queries": [{"prompt": "Q1", "response": "A1"}],
|
||||
}
|
||||
)
|
||||
|
||||
mock_llm = Mock()
|
||||
|
||||
result = service.save_conversation(
|
||||
conversation_id=str(existing_conv_id),
|
||||
question="Q2",
|
||||
response="A2",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={"sub": "user_123"},
|
||||
)
|
||||
|
||||
assert result == str(existing_conv_id)
|
||||
|
||||
def test_prevents_unauthorized_conversation_update(self, mock_mongo_db):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
service = ConversationService()
|
||||
collection = mock_mongo_db[settings.MONGO_DB_NAME]["conversations"]
|
||||
|
||||
conv_id = ObjectId()
|
||||
collection.insert_one({"_id": conv_id, "user": "owner_123", "queries": []})
|
||||
|
||||
mock_llm = Mock()
|
||||
|
||||
with pytest.raises(ValueError, match="not found or unauthorized"):
|
||||
service.save_conversation(
|
||||
conversation_id=str(conv_id),
|
||||
question="Hack",
|
||||
response="Attempt",
|
||||
thought="",
|
||||
sources=[],
|
||||
tool_calls=[],
|
||||
llm=mock_llm,
|
||||
gpt_model="gpt-4",
|
||||
decoded_token={"sub": "hacker_456"},
|
||||
)
|
||||
252
tests/api/answer/services/test_stream_processor.py
Normal file
252
tests/api/answer/services/test_stream_processor.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetPromptFunction:
|
||||
|
||||
def test_loads_custom_prompt_from_database(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
from application.core.settings import settings
|
||||
|
||||
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
|
||||
prompt_id = ObjectId()
|
||||
|
||||
prompts_collection.insert_one(
|
||||
{
|
||||
"_id": prompt_id,
|
||||
"content": "Custom prompt from database",
|
||||
"user": "user_123",
|
||||
}
|
||||
)
|
||||
|
||||
result = get_prompt(str(prompt_id), prompts_collection)
|
||||
assert result == "Custom prompt from database"
|
||||
|
||||
def test_raises_error_for_invalid_prompt_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
from application.core.settings import settings
|
||||
|
||||
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid prompt ID"):
|
||||
get_prompt(str(ObjectId()), prompts_collection)
|
||||
|
||||
def test_raises_error_for_malformed_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
from application.core.settings import settings
|
||||
|
||||
prompts_collection = mock_mongo_db[settings.MONGO_DB_NAME]["prompts"]
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid prompt ID"):
|
||||
get_prompt("not_a_valid_id", prompts_collection)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorInitialization:
|
||||
|
||||
def test_initializes_with_decoded_token(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {
|
||||
"question": "What is Python?",
|
||||
"conversation_id": str(ObjectId()),
|
||||
}
|
||||
decoded_token = {"sub": "user_123", "email": "test@example.com"}
|
||||
|
||||
processor = StreamProcessor(request_data, decoded_token)
|
||||
|
||||
assert processor.data == request_data
|
||||
assert processor.decoded_token == decoded_token
|
||||
assert processor.initial_user_id == "user_123"
|
||||
assert processor.conversation_id == request_data["conversation_id"]
|
||||
|
||||
def test_initializes_without_token(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {"question": "Test question"}
|
||||
|
||||
processor = StreamProcessor(request_data, None)
|
||||
|
||||
assert processor.decoded_token is None
|
||||
assert processor.initial_user_id is None
|
||||
assert processor.data == request_data
|
||||
|
||||
def test_initializes_default_attributes(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
processor = StreamProcessor({"question": "Test"}, {"sub": "user_123"})
|
||||
|
||||
assert processor.source == {}
|
||||
assert processor.all_sources == []
|
||||
assert processor.attachments == []
|
||||
assert processor.history == []
|
||||
assert processor.agent_config == {}
|
||||
assert processor.retriever_config == {}
|
||||
assert processor.is_shared_usage is False
|
||||
assert processor.shared_token is None
|
||||
|
||||
def test_extracts_conversation_id_from_request(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
conv_id = str(ObjectId())
|
||||
request_data = {"question": "Test", "conversation_id": conv_id}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
assert processor.conversation_id == conv_id
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorHistoryLoading:
|
||||
|
||||
def test_loads_history_from_existing_conversation(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
conversations_collection = mock_mongo_db[settings.MONGO_DB_NAME][
|
||||
"conversations"
|
||||
]
|
||||
conv_id = ObjectId()
|
||||
|
||||
conversations_collection.insert_one(
|
||||
{
|
||||
"_id": conv_id,
|
||||
"user": "user_123",
|
||||
"name": "Test Conv",
|
||||
"queries": [
|
||||
{"prompt": "What is Python?", "response": "Python is a language"},
|
||||
{"prompt": "Tell me more", "response": "Python is versatile"},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"question": "How to install it?",
|
||||
"conversation_id": str(conv_id),
|
||||
}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
processor._load_conversation_history()
|
||||
|
||||
assert len(processor.history) == 2
|
||||
assert processor.history[0]["prompt"] == "What is Python?"
|
||||
assert processor.history[1]["response"] == "Python is versatile"
|
||||
|
||||
def test_raises_error_for_unauthorized_conversation(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
conversations_collection = mock_mongo_db[settings.MONGO_DB_NAME][
|
||||
"conversations"
|
||||
]
|
||||
conv_id = ObjectId()
|
||||
|
||||
conversations_collection.insert_one(
|
||||
{
|
||||
"_id": conv_id,
|
||||
"user": "owner_123",
|
||||
"name": "Private Conv",
|
||||
"queries": [],
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {"question": "Hack attempt", "conversation_id": str(conv_id)}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "hacker_456"})
|
||||
|
||||
with pytest.raises(ValueError, match="Conversation not found or unauthorized"):
|
||||
processor._load_conversation_history()
|
||||
|
||||
def test_uses_request_history_when_no_conversation_id(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {
|
||||
"question": "What is Python?",
|
||||
"history": [{"prompt": "Hello", "response": "Hi there!"}],
|
||||
}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
assert processor.conversation_id is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorAgentConfiguration:
|
||||
|
||||
def test_configures_agent_from_valid_api_key(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "test_api_key_123",
|
||||
"endpoint": "openai",
|
||||
"model": "gpt-4",
|
||||
"prompt_id": "default",
|
||||
"user": "user_123",
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {"question": "Test", "api_key": "test_api_key_123"}
|
||||
|
||||
processor = StreamProcessor(request_data, None)
|
||||
|
||||
try:
|
||||
processor._configure_agent()
|
||||
assert processor.agent_config is not None
|
||||
except Exception as e:
|
||||
assert "Invalid API Key" in str(e)
|
||||
|
||||
def test_uses_default_config_without_api_key(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {"question": "Test"}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
processor._configure_agent()
|
||||
|
||||
assert isinstance(processor.agent_config, dict)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorAttachments:
|
||||
|
||||
def test_processes_attachments_from_request(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
attachments_collection = mock_mongo_db[settings.MONGO_DB_NAME]["attachments"]
|
||||
att_id = ObjectId()
|
||||
|
||||
attachments_collection.insert_one(
|
||||
{
|
||||
"_id": att_id,
|
||||
"filename": "document.pdf",
|
||||
"content": "Document content",
|
||||
"user": "user_123",
|
||||
}
|
||||
)
|
||||
|
||||
request_data = {"question": "Analyze this", "attachments": [str(att_id)]}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
assert processor.data.get("attachments") == [str(att_id)]
|
||||
|
||||
def test_handles_empty_attachments(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
request_data = {"question": "Simple question"}
|
||||
|
||||
processor = StreamProcessor(request_data, {"sub": "user_123"})
|
||||
|
||||
assert processor.attachments == []
|
||||
assert (
|
||||
"attachments" not in processor.data
|
||||
or processor.data.get("attachments") is None
|
||||
)
|
||||
89
tests/api/conftest.py
Normal file
89
tests/api/conftest.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""API-specific test fixtures."""
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers():
|
||||
return {"Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_token(monkeypatch, decoded_token):
|
||||
def mock_decorator(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = decoded_token
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
monkeypatch.setattr("application.auth.api_key_required", lambda: mock_decorator)
|
||||
return decoded_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation():
|
||||
return {
|
||||
"_id": ObjectId(),
|
||||
"user": "test_user",
|
||||
"name": "Test Conversation",
|
||||
"queries": [
|
||||
{
|
||||
"prompt": "What is Python?",
|
||||
"response": "Python is a programming language",
|
||||
}
|
||||
],
|
||||
"date": "2025-01-01T00:00:00",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_prompt():
|
||||
return {
|
||||
"_id": ObjectId(),
|
||||
"user": "test_user",
|
||||
"name": "Helpful Assistant",
|
||||
"content": "You are a helpful assistant that provides clear and concise answers.",
|
||||
"type": "custom",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent():
|
||||
return {
|
||||
"_id": ObjectId(),
|
||||
"user": "test_user",
|
||||
"name": "Test Agent",
|
||||
"type": "classic",
|
||||
"endpoint": "openai",
|
||||
"model": "gpt-4",
|
||||
"prompt_id": "default",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_answer_request():
|
||||
return {
|
||||
"question": "What is Python?",
|
||||
"history": [],
|
||||
"conversation_id": None,
|
||||
"prompt_id": "default",
|
||||
"chunks": 2,
|
||||
"token_limit": 1000,
|
||||
"retriever": "classic_rag",
|
||||
"active_docs": "local/test/",
|
||||
"isNoneDoc": False,
|
||||
"save_conversation": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app():
|
||||
from flask import Flask
|
||||
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
311
tests/api/user/test_base.py
Normal file
311
tests/api/user/test_base.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import datetime
|
||||
import io
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTimeRangeGenerators:
|
||||
|
||||
def test_generate_minute_range(self):
|
||||
from application.api.user.base import generate_minute_range
|
||||
|
||||
start = datetime.datetime(2024, 1, 1, 10, 0, 0)
|
||||
end = datetime.datetime(2024, 1, 1, 10, 5, 0)
|
||||
|
||||
result = generate_minute_range(start, end)
|
||||
|
||||
assert len(result) == 6
|
||||
assert "2024-01-01 10:00:00" in result
|
||||
assert "2024-01-01 10:05:00" in result
|
||||
assert all(val == 0 for val in result.values())
|
||||
|
||||
def test_generate_hourly_range(self):
|
||||
from application.api.user.base import generate_hourly_range
|
||||
|
||||
start = datetime.datetime(2024, 1, 1, 10, 0, 0)
|
||||
end = datetime.datetime(2024, 1, 1, 15, 0, 0)
|
||||
|
||||
result = generate_hourly_range(start, end)
|
||||
|
||||
assert len(result) == 6
|
||||
assert "2024-01-01 10:00" in result
|
||||
assert "2024-01-01 15:00" in result
|
||||
assert all(val == 0 for val in result.values())
|
||||
|
||||
def test_generate_date_range(self):
|
||||
from application.api.user.base import generate_date_range
|
||||
|
||||
start = datetime.date(2024, 1, 1)
|
||||
end = datetime.date(2024, 1, 5)
|
||||
|
||||
result = generate_date_range(start, end)
|
||||
|
||||
assert len(result) == 5
|
||||
assert "2024-01-01" in result
|
||||
assert "2024-01-05" in result
|
||||
assert all(val == 0 for val in result.values())
|
||||
|
||||
def test_single_minute_range(self):
|
||||
from application.api.user.base import generate_minute_range
|
||||
|
||||
time = datetime.datetime(2024, 1, 1, 10, 30, 0)
|
||||
result = generate_minute_range(time, time)
|
||||
|
||||
assert len(result) == 1
|
||||
assert "2024-01-01 10:30:00" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestEnsureUserDoc:
|
||||
|
||||
def test_creates_new_user_with_defaults(self, mock_mongo_db):
|
||||
from application.api.user.base import ensure_user_doc
|
||||
|
||||
user_id = "test_user_123"
|
||||
|
||||
result = ensure_user_doc(user_id)
|
||||
|
||||
assert result is not None
|
||||
assert result["user_id"] == user_id
|
||||
assert "agent_preferences" in result
|
||||
assert result["agent_preferences"]["pinned"] == []
|
||||
assert result["agent_preferences"]["shared_with_me"] == []
|
||||
|
||||
def test_returns_existing_user(self, mock_mongo_db):
|
||||
from application.api.user.base import ensure_user_doc
|
||||
from application.core.settings import settings
|
||||
|
||||
users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"]
|
||||
user_id = "existing_user"
|
||||
|
||||
existing_doc = {
|
||||
"user_id": user_id,
|
||||
"agent_preferences": {"pinned": ["agent1"], "shared_with_me": ["agent2"]},
|
||||
}
|
||||
users_collection.insert_one(existing_doc)
|
||||
|
||||
result = ensure_user_doc(user_id)
|
||||
|
||||
assert result["user_id"] == user_id
|
||||
assert result["agent_preferences"]["pinned"] == ["agent1"]
|
||||
assert result["agent_preferences"]["shared_with_me"] == ["agent2"]
|
||||
|
||||
def test_adds_missing_preferences_fields(self, mock_mongo_db):
|
||||
from application.api.user.base import ensure_user_doc
|
||||
from application.core.settings import settings
|
||||
|
||||
users_collection = mock_mongo_db[settings.MONGO_DB_NAME]["users"]
|
||||
user_id = "incomplete_user"
|
||||
|
||||
users_collection.insert_one(
|
||||
{"user_id": user_id, "agent_preferences": {"pinned": ["agent1"]}}
|
||||
)
|
||||
|
||||
result = ensure_user_doc(user_id)
|
||||
|
||||
assert "shared_with_me" in result["agent_preferences"]
|
||||
assert result["agent_preferences"]["shared_with_me"] == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResolveToolDetails:
|
||||
|
||||
def test_resolves_tool_ids_to_details(self, mock_mongo_db):
|
||||
from application.api.user.base import resolve_tool_details
|
||||
from application.core.settings import settings
|
||||
|
||||
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
tool_id1 = ObjectId()
|
||||
tool_id2 = ObjectId()
|
||||
|
||||
user_tools.insert_one(
|
||||
{"_id": tool_id1, "name": "calculator", "displayName": "Calculator Tool"}
|
||||
)
|
||||
user_tools.insert_one(
|
||||
{"_id": tool_id2, "name": "weather", "displayName": "Weather API"}
|
||||
)
|
||||
|
||||
result = resolve_tool_details([str(tool_id1), str(tool_id2)])
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == str(tool_id1)
|
||||
assert result[0]["name"] == "calculator"
|
||||
assert result[0]["display_name"] == "Calculator Tool"
|
||||
assert result[1]["name"] == "weather"
|
||||
|
||||
def test_handles_missing_display_name(self, mock_mongo_db):
|
||||
from application.api.user.base import resolve_tool_details
|
||||
from application.core.settings import settings
|
||||
|
||||
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
|
||||
tool_id = ObjectId()
|
||||
|
||||
user_tools.insert_one({"_id": tool_id, "name": "test_tool"})
|
||||
|
||||
result = resolve_tool_details([str(tool_id)])
|
||||
|
||||
assert result[0]["display_name"] == "test_tool"
|
||||
|
||||
def test_empty_tool_ids_list(self, mock_mongo_db):
|
||||
from application.api.user.base import resolve_tool_details
|
||||
|
||||
result = resolve_tool_details([])
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetVectorStore:
|
||||
|
||||
@patch("application.api.user.base.VectorCreator.create_vectorstore")
|
||||
def test_creates_vector_store(self, mock_create, mock_mongo_db):
|
||||
from application.api.user.base import get_vector_store
|
||||
|
||||
mock_store = Mock()
|
||||
mock_create.return_value = mock_store
|
||||
source_id = "test_source_123"
|
||||
|
||||
result = get_vector_store(source_id)
|
||||
|
||||
assert result == mock_store
|
||||
mock_create.assert_called_once()
|
||||
args, kwargs = mock_create.call_args
|
||||
assert kwargs.get("source_id") == source_id
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleImageUpload:
|
||||
|
||||
def test_returns_existing_url_when_no_file(self, flask_app):
|
||||
from application.api.user.base import handle_image_upload
|
||||
|
||||
with flask_app.test_request_context():
|
||||
mock_request = Mock()
|
||||
mock_request.files = {}
|
||||
mock_storage = Mock()
|
||||
existing_url = "existing/path/image.jpg"
|
||||
|
||||
url, error = handle_image_upload(
|
||||
mock_request, existing_url, "user123", mock_storage
|
||||
)
|
||||
|
||||
assert url == existing_url
|
||||
assert error is None
|
||||
|
||||
def test_uploads_new_image(self, flask_app):
|
||||
from application.api.user.base import handle_image_upload
|
||||
|
||||
with flask_app.test_request_context():
|
||||
mock_file = FileStorage(
|
||||
stream=io.BytesIO(b"fake image data"), filename="test_image.png"
|
||||
)
|
||||
mock_request = Mock()
|
||||
mock_request.files = {"image": mock_file}
|
||||
mock_storage = Mock()
|
||||
mock_storage.save_file.return_value = {"success": True}
|
||||
|
||||
url, error = handle_image_upload(
|
||||
mock_request, "old_url", "user123", mock_storage
|
||||
)
|
||||
|
||||
assert error is None
|
||||
assert url is not None
|
||||
assert "test_image.png" in url
|
||||
assert "user123" in url
|
||||
mock_storage.save_file.assert_called_once()
|
||||
|
||||
def test_ignores_empty_filename(self, flask_app):
|
||||
from application.api.user.base import handle_image_upload
|
||||
|
||||
with flask_app.test_request_context():
|
||||
mock_file = Mock()
|
||||
mock_file.filename = ""
|
||||
mock_request = Mock()
|
||||
mock_request.files = {"image": mock_file}
|
||||
mock_storage = Mock()
|
||||
existing_url = "existing.jpg"
|
||||
|
||||
url, error = handle_image_upload(
|
||||
mock_request, existing_url, "user123", mock_storage
|
||||
)
|
||||
|
||||
assert url == existing_url
|
||||
assert error is None
|
||||
mock_storage.save_file.assert_not_called()
|
||||
|
||||
def test_handles_upload_error(self, flask_app):
|
||||
from application.api.user.base import handle_image_upload
|
||||
|
||||
with flask_app.app_context():
|
||||
mock_file = FileStorage(stream=io.BytesIO(b"data"), filename="test.png")
|
||||
mock_request = Mock()
|
||||
mock_request.files = {"image": mock_file}
|
||||
mock_storage = Mock()
|
||||
mock_storage.save_file.side_effect = Exception("Storage error")
|
||||
|
||||
url, error = handle_image_upload(
|
||||
mock_request, "old.jpg", "user123", mock_storage
|
||||
)
|
||||
|
||||
assert url is None
|
||||
assert error is not None
|
||||
assert error.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRequireAgentDecorator:
|
||||
|
||||
def test_validates_webhook_token(self, mock_mongo_db, flask_app):
|
||||
from application.api.user.base import require_agent
|
||||
from application.core.settings import settings
|
||||
|
||||
with flask_app.app_context():
|
||||
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
|
||||
agent_id = ObjectId()
|
||||
webhook_token = "valid_webhook_token_123"
|
||||
|
||||
agents_collection.insert_one(
|
||||
{"_id": agent_id, "incoming_webhook_token": webhook_token}
|
||||
)
|
||||
|
||||
@require_agent
|
||||
def test_func(webhook_token=None, agent=None, agent_id_str=None):
|
||||
return {"agent_id": agent_id_str}
|
||||
|
||||
result = test_func(webhook_token=webhook_token)
|
||||
|
||||
assert result["agent_id"] == str(agent_id)
|
||||
|
||||
def test_returns_400_for_missing_token(self, mock_mongo_db, flask_app):
|
||||
from application.api.user.base import require_agent
|
||||
|
||||
with flask_app.app_context():
|
||||
|
||||
@require_agent
|
||||
def test_func(webhook_token=None, agent=None, agent_id_str=None):
|
||||
return {"success": True}
|
||||
|
||||
result = test_func()
|
||||
|
||||
assert result.status_code == 400
|
||||
assert result.json["success"] is False
|
||||
assert "missing" in result.json["message"].lower()
|
||||
|
||||
def test_returns_404_for_invalid_token(self, mock_mongo_db, flask_app):
|
||||
from application.api.user.base import require_agent
|
||||
|
||||
with flask_app.app_context():
|
||||
|
||||
@require_agent
|
||||
def test_func(webhook_token=None, agent=None, agent_id_str=None):
|
||||
return {"success": True}
|
||||
|
||||
result = test_func(webhook_token="invalid_token_999")
|
||||
|
||||
assert result.status_code == 404
|
||||
assert result.json["success"] is False
|
||||
assert "not found" in result.json["message"].lower()
|
||||
@@ -1,7 +1,15 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import mongomock
|
||||
|
||||
import pytest
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
def get_settings():
|
||||
"""Lazy load settings to avoid import-time errors."""
|
||||
from application.core.settings import settings
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -35,18 +43,51 @@ def mock_retriever():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mongo_db(monkeypatch):
|
||||
fake_collection = FakeMongoCollection()
|
||||
fake_db = {
|
||||
"agents": fake_collection,
|
||||
"user_tools": fake_collection,
|
||||
"memories": fake_collection,
|
||||
}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
"""Mock MongoDB using mongomock - industry standard MongoDB mocking library."""
|
||||
settings = get_settings()
|
||||
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_db = mock_client[settings.MONGO_DB_NAME]
|
||||
|
||||
def get_mock_client():
|
||||
return {settings.MONGO_DB_NAME: mock_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", get_mock_client)
|
||||
|
||||
monkeypatch.setattr("application.api.user.base.users_collection", mock_db["users"])
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client", lambda: fake_client
|
||||
"application.api.user.base.user_tools_collection", mock_db["user_tools"]
|
||||
)
|
||||
return fake_client
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.agents_collection", mock_db["agents"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.conversations_collection", mock_db["conversations"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.sources_collection", mock_db["sources"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.prompts_collection", mock_db["prompts"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.feedback_collection", mock_db["feedback"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.token_usage_collection", mock_db["token_usage"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.attachments_collection", mock_db["attachments"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.user_logs_collection", mock_db["user_logs"]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.api.user.base.shared_conversations_collections",
|
||||
mock_db["shared_conversations"],
|
||||
)
|
||||
|
||||
return get_mock_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -87,53 +128,6 @@ def log_context():
|
||||
return context
|
||||
|
||||
|
||||
class FakeMongoCollection:
|
||||
def __init__(self):
|
||||
self.docs = {}
|
||||
|
||||
def find_one(self, query, projection=None):
|
||||
if "key" in query:
|
||||
return self.docs.get(query["key"])
|
||||
if "_id" in query:
|
||||
return self.docs.get(str(query["_id"]))
|
||||
if "user" in query:
|
||||
for doc in self.docs.values():
|
||||
if doc.get("user") == query["user"]:
|
||||
return doc
|
||||
return None
|
||||
|
||||
def find(self, query, projection=None):
|
||||
results = []
|
||||
if "_id" in query and "$in" in query["_id"]:
|
||||
for doc_id in query["_id"]["$in"]:
|
||||
doc = self.docs.get(str(doc_id))
|
||||
if doc:
|
||||
results.append(doc)
|
||||
elif "user" in query:
|
||||
for doc in self.docs.values():
|
||||
if doc.get("user") == query["user"]:
|
||||
if "status" in query:
|
||||
if doc.get("status") == query["status"]:
|
||||
results.append(doc)
|
||||
else:
|
||||
results.append(doc)
|
||||
return results
|
||||
|
||||
def insert_one(self, doc):
|
||||
doc_id = doc.get("_id", len(self.docs))
|
||||
self.docs[str(doc_id)] = doc
|
||||
return Mock(inserted_id=doc_id)
|
||||
|
||||
def update_one(self, query, update, upsert=False):
|
||||
return Mock(modified_count=1)
|
||||
|
||||
def delete_one(self, query):
|
||||
return Mock(deleted_count=1)
|
||||
|
||||
def delete_many(self, query):
|
||||
return Mock(deleted_count=0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_creator(mock_llm, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
|
||||
@@ -1,19 +1,13 @@
|
||||
from unittest.mock import Mock, patch
|
||||
from typing import Any, Dict, Generator
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||
|
||||
|
||||
class TestToolCall:
|
||||
"""Test ToolCall dataclass."""
|
||||
|
||||
def test_tool_call_creation(self):
|
||||
"""Test basic ToolCall creation."""
|
||||
tool_call = ToolCall(
|
||||
id="test_id",
|
||||
name="test_function",
|
||||
arguments={"arg1": "value1"},
|
||||
index=0
|
||||
id="test_id", name="test_function", arguments={"arg1": "value1"}, index=0
|
||||
)
|
||||
assert tool_call.id == "test_id"
|
||||
assert tool_call.name == "test_function"
|
||||
@@ -21,12 +15,11 @@ class TestToolCall:
|
||||
assert tool_call.index == 0
|
||||
|
||||
def test_tool_call_from_dict(self):
|
||||
"""Test ToolCall creation from dictionary."""
|
||||
data = {
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"arguments": {"location": "New York"},
|
||||
"index": 1
|
||||
"index": 1,
|
||||
}
|
||||
tool_call = ToolCall.from_dict(data)
|
||||
assert tool_call.id == "call_123"
|
||||
@@ -35,7 +28,6 @@ class TestToolCall:
|
||||
assert tool_call.index == 1
|
||||
|
||||
def test_tool_call_from_dict_missing_fields(self):
|
||||
"""Test ToolCall creation with missing fields."""
|
||||
data = {"name": "test_func"}
|
||||
tool_call = ToolCall.from_dict(data)
|
||||
assert tool_call.id == ""
|
||||
@@ -45,16 +37,13 @@ class TestToolCall:
|
||||
|
||||
|
||||
class TestLLMResponse:
|
||||
"""Test LLMResponse dataclass."""
|
||||
|
||||
def test_llm_response_creation(self):
|
||||
"""Test basic LLMResponse creation."""
|
||||
tool_calls = [ToolCall(id="1", name="func", arguments={})]
|
||||
response = LLMResponse(
|
||||
content="Hello",
|
||||
tool_calls=tool_calls,
|
||||
finish_reason="tool_calls",
|
||||
raw_response={"test": "data"}
|
||||
raw_response={"test": "data"},
|
||||
)
|
||||
assert response.content == "Hello"
|
||||
assert len(response.tool_calls) == 1
|
||||
@@ -62,55 +51,43 @@ class TestLLMResponse:
|
||||
assert response.raw_response == {"test": "data"}
|
||||
|
||||
def test_requires_tool_call_true(self):
|
||||
"""Test requires_tool_call property when tool calls are needed."""
|
||||
tool_calls = [ToolCall(id="1", name="func", arguments={})]
|
||||
response = LLMResponse(
|
||||
content="",
|
||||
tool_calls=tool_calls,
|
||||
finish_reason="tool_calls",
|
||||
raw_response={}
|
||||
raw_response={},
|
||||
)
|
||||
assert response.requires_tool_call is True
|
||||
|
||||
def test_requires_tool_call_false_no_tools(self):
|
||||
"""Test requires_tool_call property when no tool calls."""
|
||||
response = LLMResponse(
|
||||
content="Hello",
|
||||
tool_calls=[],
|
||||
finish_reason="stop",
|
||||
raw_response={}
|
||||
content="Hello", tool_calls=[], finish_reason="stop", raw_response={}
|
||||
)
|
||||
assert response.requires_tool_call is False
|
||||
|
||||
def test_requires_tool_call_false_wrong_finish_reason(self):
|
||||
"""Test requires_tool_call property with tools but wrong finish reason."""
|
||||
tool_calls = [ToolCall(id="1", name="func", arguments={})]
|
||||
response = LLMResponse(
|
||||
content="Hello",
|
||||
tool_calls=tool_calls,
|
||||
finish_reason="stop",
|
||||
raw_response={}
|
||||
raw_response={},
|
||||
)
|
||||
assert response.requires_tool_call is False
|
||||
|
||||
|
||||
class ConcreteHandler(LLMHandler):
|
||||
"""Concrete implementation for testing abstract base class."""
|
||||
|
||||
def parse_response(self, response: Any) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
content=str(response),
|
||||
tool_calls=[],
|
||||
finish_reason="stop",
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": str(result),
|
||||
"tool_call_id": tool_call.id
|
||||
}
|
||||
return {"role": "tool", "content": str(result), "tool_call_id": tool_call.id}
|
||||
|
||||
def _iterate_stream(self, response: Any) -> Generator:
|
||||
for chunk in response:
|
||||
@@ -118,114 +95,119 @@ class ConcreteHandler(LLMHandler):
|
||||
|
||||
|
||||
class TestLLMHandler:
|
||||
"""Test LLMHandler base class."""
|
||||
|
||||
def test_handler_initialization(self):
|
||||
"""Test handler initialization."""
|
||||
handler = ConcreteHandler()
|
||||
assert handler.llm_calls == []
|
||||
assert handler.tool_calls == []
|
||||
|
||||
def test_prepare_messages_no_attachments(self):
|
||||
"""Test prepare_messages with no attachments."""
|
||||
handler = ConcreteHandler()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
|
||||
mock_agent = Mock()
|
||||
result = handler.prepare_messages(mock_agent, messages, None)
|
||||
assert result == messages
|
||||
|
||||
def test_prepare_messages_with_supported_attachments(self):
|
||||
"""Test prepare_messages with supported attachments."""
|
||||
handler = ConcreteHandler()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
attachments = [{"mime_type": "image/png", "path": "/test.png"}]
|
||||
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
|
||||
mock_agent.llm.prepare_messages_with_attachments.return_value = messages
|
||||
|
||||
|
||||
result = handler.prepare_messages(mock_agent, messages, attachments)
|
||||
mock_agent.llm.prepare_messages_with_attachments.assert_called_once_with(
|
||||
messages, attachments
|
||||
)
|
||||
assert result == messages
|
||||
|
||||
@patch('application.llm.handlers.base.logger')
|
||||
@patch("application.llm.handlers.base.logger")
|
||||
def test_prepare_messages_with_unsupported_attachments(self, mock_logger):
|
||||
"""Test prepare_messages with unsupported attachments."""
|
||||
handler = ConcreteHandler()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
attachments = [{"mime_type": "text/plain", "path": "/test.txt"}]
|
||||
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
|
||||
|
||||
with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append:
|
||||
|
||||
with patch.object(
|
||||
handler, "_append_unsupported_attachments", return_value=messages
|
||||
) as mock_append:
|
||||
result = handler.prepare_messages(mock_agent, messages, attachments)
|
||||
mock_append.assert_called_once_with(messages, attachments)
|
||||
assert result == messages
|
||||
|
||||
def test_prepare_messages_mixed_attachments(self):
|
||||
"""Test prepare_messages with both supported and unsupported attachments."""
|
||||
handler = ConcreteHandler()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
attachments = [
|
||||
{"mime_type": "image/png", "path": "/test.png"},
|
||||
{"mime_type": "text/plain", "path": "/test.txt"}
|
||||
{"mime_type": "text/plain", "path": "/test.txt"},
|
||||
]
|
||||
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
|
||||
mock_agent.llm.prepare_messages_with_attachments.return_value = messages
|
||||
|
||||
with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append:
|
||||
|
||||
with patch.object(
|
||||
handler, "_append_unsupported_attachments", return_value=messages
|
||||
) as mock_append:
|
||||
result = handler.prepare_messages(mock_agent, messages, attachments)
|
||||
|
||||
# Should call both methods
|
||||
|
||||
mock_agent.llm.prepare_messages_with_attachments.assert_called_once()
|
||||
mock_append.assert_called_once()
|
||||
assert result == messages
|
||||
|
||||
def test_process_message_flow_non_streaming(self):
|
||||
"""Test process_message_flow for non-streaming."""
|
||||
handler = ConcreteHandler()
|
||||
mock_agent = Mock()
|
||||
initial_response = "test response"
|
||||
tools_dict = {}
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare:
|
||||
with patch.object(handler, 'handle_non_streaming', return_value="final") as mock_handle:
|
||||
|
||||
with patch.object(
|
||||
handler, "prepare_messages", return_value=messages
|
||||
) as mock_prepare:
|
||||
with patch.object(
|
||||
handler, "handle_non_streaming", return_value="final"
|
||||
) as mock_handle:
|
||||
result = handler.process_message_flow(
|
||||
mock_agent, initial_response, tools_dict, messages, stream=False
|
||||
)
|
||||
|
||||
|
||||
mock_prepare.assert_called_once_with(mock_agent, messages, None)
|
||||
mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages)
|
||||
mock_handle.assert_called_once_with(
|
||||
mock_agent, initial_response, tools_dict, messages
|
||||
)
|
||||
assert result == "final"
|
||||
|
||||
def test_process_message_flow_streaming(self):
|
||||
"""Test process_message_flow for streaming."""
|
||||
handler = ConcreteHandler()
|
||||
mock_agent = Mock()
|
||||
initial_response = "test response"
|
||||
tools_dict = {}
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
|
||||
def mock_generator():
|
||||
yield "chunk1"
|
||||
yield "chunk2"
|
||||
|
||||
with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare:
|
||||
with patch.object(handler, 'handle_streaming', return_value=mock_generator()) as mock_handle:
|
||||
|
||||
with patch.object(
|
||||
handler, "prepare_messages", return_value=messages
|
||||
) as mock_prepare:
|
||||
with patch.object(
|
||||
handler, "handle_streaming", return_value=mock_generator()
|
||||
) as mock_handle:
|
||||
result = handler.process_message_flow(
|
||||
mock_agent, initial_response, tools_dict, messages, stream=True
|
||||
)
|
||||
|
||||
|
||||
mock_prepare.assert_called_once_with(mock_agent, messages, None)
|
||||
mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages)
|
||||
|
||||
# Verify it's a generator
|
||||
mock_handle.assert_called_once_with(
|
||||
mock_agent, initial_response, tools_dict, messages
|
||||
)
|
||||
|
||||
chunks = list(result)
|
||||
assert chunks == ["chunk1", "chunk2"]
|
||||
@@ -1,3 +1,4 @@
|
||||
pytest>=8.0.0
|
||||
pytest-cov>=4.1.0
|
||||
coverage>=7.4.0
|
||||
mongomock>=4.3.0
|
||||
|
||||
@@ -16,12 +16,25 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
|
||||
class DummyClient:
|
||||
def __init__(self, api_key):
|
||||
created["api_key"] = api_key
|
||||
self.generate_calls = []
|
||||
self.convert_calls = []
|
||||
|
||||
def generate(self, *, text, model, voice):
|
||||
self.generate_calls.append({"text": text, "model": model, "voice": voice})
|
||||
yield b"chunk-one"
|
||||
yield b"chunk-two"
|
||||
class TextToSpeech:
|
||||
def __init__(self, outer):
|
||||
self._outer = outer
|
||||
|
||||
def convert(self, *, voice_id, model_id, text, output_format):
|
||||
self._outer.convert_calls.append(
|
||||
{
|
||||
"voice_id": voice_id,
|
||||
"model_id": model_id,
|
||||
"text": text,
|
||||
"output_format": output_format,
|
||||
}
|
||||
)
|
||||
yield b"chunk-one"
|
||||
yield b"chunk-two"
|
||||
|
||||
self.text_to_speech = TextToSpeech(self)
|
||||
|
||||
client_module = ModuleType("elevenlabs.client")
|
||||
client_module.ElevenLabs = DummyClient
|
||||
@@ -35,8 +48,13 @@ def test_elevenlabs_text_to_speech_monkeypatched_client(monkeypatch):
|
||||
audio_base64, lang = tts.text_to_speech("Speak")
|
||||
|
||||
assert created["api_key"] == "api-key"
|
||||
assert tts.client.generate_calls == [
|
||||
{"text": "Speak", "model": "eleven_multilingual_v2", "voice": "Brian"}
|
||||
assert tts.client.convert_calls == [
|
||||
{
|
||||
"voice_id": "nPczCjzI2devNBz1zQrb",
|
||||
"model_id": "eleven_multilingual_v2",
|
||||
"text": "Speak",
|
||||
"output_format": "mp3_44100_128",
|
||||
}
|
||||
]
|
||||
assert lang == "en"
|
||||
assert base64.b64decode(audio_base64.encode()) == b"chunk-onechunk-two"
|
||||
|
||||
61
tests/tts/test_tts_creator.py
Normal file
61
tests/tts/test_tts_creator.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from application.tts.tts_creator import TTSCreator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tts_creator():
|
||||
return TTSCreator()
|
||||
|
||||
|
||||
def test_create_google_tts(tts_creator):
|
||||
# Patch the provider registry so the factory calls our mock class
|
||||
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
|
||||
mock_google_tts = TTSCreator.tts_providers["google_tts"]
|
||||
instance = MagicMock()
|
||||
mock_google_tts.return_value = instance
|
||||
|
||||
result = tts_creator.create_tts("google_tts", "arg1", key="value")
|
||||
|
||||
mock_google_tts.assert_called_once_with("arg1", key="value")
|
||||
assert result == instance
|
||||
|
||||
|
||||
def test_create_elevenlabs_tts(tts_creator):
|
||||
# Patch the provider registry so the factory calls our mock class
|
||||
with patch.dict(TTSCreator.tts_providers, {"elevenlabs": MagicMock()}):
|
||||
mock_elevenlabs_tts = TTSCreator.tts_providers["elevenlabs"]
|
||||
instance = MagicMock()
|
||||
mock_elevenlabs_tts.return_value = instance
|
||||
|
||||
result = tts_creator.create_tts("elevenlabs", "voice", lang="en")
|
||||
|
||||
mock_elevenlabs_tts.assert_called_once_with("voice", lang="en")
|
||||
assert result == instance
|
||||
|
||||
|
||||
def test_invalid_tts_type(tts_creator):
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
tts_creator.create_tts("unknown_tts")
|
||||
assert "No tts class found" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_tts_type_case_insensitivity(tts_creator):
|
||||
# Patch the provider registry to ensure case-insensitive lookup hits our mock
|
||||
with patch.dict(TTSCreator.tts_providers, {"google_tts": MagicMock()}):
|
||||
mock_google_tts = TTSCreator.tts_providers["google_tts"]
|
||||
instance = MagicMock()
|
||||
mock_google_tts.return_value = instance
|
||||
|
||||
result = tts_creator.create_tts("GoOgLe_TtS")
|
||||
|
||||
mock_google_tts.assert_called_once_with()
|
||||
assert result == instance
|
||||
|
||||
|
||||
def test_tts_providers_integrity(tts_creator):
|
||||
providers = tts_creator.tts_providers
|
||||
assert "google_tts" in providers
|
||||
assert "elevenlabs" in providers
|
||||
assert callable(providers["google_tts"])
|
||||
assert callable(providers["elevenlabs"])
|
||||
Reference in New Issue
Block a user