Merge pull request #1733 from ManishMadan2882/main

Attachments: Enhancements , strategy specific to certain LLMs
This commit is contained in:
Alex
2025-04-15 01:55:46 +03:00
committed by GitHub
16 changed files with 652 additions and 198 deletions

View File

@@ -44,10 +44,13 @@ class ClassicAgent(BaseAgent):
): ):
yield {"answer": resp.message.content} yield {"answer": resp.message.content}
else: else:
completion = self.llm.gen_stream( # completion = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=self.tools # model=self.gpt_model, messages=messages, tools=self.tools
) # )
for line in completion: # log type of resp
logger.info(f"Response type: {type(resp)}")
logger.info(f"Response: {resp}")
for line in resp:
if isinstance(line, str): if isinstance(line, str):
yield {"answer": line} yield {"answer": line}

View File

@@ -33,15 +33,53 @@ class LLMHandler(ABC):
logger.info(f"Preparing messages with {len(attachments)} attachments") logger.info(f"Preparing messages with {len(attachments)} attachments")
# Check if the LLM has its own custom attachment handling implementation supported_types = agent.llm.get_supported_attachment_types()
if hasattr(agent.llm, "prepare_messages_with_attachments") and agent.llm.__class__.__name__ != "BaseLLM":
logger.info(f"Using {agent.llm.__class__.__name__}'s own prepare_messages_with_attachments method")
return agent.llm.prepare_messages_with_attachments(messages, attachments)
# Otherwise, append attachment content to the system prompt supported_attachments = []
unsupported_attachments = []
for attachment in attachments:
mime_type = attachment.get('mime_type')
if not mime_type:
import mimetypes
file_path = attachment.get('path')
if file_path:
mime_type = mimetypes.guess_type(file_path)[0] or 'application/octet-stream'
else:
unsupported_attachments.append(attachment)
continue
if mime_type in supported_types:
supported_attachments.append(attachment)
else:
unsupported_attachments.append(attachment)
# Process supported attachments with the LLM's custom method
prepared_messages = messages
if supported_attachments:
logger.info(f"Processing {len(supported_attachments)} supported attachments with {agent.llm.__class__.__name__}'s method")
prepared_messages = agent.llm.prepare_messages_with_attachments(messages, supported_attachments)
# Process unsupported attachments with the default method
if unsupported_attachments:
logger.info(f"Processing {len(unsupported_attachments)} unsupported attachments with default method")
prepared_messages = self._append_attachment_content_to_system(prepared_messages, unsupported_attachments)
return prepared_messages
def _append_attachment_content_to_system(self, messages, attachments):
"""
Default method to append attachment content to the system prompt.
Args:
messages (list): List of message dictionaries.
attachments (list): List of attachment dictionaries with content.
Returns:
list: Messages with attachment context added to the system prompt.
"""
prepared_messages = messages.copy() prepared_messages = messages.copy()
# Build attachment content string
attachment_texts = [] attachment_texts = []
for attachment in attachments: for attachment in attachments:
logger.info(f"Adding attachment {attachment.get('id')} to context") logger.info(f"Adding attachment {attachment.get('id')} to context")
@@ -122,12 +160,13 @@ class OpenAILLMHandler(LLMHandler):
return resp return resp
else: else:
text_buffer = ""
while True: while True:
tool_calls = {} tool_calls = {}
for chunk in resp: for chunk in resp:
if isinstance(chunk, str) and len(chunk) > 0: if isinstance(chunk, str) and len(chunk) > 0:
return yield chunk
continue
elif hasattr(chunk, "delta"): elif hasattr(chunk, "delta"):
chunk_delta = chunk.delta chunk_delta = chunk.delta
@@ -206,12 +245,17 @@ class OpenAILLMHandler(LLMHandler):
} }
) )
tool_calls = {} tool_calls = {}
if hasattr(chunk_delta, "content") and chunk_delta.content:
# Add to buffer or yield immediately based on your preference
text_buffer += chunk_delta.content
yield text_buffer
text_buffer = ""
if ( if (
hasattr(chunk, "finish_reason") hasattr(chunk, "finish_reason")
and chunk.finish_reason == "stop" and chunk.finish_reason == "stop"
): ):
return return resp
elif isinstance(chunk, str) and len(chunk) == 0: elif isinstance(chunk, str) and len(chunk) == 0:
continue continue
@@ -298,6 +342,9 @@ class GoogleLLMHandler(LLMHandler):
"content": [function_response_part.to_json_dict()], "content": [function_response_part.to_json_dict()],
} }
) )
else:
tool_call_found = False
yield result
if not tool_call_found: if not tool_call_found:
return response return response

View File

@@ -835,12 +835,7 @@ def get_attachments_content(attachment_ids, user):
}) })
if attachment_doc: if attachment_doc:
attachments.append({ attachments.append(attachment_doc)
"id": str(attachment_doc["_id"]),
"content": attachment_doc["content"],
"token_count": attachment_doc.get("token_count", 0),
"path": attachment_doc.get("path", "")
})
except Exception as e: except Exception as e:
logger.error(f"Error retrieving attachment {attachment_id}: {e}") logger.error(f"Error retrieving attachment {attachment_id}: {e}")

View File

@@ -2506,23 +2506,26 @@ class StoreAttachment(Resource):
user = secure_filename(decoded_token.get("sub")) user = secure_filename(decoded_token.get("sub"))
try: try:
attachment_id = ObjectId()
original_filename = secure_filename(file.filename) original_filename = secure_filename(file.filename)
folder_name = original_filename
save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, "attachments",folder_name) save_dir = os.path.join(
current_dir,
settings.UPLOAD_FOLDER,
user,
"attachments",
str(attachment_id)
)
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
# Create directory structure: user/attachments/filename/
file_path = os.path.join(save_dir, original_filename) file_path = os.path.join(save_dir, original_filename)
# Handle filename conflicts
if os.path.exists(file_path):
name_parts = os.path.splitext(original_filename)
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
new_filename = f"{name_parts[0]}_{timestamp}{name_parts[1]}"
file_path = os.path.join(save_dir, new_filename)
original_filename = new_filename
file.save(file_path) file.save(file_path)
file_info = {"folder": folder_name, "filename": original_filename} file_info = {
"filename": original_filename,
"attachment_id": str(attachment_id)
}
current_app.logger.info(f"Saved file: {file_path}") current_app.logger.info(f"Saved file: {file_path}")
# Start async task to process single file # Start async task to process single file

View File

@@ -55,3 +55,12 @@ class BaseLLM(ABC):
def _supports_tools(self): def _supports_tools(self):
raise NotImplementedError("Subclass must implement _supports_tools method") raise NotImplementedError("Subclass must implement _supports_tools method")
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by this LLM for file uploads.
Returns:
list: List of supported MIME types
"""
return [] # Default: no attachments supported

View File

@@ -1,5 +1,9 @@
from google import genai from google import genai
from google.genai import types from google.genai import types
import os
import logging
import mimetypes
import json
from application.llm.base import BaseLLM from application.llm.base import BaseLLM
@@ -9,6 +13,138 @@ class GoogleLLM(BaseLLM):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.api_key = api_key self.api_key = api_key
self.user_api_key = user_api_key self.user_api_key = user_api_key
self.client = genai.Client(api_key=self.api_key)
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by Google Gemini for file uploads.
Returns:
list: List of supported MIME types
"""
return [
'application/pdf',
'image/png',
'image/jpeg',
'image/jpg',
'image/webp',
'image/gif'
]
def prepare_messages_with_attachments(self, messages, attachments=None):
"""
Process attachments using Google AI's file API for more efficient handling.
Args:
messages (list): List of message dictionaries.
attachments (list): List of attachment dictionaries with content and metadata.
Returns:
list: Messages formatted with file references for Google AI API.
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the user message to attach files to the last one
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
{"type": "text", "text": text_content}
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
files = []
for attachment in attachments:
mime_type = attachment.get('mime_type')
if not mime_type:
file_path = attachment.get('path')
if file_path:
mime_type = mimetypes.guess_type(file_path)[0] or 'application/octet-stream'
if mime_type in self.get_supported_attachment_types():
try:
file_uri = self._upload_file_to_google(attachment)
logging.info(f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}")
files.append({"file_uri": file_uri, "mime_type": mime_type})
except Exception as e:
logging.error(f"GoogleLLM: Error uploading file: {e}")
if 'content' in attachment:
prepared_messages[user_message_index]["content"].append({
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]"
})
if files:
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
prepared_messages[user_message_index]["content"].append({
"files": files
})
return prepared_messages
def _upload_file_to_google(self, attachment):
"""
Upload a file to Google AI and return the file URI.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Returns:
str: Google AI file URI for the uploaded file.
"""
if 'google_file_uri' in attachment:
return attachment['google_file_uri']
file_path = attachment.get('path')
if not file_path:
raise ValueError("No file path provided in attachment")
if not os.path.isabs(file_path):
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
file_path = os.path.join(current_dir, "application", file_path)
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
mime_type = attachment.get('mime_type')
if not mime_type:
mime_type = mimetypes.guess_type(file_path)[0] or 'application/octet-stream'
try:
response = self.client.files.upload(file=file_path)
file_uri = response.uri
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
attachments_collection = db["attachments"]
if '_id' in attachment:
attachments_collection.update_one(
{"_id": attachment['_id']},
{"$set": {"google_file_uri": file_uri}}
)
return file_uri
except Exception as e:
logging.error(f"Error uploading file to Google AI: {e}")
raise
def _clean_messages_google(self, messages): def _clean_messages_google(self, messages):
cleaned_messages = [] cleaned_messages = []
@@ -26,7 +162,7 @@ class GoogleLLM(BaseLLM):
elif isinstance(content, list): elif isinstance(content, list):
for item in content: for item in content:
if "text" in item: if "text" in item:
parts.append(types.Part.from_text(item["text"])) parts.append(types.Part.from_text(text=item["text"]))
elif "function_call" in item: elif "function_call" in item:
parts.append( parts.append(
types.Part.from_function_call( types.Part.from_function_call(
@@ -41,6 +177,14 @@ class GoogleLLM(BaseLLM):
response=item["function_response"]["response"], response=item["function_response"]["response"],
) )
) )
elif "files" in item:
for file_data in item["files"]:
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"]
)
)
else: else:
raise ValueError( raise ValueError(
f"Unexpected content dictionary format:{item}" f"Unexpected content dictionary format:{item}"
@@ -146,11 +290,25 @@ class GoogleLLM(BaseLLM):
cleaned_tools = self._clean_tools_format(tools) cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools config.tools = cleaned_tools
# Check if we have both tools and file attachments
has_attachments = False
for message in messages:
for part in message.parts:
if hasattr(part, 'file_data') and part.file_data is not None:
has_attachments = True
break
if has_attachments:
break
logging.info(f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}")
response = client.models.generate_content_stream( response = client.models.generate_content_stream(
model=model, model=model,
contents=messages, contents=messages,
config=config, config=config,
) )
for chunk in response: for chunk in response:
if hasattr(chunk, "candidates") and chunk.candidates: if hasattr(chunk, "candidates") and chunk.candidates:
for candidate in chunk.candidates: for candidate in chunk.candidates:

View File

@@ -1,4 +1,8 @@
import json import json
import base64
import os
import mimetypes
import logging
from application.core.settings import settings from application.core.settings import settings
from application.llm.base import BaseLLM from application.llm.base import BaseLLM
@@ -65,6 +69,15 @@ class OpenAILLM(BaseLLM):
), ),
} }
) )
elif isinstance(item, dict):
content_parts = []
if "text" in item:
content_parts.append({"type": "text", "text": item["text"]})
elif "type" in item and item["type"] == "text" and "text" in item:
content_parts.append(item)
elif "type" in item and item["type"] == "file" and "file" in item:
content_parts.append(item)
cleaned_messages.append({"role": role, "content": content_parts})
else: else:
raise ValueError( raise ValueError(
f"Unexpected content dictionary format: {item}" f"Unexpected content dictionary format: {item}"
@@ -133,6 +146,183 @@ class OpenAILLM(BaseLLM):
def _supports_tools(self): def _supports_tools(self):
return True return True
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by OpenAI for file uploads.
Returns:
list: List of supported MIME types
"""
return [
'application/pdf',
'image/png',
'image/jpeg',
'image/jpg',
'image/webp',
'image/gif'
]
def prepare_messages_with_attachments(self, messages, attachments=None):
"""
Process attachments using OpenAI's file API for more efficient handling.
Args:
messages (list): List of message dictionaries.
attachments (list): List of attachment dictionaries with content and metadata.
Returns:
list: Messages formatted with file references for OpenAI API.
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the user message to attach file_id to the last one
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
{"type": "text", "text": text_content}
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get('mime_type')
if not mime_type:
file_path = attachment.get('path')
if file_path:
mime_type = mimetypes.guess_type(file_path)[0] or 'application/octet-stream'
if mime_type and mime_type.startswith('image/'):
try:
base64_image = self._get_base64_image(attachment)
prepared_messages[user_message_index]["content"].append({
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_image}"
}
})
except Exception as e:
logging.error(f"Error processing image attachment: {e}")
if 'content' in attachment:
prepared_messages[user_message_index]["content"].append({
"type": "text",
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]"
})
# Handle PDFs using the file API
elif mime_type == 'application/pdf':
try:
file_id = self._upload_file_to_openai(attachment)
prepared_messages[user_message_index]["content"].append({
"type": "file",
"file": {"file_id": file_id}
})
except Exception as e:
logging.error(f"Error uploading PDF to OpenAI: {e}")
if 'content' in attachment:
prepared_messages[user_message_index]["content"].append({
"type": "text",
"text": f"File content:\n\n{attachment['content']}"
})
return prepared_messages
def _get_base64_image(self, attachment):
"""
Convert an image file to base64 encoding.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Returns:
str: Base64-encoded image data.
"""
file_path = attachment.get('path')
if not file_path:
raise ValueError("No file path provided in attachment")
if not os.path.isabs(file_path):
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
file_path = os.path.join(current_dir, "application", file_path)
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
def _upload_file_to_openai(self, attachment): ##pdfs
"""
Upload a file to OpenAI and return the file_id.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Expected keys:
- path: Path to the file
- id: Optional MongoDB ID for caching
Returns:
str: OpenAI file_id for the uploaded file.
"""
import os
import logging
if 'openai_file_id' in attachment:
return attachment['openai_file_id']
file_path = attachment.get('path')
if not file_path:
raise ValueError("No file path provided in attachment")
if not os.path.isabs(file_path):
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
file_path = os.path.join(current_dir,"application", file_path)
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
with open(file_path, 'rb') as file:
response = self.client.files.create(
file=file,
purpose="assistants"
)
file_id = response.id
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
attachments_collection = db["attachments"]
if '_id' in attachment:
attachments_collection.update_one(
{"_id": attachment['_id']},
{"$set": {"openai_file_id": file_id}}
)
return file_id
except Exception as e:
logging.error(f"Error uploading file to OpenAI: {e}")
raise
class AzureOpenAILLM(OpenAILLM): class AzureOpenAILLM(OpenAILLM):

View File

@@ -328,34 +328,30 @@ def attachment_worker(self, directory, file_info, user):
""" """
import datetime import datetime
import os import os
import mimetypes
from application.utils import num_tokens_from_string from application.utils import num_tokens_from_string
mongo = MongoDB.get_client() mongo = MongoDB.get_client()
db = mongo["docsgpt"] db = mongo["docsgpt"]
attachments_collection = db["attachments"] attachments_collection = db["attachments"]
job_name = file_info["folder"] filename = file_info["filename"]
logging.info(f"Processing attachment: {job_name}", extra={"user": user, "job": job_name}) attachment_id = file_info["attachment_id"]
logging.info(f"Processing attachment: {attachment_id}/{filename}", extra={"user": user})
self.update_state(state="PROGRESS", meta={"current": 10}) self.update_state(state="PROGRESS", meta={"current": 10})
folder_name = file_info["folder"]
filename = file_info["filename"]
file_path = os.path.join(directory, filename) file_path = os.path.join(directory, filename)
logging.info(f"Processing file: {file_path}", extra={"user": user, "job": job_name})
if not os.path.exists(file_path): if not os.path.exists(file_path):
logging.warning(f"File not found: {file_path}", extra={"user": user, "job": job_name}) logging.warning(f"File not found: {file_path}", extra={"user": user})
return {"error": "File not found"} raise FileNotFoundError(f"File not found: {file_path}")
try: try:
reader = SimpleDirectoryReader( reader = SimpleDirectoryReader(
input_files=[file_path] input_files=[file_path]
) )
documents = reader.load_data() documents = reader.load_data()
self.update_state(state="PROGRESS", meta={"current": 50}) self.update_state(state="PROGRESS", meta={"current": 50})
@@ -364,33 +360,37 @@ def attachment_worker(self, directory, file_info, user):
content = documents[0].text content = documents[0].text
token_count = num_tokens_from_string(content) token_count = num_tokens_from_string(content)
file_path_relative = f"{user}/attachments/{folder_name}/{filename}" file_path_relative = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{attachment_id}/{filename}"
attachment_id = attachments_collection.insert_one({ mime_type = mimetypes.guess_type(file_path)[0] or 'application/octet-stream'
doc_id = ObjectId(attachment_id)
attachments_collection.insert_one({
"_id": doc_id,
"user": user, "user": user,
"path": file_path_relative, "path": file_path_relative,
"content": content, "content": content,
"token_count": token_count, "token_count": token_count,
"mime_type": mime_type,
"date": datetime.datetime.now(), "date": datetime.datetime.now(),
}).inserted_id })
logging.info(f"Stored attachment with ID: {attachment_id}", logging.info(f"Stored attachment with ID: {attachment_id}",
extra={"user": user, "job": job_name}) extra={"user": user})
self.update_state(state="PROGRESS", meta={"current": 100}) self.update_state(state="PROGRESS", meta={"current": 100})
return { return {
"attachment_id": str(attachment_id),
"filename": filename, "filename": filename,
"folder": folder_name,
"path": file_path_relative, "path": file_path_relative,
"token_count": token_count "token_count": token_count,
"attachment_id": attachment_id,
"mime_type": mime_type
} }
else: else:
logging.warning("No content was extracted from the file", logging.warning("No content was extracted from the file",
extra={"user": user, "job": job_name}) extra={"user": user})
return {"error": "No content was extracted from the file"} raise ValueError("No content was extracted from the file")
except Exception as e: except Exception as e:
logging.error(f"Error processing file {filename}: {e}", logging.error(f"Error processing file {filename}: {e}", extra={"user": user}, exc_info=True)
extra={"user": user, "job": job_name}, exc_info=True) raise
return {"error": f"Error processing file: {str(e)}"}

View File

@@ -4,18 +4,27 @@ import { useDarkTheme } from '../hooks';
import { useSelector, useDispatch } from 'react-redux'; import { useSelector, useDispatch } from 'react-redux';
import userService from '../api/services/userService'; import userService from '../api/services/userService';
import endpoints from '../api/endpoints'; import endpoints from '../api/endpoints';
import { getOS, isTouchDevice } from '../utils/browserUtils';
import PaperPlane from '../assets/paper_plane.svg'; import PaperPlane from '../assets/paper_plane.svg';
import SourceIcon from '../assets/source.svg'; import SourceIcon from '../assets/source.svg';
import ToolIcon from '../assets/tool.svg'; import ToolIcon from '../assets/tool.svg';
import SpinnerDark from '../assets/spinner-dark.svg'; import SpinnerDark from '../assets/spinner-dark.svg';
import Spinner from '../assets/spinner.svg'; import Spinner from '../assets/spinner.svg';
import ExitIcon from '../assets/exit.svg';
import AlertIcon from '../assets/alert.svg';
import SourcesPopup from './SourcesPopup'; import SourcesPopup from './SourcesPopup';
import ToolsPopup from './ToolsPopup'; import ToolsPopup from './ToolsPopup';
import { selectSelectedDocs, selectToken } from '../preferences/preferenceSlice'; import { selectSelectedDocs, selectToken } from '../preferences/preferenceSlice';
import { ActiveState } from '../models/misc'; import { ActiveState } from '../models/misc';
import Upload from '../upload/Upload'; import Upload from '../upload/Upload';
import ClipIcon from '../assets/clip.svg'; import ClipIcon from '../assets/clip.svg';
import { setAttachments } from '../conversation/conversationSlice'; import {
addAttachment,
updateAttachment,
removeAttachment,
selectAttachments
} from '../conversation/conversationSlice';
interface MessageInputProps { interface MessageInputProps {
value: string; value: string;
@@ -47,13 +56,33 @@ export default function MessageInput({
const [isSourcesPopupOpen, setIsSourcesPopupOpen] = useState(false); const [isSourcesPopupOpen, setIsSourcesPopupOpen] = useState(false);
const [isToolsPopupOpen, setIsToolsPopupOpen] = useState(false); const [isToolsPopupOpen, setIsToolsPopupOpen] = useState(false);
const [uploadModalState, setUploadModalState] = useState<ActiveState>('INACTIVE'); const [uploadModalState, setUploadModalState] = useState<ActiveState>('INACTIVE');
const [uploads, setUploads] = useState<UploadState[]>([]);
const selectedDocs = useSelector(selectSelectedDocs); const selectedDocs = useSelector(selectSelectedDocs);
const token = useSelector(selectToken); const token = useSelector(selectToken);
const attachments = useSelector(selectAttachments);
const dispatch = useDispatch(); const dispatch = useDispatch();
const browserOS = getOS();
const isTouch = isTouchDevice();
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (
((browserOS === 'win' || browserOS === 'linux') && event.ctrlKey && event.key === 'k') ||
(browserOS === 'mac' && event.metaKey && event.key === 'k')
) {
event.preventDefault();
setIsSourcesPopupOpen(!isSourcesPopupOpen);
}
};
document.addEventListener('keydown', handleKeyDown);
return () => {
document.removeEventListener('keydown', handleKeyDown);
};
}, [browserOS]);
const handleFileAttachment = (e: React.ChangeEvent<HTMLInputElement>) => { const handleFileAttachment = (e: React.ChangeEvent<HTMLInputElement>) => {
if (!e.target.files || e.target.files.length === 0) return; if (!e.target.files || e.target.files.length === 0) return;
@@ -64,56 +93,51 @@ export default function MessageInput({
const apiHost = import.meta.env.VITE_API_HOST; const apiHost = import.meta.env.VITE_API_HOST;
const xhr = new XMLHttpRequest(); const xhr = new XMLHttpRequest();
const uploadState: UploadState = { const newAttachment = {
taskId: '',
fileName: file.name, fileName: file.name,
progress: 0, progress: 0,
status: 'uploading' status: 'uploading' as const,
taskId: '',
}; };
setUploads(prev => [...prev, uploadState]); dispatch(addAttachment(newAttachment));
const uploadIndex = uploads.length;
xhr.upload.addEventListener('progress', (event) => { xhr.upload.addEventListener('progress', (event) => {
if (event.lengthComputable) { if (event.lengthComputable) {
const progress = Math.round((event.loaded / event.total) * 100); const progress = Math.round((event.loaded / event.total) * 100);
setUploads(prev => prev.map((upload, index) => dispatch(updateAttachment({
index === uploadIndex taskId: newAttachment.taskId,
? { ...upload, progress } updates: { progress }
: upload }));
));
} }
}); });
xhr.onload = () => { xhr.onload = () => {
if (xhr.status === 200) { if (xhr.status === 200) {
const response = JSON.parse(xhr.responseText); const response = JSON.parse(xhr.responseText);
console.log('File uploaded successfully:', response);
if (response.task_id) { if (response.task_id) {
setUploads(prev => prev.map((upload, index) => dispatch(updateAttachment({
index === uploadIndex taskId: newAttachment.taskId,
? { ...upload, taskId: response.task_id, status: 'processing' } updates: {
: upload taskId: response.task_id,
)); status: 'processing',
progress: 10
}
}));
} }
} else { } else {
setUploads(prev => prev.map((upload, index) => dispatch(updateAttachment({
index === uploadIndex taskId: newAttachment.taskId,
? { ...upload, status: 'failed' } updates: { status: 'failed' }
: upload }));
));
console.error('Error uploading file:', xhr.responseText);
} }
}; };
xhr.onerror = () => { xhr.onerror = () => {
setUploads(prev => prev.map((upload, index) => dispatch(updateAttachment({
index === uploadIndex taskId: newAttachment.taskId,
? { ...upload, status: 'failed' } updates: { status: 'failed' }
: upload }));
));
console.error('Network error during file upload');
}; };
xhr.open('POST', `${apiHost}${endpoints.USER.STORE_ATTACHMENT}`); xhr.open('POST', `${apiHost}${endpoints.USER.STORE_ATTACHMENT}`);
@@ -123,64 +147,55 @@ export default function MessageInput({
}; };
useEffect(() => { useEffect(() => {
let timeoutIds: number[] = [];
const checkTaskStatus = () => { const checkTaskStatus = () => {
const processingUploads = uploads.filter(upload => const processingAttachments = attachments.filter(att =>
upload.status === 'processing' && upload.taskId att.status === 'processing' && att.taskId
); );
processingUploads.forEach(upload => { processingAttachments.forEach(attachment => {
userService userService
.getTaskStatus(upload.taskId, null) .getTaskStatus(attachment.taskId!, null)
.then((data) => data.json()) .then((data) => data.json())
.then((data) => { .then((data) => {
console.log('Task status:', data);
setUploads(prev => prev.map(u => {
if (u.taskId !== upload.taskId) return u;
if (data.status === 'SUCCESS') { if (data.status === 'SUCCESS') {
return { dispatch(updateAttachment({
...u, taskId: attachment.taskId!,
updates: {
status: 'completed', status: 'completed',
progress: 100, progress: 100,
attachment_id: data.result?.attachment_id, id: data.result?.attachment_id,
token_count: data.result?.token_count token_count: data.result?.token_count
};
} else if (data.status === 'FAILURE') {
return { ...u, status: 'failed' };
} else if (data.status === 'PROGRESS' && data.result?.current) {
return { ...u, progress: data.result.current };
} }
return u;
})); }));
} else if (data.status === 'FAILURE') {
if (data.status !== 'SUCCESS' && data.status !== 'FAILURE') { dispatch(updateAttachment({
const timeoutId = window.setTimeout(() => checkTaskStatus(), 2000); taskId: attachment.taskId!,
timeoutIds.push(timeoutId); updates: { status: 'failed' }
}));
} else if (data.status === 'PROGRESS' && data.result?.current) {
dispatch(updateAttachment({
taskId: attachment.taskId!,
updates: { progress: data.result.current }
}));
} }
}) })
.catch((error) => { .catch(() => {
console.error('Error checking task status:', error); dispatch(updateAttachment({
setUploads(prev => prev.map(u => taskId: attachment.taskId!,
u.taskId === upload.taskId updates: { status: 'failed' }
? { ...u, status: 'failed' } }));
: u
));
}); });
}); });
}; };
if (uploads.some(upload => upload.status === 'processing')) { const interval = setInterval(() => {
const timeoutId = window.setTimeout(checkTaskStatus, 2000); if (attachments.some(att => att.status === 'processing')) {
timeoutIds.push(timeoutId); checkTaskStatus();
} }
}, 2000);
return () => { return () => clearInterval(interval);
timeoutIds.forEach(id => clearTimeout(id)); }, [attachments, dispatch]);
};
}, [uploads]);
const handleInput = () => { const handleInput = () => {
if (inputRef.current) { if (inputRef.current) {
@@ -215,39 +230,53 @@ export default function MessageInput({
const handleSubmit = () => { const handleSubmit = () => {
const completedAttachments = uploads
.filter(upload => upload.status === 'completed' && upload.attachment_id)
.map(upload => ({
fileName: upload.fileName,
id: upload.attachment_id as string
}));
dispatch(setAttachments(completedAttachments));
onSubmit(); onSubmit();
}; };
return ( return (
<div className="flex flex-col w-full mx-2"> <div className="flex flex-col w-full mx-2">
<div className="flex flex-col w-full rounded-[23px] border dark:border-grey border-dark-gray bg-lotion dark:bg-transparent relative"> <div className="flex flex-col w-full rounded-[23px] border dark:border-grey border-dark-gray bg-lotion dark:bg-transparent relative">
<div className="flex flex-wrap gap-1.5 sm:gap-2 px-4 sm:px-6 pt-3 pb-0"> <div className="flex flex-wrap gap-1.5 sm:gap-2 px-4 sm:px-6 pt-3 pb-0">
{uploads.map((upload, index) => ( {attachments.map((attachment, index) => (
<div <div
key={index} key={index}
className="flex items-center px-2 sm:px-3 py-1 sm:py-1.5 rounded-[32px] border border-[#AAAAAA] dark:border-purple-taupe bg-white dark:bg-[#1F2028] text-[12px] sm:text-[14px] text-[#5D5D5D] dark:text-bright-gray" className={`flex items-center px-2 sm:px-3 py-1 sm:py-1.5 rounded-[32px] border border-[#AAAAAA] dark:border-purple-taupe bg-white dark:bg-[#1F2028] text-[12px] sm:text-[14px] text-[#5D5D5D] dark:text-bright-gray group relative ${
attachment.status !== 'completed' ? 'opacity-70' : 'opacity-100'
}`}
title={attachment.fileName}
> >
<span className="font-medium truncate max-w-[120px] sm:max-w-[150px]">{upload.fileName}</span> <span className="font-medium truncate max-w-[120px] sm:max-w-[150px]">{attachment.fileName}</span>
{upload.status === 'completed' && ( {attachment.status === 'completed' && (
<span className="ml-2 text-green-500"></span> <button
className="absolute right-2 top-1/2 -translate-y-1/2 opacity-0 group-hover:opacity-100 focus:opacity-100 transition-opacity bg-white dark:bg-[#1F2028] rounded-full p-1 hover:bg-white/95 dark:hover:bg-[#1F2028]/95"
onClick={() => {
if (attachment.id) {
dispatch(removeAttachment(attachment.id));
}
}}
aria-label="Remove attachment"
>
<img
src={ExitIcon}
alt="Remove"
className="w-2.5 h-2.5 filter dark:invert"
/>
</button>
)} )}
{upload.status === 'failed' && ( {attachment.status === 'failed' && (
<span className="ml-2 text-red-500"></span> <img
src={AlertIcon}
alt="Upload failed"
className="ml-2 w-3.5 h-3.5"
title="Upload failed"
/>
)} )}
{(upload.status === 'uploading' || upload.status === 'processing') && ( {(attachment.status === 'uploading' || attachment.status === 'processing') && (
<div className="ml-2 w-4 h-4 relative"> <div className="ml-2 w-4 h-4 relative">
<svg className="w-4 h-4" viewBox="0 0 24 24"> <svg className="w-4 h-4" viewBox="0 0 24 24">
{/* Background circle */}
<circle <circle
className="text-gray-200 dark:text-gray-700" className="text-gray-200 dark:text-gray-700"
cx="12" cx="12"
@@ -266,7 +295,7 @@ export default function MessageInput({
strokeWidth="4" strokeWidth="4"
fill="none" fill="none"
strokeDasharray="62.83" strokeDasharray="62.83"
strokeDashoffset={62.83 - (upload.progress / 100) * 62.83} strokeDashoffset={62.83 * (1 - attachment.progress / 100)}
transform="rotate(-90 12 12)" transform="rotate(-90 12 12)"
/> />
</svg> </svg>
@@ -298,15 +327,21 @@ export default function MessageInput({
<div className="flex-grow flex flex-wrap gap-1 sm:gap-2"> <div className="flex-grow flex flex-wrap gap-1 sm:gap-2">
<button <button
ref={sourceButtonRef} ref={sourceButtonRef}
className="flex items-center px-2 xs:px-3 py-1 xs:py-1.5 rounded-[32px] border border-[#AAAAAA] dark:border-purple-taupe hover:bg-gray-100 dark:hover:bg-[#2C2E3C] transition-colors max-w-[130px] xs:max-w-[150px]" className="flex items-center px-2 xs:px-3 py-1 xs:py-1.5 rounded-[32px] border border-[#AAAAAA] dark:border-purple-taupe hover:bg-gray-100 dark:hover:bg-[#2C2E3C] transition-colors max-w-[130px] sm:max-w-[150px]"
onClick={() => setIsSourcesPopupOpen(!isSourcesPopupOpen)} onClick={() => setIsSourcesPopupOpen(!isSourcesPopupOpen)}
title={selectedDocs ? selectedDocs.name : t('conversation.sources.title')}
> >
<img src={SourceIcon} alt="Sources" className="w-3.5 sm:w-4 h-3.5 sm:h-4 mr-1 sm:mr-1.5 flex-shrink-0" /> <img src={SourceIcon} alt="Sources" className="w-3.5 h-3.5 sm:h-4 mr-1 sm:mr-1.5 flex-shrink-0" />
<span className="text-[10px] xs:text-[12px] sm:text-[14px] text-[#5D5D5D] dark:text-bright-gray font-medium truncate overflow-hidden"> <span className="text-[10px] xs:text-[12px] sm:text-[14px] text-[#5D5D5D] dark:text-bright-gray font-medium truncate overflow-hidden">
{selectedDocs {selectedDocs
? selectedDocs.name ? selectedDocs.name
: t('conversation.sources.title')} : t('conversation.sources.title')}
</span> </span>
{!isTouch && (
<span className="hidden sm:inline-block ml-1 text-[10px] text-gray-500 dark:text-gray-400">
{browserOS === 'mac' ? '(⌘K)' : '(ctrl+K)'}
</span>
)}
</button> </button>
<button <button

View File

@@ -207,7 +207,7 @@ export default function SourcesPopup({
<div className="px-4 md:px-6 py-4 opacity-75 hover:opacity-100 transition-opacity duration-200 flex-shrink-0"> <div className="px-4 md:px-6 py-4 opacity-75 hover:opacity-100 transition-opacity duration-200 flex-shrink-0">
<a <a
href="/settings/documents" href="/settings/documents"
className="text-violets-are-blue text-base font-medium flex items-center gap-2" className="text-violets-are-blue text-base font-medium inline-flex items-center gap-2"
onClick={onClose} onClick={onClose}
> >
Go to Documents Go to Documents

View File

@@ -217,10 +217,10 @@ export default function ToolsPopup({
</div> </div>
)} )}
<div className="p-4 flex-shrink-0"> <div className="p-4 flex-shrink-0 opacity-75 hover:opacity-100 transition-opacity duration-200">
<a <a
href="/settings/tools" href="/settings/tools"
className="text-base text-purple-30 font-medium hover:text-violets-are-blue flex items-center" className="text-base text-purple-30 font-medium inline-flex items-center"
> >
{t('settings.tools.manageTools')} {t('settings.tools.manageTools')}
<img <img

View File

@@ -57,7 +57,6 @@ const ConversationBubble = forwardRef<
updated?: boolean, updated?: boolean,
index?: number, index?: number,
) => void; ) => void;
attachments?: { fileName: string; id: string }[];
} }
>(function ConversationBubble( >(function ConversationBubble(
{ {
@@ -72,7 +71,6 @@ const ConversationBubble = forwardRef<
retryBtn, retryBtn,
questionNumber, questionNumber,
handleUpdatedQuestionSubmission, handleUpdatedQuestionSubmission,
attachments,
}, },
ref, ref,
) { ) {
@@ -99,36 +97,6 @@ const ConversationBubble = forwardRef<
handleUpdatedQuestionSubmission?.(editInputBox, true, questionNumber); handleUpdatedQuestionSubmission?.(editInputBox, true, questionNumber);
}; };
let bubble; let bubble;
const renderAttachments = () => {
if (!attachments || attachments.length === 0) return null;
return (
<div className="mt-2 flex flex-wrap gap-2">
{attachments.map((attachment, index) => (
<div
key={index}
className="flex items-center rounded-md bg-gray-100 px-2 py-1 text-sm dark:bg-gray-700"
>
<svg
className="mr-1 h-4 w-4"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M15.172 7l-6.586 6.586a2 2 0 102.828 2.828l6.414-6.586a4 4 0 00-5.656-5.656l-6.415 6.585a6 6 0 108.486 8.486L20.5 13"
/>
</svg>
<span>{attachment.fileName}</span>
</div>
))}
</div>
);
};
if (type === 'QUESTION') { if (type === 'QUESTION') {
bubble = ( bubble = (
<div <div
@@ -157,7 +125,6 @@ const ConversationBubble = forwardRef<
> >
{message} {message}
</div> </div>
{renderAttachments()}
</div> </div>
<button <button
onClick={() => { onClick={() => {

View File

@@ -161,7 +161,6 @@ export default function ConversationMessages({
{queries.length > 0 ? ( {queries.length > 0 ? (
queries.map((query, index) => ( queries.map((query, index) => (
<Fragment key={index}> <Fragment key={index}>
<ConversationBubble <ConversationBubble
className={'first:mt-5'} className={'first:mt-5'}
key={`${index}QUESTION`} key={`${index}QUESTION`}
@@ -170,7 +169,6 @@ export default function ConversationMessages({
handleUpdatedQuestionSubmission={handleQuestionSubmission} handleUpdatedQuestionSubmission={handleQuestionSubmission}
questionNumber={index} questionNumber={index}
sources={query.sources} sources={query.sources}
attachments={query.attachments}
/> />
{prepResponseView(query, index)} {prepResponseView(query, index)}
</Fragment> </Fragment>

View File

@@ -9,11 +9,21 @@ export interface Message {
type: MESSAGE_TYPE; type: MESSAGE_TYPE;
} }
export interface Attachment {
id?: string;
fileName: string;
status: 'uploading' | 'processing' | 'completed' | 'failed';
progress: number;
taskId?: string;
token_count?: number;
}
export interface ConversationState { export interface ConversationState {
queries: Query[]; queries: Query[];
status: Status; status: Status;
conversationId: string | null; conversationId: string | null;
attachments?: { fileName: string; id: string }[]; attachments: Attachment[];
} }
export interface Answer { export interface Answer {

View File

@@ -7,7 +7,7 @@ import {
handleFetchAnswer, handleFetchAnswer,
handleFetchAnswerSteaming, handleFetchAnswerSteaming,
} from './conversationHandlers'; } from './conversationHandlers';
import { Answer, ConversationState, Query, Status } from './conversationModels'; import { Answer, Query, Status, ConversationState, Attachment } from './conversationModels';
const initialState: ConversationState = { const initialState: ConversationState = {
queries: [], queries: [],
@@ -38,7 +38,9 @@ export const fetchAnswer = createAsyncThunk<
let isSourceUpdated = false; let isSourceUpdated = false;
const state = getState() as RootState; const state = getState() as RootState;
const attachments = state.conversation.attachments?.map(a => a.id) || []; const attachmentIds = state.conversation.attachments
.filter(a => a.id && a.status === 'completed')
.map(a => a.id) as string[];
if (state.preference) { if (state.preference) {
if (API_STREAMING) { if (API_STREAMING) {
@@ -122,7 +124,7 @@ export const fetchAnswer = createAsyncThunk<
} }
}, },
indx, indx,
attachments attachmentIds
); );
} else { } else {
const answer = await handleFetchAnswer( const answer = await handleFetchAnswer(
@@ -135,7 +137,7 @@ export const fetchAnswer = createAsyncThunk<
state.preference.prompt.id, state.preference.prompt.id,
state.preference.chunks, state.preference.chunks,
state.preference.token_limit, state.preference.token_limit,
attachments attachmentIds
); );
if (answer) { if (answer) {
let sourcesPrepped = []; let sourcesPrepped = [];
@@ -286,9 +288,29 @@ export const conversationSlice = createSlice({
const { index, message } = action.payload; const { index, message } = action.payload;
state.queries[index].error = message; state.queries[index].error = message;
}, },
setAttachments: (state, action: PayloadAction<{ fileName: string; id: string }[]>) => { setAttachments: (state, action: PayloadAction<Attachment[]>) => {
state.attachments = action.payload; state.attachments = action.payload;
}, },
addAttachment: (state, action: PayloadAction<Attachment>) => {
state.attachments.push(action.payload);
},
updateAttachment: (state, action: PayloadAction<{
taskId: string;
updates: Partial<Attachment>;
}>) => {
const index = state.attachments.findIndex(att => att.taskId === action.payload.taskId);
if (index !== -1) {
state.attachments[index] = {
...state.attachments[index],
...action.payload.updates
};
}
},
removeAttachment: (state, action: PayloadAction<string>) => {
state.attachments = state.attachments.filter(att =>
att.taskId !== action.payload && att.id !== action.payload
);
},
}, },
extraReducers(builder) { extraReducers(builder) {
builder builder
@@ -312,6 +334,10 @@ export const selectQueries = (state: RootState) => state.conversation.queries;
export const selectStatus = (state: RootState) => state.conversation.status; export const selectStatus = (state: RootState) => state.conversation.status;
export const selectAttachments = (state: RootState) => state.conversation.attachments;
export const selectCompletedAttachments = (state: RootState) =>
state.conversation.attachments.filter(att => att.status === 'completed');
export const { export const {
addQuery, addQuery,
updateQuery, updateQuery,
@@ -323,5 +349,8 @@ export const {
updateToolCalls, updateToolCalls,
setConversation, setConversation,
setAttachments, setAttachments,
addAttachment,
updateAttachment,
removeAttachment,
} = conversationSlice.actions; } = conversationSlice.actions;
export default conversationSlice.reducer; export default conversationSlice.reducer;

View File

@@ -0,0 +1,10 @@
export function getOS() {
const userAgent = window.navigator.userAgent;
if (userAgent.indexOf('Mac') !== -1) return 'mac';
if (userAgent.indexOf('Win') !== -1) return 'win';
return 'linux';
}
export function isTouchDevice() {
return 'ontouchstart' in window || navigator.maxTouchPoints > 0;
}