From c629460acb9cdec630dee3c03f71c4e0b0650206 Mon Sep 17 00:00:00 2001
From: ManishMadan2882
Date: Wed, 2 Apr 2025 15:21:33 +0530
Subject: [PATCH] (feat:attach) extract contents in endpoint layer

---
 application/api/answer/routes.py | 102 ++++++++++++++++++++++++-------
 1 file changed, 79 insertions(+), 23 deletions(-)

diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index 34081784..cf90f92b 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -29,6 +29,7 @@ sources_collection = db["sources"]
 prompts_collection = db["prompts"]
 api_key_collection = db["api_keys"]
 user_logs_collection = db["user_logs"]
+attachments_collection = db["attachments"]
 
 answer = Blueprint("answer", __name__)
 answer_ns = Namespace("answer", description="Answer related operations", path="/")
@@ -127,20 +128,24 @@ def save_conversation(
     decoded_token,
     index=None,
     api_key=None,
+    attachment_ids=None,
 ):
     current_time = datetime.datetime.now(datetime.timezone.utc)
     if conversation_id is not None and index is not None:
+        update_data = {
+            f"queries.{index}.prompt": question,
+            f"queries.{index}.response": response,
+            f"queries.{index}.sources": source_log_docs,
+            f"queries.{index}.tool_calls": tool_calls,
+            f"queries.{index}.timestamp": current_time,
+        }
+
+        if attachment_ids:
+            update_data[f"queries.{index}.attachments"] = attachment_ids
+
         conversations_collection.update_one(
             {"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
-            {
-                "$set": {
-                    f"queries.{index}.prompt": question,
-                    f"queries.{index}.response": response,
-                    f"queries.{index}.sources": source_log_docs,
-                    f"queries.{index}.tool_calls": tool_calls,
-                    f"queries.{index}.timestamp": current_time,
-                }
-            },
+            {"$set": update_data},
         )
         ##remove following queries from the array
         conversations_collection.update_one(
@@ -148,19 +153,20 @@
             {"$push": {"queries": {"$each": [], "$slice": index + 1}}},
         )
     elif conversation_id is not None and conversation_id != "None":
+        query_data = {
+            "prompt": question,
+            "response": response,
+            "sources": source_log_docs,
+            "tool_calls": tool_calls,
+            "timestamp": current_time,
+        }
+
+        if attachment_ids:
+            query_data["attachments"] = attachment_ids
+
         conversations_collection.update_one(
             {"_id": ObjectId(conversation_id)},
-            {
-                "$push": {
-                    "queries": {
-                        "prompt": question,
-                        "response": response,
-                        "sources": source_log_docs,
-                        "tool_calls": tool_calls,
-                        "timestamp": current_time,
-                    }
-                }
-            },
+            {"$push": {"queries": query_data}},
         )
 
     else:
@@ -228,11 +234,17 @@ def complete_stream(
     isNoneDoc=False,
     index=None,
     should_save_conversation=True,
+    attachments=None,
 ):
     try:
         response_full = ""
         source_log_docs = []
         tool_calls = []
+        attachment_ids = []
+
+        if attachments:
+            attachment_ids = [attachment["id"] for attachment in attachments]
+            logger.info(f"Processing request with {len(attachments)} attachments: {attachment_ids}")
 
         answer = agent.gen(query=question, retriever=retriever)
 
@@ -281,6 +293,7 @@ def complete_stream(
                 decoded_token,
                 index,
                 api_key=user_api_key,
+                attachment_ids=attachment_ids,
             )
         else:
             conversation_id = None
@@ -300,6 +313,7 @@ def complete_stream(
                 "response": response_full,
                 "sources": source_log_docs,
                 "retriever_params": retriever_params,
+                "attachments": attachment_ids,
                 "timestamp": datetime.datetime.now(datetime.timezone.utc),
             }
         )
@@ -348,10 +362,13 @@ class Stream(Resource):
             required=False, description="Flag indicating if no document is used"
         ),
         "index": fields.Integer(
-            required=False, description="The position where query is to be updated"
+            required=False, description="Index of the query to update"
        ),
         "save_conversation": fields.Boolean(
-            required=False, default=True, description="Flag to save conversation"
+            required=False, default=True, description="Whether to save the conversation"
+        ),
+        "attachments": fields.List(
+            fields.String, required=False, description="List of attachment IDs"
         ),
     },
 )
@@ -376,6 +393,7 @@ class Stream(Resource):
         )
         conversation_id = data.get("conversation_id")
         prompt_id = data.get("prompt_id", "default")
+        attachment_ids = data.get("attachments", [])
         index = data.get("index", None)
         chunks = int(data.get("chunks", 2))
 
@@ -404,9 +422,11 @@ class Stream(Resource):
 
             if not decoded_token:
                 return make_response({"error": "Unauthorized"}, 401)
+
+            attachments = get_attachments_content(attachment_ids, decoded_token.get("sub"))
 
             logger.info(
-                f"/stream - request_data: {data}, source: {source}",
+                f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}",
                 extra={"data": json.dumps({"request_data": data, "source": source})},
             )
 
@@ -424,6 +444,7 @@ class Stream(Resource):
                 prompt=prompt,
                 chat_history=history,
                 decoded_token=decoded_token,
+                attachments=attachments,
             )
 
             retriever = RetrieverCreator.create_retriever(
@@ -784,3 +805,38 @@ class Search(Resource):
             return bad_request(500, str(e))
 
         return make_response(docs, 200)
+
+
+def get_attachments_content(attachment_ids, user):
+    """
+    Retrieve content from attachment documents based on their IDs.
+
+    Args:
+        attachment_ids (list): List of attachment document IDs
+        user (str): User identifier to verify ownership
+
+    Returns:
+        list: List of dictionaries containing attachment content and metadata
+    """
+    if not attachment_ids:
+        return []
+
+    attachments = []
+    for attachment_id in attachment_ids:
+        try:
+            attachment_doc = attachments_collection.find_one({
+                "_id": ObjectId(attachment_id),
+                "user": user
+            })
+
+            if attachment_doc:
+                attachments.append({
+                    "id": str(attachment_doc["_id"]),
+                    "content": attachment_doc["content"],
+                    "token_count": attachment_doc.get("token_count", 0),
+                    "path": attachment_doc.get("path", "")
+                })
+        except Exception as e:
+            logger.error(f"Error retrieving attachment {attachment_id}: {e}")
+
+    return attachments
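
Reviewer note (not part of the patch): the sketch below shows the request shape the updated /stream endpoint would accept once this change lands. It is illustrative only; the host/port, the bearer token, and the "question"/"conversation_id" fields are assumptions taken from the surrounding code rather than from this diff, and the attachment ID is a placeholder. The only field introduced by this patch is "attachments", a list of IDs referring to documents in the new attachments collection that must belong to the requesting user.

# Illustrative client call; values marked "placeholder"/"assumed" are not
# defined by this patch. Only the "attachments" field is new behaviour here.
import requests

payload = {
    "question": "Summarize the attached file",      # assumed pre-existing field
    "conversation_id": None,                         # assumed pre-existing field
    "save_conversation": True,
    "attachments": ["67ed2f9a1c9d440000a1b2c3"],     # new: attachment document IDs
}

with requests.post(
    "http://localhost:7091/stream",                  # placeholder host/port
    json=payload,
    headers={"Authorization": "Bearer <token>"},     # placeholder credentials
    stream=True,
    timeout=60,
) as resp:
    # Assuming the endpoint streams its answer, print each line as it arrives.
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)

Server-side, those IDs are resolved by get_attachments_content() before create_agent() is called, so the agent receives the attachment content (plus token_count and path metadata) rather than raw IDs.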