(feat:attach) extract contents in endpoint layer

This commit is contained in:
ManishMadan2882
2025-04-02 15:21:33 +05:30
parent f235a94986
commit c629460acb

View File

@@ -29,6 +29,7 @@ sources_collection = db["sources"]
# Module-level handles to named collections of `db` (defined outside this
# view; presumably a MongoDB database — confirm against the file's setup code).
prompts_collection = db["prompts"]
api_key_collection = db["api_keys"]
user_logs_collection = db["user_logs"]
# Attachment documents; read by get_attachments_content via find_one.
attachments_collection = db["attachments"]
# Blueprint and namespace under which the answer-related routes are registered
# (presumably Flask / flask-restx objects — confirm against the imports).
answer = Blueprint("answer", __name__)
answer_ns = Namespace("answer", description="Answer related operations", path="/")
@@ -127,20 +128,24 @@ def save_conversation(
decoded_token,
index=None,
api_key=None,
attachment_ids=None,
):
current_time = datetime.datetime.now(datetime.timezone.utc)
if conversation_id is not None and index is not None:
update_data = {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.sources": source_log_docs,
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
}
if attachment_ids:
update_data[f"queries.{index}.attachments"] = attachment_ids
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$set": {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.sources": source_log_docs,
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
}
},
{"$set": update_data},
)
##remove following queries from the array
conversations_collection.update_one(
@@ -148,19 +153,20 @@ def save_conversation(
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
elif conversation_id is not None and conversation_id != "None":
query_data = {
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
"timestamp": current_time,
}
if attachment_ids:
query_data["attachments"] = attachment_ids
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
"timestamp": current_time,
}
}
},
{"$push": {"queries": query_data}},
)
else:
@@ -228,11 +234,17 @@ def complete_stream(
isNoneDoc=False,
index=None,
should_save_conversation=True,
attachments=None,
):
try:
response_full = ""
source_log_docs = []
tool_calls = []
attachment_ids = []
if attachments:
attachment_ids = [attachment["id"] for attachment in attachments]
logger.info(f"Processing request with {len(attachments)} attachments: {attachment_ids}")
answer = agent.gen(query=question, retriever=retriever)
@@ -281,6 +293,7 @@ def complete_stream(
decoded_token,
index,
api_key=user_api_key,
attachment_ids=attachment_ids,
)
else:
conversation_id = None
@@ -300,6 +313,7 @@ def complete_stream(
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
@@ -348,10 +362,13 @@ class Stream(Resource):
required=False, description="Flag indicating if no document is used"
),
"index": fields.Integer(
required=False, description="The position where query is to be updated"
required=False, description="Index of the query to update"
),
"save_conversation": fields.Boolean(
required=False, default=True, description="Flag to save conversation"
required=False, default=True, description="Whether to save the conversation"
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
},
)
@@ -376,6 +393,7 @@ class Stream(Resource):
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
attachment_ids = data.get("attachments", [])
index = data.get("index", None)
chunks = int(data.get("chunks", 2))
@@ -404,9 +422,11 @@ class Stream(Resource):
if not decoded_token:
return make_response({"error": "Unauthorized"}, 401)
attachments = get_attachments_content(attachment_ids, decoded_token.get("sub"))
logger.info(
f"/stream - request_data: {data}, source: {source}",
f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
@@ -424,6 +444,7 @@ class Stream(Resource):
prompt=prompt,
chat_history=history,
decoded_token=decoded_token,
attachments=attachments,
)
retriever = RetrieverCreator.create_retriever(
@@ -784,3 +805,38 @@ class Search(Resource):
return bad_request(500, str(e))
return make_response(docs, 200)
def get_attachments_content(attachment_ids, user):
    """
    Retrieve content from attachment documents based on their IDs.

    Each ID is looked up individually in ``attachments_collection`` with a
    filter on both ``_id`` and ``user``. IDs that match no document, or whose
    lookup/conversion raises, are logged and skipped, so the result may be
    shorter than ``attachment_ids``. Results keep the order of the input IDs.

    Args:
        attachment_ids (list): List of attachment document IDs
        user (str): User identifier to verify ownership

    Returns:
        list: List of dictionaries containing attachment content and metadata
        (keys: ``id``, ``content``, ``token_count``, ``path``). Empty when no
        IDs are given or when ``user`` is ``None``.
    """
    # Without a user the ownership filter is meaningless: in MongoDB a
    # ``{"user": None}`` filter matches documents whose ``user`` field is
    # null OR missing, which would expose un-owned attachments to any caller.
    if not attachment_ids or user is None:
        return []

    attachments = []
    for attachment_id in attachment_ids:
        # Best-effort per attachment: a malformed ID or a malformed document
        # must not abort the whole request, so failures are logged and skipped.
        try:
            attachment_doc = attachments_collection.find_one(
                {
                    "_id": ObjectId(attachment_id),
                    "user": user,
                }
            )

            if attachment_doc:
                attachments.append(
                    {
                        "id": str(attachment_doc["_id"]),
                        "content": attachment_doc["content"],
                        "token_count": attachment_doc.get("token_count", 0),
                        "path": attachment_doc.get("path", ""),
                    }
                )
        except Exception as e:
            logger.error(f"Error retrieving attachment {attachment_id}: {e}")

    return attachments