mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: answer routes re-structure for better maintainability and reuse
This commit is contained in:
@@ -0,0 +1,7 @@
|
|||||||
|
from flask_restx import Api
|
||||||
|
|
||||||
|
api = Api(
|
||||||
|
version="1.0",
|
||||||
|
title="DocsGPT API",
|
||||||
|
description="API for DocsGPT",
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
from flask import Blueprint
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.api.answer.routes.answer import AnswerResource
|
||||||
|
from application.api.answer.routes.base import answer_ns
|
||||||
|
from application.api.answer.routes.stream import StreamResource
|
||||||
|
|
||||||
|
|
||||||
|
answer = Blueprint("answer", __name__)
|
||||||
|
|
||||||
|
api.add_namespace(answer_ns)
|
||||||
|
|
||||||
|
|
||||||
|
def init_answer_routes():
|
||||||
|
api.add_resource(StreamResource, "/stream")
|
||||||
|
api.add_resource(AnswerResource, "/api/answer")
|
||||||
|
|
||||||
|
|
||||||
|
init_answer_routes()
|
||||||
|
|||||||
@@ -1,914 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
from bson.dbref import DBRef
|
|
||||||
from bson.objectid import ObjectId
|
|
||||||
from flask import Blueprint, make_response, request, Response
|
|
||||||
from flask_restx import fields, Namespace, Resource
|
|
||||||
|
|
||||||
from application.agents.agent_creator import AgentCreator
|
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
|
||||||
from application.core.settings import settings
|
|
||||||
from application.error import bad_request
|
|
||||||
from application.extensions import api
|
|
||||||
from application.llm.llm_creator import LLMCreator
|
|
||||||
from application.retriever.retriever_creator import RetrieverCreator
|
|
||||||
from application.utils import check_required_fields, limit_chat_history
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
mongo = MongoDB.get_client()
|
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
|
||||||
conversations_collection = db["conversations"]
|
|
||||||
sources_collection = db["sources"]
|
|
||||||
prompts_collection = db["prompts"]
|
|
||||||
agents_collection = db["agents"]
|
|
||||||
user_logs_collection = db["user_logs"]
|
|
||||||
attachments_collection = db["attachments"]
|
|
||||||
|
|
||||||
answer = Blueprint("answer", __name__)
|
|
||||||
answer_ns = Namespace("answer", description="Answer related operations", path="/")
|
|
||||||
api.add_namespace(answer_ns)
|
|
||||||
|
|
||||||
gpt_model = ""
|
|
||||||
# to have some kind of default behaviour
|
|
||||||
if settings.LLM_PROVIDER == "openai":
|
|
||||||
gpt_model = "gpt-4o-mini"
|
|
||||||
elif settings.LLM_PROVIDER == "anthropic":
|
|
||||||
gpt_model = "claude-2"
|
|
||||||
elif settings.LLM_PROVIDER == "groq":
|
|
||||||
gpt_model = "llama3-8b-8192"
|
|
||||||
elif settings.LLM_PROVIDER == "novita":
|
|
||||||
gpt_model = "deepseek/deepseek-r1"
|
|
||||||
|
|
||||||
if settings.LLM_NAME: # in case there is particular model name configured
|
|
||||||
gpt_model = settings.LLM_NAME
|
|
||||||
|
|
||||||
# load the prompts
|
|
||||||
current_dir = os.path.dirname(
|
|
||||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
)
|
|
||||||
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
|
|
||||||
chat_combine_template = f.read()
|
|
||||||
|
|
||||||
with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
|
|
||||||
chat_reduce_template = f.read()
|
|
||||||
|
|
||||||
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
|
|
||||||
chat_combine_creative = f.read()
|
|
||||||
|
|
||||||
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
|
|
||||||
chat_combine_strict = f.read()
|
|
||||||
|
|
||||||
api_key_set = settings.API_KEY is not None
|
|
||||||
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
|
|
||||||
|
|
||||||
|
|
||||||
async def async_generate(chain, question, chat_history):
|
|
||||||
result = await chain.arun({"question": question, "chat_history": chat_history})
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def run_async_chain(chain, question, chat_history):
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
result = {}
|
|
||||||
try:
|
|
||||||
answer = loop.run_until_complete(async_generate(chain, question, chat_history))
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
result["answer"] = answer
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_key(agent_id, user_id):
|
|
||||||
if not agent_id:
|
|
||||||
return None, False, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
|
|
||||||
if agent is None:
|
|
||||||
raise Exception("Agent not found", 404)
|
|
||||||
|
|
||||||
is_owner = agent.get("user") == user_id
|
|
||||||
|
|
||||||
if is_owner:
|
|
||||||
agents_collection.update_one(
|
|
||||||
{"_id": ObjectId(agent_id)},
|
|
||||||
{"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}},
|
|
||||||
)
|
|
||||||
return str(agent["key"]), False, None
|
|
||||||
|
|
||||||
is_shared_with_user = agent.get(
|
|
||||||
"shared_publicly", False
|
|
||||||
) or user_id in agent.get("shared_with", [])
|
|
||||||
|
|
||||||
if is_shared_with_user:
|
|
||||||
return str(agent["key"]), True, agent.get("shared_token")
|
|
||||||
|
|
||||||
raise Exception("Unauthorized access to the agent", 403)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def get_data_from_api_key(api_key):
|
|
||||||
data = agents_collection.find_one({"key": api_key})
|
|
||||||
if not data:
|
|
||||||
raise Exception("Invalid API Key, please generate a new key", 401)
|
|
||||||
|
|
||||||
source = data.get("source")
|
|
||||||
if isinstance(source, DBRef):
|
|
||||||
source_doc = db.dereference(source)
|
|
||||||
data["source"] = str(source_doc["_id"])
|
|
||||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
|
||||||
else:
|
|
||||||
data["source"] = {}
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def get_retriever(source_id: str):
|
|
||||||
doc = sources_collection.find_one({"_id": ObjectId(source_id)})
|
|
||||||
if doc is None:
|
|
||||||
raise Exception("Source document does not exist", 404)
|
|
||||||
retriever_name = None if "retriever" not in doc else doc["retriever"]
|
|
||||||
return retriever_name
|
|
||||||
|
|
||||||
|
|
||||||
def is_azure_configured():
|
|
||||||
return (
|
|
||||||
settings.OPENAI_API_BASE
|
|
||||||
and settings.OPENAI_API_VERSION
|
|
||||||
and settings.AZURE_DEPLOYMENT_NAME
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def save_conversation(
|
|
||||||
conversation_id,
|
|
||||||
question,
|
|
||||||
response,
|
|
||||||
thought,
|
|
||||||
source_log_docs,
|
|
||||||
tool_calls,
|
|
||||||
llm,
|
|
||||||
decoded_token,
|
|
||||||
index=None,
|
|
||||||
api_key=None,
|
|
||||||
agent_id=None,
|
|
||||||
is_shared_usage=False,
|
|
||||||
shared_token=None,
|
|
||||||
attachment_ids=None,
|
|
||||||
):
|
|
||||||
current_time = datetime.datetime.now(datetime.timezone.utc)
|
|
||||||
if conversation_id is not None and index is not None:
|
|
||||||
conversations_collection.update_one(
|
|
||||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
|
||||||
{
|
|
||||||
"$set": {
|
|
||||||
f"queries.{index}.prompt": question,
|
|
||||||
f"queries.{index}.response": response,
|
|
||||||
f"queries.{index}.thought": thought,
|
|
||||||
f"queries.{index}.sources": source_log_docs,
|
|
||||||
f"queries.{index}.tool_calls": tool_calls,
|
|
||||||
f"queries.{index}.timestamp": current_time,
|
|
||||||
f"queries.{index}.attachments": attachment_ids,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
##remove following queries from the array
|
|
||||||
conversations_collection.update_one(
|
|
||||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
|
||||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
|
||||||
)
|
|
||||||
elif conversation_id is not None and conversation_id != "None":
|
|
||||||
conversations_collection.update_one(
|
|
||||||
{"_id": ObjectId(conversation_id)},
|
|
||||||
{
|
|
||||||
"$push": {
|
|
||||||
"queries": {
|
|
||||||
"prompt": question,
|
|
||||||
"response": response,
|
|
||||||
"thought": thought,
|
|
||||||
"sources": source_log_docs,
|
|
||||||
"tool_calls": tool_calls,
|
|
||||||
"timestamp": current_time,
|
|
||||||
"attachments": attachment_ids,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# create new conversation
|
|
||||||
# generate summary
|
|
||||||
messages_summary = [
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Summarise following conversation in no more than 3 "
|
|
||||||
"words, respond ONLY with the summary, use the same "
|
|
||||||
"language as the user query",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Summarise following conversation in no more than 3 words, "
|
|
||||||
"respond ONLY with the summary, use the same language as the "
|
|
||||||
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
|
|
||||||
conversation_data = {
|
|
||||||
"user": decoded_token.get("sub"),
|
|
||||||
"date": datetime.datetime.utcnow(),
|
|
||||||
"name": completion,
|
|
||||||
"queries": [
|
|
||||||
{
|
|
||||||
"prompt": question,
|
|
||||||
"response": response,
|
|
||||||
"thought": thought,
|
|
||||||
"sources": source_log_docs,
|
|
||||||
"tool_calls": tool_calls,
|
|
||||||
"timestamp": current_time,
|
|
||||||
"attachments": attachment_ids,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
if api_key:
|
|
||||||
if agent_id:
|
|
||||||
conversation_data["agent_id"] = agent_id
|
|
||||||
if is_shared_usage:
|
|
||||||
conversation_data["is_shared_usage"] = is_shared_usage
|
|
||||||
conversation_data["shared_token"] = shared_token
|
|
||||||
api_key_doc = agents_collection.find_one({"key": api_key})
|
|
||||||
if api_key_doc:
|
|
||||||
conversation_data["api_key"] = api_key_doc["key"]
|
|
||||||
conversation_id = conversations_collection.insert_one(
|
|
||||||
conversation_data
|
|
||||||
).inserted_id
|
|
||||||
return conversation_id
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt(prompt_id):
|
|
||||||
if prompt_id == "default":
|
|
||||||
prompt = chat_combine_template
|
|
||||||
elif prompt_id == "creative":
|
|
||||||
prompt = chat_combine_creative
|
|
||||||
elif prompt_id == "strict":
|
|
||||||
prompt = chat_combine_strict
|
|
||||||
else:
|
|
||||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
||||||
def complete_stream(
|
|
||||||
question,
|
|
||||||
agent,
|
|
||||||
retriever,
|
|
||||||
conversation_id,
|
|
||||||
user_api_key,
|
|
||||||
decoded_token,
|
|
||||||
isNoneDoc=False,
|
|
||||||
index=None,
|
|
||||||
should_save_conversation=True,
|
|
||||||
attachment_ids=None,
|
|
||||||
agent_id=None,
|
|
||||||
is_shared_usage=False,
|
|
||||||
shared_token=None,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
|
||||||
|
|
||||||
answer = agent.gen(query=question, retriever=retriever)
|
|
||||||
|
|
||||||
for line in answer:
|
|
||||||
if "answer" in line:
|
|
||||||
response_full += str(line["answer"])
|
|
||||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
elif "sources" in line:
|
|
||||||
truncated_sources = []
|
|
||||||
source_log_docs = line["sources"]
|
|
||||||
for source in line["sources"]:
|
|
||||||
truncated_source = source.copy()
|
|
||||||
if "text" in truncated_source:
|
|
||||||
truncated_source["text"] = (
|
|
||||||
truncated_source["text"][:100].strip() + "..."
|
|
||||||
)
|
|
||||||
truncated_sources.append(truncated_source)
|
|
||||||
if len(truncated_sources) > 0:
|
|
||||||
data = json.dumps({"type": "source", "source": truncated_sources})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
elif "tool_calls" in line:
|
|
||||||
tool_calls = line["tool_calls"]
|
|
||||||
elif "thought" in line:
|
|
||||||
thought += line["thought"]
|
|
||||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
elif "type" in line:
|
|
||||||
data = json.dumps(line)
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
|
|
||||||
if isNoneDoc:
|
|
||||||
for doc in source_log_docs:
|
|
||||||
doc["source"] = "None"
|
|
||||||
|
|
||||||
llm = LLMCreator.create_llm(
|
|
||||||
settings.LLM_PROVIDER,
|
|
||||||
api_key=settings.API_KEY,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
if should_save_conversation:
|
|
||||||
conversation_id = save_conversation(
|
|
||||||
conversation_id,
|
|
||||||
question,
|
|
||||||
response_full,
|
|
||||||
thought,
|
|
||||||
source_log_docs,
|
|
||||||
tool_calls,
|
|
||||||
llm,
|
|
||||||
decoded_token,
|
|
||||||
index,
|
|
||||||
api_key=user_api_key,
|
|
||||||
attachment_ids=attachment_ids,
|
|
||||||
agent_id=agent_id,
|
|
||||||
is_shared_usage=is_shared_usage,
|
|
||||||
shared_token=shared_token,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
conversation_id = None
|
|
||||||
|
|
||||||
# send data.type = "end" to indicate that the stream has ended as json
|
|
||||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
|
|
||||||
retriever_params = retriever.get_params()
|
|
||||||
user_logs_collection.insert_one(
|
|
||||||
{
|
|
||||||
"action": "stream_answer",
|
|
||||||
"level": "info",
|
|
||||||
"user": decoded_token.get("sub"),
|
|
||||||
"api_key": user_api_key,
|
|
||||||
"question": question,
|
|
||||||
"response": response_full,
|
|
||||||
"sources": source_log_docs,
|
|
||||||
"retriever_params": retriever_params,
|
|
||||||
"attachments": attachment_ids,
|
|
||||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
data = json.dumps({"type": "end"})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
|
||||||
data = json.dumps(
|
|
||||||
{
|
|
||||||
"type": "error",
|
|
||||||
"error": "Please try again later. We apologize for any inconvenience.",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
@answer_ns.route("/stream")
|
|
||||||
class Stream(Resource):
|
|
||||||
stream_model = api.model(
|
|
||||||
"StreamModel",
|
|
||||||
{
|
|
||||||
"question": fields.String(
|
|
||||||
required=True, description="Question to be asked"
|
|
||||||
),
|
|
||||||
"history": fields.List(
|
|
||||||
fields.String, required=False, description="Chat history"
|
|
||||||
),
|
|
||||||
"conversation_id": fields.String(
|
|
||||||
required=False, description="Conversation ID"
|
|
||||||
),
|
|
||||||
"prompt_id": fields.String(
|
|
||||||
required=False, default="default", description="Prompt ID"
|
|
||||||
),
|
|
||||||
"chunks": fields.Integer(
|
|
||||||
required=False, default=2, description="Number of chunks"
|
|
||||||
),
|
|
||||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
|
||||||
"retriever": fields.String(required=False, description="Retriever type"),
|
|
||||||
"api_key": fields.String(required=False, description="API key"),
|
|
||||||
"active_docs": fields.String(
|
|
||||||
required=False, description="Active documents"
|
|
||||||
),
|
|
||||||
"isNoneDoc": fields.Boolean(
|
|
||||||
required=False, description="Flag indicating if no document is used"
|
|
||||||
),
|
|
||||||
"index": fields.Integer(
|
|
||||||
required=False, description="Index of the query to update"
|
|
||||||
),
|
|
||||||
"save_conversation": fields.Boolean(
|
|
||||||
required=False,
|
|
||||||
default=True,
|
|
||||||
description="Whether to save the conversation",
|
|
||||||
),
|
|
||||||
"attachments": fields.List(
|
|
||||||
fields.String, required=False, description="List of attachment IDs"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@api.expect(stream_model)
|
|
||||||
@api.doc(description="Stream a response based on the question and retriever")
|
|
||||||
def post(self):
|
|
||||||
data = request.get_json()
|
|
||||||
required_fields = ["question"]
|
|
||||||
if "index" in data:
|
|
||||||
required_fields = ["question", "conversation_id"]
|
|
||||||
missing_fields = check_required_fields(data, required_fields)
|
|
||||||
if missing_fields:
|
|
||||||
return missing_fields
|
|
||||||
|
|
||||||
save_conv = data.get("save_conversation", True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
question = data["question"]
|
|
||||||
history = limit_chat_history(
|
|
||||||
json.loads(data.get("history", "[]")), gpt_model=gpt_model
|
|
||||||
)
|
|
||||||
conversation_id = data.get("conversation_id")
|
|
||||||
prompt_id = data.get("prompt_id", "default")
|
|
||||||
attachment_ids = data.get("attachments", [])
|
|
||||||
|
|
||||||
index = data.get("index", None)
|
|
||||||
chunks = int(data.get("chunks", 2))
|
|
||||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
|
||||||
retriever_name = data.get("retriever", "classic")
|
|
||||||
agent_id = data.get("agent_id", None)
|
|
||||||
agent_type = settings.AGENT_NAME
|
|
||||||
decoded_token = getattr(request, "decoded_token", None)
|
|
||||||
user_sub = decoded_token.get("sub") if decoded_token else None
|
|
||||||
agent_key, is_shared_usage, shared_token = get_agent_key(agent_id, user_sub)
|
|
||||||
|
|
||||||
if agent_key:
|
|
||||||
data.update({"api_key": agent_key})
|
|
||||||
else:
|
|
||||||
agent_id = None
|
|
||||||
|
|
||||||
if "api_key" in data:
|
|
||||||
data_key = get_data_from_api_key(data["api_key"])
|
|
||||||
chunks = int(data_key.get("chunks", 2))
|
|
||||||
prompt_id = data_key.get("prompt_id", "default")
|
|
||||||
source = {"active_docs": data_key.get("source")}
|
|
||||||
retriever_name = data_key.get("retriever", retriever_name)
|
|
||||||
user_api_key = data["api_key"]
|
|
||||||
agent_type = data_key.get("agent_type", agent_type)
|
|
||||||
if is_shared_usage:
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
else:
|
|
||||||
decoded_token = {"sub": data_key.get("user")}
|
|
||||||
is_shared_usage = False
|
|
||||||
|
|
||||||
elif "active_docs" in data:
|
|
||||||
source = {"active_docs": data["active_docs"]}
|
|
||||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
|
||||||
user_api_key = None
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
|
|
||||||
else:
|
|
||||||
source = {}
|
|
||||||
user_api_key = None
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
|
|
||||||
if not decoded_token:
|
|
||||||
return make_response({"error": "Unauthorized"}, 401)
|
|
||||||
|
|
||||||
attachments = get_attachments_content(
|
|
||||||
attachment_ids, decoded_token.get("sub")
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}",
|
|
||||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = get_prompt(prompt_id)
|
|
||||||
if "isNoneDoc" in data and data["isNoneDoc"] is True:
|
|
||||||
chunks = 0
|
|
||||||
|
|
||||||
agent = AgentCreator.create_agent(
|
|
||||||
agent_type,
|
|
||||||
endpoint="stream",
|
|
||||||
llm_name=settings.LLM_PROVIDER,
|
|
||||||
gpt_model=gpt_model,
|
|
||||||
api_key=settings.API_KEY,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
prompt=prompt,
|
|
||||||
chat_history=history,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
attachments=attachments,
|
|
||||||
)
|
|
||||||
|
|
||||||
retriever = RetrieverCreator.create_retriever(
|
|
||||||
retriever_name,
|
|
||||||
source=source,
|
|
||||||
chat_history=history,
|
|
||||||
prompt=prompt,
|
|
||||||
chunks=chunks,
|
|
||||||
token_limit=token_limit,
|
|
||||||
gpt_model=gpt_model,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
return Response(
|
|
||||||
complete_stream(
|
|
||||||
question=question,
|
|
||||||
agent=agent,
|
|
||||||
retriever=retriever,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
isNoneDoc=data.get("isNoneDoc"),
|
|
||||||
index=index,
|
|
||||||
should_save_conversation=save_conv,
|
|
||||||
attachment_ids=attachment_ids,
|
|
||||||
agent_id=agent_id,
|
|
||||||
is_shared_usage=is_shared_usage,
|
|
||||||
shared_token=shared_token,
|
|
||||||
),
|
|
||||||
mimetype="text/event-stream",
|
|
||||||
)
|
|
||||||
|
|
||||||
except ValueError:
|
|
||||||
message = "Malformed request body"
|
|
||||||
logger.error(f"/stream - error: {message}")
|
|
||||||
return Response(
|
|
||||||
error_stream_generate(message),
|
|
||||||
status=400,
|
|
||||||
mimetype="text/event-stream",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
|
||||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
|
||||||
)
|
|
||||||
status_code = 400
|
|
||||||
return Response(
|
|
||||||
error_stream_generate("Unknown error occurred"),
|
|
||||||
status=status_code,
|
|
||||||
mimetype="text/event-stream",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def error_stream_generate(err_response):
|
|
||||||
data = json.dumps({"type": "error", "error": err_response})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
@answer_ns.route("/api/answer")
|
|
||||||
class Answer(Resource):
|
|
||||||
answer_model = api.model(
|
|
||||||
"AnswerModel",
|
|
||||||
{
|
|
||||||
"question": fields.String(
|
|
||||||
required=True, description="The question to answer"
|
|
||||||
),
|
|
||||||
"history": fields.List(
|
|
||||||
fields.String, required=False, description="Conversation history"
|
|
||||||
),
|
|
||||||
"conversation_id": fields.String(
|
|
||||||
required=False, description="Conversation ID"
|
|
||||||
),
|
|
||||||
"prompt_id": fields.String(
|
|
||||||
required=False, default="default", description="Prompt ID"
|
|
||||||
),
|
|
||||||
"chunks": fields.Integer(
|
|
||||||
required=False, default=2, description="Number of chunks"
|
|
||||||
),
|
|
||||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
|
||||||
"retriever": fields.String(required=False, description="Retriever type"),
|
|
||||||
"api_key": fields.String(required=False, description="API key"),
|
|
||||||
"active_docs": fields.String(
|
|
||||||
required=False, description="Active documents"
|
|
||||||
),
|
|
||||||
"isNoneDoc": fields.Boolean(
|
|
||||||
required=False, description="Flag indicating if no document is used"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@api.expect(answer_model)
|
|
||||||
@api.doc(description="Provide an answer based on the question and retriever")
|
|
||||||
def post(self):
|
|
||||||
data = request.get_json()
|
|
||||||
required_fields = ["question"]
|
|
||||||
missing_fields = check_required_fields(data, required_fields)
|
|
||||||
if missing_fields:
|
|
||||||
return missing_fields
|
|
||||||
|
|
||||||
try:
|
|
||||||
question = data["question"]
|
|
||||||
history = limit_chat_history(
|
|
||||||
json.loads(data.get("history", "[]")), gpt_model=gpt_model
|
|
||||||
)
|
|
||||||
conversation_id = data.get("conversation_id")
|
|
||||||
prompt_id = data.get("prompt_id", "default")
|
|
||||||
chunks = int(data.get("chunks", 2))
|
|
||||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
|
||||||
retriever_name = data.get("retriever", "classic")
|
|
||||||
agent_type = settings.AGENT_NAME
|
|
||||||
|
|
||||||
if "api_key" in data:
|
|
||||||
data_key = get_data_from_api_key(data["api_key"])
|
|
||||||
chunks = int(data_key.get("chunks", 2))
|
|
||||||
prompt_id = data_key.get("prompt_id", "default")
|
|
||||||
source = {"active_docs": data_key.get("source")}
|
|
||||||
retriever_name = data_key.get("retriever", retriever_name)
|
|
||||||
user_api_key = data["api_key"]
|
|
||||||
agent_type = data_key.get("agent_type", agent_type)
|
|
||||||
decoded_token = {"sub": data_key.get("user")}
|
|
||||||
|
|
||||||
elif "active_docs" in data:
|
|
||||||
source = {"active_docs": data["active_docs"]}
|
|
||||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
|
||||||
user_api_key = None
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
|
|
||||||
else:
|
|
||||||
source = {}
|
|
||||||
user_api_key = None
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
|
|
||||||
if not decoded_token:
|
|
||||||
return make_response({"error": "Unauthorized"}, 401)
|
|
||||||
|
|
||||||
prompt = get_prompt(prompt_id)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"/api/answer - request_data: {data}, source: {source}",
|
|
||||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
|
||||||
)
|
|
||||||
|
|
||||||
agent = AgentCreator.create_agent(
|
|
||||||
agent_type,
|
|
||||||
endpoint="api/answer",
|
|
||||||
llm_name=settings.LLM_PROVIDER,
|
|
||||||
gpt_model=gpt_model,
|
|
||||||
api_key=settings.API_KEY,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
prompt=prompt,
|
|
||||||
chat_history=history,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
retriever = RetrieverCreator.create_retriever(
|
|
||||||
retriever_name,
|
|
||||||
source=source,
|
|
||||||
chat_history=history,
|
|
||||||
prompt=prompt,
|
|
||||||
chunks=chunks,
|
|
||||||
token_limit=token_limit,
|
|
||||||
gpt_model=gpt_model,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
response_full = ""
|
|
||||||
source_log_docs = []
|
|
||||||
tool_calls = []
|
|
||||||
stream_ended = False
|
|
||||||
thought = ""
|
|
||||||
|
|
||||||
for line in complete_stream(
|
|
||||||
question=question,
|
|
||||||
agent=agent,
|
|
||||||
retriever=retriever,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
isNoneDoc=data.get("isNoneDoc"),
|
|
||||||
index=None,
|
|
||||||
should_save_conversation=False,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
event_data = line.replace("data: ", "").strip()
|
|
||||||
event = json.loads(event_data)
|
|
||||||
|
|
||||||
if event["type"] == "answer":
|
|
||||||
response_full += event["answer"]
|
|
||||||
elif event["type"] == "source":
|
|
||||||
source_log_docs = event["source"]
|
|
||||||
elif event["type"] == "tool_calls":
|
|
||||||
tool_calls = event["tool_calls"]
|
|
||||||
elif event["type"] == "thought":
|
|
||||||
thought = event["thought"]
|
|
||||||
elif event["type"] == "error":
|
|
||||||
logger.error(f"Error from stream: {event['error']}")
|
|
||||||
return bad_request(500, event["error"])
|
|
||||||
elif event["type"] == "end":
|
|
||||||
stream_ended = True
|
|
||||||
|
|
||||||
except (json.JSONDecodeError, KeyError) as e:
|
|
||||||
logger.warning(f"Error parsing stream event: {e}, line: {line}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not stream_ended:
|
|
||||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
|
||||||
return bad_request(500, "Stream ended unexpectedly.")
|
|
||||||
|
|
||||||
if data.get("isNoneDoc"):
|
|
||||||
for doc in source_log_docs:
|
|
||||||
doc["source"] = "None"
|
|
||||||
|
|
||||||
llm = LLMCreator.create_llm(
|
|
||||||
settings.LLM_PROVIDER,
|
|
||||||
api_key=settings.API_KEY,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {"answer": response_full, "sources": source_log_docs}
|
|
||||||
result["conversation_id"] = str(
|
|
||||||
save_conversation(
|
|
||||||
conversation_id,
|
|
||||||
question,
|
|
||||||
response_full,
|
|
||||||
thought,
|
|
||||||
source_log_docs,
|
|
||||||
tool_calls,
|
|
||||||
llm,
|
|
||||||
decoded_token,
|
|
||||||
api_key=user_api_key,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
retriever_params = retriever.get_params()
|
|
||||||
user_logs_collection.insert_one(
|
|
||||||
{
|
|
||||||
"action": "api_answer",
|
|
||||||
"level": "info",
|
|
||||||
"user": decoded_token.get("sub"),
|
|
||||||
"api_key": user_api_key,
|
|
||||||
"question": question,
|
|
||||||
"response": response_full,
|
|
||||||
"sources": source_log_docs,
|
|
||||||
"retriever_params": retriever_params,
|
|
||||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
|
||||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
|
||||||
)
|
|
||||||
return bad_request(500, str(e))
|
|
||||||
|
|
||||||
return make_response(result, 200)
|
|
||||||
|
|
||||||
|
|
||||||
@answer_ns.route("/api/search")
|
|
||||||
class Search(Resource):
|
|
||||||
search_model = api.model(
|
|
||||||
"SearchModel",
|
|
||||||
{
|
|
||||||
"question": fields.String(
|
|
||||||
required=True, description="The question to search"
|
|
||||||
),
|
|
||||||
"chunks": fields.Integer(
|
|
||||||
required=False, default=2, description="Number of chunks"
|
|
||||||
),
|
|
||||||
"api_key": fields.String(
|
|
||||||
required=False, description="API key for authentication"
|
|
||||||
),
|
|
||||||
"active_docs": fields.String(
|
|
||||||
required=False, description="Active documents for retrieval"
|
|
||||||
),
|
|
||||||
"retriever": fields.String(required=False, description="Retriever type"),
|
|
||||||
"token_limit": fields.Integer(
|
|
||||||
required=False, description="Limit for tokens"
|
|
||||||
),
|
|
||||||
"isNoneDoc": fields.Boolean(
|
|
||||||
required=False, description="Flag indicating if no document is used"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@api.expect(search_model)
|
|
||||||
@api.doc(
|
|
||||||
description="Search for relevant documents based on the question and retriever"
|
|
||||||
)
|
|
||||||
def post(self):
|
|
||||||
data = request.get_json()
|
|
||||||
required_fields = ["question"]
|
|
||||||
missing_fields = check_required_fields(data, required_fields)
|
|
||||||
if missing_fields:
|
|
||||||
return missing_fields
|
|
||||||
|
|
||||||
try:
|
|
||||||
question = data["question"]
|
|
||||||
chunks = int(data.get("chunks", 2))
|
|
||||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
|
||||||
retriever_name = data.get("retriever", "classic")
|
|
||||||
|
|
||||||
if "api_key" in data:
|
|
||||||
data_key = get_data_from_api_key(data["api_key"])
|
|
||||||
chunks = int(data_key.get("chunks", 2))
|
|
||||||
source = {"active_docs": data_key.get("source")}
|
|
||||||
user_api_key = data["api_key"]
|
|
||||||
decoded_token = {"sub": data_key.get("user")}
|
|
||||||
|
|
||||||
elif "active_docs" in data:
|
|
||||||
source = {"active_docs": data["active_docs"]}
|
|
||||||
user_api_key = None
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
|
|
||||||
else:
|
|
||||||
source = {}
|
|
||||||
user_api_key = None
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
|
|
||||||
if not decoded_token:
|
|
||||||
return make_response({"error": "Unauthorized"}, 401)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"/api/answer - request_data: {data}, source: {source}",
|
|
||||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
|
||||||
)
|
|
||||||
|
|
||||||
retriever = RetrieverCreator.create_retriever(
|
|
||||||
retriever_name,
|
|
||||||
source=source,
|
|
||||||
chat_history=[],
|
|
||||||
prompt="default",
|
|
||||||
chunks=chunks,
|
|
||||||
token_limit=token_limit,
|
|
||||||
gpt_model=gpt_model,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
docs = retriever.search(question)
|
|
||||||
retriever_params = retriever.get_params()
|
|
||||||
|
|
||||||
user_logs_collection.insert_one(
|
|
||||||
{
|
|
||||||
"action": "api_search",
|
|
||||||
"level": "info",
|
|
||||||
"user": decoded_token.get("sub"),
|
|
||||||
"api_key": user_api_key,
|
|
||||||
"question": question,
|
|
||||||
"sources": docs,
|
|
||||||
"retriever_params": retriever_params,
|
|
||||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if data.get("isNoneDoc"):
|
|
||||||
for doc in docs:
|
|
||||||
doc["source"] = "None"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
|
|
||||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
|
||||||
)
|
|
||||||
return bad_request(500, str(e))
|
|
||||||
|
|
||||||
return make_response(docs, 200)
|
|
||||||
|
|
||||||
|
|
||||||
def get_attachments_content(attachment_ids, user):
|
|
||||||
"""
|
|
||||||
Retrieve content from attachment documents based on their IDs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
attachment_ids (list): List of attachment document IDs
|
|
||||||
user (str): User identifier to verify ownership
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: List of dictionaries containing attachment content and metadata
|
|
||||||
"""
|
|
||||||
if not attachment_ids:
|
|
||||||
return []
|
|
||||||
|
|
||||||
attachments = []
|
|
||||||
for attachment_id in attachment_ids:
|
|
||||||
try:
|
|
||||||
attachment_doc = attachments_collection.find_one(
|
|
||||||
{"_id": ObjectId(attachment_id), "user": user}
|
|
||||||
)
|
|
||||||
|
|
||||||
if attachment_doc:
|
|
||||||
attachments.append(attachment_doc)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return attachments
|
|
||||||
0
application/api/answer/routes/__init__.py
Normal file
0
application/api/answer/routes/__init__.py
Normal file
103
application/api/answer/routes/answer.py
Normal file
103
application/api/answer/routes/answer.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from flask import make_response, request
|
||||||
|
from flask_restx import fields, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
|
||||||
|
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
|
||||||
|
|
||||||
|
from application.api.answer.services.stream_processor import StreamProcessor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@answer_ns.route("/api/answer")
|
||||||
|
class AnswerResource(Resource, BaseAnswerResource):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
Resource.__init__(self, *args, **kwargs)
|
||||||
|
BaseAnswerResource.__init__(self)
|
||||||
|
|
||||||
|
answer_model = answer_ns.model(
|
||||||
|
"AnswerModel",
|
||||||
|
{
|
||||||
|
"question": fields.String(
|
||||||
|
required=True, description="Question to be asked"
|
||||||
|
),
|
||||||
|
"history": fields.List(
|
||||||
|
fields.String,
|
||||||
|
required=False,
|
||||||
|
description="Conversation history (only for new conversations)",
|
||||||
|
),
|
||||||
|
"conversation_id": fields.String(
|
||||||
|
required=False,
|
||||||
|
description="Existing conversation ID (loads history)",
|
||||||
|
),
|
||||||
|
"prompt_id": fields.String(
|
||||||
|
required=False, default="default", description="Prompt ID"
|
||||||
|
),
|
||||||
|
"chunks": fields.Integer(
|
||||||
|
required=False, default=2, description="Number of chunks"
|
||||||
|
),
|
||||||
|
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||||
|
"retriever": fields.String(required=False, description="Retriever type"),
|
||||||
|
"api_key": fields.String(required=False, description="API key"),
|
||||||
|
"active_docs": fields.String(
|
||||||
|
required=False, description="Active documents"
|
||||||
|
),
|
||||||
|
"isNoneDoc": fields.Boolean(
|
||||||
|
required=False, description="Flag indicating if no document is used"
|
||||||
|
),
|
||||||
|
"save_conversation": fields.Boolean(
|
||||||
|
required=False,
|
||||||
|
default=True,
|
||||||
|
description="Whether to save the conversation",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(answer_model)
|
||||||
|
@api.doc(description="Provide a response based on the question and retriever")
|
||||||
|
def post(self):
|
||||||
|
data = request.get_json()
|
||||||
|
if error := self.validate_request(data):
|
||||||
|
return error
|
||||||
|
processor = StreamProcessor(data, None)
|
||||||
|
try:
|
||||||
|
processor.initialize()
|
||||||
|
if not processor.decoded_token:
|
||||||
|
return make_response({"error": "Unauthorized"}, 401)
|
||||||
|
agent = processor.create_agent()
|
||||||
|
retriever = processor.create_retriever()
|
||||||
|
|
||||||
|
stream = self.complete_stream(
|
||||||
|
question=data["question"],
|
||||||
|
agent=agent,
|
||||||
|
retriever=retriever,
|
||||||
|
conversation_id=processor.conversation_id,
|
||||||
|
user_api_key=processor.agent_config.get("user_api_key"),
|
||||||
|
decoded_token=processor.decoded_token,
|
||||||
|
isNoneDoc=data.get("isNoneDoc"),
|
||||||
|
index=None,
|
||||||
|
should_save_conversation=data.get("save_conversation", True),
|
||||||
|
)
|
||||||
|
conversation_id, response, sources, tool_calls, thought, error = (
|
||||||
|
self.process_response_stream(stream)
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
return make_response({"error": error}, 400)
|
||||||
|
result = {
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"answer": response,
|
||||||
|
"sources": sources,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
"thought": thought,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||||
|
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||||
|
)
|
||||||
|
return make_response({"error": str(e)}, 500)
|
||||||
|
return make_response(result, 200)
|
||||||
226
application/api/answer/routes/base.py
Normal file
226
application/api/answer/routes/base.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Generator, List, Optional
|
||||||
|
|
||||||
|
from flask import Response
|
||||||
|
from flask_restx import Namespace
|
||||||
|
|
||||||
|
from application.api.answer.services.conversation_service import ConversationService
|
||||||
|
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.llm.llm_creator import LLMCreator
|
||||||
|
from application.utils import check_required_fields, get_gpt_model
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
answer_ns = Namespace("answer", description="Answer related operations", path="/")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAnswerResource:
|
||||||
|
"""Shared base class for answer endpoints"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
self.user_logs_collection = db["user_logs"]
|
||||||
|
self.gpt_model = get_gpt_model()
|
||||||
|
self.conversation_service = ConversationService()
|
||||||
|
|
||||||
|
def validate_request(
|
||||||
|
self, data: Dict[str, Any], require_conversation_id: bool = False
|
||||||
|
) -> Optional[Response]:
|
||||||
|
"""Common request validation"""
|
||||||
|
required_fields = ["question"]
|
||||||
|
if require_conversation_id:
|
||||||
|
required_fields.append("conversation_id")
|
||||||
|
if missing_fields := check_required_fields(data, required_fields):
|
||||||
|
return missing_fields
|
||||||
|
return None
|
||||||
|
|
||||||
|
def complete_stream(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
agent: Any,
|
||||||
|
retriever: Any,
|
||||||
|
conversation_id: Optional[str],
|
||||||
|
user_api_key: Optional[str],
|
||||||
|
decoded_token: Dict[str, Any],
|
||||||
|
isNoneDoc: bool = False,
|
||||||
|
index: Optional[int] = None,
|
||||||
|
should_save_conversation: bool = True,
|
||||||
|
attachment_ids: Optional[List[str]] = None,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
|
is_shared_usage: bool = False,
|
||||||
|
shared_token: Optional[str] = None,
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
|
"""
|
||||||
|
Generator function that streams the complete conversation response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
question: The user's question
|
||||||
|
agent: The agent instance
|
||||||
|
retriever: The retriever instance
|
||||||
|
conversation_id: Existing conversation ID
|
||||||
|
user_api_key: User's API key if any
|
||||||
|
decoded_token: Decoded JWT token
|
||||||
|
isNoneDoc: Flag for document-less responses
|
||||||
|
index: Index of message to update
|
||||||
|
should_save_conversation: Whether to persist the conversation
|
||||||
|
attachment_ids: List of attachment IDs
|
||||||
|
agent_id: ID of agent used
|
||||||
|
is_shared_usage: Flag for shared agent usage
|
||||||
|
shared_token: Token for shared agent
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Server-sent event strings
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||||
|
|
||||||
|
for line in agent.gen(query=question, retriever=retriever):
|
||||||
|
if "answer" in line:
|
||||||
|
response_full += str(line["answer"])
|
||||||
|
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
elif "sources" in line:
|
||||||
|
truncated_sources = []
|
||||||
|
source_log_docs = line["sources"]
|
||||||
|
for source in line["sources"]:
|
||||||
|
truncated_source = source.copy()
|
||||||
|
if "text" in truncated_source:
|
||||||
|
truncated_source["text"] = (
|
||||||
|
truncated_source["text"][:100].strip() + "..."
|
||||||
|
)
|
||||||
|
truncated_sources.append(truncated_source)
|
||||||
|
if truncated_sources:
|
||||||
|
data = json.dumps(
|
||||||
|
{"type": "source", "source": truncated_sources}
|
||||||
|
)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
elif "tool_calls" in line:
|
||||||
|
tool_calls = line["tool_calls"]
|
||||||
|
elif "thought" in line:
|
||||||
|
thought += line["thought"]
|
||||||
|
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
elif "type" in line:
|
||||||
|
data = json.dumps(line)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
if isNoneDoc:
|
||||||
|
for doc in source_log_docs:
|
||||||
|
doc["source"] = "None"
|
||||||
|
llm = LLMCreator.create_llm(
|
||||||
|
settings.LLM_PROVIDER,
|
||||||
|
api_key=settings.API_KEY,
|
||||||
|
user_api_key=user_api_key,
|
||||||
|
decoded_token=decoded_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_save_conversation:
|
||||||
|
conversation_id = self.conversation_service.save_conversation(
|
||||||
|
conversation_id,
|
||||||
|
question,
|
||||||
|
response_full,
|
||||||
|
thought,
|
||||||
|
source_log_docs,
|
||||||
|
tool_calls,
|
||||||
|
llm,
|
||||||
|
self.gpt_model,
|
||||||
|
decoded_token,
|
||||||
|
index=index,
|
||||||
|
api_key=user_api_key,
|
||||||
|
agent_id=agent_id,
|
||||||
|
is_shared_usage=is_shared_usage,
|
||||||
|
shared_token=shared_token,
|
||||||
|
attachment_ids=attachment_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
conversation_id = None
|
||||||
|
# Send conversation ID
|
||||||
|
|
||||||
|
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
|
# Log the interaction
|
||||||
|
|
||||||
|
retriever_params = retriever.get_params()
|
||||||
|
self.user_logs_collection.insert_one(
|
||||||
|
{
|
||||||
|
"action": "stream_answer",
|
||||||
|
"level": "info",
|
||||||
|
"user": decoded_token.get("sub"),
|
||||||
|
"api_key": user_api_key,
|
||||||
|
"question": question,
|
||||||
|
"response": response_full,
|
||||||
|
"sources": source_log_docs,
|
||||||
|
"retriever_params": retriever_params,
|
||||||
|
"attachments": attachment_ids,
|
||||||
|
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# End of stream
|
||||||
|
|
||||||
|
data = json.dumps({"type": "end"})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||||
|
data = json.dumps(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"error": "Please try again later. We apologize for any inconvenience.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
def process_response_stream(self, stream):
|
||||||
|
"""Process the stream response for non-streaming endpoint"""
|
||||||
|
conversation_id = ""
|
||||||
|
response_full = ""
|
||||||
|
source_log_docs = []
|
||||||
|
tool_calls = []
|
||||||
|
thought = ""
|
||||||
|
stream_ended = False
|
||||||
|
|
||||||
|
for line in stream:
|
||||||
|
try:
|
||||||
|
event_data = line.replace("data: ", "").strip()
|
||||||
|
event = json.loads(event_data)
|
||||||
|
|
||||||
|
if event["type"] == "id":
|
||||||
|
conversation_id = event["id"]
|
||||||
|
elif event["type"] == "answer":
|
||||||
|
response_full += event["answer"]
|
||||||
|
elif event["type"] == "source":
|
||||||
|
source_log_docs = event["source"]
|
||||||
|
elif event["type"] == "tool_calls":
|
||||||
|
tool_calls = event["tool_calls"]
|
||||||
|
elif event["type"] == "thought":
|
||||||
|
thought = event["thought"]
|
||||||
|
elif event["type"] == "error":
|
||||||
|
logger.error(f"Error from stream: {event['error']}")
|
||||||
|
return None, None, None, None, event["error"]
|
||||||
|
elif event["type"] == "end":
|
||||||
|
stream_ended = True
|
||||||
|
except (json.JSONDecodeError, KeyError) as e:
|
||||||
|
logger.warning(f"Error parsing stream event: {e}, line: {line}")
|
||||||
|
continue
|
||||||
|
if not stream_ended:
|
||||||
|
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||||
|
return None, None, None, None, "Stream ended unexpectedly"
|
||||||
|
return (
|
||||||
|
conversation_id,
|
||||||
|
response_full,
|
||||||
|
source_log_docs,
|
||||||
|
tool_calls,
|
||||||
|
thought,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def error_stream_generate(self, err_response):
|
||||||
|
data = json.dumps({"type": "error", "error": err_response})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
116
application/api/answer/routes/stream.py
Normal file
116
application/api/answer/routes/stream.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from flask import make_response, request, Response
|
||||||
|
from flask_restx import fields, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
|
||||||
|
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
|
||||||
|
|
||||||
|
from application.api.answer.services.stream_processor import StreamProcessor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@answer_ns.route("/stream")
|
||||||
|
class StreamResource(Resource, BaseAnswerResource):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
Resource.__init__(self, *args, **kwargs)
|
||||||
|
BaseAnswerResource.__init__(self)
|
||||||
|
|
||||||
|
stream_model = answer_ns.model(
|
||||||
|
"StreamModel",
|
||||||
|
{
|
||||||
|
"question": fields.String(
|
||||||
|
required=True, description="Question to be asked"
|
||||||
|
),
|
||||||
|
"history": fields.List(
|
||||||
|
fields.String,
|
||||||
|
required=False,
|
||||||
|
description="Conversation history (only for new conversations)",
|
||||||
|
),
|
||||||
|
"conversation_id": fields.String(
|
||||||
|
required=False,
|
||||||
|
description="Existing conversation ID (loads history)",
|
||||||
|
),
|
||||||
|
"prompt_id": fields.String(
|
||||||
|
required=False, default="default", description="Prompt ID"
|
||||||
|
),
|
||||||
|
"chunks": fields.Integer(
|
||||||
|
required=False, default=2, description="Number of chunks"
|
||||||
|
),
|
||||||
|
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||||
|
"retriever": fields.String(required=False, description="Retriever type"),
|
||||||
|
"api_key": fields.String(required=False, description="API key"),
|
||||||
|
"active_docs": fields.String(
|
||||||
|
required=False, description="Active documents"
|
||||||
|
),
|
||||||
|
"isNoneDoc": fields.Boolean(
|
||||||
|
required=False, description="Flag indicating if no document is used"
|
||||||
|
),
|
||||||
|
"index": fields.Integer(
|
||||||
|
required=False, description="Index of the query to update"
|
||||||
|
),
|
||||||
|
"save_conversation": fields.Boolean(
|
||||||
|
required=False,
|
||||||
|
default=True,
|
||||||
|
description="Whether to save the conversation",
|
||||||
|
),
|
||||||
|
"attachments": fields.List(
|
||||||
|
fields.String, required=False, description="List of attachment IDs"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(stream_model)
|
||||||
|
@api.doc(description="Stream a response based on the question and retriever")
|
||||||
|
def post(self):
|
||||||
|
data = request.get_json()
|
||||||
|
if error := self.validate_request(data, "index" in data):
|
||||||
|
return error
|
||||||
|
decoded_token = getattr(request, "decoded_token", None)
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response({"error": "Unauthorized"}, 401)
|
||||||
|
processor = StreamProcessor(data, decoded_token)
|
||||||
|
try:
|
||||||
|
processor.initialize()
|
||||||
|
agent = processor.create_agent()
|
||||||
|
retriever = processor.create_retriever()
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
self.complete_stream(
|
||||||
|
question=data["question"],
|
||||||
|
agent=agent,
|
||||||
|
retriever=retriever,
|
||||||
|
conversation_id=processor.conversation_id,
|
||||||
|
user_api_key=processor.agent_config.get("user_api_key"),
|
||||||
|
decoded_token=processor.decoded_token,
|
||||||
|
isNoneDoc=data.get("isNoneDoc"),
|
||||||
|
index=data.get("index"),
|
||||||
|
should_save_conversation=data.get("save_conversation", True),
|
||||||
|
attachment_ids=data.get("attachments", []),
|
||||||
|
agent_id=data.get("agent_id"),
|
||||||
|
is_shared_usage=processor.is_shared_usage,
|
||||||
|
shared_token=processor.shared_token,
|
||||||
|
),
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
message = "Malformed request body"
|
||||||
|
logger.error(f"/stream - error: {message}")
|
||||||
|
return Response(
|
||||||
|
self.error_stream_generate(message),
|
||||||
|
status=400,
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||||
|
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||||
|
)
|
||||||
|
return Response(
|
||||||
|
self.error_stream_generate("Unknown error occurred"),
|
||||||
|
status=400,
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
0
application/api/answer/services/__init__.py
Normal file
0
application/api/answer/services/__init__.py
Normal file
175
application/api/answer/services/conversation_service.py
Normal file
175
application/api/answer/services/conversation_service.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
|
||||||
|
from application.core.settings import settings
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationService:
|
||||||
|
def __init__(self):
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
self.conversations_collection = db["conversations"]
|
||||||
|
self.agents_collection = db["agents"]
|
||||||
|
|
||||||
|
def get_conversation(
|
||||||
|
self, conversation_id: str, user_id: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Retrieve a conversation with proper access control"""
|
||||||
|
if not conversation_id or not user_id:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
conversation = self.conversations_collection.find_one(
|
||||||
|
{
|
||||||
|
"_id": ObjectId(conversation_id),
|
||||||
|
"$or": [{"user": user_id}, {"shared_with": user_id}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
logger.warning(
|
||||||
|
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
conversation["_id"] = str(conversation["_id"])
|
||||||
|
return conversation
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def save_conversation(
|
||||||
|
self,
|
||||||
|
conversation_id: Optional[str],
|
||||||
|
question: str,
|
||||||
|
response: str,
|
||||||
|
thought: str,
|
||||||
|
sources: List[Dict[str, Any]],
|
||||||
|
tool_calls: List[Dict[str, Any]],
|
||||||
|
llm: Any,
|
||||||
|
gpt_model: str,
|
||||||
|
decoded_token: Dict[str, Any],
|
||||||
|
index: Optional[int] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
|
is_shared_usage: bool = False,
|
||||||
|
shared_token: Optional[str] = None,
|
||||||
|
attachment_ids: Optional[List[str]] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Save or update a conversation in the database"""
|
||||||
|
user_id = decoded_token.get("sub")
|
||||||
|
if not user_id:
|
||||||
|
raise ValueError("User ID not found in token")
|
||||||
|
current_time = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
if conversation_id is not None and index is not None:
|
||||||
|
# Update existing conversation with new query
|
||||||
|
|
||||||
|
result = self.conversations_collection.update_one(
|
||||||
|
{
|
||||||
|
"_id": ObjectId(conversation_id),
|
||||||
|
"user": user_id,
|
||||||
|
f"queries.{index}": {"$exists": True},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
f"queries.{index}.prompt": question,
|
||||||
|
f"queries.{index}.response": response,
|
||||||
|
f"queries.{index}.thought": thought,
|
||||||
|
f"queries.{index}.sources": sources,
|
||||||
|
f"queries.{index}.tool_calls": tool_calls,
|
||||||
|
f"queries.{index}.timestamp": current_time,
|
||||||
|
f"queries.{index}.attachments": attachment_ids,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.matched_count == 0:
|
||||||
|
raise ValueError("Conversation not found or unauthorized")
|
||||||
|
self.conversations_collection.update_one(
|
||||||
|
{
|
||||||
|
"_id": ObjectId(conversation_id),
|
||||||
|
"user": user_id,
|
||||||
|
f"queries.{index}": {"$exists": True},
|
||||||
|
},
|
||||||
|
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
||||||
|
)
|
||||||
|
return conversation_id
|
||||||
|
elif conversation_id:
|
||||||
|
# Append new message to existing conversation
|
||||||
|
|
||||||
|
result = self.conversations_collection.update_one(
|
||||||
|
{"_id": ObjectId(conversation_id), "user": user_id},
|
||||||
|
{
|
||||||
|
"$push": {
|
||||||
|
"queries": {
|
||||||
|
"prompt": question,
|
||||||
|
"response": response,
|
||||||
|
"thought": thought,
|
||||||
|
"sources": sources,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
"timestamp": current_time,
|
||||||
|
"attachments": attachment_ids,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.matched_count == 0:
|
||||||
|
raise ValueError("Conversation not found or unauthorized")
|
||||||
|
return conversation_id
|
||||||
|
else:
|
||||||
|
# Create new conversation
|
||||||
|
|
||||||
|
messages_summary = [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Summarise following conversation in no more than 3 "
|
||||||
|
"words, respond ONLY with the summary, use the same "
|
||||||
|
"language as the user query",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Summarise following conversation in no more than 3 words, "
|
||||||
|
"respond ONLY with the summary, use the same language as the "
|
||||||
|
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
completion = llm.gen(
|
||||||
|
model=gpt_model, messages=messages_summary, max_tokens=30
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_data = {
|
||||||
|
"user": user_id,
|
||||||
|
"date": current_time,
|
||||||
|
"name": completion,
|
||||||
|
"queries": [
|
||||||
|
{
|
||||||
|
"prompt": question,
|
||||||
|
"response": response,
|
||||||
|
"thought": thought,
|
||||||
|
"sources": sources,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
"timestamp": current_time,
|
||||||
|
"attachments": attachment_ids,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
if agent_id:
|
||||||
|
conversation_data["agent_id"] = agent_id
|
||||||
|
if is_shared_usage:
|
||||||
|
conversation_data["is_shared_usage"] = is_shared_usage
|
||||||
|
conversation_data["shared_token"] = shared_token
|
||||||
|
agent = self.agents_collection.find_one({"key": api_key})
|
||||||
|
if agent:
|
||||||
|
conversation_data["api_key"] = agent["key"]
|
||||||
|
result = self.conversations_collection.insert_one(conversation_data)
|
||||||
|
return str(result.inserted_id)
|
||||||
260
application/api/answer/services/stream_processor.py
Normal file
260
application/api/answer/services/stream_processor.py
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from bson.dbref import DBRef
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
|
||||||
|
from application.agents.agent_creator import AgentCreator
|
||||||
|
from application.api.answer.services.conversation_service import ConversationService
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.retriever.retriever_creator import RetrieverCreator
|
||||||
|
from application.utils import get_gpt_model, limit_chat_history
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt(prompt_id: str, prompts_collection=None) -> str:
|
||||||
|
"""
|
||||||
|
Get a prompt by preset name or MongoDB ID
|
||||||
|
"""
|
||||||
|
current_dir = Path(__file__).resolve().parents[3]
|
||||||
|
prompts_dir = current_dir / "prompts"
|
||||||
|
|
||||||
|
preset_mapping = {
|
||||||
|
"default": "chat_combine_default.txt",
|
||||||
|
"creative": "chat_combine_creative.txt",
|
||||||
|
"strict": "chat_combine_strict.txt",
|
||||||
|
"reduce": "chat_reduce_prompt.txt",
|
||||||
|
}
|
||||||
|
|
||||||
|
if prompt_id in preset_mapping:
|
||||||
|
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
|
||||||
|
try:
|
||||||
|
with open(file_path, "r") as f:
|
||||||
|
return f.read()
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise FileNotFoundError(f"Prompt file not found: {file_path}")
|
||||||
|
try:
|
||||||
|
if not prompts_collection:
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
prompts_collection = db["prompts"]
|
||||||
|
prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
|
||||||
|
if not prompt_doc:
|
||||||
|
raise ValueError(f"Prompt with ID {prompt_id} not found")
|
||||||
|
return prompt_doc["content"]
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
|
||||||
|
|
||||||
|
|
||||||
|
class StreamProcessor:
|
||||||
|
def __init__(
|
||||||
|
self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
|
||||||
|
):
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
self.db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
self.agents_collection = self.db["agents"]
|
||||||
|
self.attachments_collection = self.db["attachments"]
|
||||||
|
self.prompts_collection = self.db["prompts"]
|
||||||
|
|
||||||
|
self.data = request_data
|
||||||
|
self.decoded_token = decoded_token
|
||||||
|
self.initial_user_id = (
|
||||||
|
self.decoded_token.get("sub") if self.decoded_token is not None else None
|
||||||
|
)
|
||||||
|
self.conversation_id = self.data.get("conversation_id")
|
||||||
|
self.source = (
|
||||||
|
{"active_docs": self.data["active_docs"]}
|
||||||
|
if "active_docs" in self.data
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
self.attachments = []
|
||||||
|
self.history = []
|
||||||
|
self.agent_config = {}
|
||||||
|
self.retriever_config = {}
|
||||||
|
self.is_shared_usage = False
|
||||||
|
self.shared_token = None
|
||||||
|
self.gpt_model = get_gpt_model()
|
||||||
|
self.conversation_service = ConversationService()
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
"""Initialize all required components for processing"""
|
||||||
|
self._configure_agent()
|
||||||
|
self._configure_retriever()
|
||||||
|
self._load_conversation_history()
|
||||||
|
self._process_attachments()
|
||||||
|
|
||||||
|
def _load_conversation_history(self):
|
||||||
|
"""Load conversation history either from DB or request"""
|
||||||
|
if self.conversation_id and self.initial_user_id:
|
||||||
|
conversation = self.conversation_service.get_conversation(
|
||||||
|
self.conversation_id, self.initial_user_id
|
||||||
|
)
|
||||||
|
if not conversation:
|
||||||
|
raise ValueError("Conversation not found or unauthorized")
|
||||||
|
self.history = [
|
||||||
|
{"prompt": query["prompt"], "response": query["response"]}
|
||||||
|
for query in conversation.get("queries", [])
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
self.history = limit_chat_history(
|
||||||
|
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_attachments(self):
|
||||||
|
"""Process any attachments in the request"""
|
||||||
|
attachment_ids = self.data.get("attachments", [])
|
||||||
|
self.attachments = self._get_attachments_content(
|
||||||
|
attachment_ids, self.initial_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_attachments_content(self, attachment_ids, user_id):
|
||||||
|
"""
|
||||||
|
Retrieve content from attachment documents based on their IDs.
|
||||||
|
"""
|
||||||
|
if not attachment_ids:
|
||||||
|
return []
|
||||||
|
attachments = []
|
||||||
|
for attachment_id in attachment_ids:
|
||||||
|
try:
|
||||||
|
attachment_doc = self.attachments_collection.find_one(
|
||||||
|
{"_id": ObjectId(attachment_id), "user": user_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
if attachment_doc:
|
||||||
|
attachments.append(attachment_doc)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
return attachments
|
||||||
|
|
||||||
|
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||||
|
"""Get API key for agent with access control"""
|
||||||
|
if not agent_id:
|
||||||
|
return None, False, None
|
||||||
|
try:
|
||||||
|
agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
|
||||||
|
if agent is None:
|
||||||
|
raise Exception("Agent not found")
|
||||||
|
is_owner = agent.get("user") == user_id
|
||||||
|
is_shared_with_user = agent.get(
|
||||||
|
"shared_publicly", False
|
||||||
|
) or user_id in agent.get("shared_with", [])
|
||||||
|
|
||||||
|
if not (is_owner or is_shared_with_user):
|
||||||
|
raise Exception("Unauthorized access to the agent")
|
||||||
|
if is_owner:
|
||||||
|
self.agents_collection.update_one(
|
||||||
|
{"_id": ObjectId(agent_id)},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return str(agent["key"]), not is_owner, agent.get("shared_token")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
|
||||||
|
data = self.agents_collection.find_one({"key": api_key})
|
||||||
|
if not data:
|
||||||
|
raise Exception("Invalid API Key, please generate a new key", 401)
|
||||||
|
source = data.get("source")
|
||||||
|
if isinstance(source, DBRef):
|
||||||
|
source_doc = self.db.dereference(source)
|
||||||
|
data["source"] = str(source_doc["_id"])
|
||||||
|
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||||
|
else:
|
||||||
|
data["source"] = {}
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _configure_agent(self):
|
||||||
|
"""Configure the agent based on request data"""
|
||||||
|
agent_id = self.data.get("agent_id")
|
||||||
|
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
|
||||||
|
agent_id, self.initial_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key = self.data.get("api_key")
|
||||||
|
if api_key:
|
||||||
|
data_key = self._get_data_from_api_key(api_key)
|
||||||
|
self.agent_config.update(
|
||||||
|
{
|
||||||
|
"prompt_id": data_key.get("prompt_id", "default"),
|
||||||
|
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||||
|
"user_api_key": api_key,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.initial_user_id = data_key.get("user")
|
||||||
|
self.decoded_token = {"sub": data_key.get("user")}
|
||||||
|
elif self.agent_key:
|
||||||
|
data_key = self._get_data_from_api_key(self.agent_key)
|
||||||
|
self.agent_config.update(
|
||||||
|
{
|
||||||
|
"prompt_id": data_key.get("prompt_id", "default"),
|
||||||
|
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||||
|
"user_api_key": self.agent_key,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.decoded_token = (
|
||||||
|
self.decoded_token
|
||||||
|
if self.is_shared_usage
|
||||||
|
else {"sub": data_key.get("user")}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.agent_config.update(
|
||||||
|
{
|
||||||
|
"prompt_id": self.data.get("prompt_id", "default"),
|
||||||
|
"agent_type": settings.AGENT_NAME,
|
||||||
|
"user_api_key": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _configure_retriever(self):
|
||||||
|
"""Configure the retriever based on request data"""
|
||||||
|
self.retriever_config = {
|
||||||
|
"retriever_name": self.data.get("retriever", "classic"),
|
||||||
|
"chunks": int(self.data.get("chunks", 2)),
|
||||||
|
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
|
||||||
|
}
|
||||||
|
|
||||||
|
if "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||||
|
self.retriever_config["chunks"] = 0
|
||||||
|
|
||||||
|
def create_agent(self):
|
||||||
|
"""Create and return the configured agent"""
|
||||||
|
return AgentCreator.create_agent(
|
||||||
|
self.agent_config["agent_type"],
|
||||||
|
endpoint="stream",
|
||||||
|
llm_name=settings.LLM_PROVIDER,
|
||||||
|
gpt_model=self.gpt_model,
|
||||||
|
api_key=settings.API_KEY,
|
||||||
|
user_api_key=self.agent_config["user_api_key"],
|
||||||
|
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||||
|
chat_history=self.history,
|
||||||
|
decoded_token=self.decoded_token,
|
||||||
|
attachments=self.attachments,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_retriever(self):
|
||||||
|
"""Create and return the configured retriever"""
|
||||||
|
return RetrieverCreator.create_retriever(
|
||||||
|
self.retriever_config["retriever_name"],
|
||||||
|
source=self.source,
|
||||||
|
chat_history=self.history,
|
||||||
|
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||||
|
chunks=self.retriever_config["chunks"],
|
||||||
|
token_limit=self.retriever_config["token_limit"],
|
||||||
|
gpt_model=self.gpt_model,
|
||||||
|
user_api_key=self.agent_config["user_api_key"],
|
||||||
|
decoded_token=self.decoded_token,
|
||||||
|
)
|
||||||
@@ -34,7 +34,7 @@ from application.api.user.tasks import (
|
|||||||
)
|
)
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.extensions import api
|
from application.api import api
|
||||||
from application.storage.storage_creator import StorageCreator
|
from application.storage.storage_creator import StorageCreator
|
||||||
from application.tts.google_tts import GoogleTTS
|
from application.tts.google_tts import GoogleTTS
|
||||||
from application.utils import (
|
from application.utils import (
|
||||||
@@ -3538,14 +3538,18 @@ class StoreAttachment(Resource):
|
|||||||
"AttachmentModel",
|
"AttachmentModel",
|
||||||
{
|
{
|
||||||
"file": fields.Raw(required=True, description="File to upload"),
|
"file": fields.Raw(required=True, description="File to upload"),
|
||||||
|
"api_key": fields.String(
|
||||||
|
required=False, description="API key (optional)"
|
||||||
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@api.doc(description="Stores a single attachment without vectorization or training")
|
@api.doc(
|
||||||
|
description="Stores a single attachment without vectorization or training. Supports user or API key authentication."
|
||||||
|
)
|
||||||
def post(self):
|
def post(self):
|
||||||
decoded_token = request.decoded_token
|
decoded_token = getattr(request, "decoded_token", None)
|
||||||
if not decoded_token:
|
api_key = request.form.get("api_key") or request.args.get("api_key")
|
||||||
return make_response(jsonify({"success": False}), 401)
|
|
||||||
file = request.files.get("file")
|
file = request.files.get("file")
|
||||||
|
|
||||||
if not file or file.filename == "":
|
if not file or file.filename == "":
|
||||||
@@ -3553,7 +3557,21 @@ class StoreAttachment(Resource):
|
|||||||
jsonify({"status": "error", "message": "Missing file"}),
|
jsonify({"status": "error", "message": "Missing file"}),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
user = safe_filename(decoded_token.get("sub"))
|
|
||||||
|
user = None
|
||||||
|
if decoded_token:
|
||||||
|
user = safe_filename(decoded_token.get("sub"))
|
||||||
|
elif api_key:
|
||||||
|
agent = agents_collection.find_one({"key": api_key})
|
||||||
|
if not agent:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid API key"}), 401
|
||||||
|
)
|
||||||
|
user = safe_filename(agent.get("user"))
|
||||||
|
else:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Authentication required"}), 401
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
attachment_id = ObjectId()
|
attachment_id = ObjectId()
|
||||||
|
|||||||
@@ -12,19 +12,18 @@ from application.core.logging_config import setup_logging
|
|||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
|
|
||||||
from application.api.answer.routes import answer # noqa: E402
|
from application.api import api # noqa: E402
|
||||||
|
from application.api.answer import answer # noqa: E402
|
||||||
from application.api.internal.routes import internal # noqa: E402
|
from application.api.internal.routes import internal # noqa: E402
|
||||||
from application.api.user.routes import user # noqa: E402
|
from application.api.user.routes import user # noqa: E402
|
||||||
from application.celery_init import celery # noqa: E402
|
from application.celery_init import celery # noqa: E402
|
||||||
from application.core.settings import settings # noqa: E402
|
from application.core.settings import settings # noqa: E402
|
||||||
from application.extensions import api # noqa: E402
|
|
||||||
|
|
||||||
|
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
pathlib.PosixPath = pathlib.WindowsPath
|
pathlib.PosixPath = pathlib.WindowsPath
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
@@ -52,7 +51,6 @@ if settings.AUTH_TYPE in ("simple_jwt", "session_jwt") and not settings.JWT_SECR
|
|||||||
settings.JWT_SECRET_KEY = new_key
|
settings.JWT_SECRET_KEY = new_key
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}")
|
raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}")
|
||||||
|
|
||||||
SIMPLE_JWT_TOKEN = None
|
SIMPLE_JWT_TOKEN = None
|
||||||
if settings.AUTH_TYPE == "simple_jwt":
|
if settings.AUTH_TYPE == "simple_jwt":
|
||||||
payload = {"sub": "local"}
|
payload = {"sub": "local"}
|
||||||
@@ -92,7 +90,6 @@ def generate_token():
|
|||||||
def authenticate_request():
|
def authenticate_request():
|
||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
return "", 200
|
return "", 200
|
||||||
|
|
||||||
decoded_token = handle_auth(request)
|
decoded_token = handle_auth(request)
|
||||||
if not decoded_token:
|
if not decoded_token:
|
||||||
request.decoded_token = None
|
request.decoded_token = None
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ current_dir = os.path.dirname(
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
AUTH_TYPE: Optional[str] = None
|
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
|
||||||
LLM_PROVIDER: str = "docsgpt"
|
LLM_PROVIDER: str = "docsgpt"
|
||||||
LLM_NAME: Optional[str] = (
|
LLM_NAME: Optional[str] = (
|
||||||
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
|
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
from flask_restx import Api
|
|
||||||
|
|
||||||
api = Api(
|
|
||||||
version="1.0",
|
|
||||||
title="DocsGPT API",
|
|
||||||
description="API for DocsGPT",
|
|
||||||
)
|
|
||||||
0
application/storage/__init__.py
Normal file
0
application/storage/__init__.py
Normal file
@@ -6,6 +6,7 @@ import uuid
|
|||||||
import tiktoken
|
import tiktoken
|
||||||
from flask import jsonify, make_response
|
from flask import jsonify, make_response
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
|
||||||
@@ -19,6 +20,17 @@ def get_encoding():
|
|||||||
return _encoding
|
return _encoding
|
||||||
|
|
||||||
|
|
||||||
|
def get_gpt_model() -> str:
|
||||||
|
"""Get the appropriate GPT model based on provider"""
|
||||||
|
model_map = {
|
||||||
|
"openai": "gpt-4o-mini",
|
||||||
|
"anthropic": "claude-2",
|
||||||
|
"groq": "llama3-8b-8192",
|
||||||
|
"novita": "deepseek/deepseek-r1",
|
||||||
|
}
|
||||||
|
return settings.LLM_NAME or model_map.get(settings.LLM_PROVIDER, "")
|
||||||
|
|
||||||
|
|
||||||
def safe_filename(filename):
|
def safe_filename(filename):
|
||||||
"""
|
"""
|
||||||
Creates a safe filename that preserves the original extension.
|
Creates a safe filename that preserves the original extension.
|
||||||
@@ -32,15 +44,14 @@ def safe_filename(filename):
|
|||||||
"""
|
"""
|
||||||
if not filename:
|
if not filename:
|
||||||
return str(uuid.uuid4())
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
_, extension = os.path.splitext(filename)
|
_, extension = os.path.splitext(filename)
|
||||||
|
|
||||||
safe_name = secure_filename(filename)
|
safe_name = secure_filename(filename)
|
||||||
|
|
||||||
# If secure_filename returns just the extension or an empty string
|
# If secure_filename returns just the extension or an empty string
|
||||||
|
|
||||||
if not safe_name or safe_name == extension.lstrip("."):
|
if not safe_name or safe_name == extension.lstrip("."):
|
||||||
return f"{str(uuid.uuid4())}{extension}"
|
return f"{str(uuid.uuid4())}{extension}"
|
||||||
|
|
||||||
return safe_name
|
return safe_name
|
||||||
|
|
||||||
|
|
||||||
@@ -68,7 +79,6 @@ def count_tokens_docs(docs):
|
|||||||
docs_content = ""
|
docs_content = ""
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
docs_content += doc.page_content
|
docs_content += doc.page_content
|
||||||
|
|
||||||
tokens = num_tokens_from_string(docs_content)
|
tokens = num_tokens_from_string(docs_content)
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
@@ -97,13 +107,11 @@ def validate_required_fields(data, required_fields):
|
|||||||
missing_fields.append(field)
|
missing_fields.append(field)
|
||||||
elif not data[field]:
|
elif not data[field]:
|
||||||
empty_fields.append(field)
|
empty_fields.append(field)
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
if missing_fields:
|
if missing_fields:
|
||||||
errors.append(f"Missing required fields: {', '.join(missing_fields)}")
|
errors.append(f"Missing required fields: {', '.join(missing_fields)}")
|
||||||
if empty_fields:
|
if empty_fields:
|
||||||
errors.append(f"Empty values in required fields: {', '.join(empty_fields)}")
|
errors.append(f"Empty values in required fields: {', '.join(empty_fields)}")
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": " | ".join(errors)}), 400
|
jsonify({"success": False, "message": " | ".join(errors)}), 400
|
||||||
@@ -132,7 +140,6 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
|
|||||||
|
|
||||||
if not history:
|
if not history:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
trimmed_history = []
|
trimmed_history = []
|
||||||
tokens_current_history = 0
|
tokens_current_history = 0
|
||||||
|
|
||||||
@@ -141,18 +148,15 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
|
|||||||
if "prompt" in message and "response" in message:
|
if "prompt" in message and "response" in message:
|
||||||
tokens_batch += num_tokens_from_string(message["prompt"])
|
tokens_batch += num_tokens_from_string(message["prompt"])
|
||||||
tokens_batch += num_tokens_from_string(message["response"])
|
tokens_batch += num_tokens_from_string(message["response"])
|
||||||
|
|
||||||
if "tool_calls" in message:
|
if "tool_calls" in message:
|
||||||
for tool_call in message["tool_calls"]:
|
for tool_call in message["tool_calls"]:
|
||||||
tool_call_string = f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}"
|
tool_call_string = f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}"
|
||||||
tokens_batch += num_tokens_from_string(tool_call_string)
|
tokens_batch += num_tokens_from_string(tool_call_string)
|
||||||
|
|
||||||
if tokens_current_history + tokens_batch < max_token_limit:
|
if tokens_current_history + tokens_batch < max_token_limit:
|
||||||
tokens_current_history += tokens_batch
|
tokens_current_history += tokens_batch
|
||||||
trimmed_history.insert(0, message)
|
trimmed_history.insert(0, message)
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
return trimmed_history
|
return trimmed_history
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from bson.dbref import DBRef
|
|||||||
from bson.objectid import ObjectId
|
from bson.objectid import ObjectId
|
||||||
|
|
||||||
from application.agents.agent_creator import AgentCreator
|
from application.agents.agent_creator import AgentCreator
|
||||||
from application.api.answer.routes import get_prompt
|
from application.api.answer.services.stream_processor import get_prompt
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
@@ -35,17 +35,22 @@ db = mongo[settings.MONGO_DB_NAME]
|
|||||||
sources_collection = db["sources"]
|
sources_collection = db["sources"]
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
|
|
||||||
MIN_TOKENS = 150
|
MIN_TOKENS = 150
|
||||||
MAX_TOKENS = 1250
|
MAX_TOKENS = 1250
|
||||||
RECURSION_DEPTH = 2
|
RECURSION_DEPTH = 2
|
||||||
|
|
||||||
|
|
||||||
# Define a function to extract metadata from a given filename.
|
# Define a function to extract metadata from a given filename.
|
||||||
|
|
||||||
|
|
||||||
def metadata_from_filename(title):
|
def metadata_from_filename(title):
|
||||||
return {"title": title}
|
return {"title": title}
|
||||||
|
|
||||||
|
|
||||||
# Define a function to generate a random string of a given length.
|
# Define a function to generate a random string of a given length.
|
||||||
|
|
||||||
|
|
||||||
def generate_random_string(length):
|
def generate_random_string(length):
|
||||||
return "".join([string.ascii_letters[i % 52] for i in range(length)])
|
return "".join([string.ascii_letters[i % 52] for i in range(length)])
|
||||||
|
|
||||||
@@ -68,7 +73,6 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
|
|||||||
if current_depth > max_depth:
|
if current_depth > max_depth:
|
||||||
logging.warning(f"Reached maximum recursion depth of {max_depth}")
|
logging.warning(f"Reached maximum recursion depth of {max_depth}")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||||
zip_ref.extractall(extract_to)
|
zip_ref.extractall(extract_to)
|
||||||
@@ -76,12 +80,13 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error extracting zip file {zip_path}: {e}", exc_info=True)
|
logging.error(f"Error extracting zip file {zip_path}: {e}", exc_info=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check for nested zip files and extract them
|
# Check for nested zip files and extract them
|
||||||
|
|
||||||
for root, dirs, files in os.walk(extract_to):
|
for root, dirs, files in os.walk(extract_to):
|
||||||
for file in files:
|
for file in files:
|
||||||
if file.endswith(".zip"):
|
if file.endswith(".zip"):
|
||||||
# If a nested zip file is found, extract it recursively
|
# If a nested zip file is found, extract it recursively
|
||||||
|
|
||||||
file_path = os.path.join(root, file)
|
file_path = os.path.join(root, file)
|
||||||
extract_zip_recursive(file_path, root, current_depth + 1, max_depth)
|
extract_zip_recursive(file_path, root, current_depth + 1, max_depth)
|
||||||
|
|
||||||
@@ -139,7 +144,7 @@ def run_agent_logic(agent_config, input_data):
|
|||||||
user_api_key = agent_config["key"]
|
user_api_key = agent_config["key"]
|
||||||
agent_type = agent_config.get("agent_type", "classic")
|
agent_type = agent_config.get("agent_type", "classic")
|
||||||
decoded_token = {"sub": agent_config.get("user")}
|
decoded_token = {"sub": agent_config.get("user")}
|
||||||
prompt = get_prompt(prompt_id)
|
prompt = get_prompt(prompt_id, db["prompts"])
|
||||||
agent = AgentCreator.create_agent(
|
agent = AgentCreator.create_agent(
|
||||||
agent_type,
|
agent_type,
|
||||||
endpoint="webhook",
|
endpoint="webhook",
|
||||||
@@ -178,7 +183,6 @@ def run_agent_logic(agent_config, input_data):
|
|||||||
tool_calls.extend(line["tool_calls"])
|
tool_calls.extend(line["tool_calls"])
|
||||||
elif "thought" in line:
|
elif "thought" in line:
|
||||||
thought += line["thought"]
|
thought += line["thought"]
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"answer": response_full,
|
"answer": response_full,
|
||||||
"sources": source_log_docs,
|
"sources": source_log_docs,
|
||||||
@@ -193,8 +197,18 @@ def run_agent_logic(agent_config, input_data):
|
|||||||
|
|
||||||
|
|
||||||
# Define the main function for ingesting and processing documents.
|
# Define the main function for ingesting and processing documents.
|
||||||
|
|
||||||
|
|
||||||
def ingest_worker(
|
def ingest_worker(
|
||||||
self, directory, formats, job_name, filename, user, dir_name=None, user_dir=None, retriever="classic"
|
self,
|
||||||
|
directory,
|
||||||
|
formats,
|
||||||
|
job_name,
|
||||||
|
filename,
|
||||||
|
user,
|
||||||
|
dir_name=None,
|
||||||
|
user_dir=None,
|
||||||
|
retriever="classic",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Ingest and process documents.
|
Ingest and process documents.
|
||||||
@@ -227,29 +241,29 @@ def ingest_worker(
|
|||||||
logging.info(f"Ingest file: {full_path}", extra={"user": user, "job": job_name})
|
logging.info(f"Ingest file: {full_path}", extra={"user": user, "job": job_name})
|
||||||
|
|
||||||
# Create temporary working directory
|
# Create temporary working directory
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
try:
|
try:
|
||||||
os.makedirs(temp_dir, exist_ok=True)
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
|
||||||
# Download file from storage to temp directory
|
# Download file from storage to temp directory
|
||||||
|
|
||||||
temp_file_path = os.path.join(temp_dir, filename)
|
temp_file_path = os.path.join(temp_dir, filename)
|
||||||
file_data = storage.get_file(source_file_path)
|
file_data = storage.get_file(source_file_path)
|
||||||
|
|
||||||
with open(temp_file_path, "wb") as f:
|
with open(temp_file_path, "wb") as f:
|
||||||
f.write(file_data.read())
|
f.write(file_data.read())
|
||||||
|
|
||||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||||
|
|
||||||
# Handle zip files
|
# Handle zip files
|
||||||
|
|
||||||
if filename.endswith(".zip"):
|
if filename.endswith(".zip"):
|
||||||
logging.info(f"Extracting zip file: {filename}")
|
logging.info(f"Extracting zip file: {filename}")
|
||||||
extract_zip_recursive(
|
extract_zip_recursive(
|
||||||
temp_file_path, temp_dir, current_depth=0, max_depth=RECURSION_DEPTH
|
temp_file_path, temp_dir, current_depth=0, max_depth=RECURSION_DEPTH
|
||||||
)
|
)
|
||||||
|
|
||||||
if sample:
|
if sample:
|
||||||
logging.info(f"Sample mode enabled. Using {limit} documents.")
|
logging.info(f"Sample mode enabled. Using {limit} documents.")
|
||||||
|
|
||||||
reader = SimpleDirectoryReader(
|
reader = SimpleDirectoryReader(
|
||||||
input_dir=temp_dir,
|
input_dir=temp_dir,
|
||||||
input_files=input_files,
|
input_files=input_files,
|
||||||
@@ -296,11 +310,9 @@ def ingest_worker(
|
|||||||
}
|
}
|
||||||
|
|
||||||
upload_index(vector_store_path, file_data)
|
upload_index(vector_store_path, file_data)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in ingest_worker: {e}", exc_info=True)
|
logging.error(f"Error in ingest_worker: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"directory": directory,
|
"directory": directory,
|
||||||
"formats": formats,
|
"formats": formats,
|
||||||
@@ -326,7 +338,6 @@ def remote_worker(
|
|||||||
full_path = os.path.join(directory, user, name_job)
|
full_path = os.path.join(directory, user, name_job)
|
||||||
if not os.path.exists(full_path):
|
if not os.path.exists(full_path):
|
||||||
os.makedirs(full_path)
|
os.makedirs(full_path)
|
||||||
|
|
||||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||||
try:
|
try:
|
||||||
logging.info("Initializing remote loader with type: %s", loader)
|
logging.info("Initializing remote loader with type: %s", loader)
|
||||||
@@ -353,7 +364,6 @@ def remote_worker(
|
|||||||
raise ValueError("doc_id must be provided for sync operation.")
|
raise ValueError("doc_id must be provided for sync operation.")
|
||||||
id = ObjectId(doc_id)
|
id = ObjectId(doc_id)
|
||||||
embed_and_store_documents(docs, full_path, id, self)
|
embed_and_store_documents(docs, full_path, id, self)
|
||||||
|
|
||||||
self.update_state(state="PROGRESS", meta={"current": 100})
|
self.update_state(state="PROGRESS", meta={"current": 100})
|
||||||
|
|
||||||
file_data = {
|
file_data = {
|
||||||
@@ -367,15 +377,12 @@ def remote_worker(
|
|||||||
"sync_frequency": sync_frequency,
|
"sync_frequency": sync_frequency,
|
||||||
}
|
}
|
||||||
upload_index(full_path, file_data)
|
upload_index(full_path, file_data)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("Error in remote_worker task: %s", str(e), exc_info=True)
|
logging.error("Error in remote_worker task: %s", str(e), exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(full_path):
|
if os.path.exists(full_path):
|
||||||
shutil.rmtree(full_path)
|
shutil.rmtree(full_path)
|
||||||
|
|
||||||
logging.info("remote_worker task completed successfully")
|
logging.info("remote_worker task completed successfully")
|
||||||
return {"urls": source_data, "name_job": name_job, "user": user, "limited": False}
|
return {"urls": source_data, "name_job": name_job, "user": user, "limited": False}
|
||||||
|
|
||||||
@@ -428,7 +435,6 @@ def sync_worker(self, frequency):
|
|||||||
sync_counts[
|
sync_counts[
|
||||||
"sync_success" if resp["status"] == "success" else "sync_failure"
|
"sync_success" if resp["status"] == "success" else "sync_failure"
|
||||||
] += 1
|
] += 1
|
||||||
|
|
||||||
return {
|
return {
|
||||||
key: sync_counts[key]
|
key: sync_counts[key]
|
||||||
for key in ["total_sync_count", "sync_success", "sync_failure"]
|
for key in ["total_sync_count", "sync_success", "sync_failure"]
|
||||||
@@ -503,7 +509,6 @@ def attachment_worker(self, file_info, user):
|
|||||||
"mime_type": mime_type,
|
"mime_type": mime_type,
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(
|
logging.error(
|
||||||
f"Error processing file {filename}: {e}",
|
f"Error processing file {filename}: {e}",
|
||||||
@@ -539,7 +544,6 @@ def agent_webhook_worker(self, agent_id, payload):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error processing agent webhook: {e}", exc_info=True)
|
logging.error(f"Error processing agent webhook: {e}", exc_info=True)
|
||||||
return {"status": "error", "error": str(e)}
|
return {"status": "error", "error": str(e)}
|
||||||
|
|
||||||
self.update_state(state="PROGRESS", meta={"current": 50})
|
self.update_state(state="PROGRESS", meta={"current": 50})
|
||||||
try:
|
try:
|
||||||
result = run_agent_logic(agent_config, input_data)
|
result = run_agent_logic(agent_config, input_data)
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ export function handleFetchAnswer(
|
|||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
token: string | null,
|
token: string | null,
|
||||||
selectedDocs: Doc | null,
|
selectedDocs: Doc | null,
|
||||||
history: Array<any> = [],
|
|
||||||
conversationId: string | null,
|
conversationId: string | null,
|
||||||
promptId: string | null,
|
promptId: string | null,
|
||||||
chunks: string,
|
chunks: string,
|
||||||
@@ -37,16 +36,8 @@ export function handleFetchAnswer(
|
|||||||
title: any;
|
title: any;
|
||||||
}
|
}
|
||||||
> {
|
> {
|
||||||
history = history.map((item) => {
|
|
||||||
return {
|
|
||||||
prompt: item.prompt,
|
|
||||||
response: item.response,
|
|
||||||
tool_calls: item.tool_calls,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
const payload: RetrievalPayload = {
|
const payload: RetrievalPayload = {
|
||||||
question: question,
|
question: question,
|
||||||
history: JSON.stringify(history),
|
|
||||||
conversation_id: conversationId,
|
conversation_id: conversationId,
|
||||||
prompt_id: promptId,
|
prompt_id: promptId,
|
||||||
chunks: chunks,
|
chunks: chunks,
|
||||||
@@ -94,7 +85,6 @@ export function handleFetchAnswerSteaming(
|
|||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
token: string | null,
|
token: string | null,
|
||||||
selectedDocs: Doc | null,
|
selectedDocs: Doc | null,
|
||||||
history: Array<any> = [],
|
|
||||||
conversationId: string | null,
|
conversationId: string | null,
|
||||||
promptId: string | null,
|
promptId: string | null,
|
||||||
chunks: string,
|
chunks: string,
|
||||||
@@ -105,17 +95,8 @@ export function handleFetchAnswerSteaming(
|
|||||||
attachments?: string[],
|
attachments?: string[],
|
||||||
save_conversation = true,
|
save_conversation = true,
|
||||||
): Promise<Answer> {
|
): Promise<Answer> {
|
||||||
history = history.map((item) => {
|
|
||||||
return {
|
|
||||||
prompt: item.prompt,
|
|
||||||
response: item.response,
|
|
||||||
thought: item.thought,
|
|
||||||
tool_calls: item.tool_calls,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
const payload: RetrievalPayload = {
|
const payload: RetrievalPayload = {
|
||||||
question: question,
|
question: question,
|
||||||
history: JSON.stringify(history),
|
|
||||||
conversation_id: conversationId,
|
conversation_id: conversationId,
|
||||||
prompt_id: promptId,
|
prompt_id: promptId,
|
||||||
chunks: chunks,
|
chunks: chunks,
|
||||||
@@ -192,20 +173,11 @@ export function handleSearch(
|
|||||||
token: string | null,
|
token: string | null,
|
||||||
selectedDocs: Doc | null,
|
selectedDocs: Doc | null,
|
||||||
conversation_id: string | null,
|
conversation_id: string | null,
|
||||||
history: Array<any> = [],
|
|
||||||
chunks: string,
|
chunks: string,
|
||||||
token_limit: number,
|
token_limit: number,
|
||||||
) {
|
) {
|
||||||
history = history.map((item) => {
|
|
||||||
return {
|
|
||||||
prompt: item.prompt,
|
|
||||||
response: item.response,
|
|
||||||
tool_calls: item.tool_calls,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
const payload: RetrievalPayload = {
|
const payload: RetrievalPayload = {
|
||||||
question: question,
|
question: question,
|
||||||
history: JSON.stringify(history),
|
|
||||||
conversation_id: conversation_id,
|
conversation_id: conversation_id,
|
||||||
chunks: chunks,
|
chunks: chunks,
|
||||||
token_limit: token_limit,
|
token_limit: token_limit,
|
||||||
|
|||||||
@@ -52,7 +52,6 @@ export interface RetrievalPayload {
|
|||||||
question: string;
|
question: string;
|
||||||
active_docs?: string;
|
active_docs?: string;
|
||||||
retriever?: string;
|
retriever?: string;
|
||||||
history: string;
|
|
||||||
conversation_id: string | null;
|
conversation_id: string | null;
|
||||||
prompt_id?: string | null;
|
prompt_id?: string | null;
|
||||||
chunks: string;
|
chunks: string;
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ export const fetchAnswer = createAsyncThunk<
|
|||||||
signal,
|
signal,
|
||||||
state.preference.token,
|
state.preference.token,
|
||||||
state.preference.selectedDocs!,
|
state.preference.selectedDocs!,
|
||||||
state.conversation.queries,
|
|
||||||
currentConversationId,
|
currentConversationId,
|
||||||
state.preference.prompt.id,
|
state.preference.prompt.id,
|
||||||
state.preference.chunks,
|
state.preference.chunks,
|
||||||
@@ -153,7 +152,6 @@ export const fetchAnswer = createAsyncThunk<
|
|||||||
signal,
|
signal,
|
||||||
state.preference.token,
|
state.preference.token,
|
||||||
state.preference.selectedDocs!,
|
state.preference.selectedDocs!,
|
||||||
state.conversation.queries,
|
|
||||||
state.conversation.conversationId,
|
state.conversation.conversationId,
|
||||||
state.preference.prompt.id,
|
state.preference.prompt.id,
|
||||||
state.preference.chunks,
|
state.preference.chunks,
|
||||||
|
|||||||
Reference in New Issue
Block a user