diff --git a/application/api/__init__.py b/application/api/__init__.py index e69de29b..b6f52893 100644 --- a/application/api/__init__.py +++ b/application/api/__init__.py @@ -0,0 +1,7 @@ +from flask_restx import Api + +api = Api( + version="1.0", + title="DocsGPT API", + description="API for DocsGPT", +) diff --git a/application/api/answer/__init__.py b/application/api/answer/__init__.py index e69de29b..861c922d 100644 --- a/application/api/answer/__init__.py +++ b/application/api/answer/__init__.py @@ -0,0 +1,19 @@ +from flask import Blueprint + +from application.api import api +from application.api.answer.routes.answer import AnswerResource +from application.api.answer.routes.base import answer_ns +from application.api.answer.routes.stream import StreamResource + + +answer = Blueprint("answer", __name__) + +api.add_namespace(answer_ns) + + +def init_answer_routes(): + api.add_resource(StreamResource, "/stream") + api.add_resource(AnswerResource, "/api/answer") + + +init_answer_routes() diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py deleted file mode 100644 index c9d5bf53..00000000 --- a/application/api/answer/routes.py +++ /dev/null @@ -1,914 +0,0 @@ -import asyncio -import datetime -import json -import logging -import os -import traceback - -from bson.dbref import DBRef -from bson.objectid import ObjectId -from flask import Blueprint, make_response, request, Response -from flask_restx import fields, Namespace, Resource - -from application.agents.agent_creator import AgentCreator - -from application.core.mongo_db import MongoDB -from application.core.settings import settings -from application.error import bad_request -from application.extensions import api -from application.llm.llm_creator import LLMCreator -from application.retriever.retriever_creator import RetrieverCreator -from application.utils import check_required_fields, limit_chat_history - -logger = logging.getLogger(__name__) - -mongo = MongoDB.get_client() -db = mongo[settings.MONGO_DB_NAME] -conversations_collection = db["conversations"] -sources_collection = db["sources"] -prompts_collection = db["prompts"] -agents_collection = db["agents"] -user_logs_collection = db["user_logs"] -attachments_collection = db["attachments"] - -answer = Blueprint("answer", __name__) -answer_ns = Namespace("answer", description="Answer related operations", path="/") -api.add_namespace(answer_ns) - -gpt_model = "" -# to have some kind of default behaviour -if settings.LLM_PROVIDER == "openai": - gpt_model = "gpt-4o-mini" -elif settings.LLM_PROVIDER == "anthropic": - gpt_model = "claude-2" -elif settings.LLM_PROVIDER == "groq": - gpt_model = "llama3-8b-8192" -elif settings.LLM_PROVIDER == "novita": - gpt_model = "deepseek/deepseek-r1" - -if settings.LLM_NAME: # in case there is particular model name configured - gpt_model = settings.LLM_NAME - -# load the prompts -current_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) -with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: - chat_combine_template = f.read() - -with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f: - chat_reduce_template = f.read() - -with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f: - chat_combine_creative = f.read() - -with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f: - chat_combine_strict = f.read() - -api_key_set = settings.API_KEY is not None -embeddings_key_set = settings.EMBEDDINGS_KEY is not None - - -async def async_generate(chain, question, chat_history): - result = await chain.arun({"question": question, "chat_history": chat_history}) - return result - - -def run_async_chain(chain, question, chat_history): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - result = {} - try: - answer = loop.run_until_complete(async_generate(chain, question, chat_history)) - finally: - loop.close() - result["answer"] = answer - return result - - -def get_agent_key(agent_id, user_id): - if not agent_id: - return None, False, None - - try: - agent = agents_collection.find_one({"_id": ObjectId(agent_id)}) - if agent is None: - raise Exception("Agent not found", 404) - - is_owner = agent.get("user") == user_id - - if is_owner: - agents_collection.update_one( - {"_id": ObjectId(agent_id)}, - {"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}}, - ) - return str(agent["key"]), False, None - - is_shared_with_user = agent.get( - "shared_publicly", False - ) or user_id in agent.get("shared_with", []) - - if is_shared_with_user: - return str(agent["key"]), True, agent.get("shared_token") - - raise Exception("Unauthorized access to the agent", 403) - - except Exception as e: - logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True) - raise - - -def get_data_from_api_key(api_key): - data = agents_collection.find_one({"key": api_key}) - if not data: - raise Exception("Invalid API Key, please generate a new key", 401) - - source = data.get("source") - if isinstance(source, DBRef): - source_doc = db.dereference(source) - data["source"] = str(source_doc["_id"]) - data["retriever"] = source_doc.get("retriever", data.get("retriever")) - else: - data["source"] = {} - - return data - - -def get_retriever(source_id: str): - doc = sources_collection.find_one({"_id": ObjectId(source_id)}) - if doc is None: - raise Exception("Source document does not exist", 404) - retriever_name = None if "retriever" not in doc else doc["retriever"] - return retriever_name - - -def is_azure_configured(): - return ( - settings.OPENAI_API_BASE - and settings.OPENAI_API_VERSION - and settings.AZURE_DEPLOYMENT_NAME - ) - - -def save_conversation( - conversation_id, - question, - response, - thought, - source_log_docs, - tool_calls, - llm, - decoded_token, - index=None, - api_key=None, - agent_id=None, - is_shared_usage=False, - shared_token=None, - attachment_ids=None, -): - current_time = datetime.datetime.now(datetime.timezone.utc) - if conversation_id is not None and index is not None: - conversations_collection.update_one( - {"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}}, - { - "$set": { - f"queries.{index}.prompt": question, - f"queries.{index}.response": response, - f"queries.{index}.thought": thought, - f"queries.{index}.sources": source_log_docs, - f"queries.{index}.tool_calls": tool_calls, - f"queries.{index}.timestamp": current_time, - f"queries.{index}.attachments": attachment_ids, - } - }, - ) - ##remove following queries from the array - conversations_collection.update_one( - {"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}}, - {"$push": {"queries": {"$each": [], "$slice": index + 1}}}, - ) - elif conversation_id is not None and conversation_id != "None": - conversations_collection.update_one( - {"_id": ObjectId(conversation_id)}, - { - "$push": { - "queries": { - "prompt": question, - "response": response, - "thought": thought, - "sources": source_log_docs, - "tool_calls": tool_calls, - "timestamp": current_time, - "attachments": attachment_ids, - } - } - }, - ) - - else: - # create new conversation - # generate summary - messages_summary = [ - { - "role": "assistant", - "content": "Summarise following conversation in no more than 3 " - "words, respond ONLY with the summary, use the same " - "language as the user query", - }, - { - "role": "user", - "content": "Summarise following conversation in no more than 3 words, " - "respond ONLY with the summary, use the same language as the " - "user query \n\nUser: " + question + "\n\n" + "AI: " + response, - }, - ] - - completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30) - conversation_data = { - "user": decoded_token.get("sub"), - "date": datetime.datetime.utcnow(), - "name": completion, - "queries": [ - { - "prompt": question, - "response": response, - "thought": thought, - "sources": source_log_docs, - "tool_calls": tool_calls, - "timestamp": current_time, - "attachments": attachment_ids, - } - ], - } - if api_key: - if agent_id: - conversation_data["agent_id"] = agent_id - if is_shared_usage: - conversation_data["is_shared_usage"] = is_shared_usage - conversation_data["shared_token"] = shared_token - api_key_doc = agents_collection.find_one({"key": api_key}) - if api_key_doc: - conversation_data["api_key"] = api_key_doc["key"] - conversation_id = conversations_collection.insert_one( - conversation_data - ).inserted_id - return conversation_id - - -def get_prompt(prompt_id): - if prompt_id == "default": - prompt = chat_combine_template - elif prompt_id == "creative": - prompt = chat_combine_creative - elif prompt_id == "strict": - prompt = chat_combine_strict - else: - prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] - return prompt - - -def complete_stream( - question, - agent, - retriever, - conversation_id, - user_api_key, - decoded_token, - isNoneDoc=False, - index=None, - should_save_conversation=True, - attachment_ids=None, - agent_id=None, - is_shared_usage=False, - shared_token=None, -): - try: - response_full, thought, source_log_docs, tool_calls = "", "", [], [] - - answer = agent.gen(query=question, retriever=retriever) - - for line in answer: - if "answer" in line: - response_full += str(line["answer"]) - data = json.dumps({"type": "answer", "answer": line["answer"]}) - yield f"data: {data}\n\n" - elif "sources" in line: - truncated_sources = [] - source_log_docs = line["sources"] - for source in line["sources"]: - truncated_source = source.copy() - if "text" in truncated_source: - truncated_source["text"] = ( - truncated_source["text"][:100].strip() + "..." - ) - truncated_sources.append(truncated_source) - if len(truncated_sources) > 0: - data = json.dumps({"type": "source", "source": truncated_sources}) - yield f"data: {data}\n\n" - elif "tool_calls" in line: - tool_calls = line["tool_calls"] - elif "thought" in line: - thought += line["thought"] - data = json.dumps({"type": "thought", "thought": line["thought"]}) - yield f"data: {data}\n\n" - elif "type" in line: - data = json.dumps(line) - yield f"data: {data}\n\n" - - if isNoneDoc: - for doc in source_log_docs: - doc["source"] = "None" - - llm = LLMCreator.create_llm( - settings.LLM_PROVIDER, - api_key=settings.API_KEY, - user_api_key=user_api_key, - decoded_token=decoded_token, - ) - - if should_save_conversation: - conversation_id = save_conversation( - conversation_id, - question, - response_full, - thought, - source_log_docs, - tool_calls, - llm, - decoded_token, - index, - api_key=user_api_key, - attachment_ids=attachment_ids, - agent_id=agent_id, - is_shared_usage=is_shared_usage, - shared_token=shared_token, - ) - else: - conversation_id = None - - # send data.type = "end" to indicate that the stream has ended as json - data = json.dumps({"type": "id", "id": str(conversation_id)}) - yield f"data: {data}\n\n" - - retriever_params = retriever.get_params() - user_logs_collection.insert_one( - { - "action": "stream_answer", - "level": "info", - "user": decoded_token.get("sub"), - "api_key": user_api_key, - "question": question, - "response": response_full, - "sources": source_log_docs, - "retriever_params": retriever_params, - "attachments": attachment_ids, - "timestamp": datetime.datetime.now(datetime.timezone.utc), - } - ) - data = json.dumps({"type": "end"}) - yield f"data: {data}\n\n" - except Exception as e: - logger.error(f"Error in stream: {str(e)}", exc_info=True) - data = json.dumps( - { - "type": "error", - "error": "Please try again later. We apologize for any inconvenience.", - } - ) - yield f"data: {data}\n\n" - return - - -@answer_ns.route("/stream") -class Stream(Resource): - stream_model = api.model( - "StreamModel", - { - "question": fields.String( - required=True, description="Question to be asked" - ), - "history": fields.List( - fields.String, required=False, description="Chat history" - ), - "conversation_id": fields.String( - required=False, description="Conversation ID" - ), - "prompt_id": fields.String( - required=False, default="default", description="Prompt ID" - ), - "chunks": fields.Integer( - required=False, default=2, description="Number of chunks" - ), - "token_limit": fields.Integer(required=False, description="Token limit"), - "retriever": fields.String(required=False, description="Retriever type"), - "api_key": fields.String(required=False, description="API key"), - "active_docs": fields.String( - required=False, description="Active documents" - ), - "isNoneDoc": fields.Boolean( - required=False, description="Flag indicating if no document is used" - ), - "index": fields.Integer( - required=False, description="Index of the query to update" - ), - "save_conversation": fields.Boolean( - required=False, - default=True, - description="Whether to save the conversation", - ), - "attachments": fields.List( - fields.String, required=False, description="List of attachment IDs" - ), - }, - ) - - @api.expect(stream_model) - @api.doc(description="Stream a response based on the question and retriever") - def post(self): - data = request.get_json() - required_fields = ["question"] - if "index" in data: - required_fields = ["question", "conversation_id"] - missing_fields = check_required_fields(data, required_fields) - if missing_fields: - return missing_fields - - save_conv = data.get("save_conversation", True) - - try: - question = data["question"] - history = limit_chat_history( - json.loads(data.get("history", "[]")), gpt_model=gpt_model - ) - conversation_id = data.get("conversation_id") - prompt_id = data.get("prompt_id", "default") - attachment_ids = data.get("attachments", []) - - index = data.get("index", None) - chunks = int(data.get("chunks", 2)) - token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY) - retriever_name = data.get("retriever", "classic") - agent_id = data.get("agent_id", None) - agent_type = settings.AGENT_NAME - decoded_token = getattr(request, "decoded_token", None) - user_sub = decoded_token.get("sub") if decoded_token else None - agent_key, is_shared_usage, shared_token = get_agent_key(agent_id, user_sub) - - if agent_key: - data.update({"api_key": agent_key}) - else: - agent_id = None - - if "api_key" in data: - data_key = get_data_from_api_key(data["api_key"]) - chunks = int(data_key.get("chunks", 2)) - prompt_id = data_key.get("prompt_id", "default") - source = {"active_docs": data_key.get("source")} - retriever_name = data_key.get("retriever", retriever_name) - user_api_key = data["api_key"] - agent_type = data_key.get("agent_type", agent_type) - if is_shared_usage: - decoded_token = request.decoded_token - else: - decoded_token = {"sub": data_key.get("user")} - is_shared_usage = False - - elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} - retriever_name = get_retriever(data["active_docs"]) or retriever_name - user_api_key = None - decoded_token = request.decoded_token - - else: - source = {} - user_api_key = None - decoded_token = request.decoded_token - - if not decoded_token: - return make_response({"error": "Unauthorized"}, 401) - - attachments = get_attachments_content( - attachment_ids, decoded_token.get("sub") - ) - - logger.info( - f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}", - extra={"data": json.dumps({"request_data": data, "source": source})}, - ) - - prompt = get_prompt(prompt_id) - if "isNoneDoc" in data and data["isNoneDoc"] is True: - chunks = 0 - - agent = AgentCreator.create_agent( - agent_type, - endpoint="stream", - llm_name=settings.LLM_PROVIDER, - gpt_model=gpt_model, - api_key=settings.API_KEY, - user_api_key=user_api_key, - prompt=prompt, - chat_history=history, - decoded_token=decoded_token, - attachments=attachments, - ) - - retriever = RetrieverCreator.create_retriever( - retriever_name, - source=source, - chat_history=history, - prompt=prompt, - chunks=chunks, - token_limit=token_limit, - gpt_model=gpt_model, - user_api_key=user_api_key, - decoded_token=decoded_token, - ) - - return Response( - complete_stream( - question=question, - agent=agent, - retriever=retriever, - conversation_id=conversation_id, - user_api_key=user_api_key, - decoded_token=decoded_token, - isNoneDoc=data.get("isNoneDoc"), - index=index, - should_save_conversation=save_conv, - attachment_ids=attachment_ids, - agent_id=agent_id, - is_shared_usage=is_shared_usage, - shared_token=shared_token, - ), - mimetype="text/event-stream", - ) - - except ValueError: - message = "Malformed request body" - logger.error(f"/stream - error: {message}") - return Response( - error_stream_generate(message), - status=400, - mimetype="text/event-stream", - ) - except Exception as e: - logger.error( - f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}", - extra={"error": str(e), "traceback": traceback.format_exc()}, - ) - status_code = 400 - return Response( - error_stream_generate("Unknown error occurred"), - status=status_code, - mimetype="text/event-stream", - ) - - -def error_stream_generate(err_response): - data = json.dumps({"type": "error", "error": err_response}) - yield f"data: {data}\n\n" - - -@answer_ns.route("/api/answer") -class Answer(Resource): - answer_model = api.model( - "AnswerModel", - { - "question": fields.String( - required=True, description="The question to answer" - ), - "history": fields.List( - fields.String, required=False, description="Conversation history" - ), - "conversation_id": fields.String( - required=False, description="Conversation ID" - ), - "prompt_id": fields.String( - required=False, default="default", description="Prompt ID" - ), - "chunks": fields.Integer( - required=False, default=2, description="Number of chunks" - ), - "token_limit": fields.Integer(required=False, description="Token limit"), - "retriever": fields.String(required=False, description="Retriever type"), - "api_key": fields.String(required=False, description="API key"), - "active_docs": fields.String( - required=False, description="Active documents" - ), - "isNoneDoc": fields.Boolean( - required=False, description="Flag indicating if no document is used" - ), - }, - ) - - @api.expect(answer_model) - @api.doc(description="Provide an answer based on the question and retriever") - def post(self): - data = request.get_json() - required_fields = ["question"] - missing_fields = check_required_fields(data, required_fields) - if missing_fields: - return missing_fields - - try: - question = data["question"] - history = limit_chat_history( - json.loads(data.get("history", "[]")), gpt_model=gpt_model - ) - conversation_id = data.get("conversation_id") - prompt_id = data.get("prompt_id", "default") - chunks = int(data.get("chunks", 2)) - token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY) - retriever_name = data.get("retriever", "classic") - agent_type = settings.AGENT_NAME - - if "api_key" in data: - data_key = get_data_from_api_key(data["api_key"]) - chunks = int(data_key.get("chunks", 2)) - prompt_id = data_key.get("prompt_id", "default") - source = {"active_docs": data_key.get("source")} - retriever_name = data_key.get("retriever", retriever_name) - user_api_key = data["api_key"] - agent_type = data_key.get("agent_type", agent_type) - decoded_token = {"sub": data_key.get("user")} - - elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} - retriever_name = get_retriever(data["active_docs"]) or retriever_name - user_api_key = None - decoded_token = request.decoded_token - - else: - source = {} - user_api_key = None - decoded_token = request.decoded_token - - if not decoded_token: - return make_response({"error": "Unauthorized"}, 401) - - prompt = get_prompt(prompt_id) - - logger.info( - f"/api/answer - request_data: {data}, source: {source}", - extra={"data": json.dumps({"request_data": data, "source": source})}, - ) - - agent = AgentCreator.create_agent( - agent_type, - endpoint="api/answer", - llm_name=settings.LLM_PROVIDER, - gpt_model=gpt_model, - api_key=settings.API_KEY, - user_api_key=user_api_key, - prompt=prompt, - chat_history=history, - decoded_token=decoded_token, - ) - - retriever = RetrieverCreator.create_retriever( - retriever_name, - source=source, - chat_history=history, - prompt=prompt, - chunks=chunks, - token_limit=token_limit, - gpt_model=gpt_model, - user_api_key=user_api_key, - decoded_token=decoded_token, - ) - - response_full = "" - source_log_docs = [] - tool_calls = [] - stream_ended = False - thought = "" - - for line in complete_stream( - question=question, - agent=agent, - retriever=retriever, - conversation_id=conversation_id, - user_api_key=user_api_key, - decoded_token=decoded_token, - isNoneDoc=data.get("isNoneDoc"), - index=None, - should_save_conversation=False, - ): - try: - event_data = line.replace("data: ", "").strip() - event = json.loads(event_data) - - if event["type"] == "answer": - response_full += event["answer"] - elif event["type"] == "source": - source_log_docs = event["source"] - elif event["type"] == "tool_calls": - tool_calls = event["tool_calls"] - elif event["type"] == "thought": - thought = event["thought"] - elif event["type"] == "error": - logger.error(f"Error from stream: {event['error']}") - return bad_request(500, event["error"]) - elif event["type"] == "end": - stream_ended = True - - except (json.JSONDecodeError, KeyError) as e: - logger.warning(f"Error parsing stream event: {e}, line: {line}") - continue - - if not stream_ended: - logger.error("Stream ended unexpectedly without an 'end' event.") - return bad_request(500, "Stream ended unexpectedly.") - - if data.get("isNoneDoc"): - for doc in source_log_docs: - doc["source"] = "None" - - llm = LLMCreator.create_llm( - settings.LLM_PROVIDER, - api_key=settings.API_KEY, - user_api_key=user_api_key, - decoded_token=decoded_token, - ) - - result = {"answer": response_full, "sources": source_log_docs} - result["conversation_id"] = str( - save_conversation( - conversation_id, - question, - response_full, - thought, - source_log_docs, - tool_calls, - llm, - decoded_token, - api_key=user_api_key, - ) - ) - - retriever_params = retriever.get_params() - user_logs_collection.insert_one( - { - "action": "api_answer", - "level": "info", - "user": decoded_token.get("sub"), - "api_key": user_api_key, - "question": question, - "response": response_full, - "sources": source_log_docs, - "retriever_params": retriever_params, - "timestamp": datetime.datetime.now(datetime.timezone.utc), - } - ) - - except Exception as e: - logger.error( - f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}", - extra={"error": str(e), "traceback": traceback.format_exc()}, - ) - return bad_request(500, str(e)) - - return make_response(result, 200) - - -@answer_ns.route("/api/search") -class Search(Resource): - search_model = api.model( - "SearchModel", - { - "question": fields.String( - required=True, description="The question to search" - ), - "chunks": fields.Integer( - required=False, default=2, description="Number of chunks" - ), - "api_key": fields.String( - required=False, description="API key for authentication" - ), - "active_docs": fields.String( - required=False, description="Active documents for retrieval" - ), - "retriever": fields.String(required=False, description="Retriever type"), - "token_limit": fields.Integer( - required=False, description="Limit for tokens" - ), - "isNoneDoc": fields.Boolean( - required=False, description="Flag indicating if no document is used" - ), - }, - ) - - @api.expect(search_model) - @api.doc( - description="Search for relevant documents based on the question and retriever" - ) - def post(self): - data = request.get_json() - required_fields = ["question"] - missing_fields = check_required_fields(data, required_fields) - if missing_fields: - return missing_fields - - try: - question = data["question"] - chunks = int(data.get("chunks", 2)) - token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY) - retriever_name = data.get("retriever", "classic") - - if "api_key" in data: - data_key = get_data_from_api_key(data["api_key"]) - chunks = int(data_key.get("chunks", 2)) - source = {"active_docs": data_key.get("source")} - user_api_key = data["api_key"] - decoded_token = {"sub": data_key.get("user")} - - elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} - user_api_key = None - decoded_token = request.decoded_token - - else: - source = {} - user_api_key = None - decoded_token = request.decoded_token - - if not decoded_token: - return make_response({"error": "Unauthorized"}, 401) - - logger.info( - f"/api/answer - request_data: {data}, source: {source}", - extra={"data": json.dumps({"request_data": data, "source": source})}, - ) - - retriever = RetrieverCreator.create_retriever( - retriever_name, - source=source, - chat_history=[], - prompt="default", - chunks=chunks, - token_limit=token_limit, - gpt_model=gpt_model, - user_api_key=user_api_key, - decoded_token=decoded_token, - ) - - docs = retriever.search(question) - retriever_params = retriever.get_params() - - user_logs_collection.insert_one( - { - "action": "api_search", - "level": "info", - "user": decoded_token.get("sub"), - "api_key": user_api_key, - "question": question, - "sources": docs, - "retriever_params": retriever_params, - "timestamp": datetime.datetime.now(datetime.timezone.utc), - } - ) - - if data.get("isNoneDoc"): - for doc in docs: - doc["source"] = "None" - - except Exception as e: - logger.error( - f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}", - extra={"error": str(e), "traceback": traceback.format_exc()}, - ) - return bad_request(500, str(e)) - - return make_response(docs, 200) - - -def get_attachments_content(attachment_ids, user): - """ - Retrieve content from attachment documents based on their IDs. - - Args: - attachment_ids (list): List of attachment document IDs - user (str): User identifier to verify ownership - - Returns: - list: List of dictionaries containing attachment content and metadata - """ - if not attachment_ids: - return [] - - attachments = [] - for attachment_id in attachment_ids: - try: - attachment_doc = attachments_collection.find_one( - {"_id": ObjectId(attachment_id), "user": user} - ) - - if attachment_doc: - attachments.append(attachment_doc) - except Exception as e: - logger.error( - f"Error retrieving attachment {attachment_id}: {e}", exc_info=True - ) - - return attachments diff --git a/application/api/answer/routes/__init__.py b/application/api/answer/routes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/application/api/answer/routes/answer.py b/application/api/answer/routes/answer.py new file mode 100644 index 00000000..1ca23b7d --- /dev/null +++ b/application/api/answer/routes/answer.py @@ -0,0 +1,103 @@ +import logging +import traceback + +from flask import make_response, request +from flask_restx import fields, Resource + +from application.api import api + +from application.api.answer.routes.base import answer_ns, BaseAnswerResource + +from application.api.answer.services.stream_processor import StreamProcessor + +logger = logging.getLogger(__name__) + + +@answer_ns.route("/api/answer") +class AnswerResource(Resource, BaseAnswerResource): + def __init__(self, *args, **kwargs): + Resource.__init__(self, *args, **kwargs) + BaseAnswerResource.__init__(self) + + answer_model = answer_ns.model( + "AnswerModel", + { + "question": fields.String( + required=True, description="Question to be asked" + ), + "history": fields.List( + fields.String, + required=False, + description="Conversation history (only for new conversations)", + ), + "conversation_id": fields.String( + required=False, + description="Existing conversation ID (loads history)", + ), + "prompt_id": fields.String( + required=False, default="default", description="Prompt ID" + ), + "chunks": fields.Integer( + required=False, default=2, description="Number of chunks" + ), + "token_limit": fields.Integer(required=False, description="Token limit"), + "retriever": fields.String(required=False, description="Retriever type"), + "api_key": fields.String(required=False, description="API key"), + "active_docs": fields.String( + required=False, description="Active documents" + ), + "isNoneDoc": fields.Boolean( + required=False, description="Flag indicating if no document is used" + ), + "save_conversation": fields.Boolean( + required=False, + default=True, + description="Whether to save the conversation", + ), + }, + ) + + @api.expect(answer_model) + @api.doc(description="Provide a response based on the question and retriever") + def post(self): + data = request.get_json() + if error := self.validate_request(data): + return error + processor = StreamProcessor(data, None) + try: + processor.initialize() + if not processor.decoded_token: + return make_response({"error": "Unauthorized"}, 401) + agent = processor.create_agent() + retriever = processor.create_retriever() + + stream = self.complete_stream( + question=data["question"], + agent=agent, + retriever=retriever, + conversation_id=processor.conversation_id, + user_api_key=processor.agent_config.get("user_api_key"), + decoded_token=processor.decoded_token, + isNoneDoc=data.get("isNoneDoc"), + index=None, + should_save_conversation=data.get("save_conversation", True), + ) + conversation_id, response, sources, tool_calls, thought, error = ( + self.process_response_stream(stream) + ) + if error: + return make_response({"error": error}, 400) + result = { + "conversation_id": conversation_id, + "answer": response, + "sources": sources, + "tool_calls": tool_calls, + "thought": thought, + } + except Exception as e: + logger.error( + f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}", + extra={"error": str(e), "traceback": traceback.format_exc()}, + ) + return make_response({"error": str(e)}, 500) + return make_response(result, 200) diff --git a/application/api/answer/routes/base.py b/application/api/answer/routes/base.py new file mode 100644 index 00000000..682da1f0 --- /dev/null +++ b/application/api/answer/routes/base.py @@ -0,0 +1,226 @@ +import datetime +import json +import logging +from typing import Any, Dict, Generator, List, Optional + +from flask import Response +from flask_restx import Namespace + +from application.api.answer.services.conversation_service import ConversationService + +from application.core.mongo_db import MongoDB +from application.core.settings import settings +from application.llm.llm_creator import LLMCreator +from application.utils import check_required_fields, get_gpt_model + +logger = logging.getLogger(__name__) + + +answer_ns = Namespace("answer", description="Answer related operations", path="/") + + +class BaseAnswerResource: + """Shared base class for answer endpoints""" + + def __init__(self): + mongo = MongoDB.get_client() + db = mongo[settings.MONGO_DB_NAME] + self.user_logs_collection = db["user_logs"] + self.gpt_model = get_gpt_model() + self.conversation_service = ConversationService() + + def validate_request( + self, data: Dict[str, Any], require_conversation_id: bool = False + ) -> Optional[Response]: + """Common request validation""" + required_fields = ["question"] + if require_conversation_id: + required_fields.append("conversation_id") + if missing_fields := check_required_fields(data, required_fields): + return missing_fields + return None + + def complete_stream( + self, + question: str, + agent: Any, + retriever: Any, + conversation_id: Optional[str], + user_api_key: Optional[str], + decoded_token: Dict[str, Any], + isNoneDoc: bool = False, + index: Optional[int] = None, + should_save_conversation: bool = True, + attachment_ids: Optional[List[str]] = None, + agent_id: Optional[str] = None, + is_shared_usage: bool = False, + shared_token: Optional[str] = None, + ) -> Generator[str, None, None]: + """ + Generator function that streams the complete conversation response. + + Args: + question: The user's question + agent: The agent instance + retriever: The retriever instance + conversation_id: Existing conversation ID + user_api_key: User's API key if any + decoded_token: Decoded JWT token + isNoneDoc: Flag for document-less responses + index: Index of message to update + should_save_conversation: Whether to persist the conversation + attachment_ids: List of attachment IDs + agent_id: ID of agent used + is_shared_usage: Flag for shared agent usage + shared_token: Token for shared agent + + Yields: + Server-sent event strings + """ + try: + response_full, thought, source_log_docs, tool_calls = "", "", [], [] + + for line in agent.gen(query=question, retriever=retriever): + if "answer" in line: + response_full += str(line["answer"]) + data = json.dumps({"type": "answer", "answer": line["answer"]}) + yield f"data: {data}\n\n" + elif "sources" in line: + truncated_sources = [] + source_log_docs = line["sources"] + for source in line["sources"]: + truncated_source = source.copy() + if "text" in truncated_source: + truncated_source["text"] = ( + truncated_source["text"][:100].strip() + "..." + ) + truncated_sources.append(truncated_source) + if truncated_sources: + data = json.dumps( + {"type": "source", "source": truncated_sources} + ) + yield f"data: {data}\n\n" + elif "tool_calls" in line: + tool_calls = line["tool_calls"] + elif "thought" in line: + thought += line["thought"] + data = json.dumps({"type": "thought", "thought": line["thought"]}) + yield f"data: {data}\n\n" + elif "type" in line: + data = json.dumps(line) + yield f"data: {data}\n\n" + if isNoneDoc: + for doc in source_log_docs: + doc["source"] = "None" + llm = LLMCreator.create_llm( + settings.LLM_PROVIDER, + api_key=settings.API_KEY, + user_api_key=user_api_key, + decoded_token=decoded_token, + ) + + if should_save_conversation: + conversation_id = self.conversation_service.save_conversation( + conversation_id, + question, + response_full, + thought, + source_log_docs, + tool_calls, + llm, + self.gpt_model, + decoded_token, + index=index, + api_key=user_api_key, + agent_id=agent_id, + is_shared_usage=is_shared_usage, + shared_token=shared_token, + attachment_ids=attachment_ids, + ) + else: + conversation_id = None + # Send conversation ID + + data = json.dumps({"type": "id", "id": str(conversation_id)}) + yield f"data: {data}\n\n" + + # Log the interaction + + retriever_params = retriever.get_params() + self.user_logs_collection.insert_one( + { + "action": "stream_answer", + "level": "info", + "user": decoded_token.get("sub"), + "api_key": user_api_key, + "question": question, + "response": response_full, + "sources": source_log_docs, + "retriever_params": retriever_params, + "attachments": attachment_ids, + "timestamp": datetime.datetime.now(datetime.timezone.utc), + } + ) + + # End of stream + + data = json.dumps({"type": "end"}) + yield f"data: {data}\n\n" + except Exception as e: + logger.error(f"Error in stream: {str(e)}", exc_info=True) + data = json.dumps( + { + "type": "error", + "error": "Please try again later. We apologize for any inconvenience.", + } + ) + yield f"data: {data}\n\n" + return + + def process_response_stream(self, stream): + """Process the stream response for non-streaming endpoint""" + conversation_id = "" + response_full = "" + source_log_docs = [] + tool_calls = [] + thought = "" + stream_ended = False + + for line in stream: + try: + event_data = line.replace("data: ", "").strip() + event = json.loads(event_data) + + if event["type"] == "id": + conversation_id = event["id"] + elif event["type"] == "answer": + response_full += event["answer"] + elif event["type"] == "source": + source_log_docs = event["source"] + elif event["type"] == "tool_calls": + tool_calls = event["tool_calls"] + elif event["type"] == "thought": + thought = event["thought"] + elif event["type"] == "error": + logger.error(f"Error from stream: {event['error']}") + return None, None, None, None, event["error"] + elif event["type"] == "end": + stream_ended = True + except (json.JSONDecodeError, KeyError) as e: + logger.warning(f"Error parsing stream event: {e}, line: {line}") + continue + if not stream_ended: + logger.error("Stream ended unexpectedly without an 'end' event.") + return None, None, None, None, "Stream ended unexpectedly" + return ( + conversation_id, + response_full, + source_log_docs, + tool_calls, + thought, + None, + ) + + def error_stream_generate(self, err_response): + data = json.dumps({"type": "error", "error": err_response}) + yield f"data: {data}\n\n" diff --git a/application/api/answer/routes/stream.py b/application/api/answer/routes/stream.py new file mode 100644 index 00000000..eb0ba6eb --- /dev/null +++ b/application/api/answer/routes/stream.py @@ -0,0 +1,116 @@ +import logging +import traceback + +from flask import make_response, request, Response +from flask_restx import fields, Resource + +from application.api import api + +from application.api.answer.routes.base import answer_ns, BaseAnswerResource + +from application.api.answer.services.stream_processor import StreamProcessor + +logger = logging.getLogger(__name__) + + +@answer_ns.route("/stream") +class StreamResource(Resource, BaseAnswerResource): + def __init__(self, *args, **kwargs): + Resource.__init__(self, *args, **kwargs) + BaseAnswerResource.__init__(self) + + stream_model = answer_ns.model( + "StreamModel", + { + "question": fields.String( + required=True, description="Question to be asked" + ), + "history": fields.List( + fields.String, + required=False, + description="Conversation history (only for new conversations)", + ), + "conversation_id": fields.String( + required=False, + description="Existing conversation ID (loads history)", + ), + "prompt_id": fields.String( + required=False, default="default", description="Prompt ID" + ), + "chunks": fields.Integer( + required=False, default=2, description="Number of chunks" + ), + "token_limit": fields.Integer(required=False, description="Token limit"), + "retriever": fields.String(required=False, description="Retriever type"), + "api_key": fields.String(required=False, description="API key"), + "active_docs": fields.String( + required=False, description="Active documents" + ), + "isNoneDoc": fields.Boolean( + required=False, description="Flag indicating if no document is used" + ), + "index": fields.Integer( + required=False, description="Index of the query to update" + ), + "save_conversation": fields.Boolean( + required=False, + default=True, + description="Whether to save the conversation", + ), + "attachments": fields.List( + fields.String, required=False, description="List of attachment IDs" + ), + }, + ) + + @api.expect(stream_model) + @api.doc(description="Stream a response based on the question and retriever") + def post(self): + data = request.get_json() + if error := self.validate_request(data, "index" in data): + return error + decoded_token = getattr(request, "decoded_token", None) + if not decoded_token: + return make_response({"error": "Unauthorized"}, 401) + processor = StreamProcessor(data, decoded_token) + try: + processor.initialize() + agent = processor.create_agent() + retriever = processor.create_retriever() + + return Response( + self.complete_stream( + question=data["question"], + agent=agent, + retriever=retriever, + conversation_id=processor.conversation_id, + user_api_key=processor.agent_config.get("user_api_key"), + decoded_token=processor.decoded_token, + isNoneDoc=data.get("isNoneDoc"), + index=data.get("index"), + should_save_conversation=data.get("save_conversation", True), + attachment_ids=data.get("attachments", []), + agent_id=data.get("agent_id"), + is_shared_usage=processor.is_shared_usage, + shared_token=processor.shared_token, + ), + mimetype="text/event-stream", + ) + except ValueError: + message = "Malformed request body" + logger.error(f"/stream - error: {message}") + return Response( + self.error_stream_generate(message), + status=400, + mimetype="text/event-stream", + ) + except Exception as e: + logger.error( + f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}", + extra={"error": str(e), "traceback": traceback.format_exc()}, + ) + return Response( + self.error_stream_generate("Unknown error occurred"), + status=400, + mimetype="text/event-stream", + ) diff --git a/application/api/answer/services/__init__.py b/application/api/answer/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/application/api/answer/services/conversation_service.py b/application/api/answer/services/conversation_service.py new file mode 100644 index 00000000..e35fcc40 --- /dev/null +++ b/application/api/answer/services/conversation_service.py @@ -0,0 +1,175 @@ +import logging +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from application.core.mongo_db import MongoDB + +from application.core.settings import settings +from bson import ObjectId + + +logger = logging.getLogger(__name__) + + +class ConversationService: + def __init__(self): + mongo = MongoDB.get_client() + db = mongo[settings.MONGO_DB_NAME] + self.conversations_collection = db["conversations"] + self.agents_collection = db["agents"] + + def get_conversation( + self, conversation_id: str, user_id: str + ) -> Optional[Dict[str, Any]]: + """Retrieve a conversation with proper access control""" + if not conversation_id or not user_id: + return None + try: + conversation = self.conversations_collection.find_one( + { + "_id": ObjectId(conversation_id), + "$or": [{"user": user_id}, {"shared_with": user_id}], + } + ) + + if not conversation: + logger.warning( + f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}" + ) + return None + conversation["_id"] = str(conversation["_id"]) + return conversation + except Exception as e: + logger.error(f"Error fetching conversation: {str(e)}", exc_info=True) + return None + + def save_conversation( + self, + conversation_id: Optional[str], + question: str, + response: str, + thought: str, + sources: List[Dict[str, Any]], + tool_calls: List[Dict[str, Any]], + llm: Any, + gpt_model: str, + decoded_token: Dict[str, Any], + index: Optional[int] = None, + api_key: Optional[str] = None, + agent_id: Optional[str] = None, + is_shared_usage: bool = False, + shared_token: Optional[str] = None, + attachment_ids: Optional[List[str]] = None, + ) -> str: + """Save or update a conversation in the database""" + user_id = decoded_token.get("sub") + if not user_id: + raise ValueError("User ID not found in token") + current_time = datetime.now(timezone.utc) + + if conversation_id is not None and index is not None: + # Update existing conversation with new query + + result = self.conversations_collection.update_one( + { + "_id": ObjectId(conversation_id), + "user": user_id, + f"queries.{index}": {"$exists": True}, + }, + { + "$set": { + f"queries.{index}.prompt": question, + f"queries.{index}.response": response, + f"queries.{index}.thought": thought, + f"queries.{index}.sources": sources, + f"queries.{index}.tool_calls": tool_calls, + f"queries.{index}.timestamp": current_time, + f"queries.{index}.attachments": attachment_ids, + } + }, + ) + + if result.matched_count == 0: + raise ValueError("Conversation not found or unauthorized") + self.conversations_collection.update_one( + { + "_id": ObjectId(conversation_id), + "user": user_id, + f"queries.{index}": {"$exists": True}, + }, + {"$push": {"queries": {"$each": [], "$slice": index + 1}}}, + ) + return conversation_id + elif conversation_id: + # Append new message to existing conversation + + result = self.conversations_collection.update_one( + {"_id": ObjectId(conversation_id), "user": user_id}, + { + "$push": { + "queries": { + "prompt": question, + "response": response, + "thought": thought, + "sources": sources, + "tool_calls": tool_calls, + "timestamp": current_time, + "attachments": attachment_ids, + } + } + }, + ) + + if result.matched_count == 0: + raise ValueError("Conversation not found or unauthorized") + return conversation_id + else: + # Create new conversation + + messages_summary = [ + { + "role": "assistant", + "content": "Summarise following conversation in no more than 3 " + "words, respond ONLY with the summary, use the same " + "language as the user query", + }, + { + "role": "user", + "content": "Summarise following conversation in no more than 3 words, " + "respond ONLY with the summary, use the same language as the " + "user query \n\nUser: " + question + "\n\n" + "AI: " + response, + }, + ] + + completion = llm.gen( + model=gpt_model, messages=messages_summary, max_tokens=30 + ) + + conversation_data = { + "user": user_id, + "date": current_time, + "name": completion, + "queries": [ + { + "prompt": question, + "response": response, + "thought": thought, + "sources": sources, + "tool_calls": tool_calls, + "timestamp": current_time, + "attachments": attachment_ids, + } + ], + } + + if api_key: + if agent_id: + conversation_data["agent_id"] = agent_id + if is_shared_usage: + conversation_data["is_shared_usage"] = is_shared_usage + conversation_data["shared_token"] = shared_token + agent = self.agents_collection.find_one({"key": api_key}) + if agent: + conversation_data["api_key"] = agent["key"] + result = self.conversations_collection.insert_one(conversation_data) + return str(result.inserted_id) diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py new file mode 100644 index 00000000..8c017a96 --- /dev/null +++ b/application/api/answer/services/stream_processor.py @@ -0,0 +1,260 @@ +import datetime +import json +import logging +import os +from pathlib import Path +from typing import Any, Dict, Optional + +from bson.dbref import DBRef + +from bson.objectid import ObjectId + +from application.agents.agent_creator import AgentCreator +from application.api.answer.services.conversation_service import ConversationService +from application.core.mongo_db import MongoDB +from application.core.settings import settings +from application.retriever.retriever_creator import RetrieverCreator +from application.utils import get_gpt_model, limit_chat_history + +logger = logging.getLogger(__name__) + + +def get_prompt(prompt_id: str, prompts_collection=None) -> str: + """ + Get a prompt by preset name or MongoDB ID + """ + current_dir = Path(__file__).resolve().parents[3] + prompts_dir = current_dir / "prompts" + + preset_mapping = { + "default": "chat_combine_default.txt", + "creative": "chat_combine_creative.txt", + "strict": "chat_combine_strict.txt", + "reduce": "chat_reduce_prompt.txt", + } + + if prompt_id in preset_mapping: + file_path = os.path.join(prompts_dir, preset_mapping[prompt_id]) + try: + with open(file_path, "r") as f: + return f.read() + except FileNotFoundError: + raise FileNotFoundError(f"Prompt file not found: {file_path}") + try: + if not prompts_collection: + mongo = MongoDB.get_client() + db = mongo[settings.MONGO_DB_NAME] + prompts_collection = db["prompts"] + prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)}) + if not prompt_doc: + raise ValueError(f"Prompt with ID {prompt_id} not found") + return prompt_doc["content"] + except Exception as e: + raise ValueError(f"Invalid prompt ID: {prompt_id}") from e + + +class StreamProcessor: + def __init__( + self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]] + ): + mongo = MongoDB.get_client() + self.db = mongo[settings.MONGO_DB_NAME] + self.agents_collection = self.db["agents"] + self.attachments_collection = self.db["attachments"] + self.prompts_collection = self.db["prompts"] + + self.data = request_data + self.decoded_token = decoded_token + self.initial_user_id = ( + self.decoded_token.get("sub") if self.decoded_token is not None else None + ) + self.conversation_id = self.data.get("conversation_id") + self.source = ( + {"active_docs": self.data["active_docs"]} + if "active_docs" in self.data + else {} + ) + self.attachments = [] + self.history = [] + self.agent_config = {} + self.retriever_config = {} + self.is_shared_usage = False + self.shared_token = None + self.gpt_model = get_gpt_model() + self.conversation_service = ConversationService() + + def initialize(self): + """Initialize all required components for processing""" + self._configure_agent() + self._configure_retriever() + self._load_conversation_history() + self._process_attachments() + + def _load_conversation_history(self): + """Load conversation history either from DB or request""" + if self.conversation_id and self.initial_user_id: + conversation = self.conversation_service.get_conversation( + self.conversation_id, self.initial_user_id + ) + if not conversation: + raise ValueError("Conversation not found or unauthorized") + self.history = [ + {"prompt": query["prompt"], "response": query["response"]} + for query in conversation.get("queries", []) + ] + else: + self.history = limit_chat_history( + json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model + ) + + def _process_attachments(self): + """Process any attachments in the request""" + attachment_ids = self.data.get("attachments", []) + self.attachments = self._get_attachments_content( + attachment_ids, self.initial_user_id + ) + + def _get_attachments_content(self, attachment_ids, user_id): + """ + Retrieve content from attachment documents based on their IDs. + """ + if not attachment_ids: + return [] + attachments = [] + for attachment_id in attachment_ids: + try: + attachment_doc = self.attachments_collection.find_one( + {"_id": ObjectId(attachment_id), "user": user_id} + ) + + if attachment_doc: + attachments.append(attachment_doc) + except Exception as e: + logger.error( + f"Error retrieving attachment {attachment_id}: {e}", exc_info=True + ) + return attachments + + def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple: + """Get API key for agent with access control""" + if not agent_id: + return None, False, None + try: + agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)}) + if agent is None: + raise Exception("Agent not found") + is_owner = agent.get("user") == user_id + is_shared_with_user = agent.get( + "shared_publicly", False + ) or user_id in agent.get("shared_with", []) + + if not (is_owner or is_shared_with_user): + raise Exception("Unauthorized access to the agent") + if is_owner: + self.agents_collection.update_one( + {"_id": ObjectId(agent_id)}, + { + "$set": { + "lastUsedAt": datetime.datetime.now(datetime.timezone.utc) + } + }, + ) + return str(agent["key"]), not is_owner, agent.get("shared_token") + except Exception as e: + logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True) + raise + + def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]: + data = self.agents_collection.find_one({"key": api_key}) + if not data: + raise Exception("Invalid API Key, please generate a new key", 401) + source = data.get("source") + if isinstance(source, DBRef): + source_doc = self.db.dereference(source) + data["source"] = str(source_doc["_id"]) + data["retriever"] = source_doc.get("retriever", data.get("retriever")) + else: + data["source"] = {} + return data + + def _configure_agent(self): + """Configure the agent based on request data""" + agent_id = self.data.get("agent_id") + self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key( + agent_id, self.initial_user_id + ) + + api_key = self.data.get("api_key") + if api_key: + data_key = self._get_data_from_api_key(api_key) + self.agent_config.update( + { + "prompt_id": data_key.get("prompt_id", "default"), + "agent_type": data_key.get("agent_type", settings.AGENT_NAME), + "user_api_key": api_key, + } + ) + self.initial_user_id = data_key.get("user") + self.decoded_token = {"sub": data_key.get("user")} + elif self.agent_key: + data_key = self._get_data_from_api_key(self.agent_key) + self.agent_config.update( + { + "prompt_id": data_key.get("prompt_id", "default"), + "agent_type": data_key.get("agent_type", settings.AGENT_NAME), + "user_api_key": self.agent_key, + } + ) + self.decoded_token = ( + self.decoded_token + if self.is_shared_usage + else {"sub": data_key.get("user")} + ) + else: + self.agent_config.update( + { + "prompt_id": self.data.get("prompt_id", "default"), + "agent_type": settings.AGENT_NAME, + "user_api_key": None, + } + ) + + def _configure_retriever(self): + """Configure the retriever based on request data""" + self.retriever_config = { + "retriever_name": self.data.get("retriever", "classic"), + "chunks": int(self.data.get("chunks", 2)), + "token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY), + } + + if "isNoneDoc" in self.data and self.data["isNoneDoc"]: + self.retriever_config["chunks"] = 0 + + def create_agent(self): + """Create and return the configured agent""" + return AgentCreator.create_agent( + self.agent_config["agent_type"], + endpoint="stream", + llm_name=settings.LLM_PROVIDER, + gpt_model=self.gpt_model, + api_key=settings.API_KEY, + user_api_key=self.agent_config["user_api_key"], + prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection), + chat_history=self.history, + decoded_token=self.decoded_token, + attachments=self.attachments, + ) + + def create_retriever(self): + """Create and return the configured retriever""" + return RetrieverCreator.create_retriever( + self.retriever_config["retriever_name"], + source=self.source, + chat_history=self.history, + prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection), + chunks=self.retriever_config["chunks"], + token_limit=self.retriever_config["token_limit"], + gpt_model=self.gpt_model, + user_api_key=self.agent_config["user_api_key"], + decoded_token=self.decoded_token, + ) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 0d35b8ca..ca2c50eb 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -34,7 +34,7 @@ from application.api.user.tasks import ( ) from application.core.mongo_db import MongoDB from application.core.settings import settings -from application.extensions import api +from application.api import api from application.storage.storage_creator import StorageCreator from application.tts.google_tts import GoogleTTS from application.utils import ( @@ -3538,14 +3538,18 @@ class StoreAttachment(Resource): "AttachmentModel", { "file": fields.Raw(required=True, description="File to upload"), + "api_key": fields.String( + required=False, description="API key (optional)" + ), }, ) ) - @api.doc(description="Stores a single attachment without vectorization or training") + @api.doc( + description="Stores a single attachment without vectorization or training. Supports user or API key authentication." + ) def post(self): - decoded_token = request.decoded_token - if not decoded_token: - return make_response(jsonify({"success": False}), 401) + decoded_token = getattr(request, "decoded_token", None) + api_key = request.form.get("api_key") or request.args.get("api_key") file = request.files.get("file") if not file or file.filename == "": @@ -3553,7 +3557,21 @@ class StoreAttachment(Resource): jsonify({"status": "error", "message": "Missing file"}), 400, ) - user = safe_filename(decoded_token.get("sub")) + + user = None + if decoded_token: + user = safe_filename(decoded_token.get("sub")) + elif api_key: + agent = agents_collection.find_one({"key": api_key}) + if not agent: + return make_response( + jsonify({"success": False, "message": "Invalid API key"}), 401 + ) + user = safe_filename(agent.get("user")) + else: + return make_response( + jsonify({"success": False, "message": "Authentication required"}), 401 + ) try: attachment_id = ObjectId() diff --git a/application/app.py b/application/app.py index 7ca0ac2b..4159a2bb 100644 --- a/application/app.py +++ b/application/app.py @@ -12,19 +12,18 @@ from application.core.logging_config import setup_logging setup_logging() -from application.api.answer.routes import answer # noqa: E402 +from application.api import api # noqa: E402 +from application.api.answer import answer # noqa: E402 from application.api.internal.routes import internal # noqa: E402 from application.api.user.routes import user # noqa: E402 from application.celery_init import celery # noqa: E402 from application.core.settings import settings # noqa: E402 -from application.extensions import api # noqa: E402 if platform.system() == "Windows": import pathlib pathlib.PosixPath = pathlib.WindowsPath - dotenv.load_dotenv() app = Flask(__name__) @@ -52,7 +51,6 @@ if settings.AUTH_TYPE in ("simple_jwt", "session_jwt") and not settings.JWT_SECR settings.JWT_SECRET_KEY = new_key except Exception as e: raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}") - SIMPLE_JWT_TOKEN = None if settings.AUTH_TYPE == "simple_jwt": payload = {"sub": "local"} @@ -92,7 +90,6 @@ def generate_token(): def authenticate_request(): if request.method == "OPTIONS": return "", 200 - decoded_token = handle_auth(request) if not decoded_token: request.decoded_token = None diff --git a/application/core/settings.py b/application/core/settings.py index 35e1bb75..7ecc9aeb 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -10,7 +10,7 @@ current_dir = os.path.dirname( class Settings(BaseSettings): - AUTH_TYPE: Optional[str] = None + AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None LLM_PROVIDER: str = "docsgpt" LLM_NAME: Optional[str] = ( None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo diff --git a/application/extensions.py b/application/extensions.py deleted file mode 100644 index b6f52893..00000000 --- a/application/extensions.py +++ /dev/null @@ -1,7 +0,0 @@ -from flask_restx import Api - -api = Api( - version="1.0", - title="DocsGPT API", - description="API for DocsGPT", -) diff --git a/application/storage/__init__.py b/application/storage/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/application/utils.py b/application/utils.py index 883eb926..d4f0a362 100644 --- a/application/utils.py +++ b/application/utils.py @@ -6,6 +6,7 @@ import uuid import tiktoken from flask import jsonify, make_response from werkzeug.utils import secure_filename + from application.core.settings import settings @@ -19,6 +20,17 @@ def get_encoding(): return _encoding +def get_gpt_model() -> str: + """Get the appropriate GPT model based on provider""" + model_map = { + "openai": "gpt-4o-mini", + "anthropic": "claude-2", + "groq": "llama3-8b-8192", + "novita": "deepseek/deepseek-r1", + } + return settings.LLM_NAME or model_map.get(settings.LLM_PROVIDER, "") + + def safe_filename(filename): """ Creates a safe filename that preserves the original extension. @@ -32,15 +44,14 @@ def safe_filename(filename): """ if not filename: return str(uuid.uuid4()) - _, extension = os.path.splitext(filename) safe_name = secure_filename(filename) # If secure_filename returns just the extension or an empty string + if not safe_name or safe_name == extension.lstrip("."): return f"{str(uuid.uuid4())}{extension}" - return safe_name @@ -68,7 +79,6 @@ def count_tokens_docs(docs): docs_content = "" for doc in docs: docs_content += doc.page_content - tokens = num_tokens_from_string(docs_content) return tokens @@ -97,13 +107,11 @@ def validate_required_fields(data, required_fields): missing_fields.append(field) elif not data[field]: empty_fields.append(field) - errors = [] if missing_fields: errors.append(f"Missing required fields: {', '.join(missing_fields)}") if empty_fields: errors.append(f"Empty values in required fields: {', '.join(empty_fields)}") - if errors: return make_response( jsonify({"success": False, "message": " | ".join(errors)}), 400 @@ -132,7 +140,6 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"): if not history: return [] - trimmed_history = [] tokens_current_history = 0 @@ -141,18 +148,15 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"): if "prompt" in message and "response" in message: tokens_batch += num_tokens_from_string(message["prompt"]) tokens_batch += num_tokens_from_string(message["response"]) - if "tool_calls" in message: for tool_call in message["tool_calls"]: tool_call_string = f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}" tokens_batch += num_tokens_from_string(tool_call_string) - if tokens_current_history + tokens_batch < max_token_limit: tokens_current_history += tokens_batch trimmed_history.insert(0, message) else: break - return trimmed_history diff --git a/application/worker.py b/application/worker.py index c6178931..23f96bf5 100755 --- a/application/worker.py +++ b/application/worker.py @@ -16,7 +16,7 @@ from bson.dbref import DBRef from bson.objectid import ObjectId from application.agents.agent_creator import AgentCreator -from application.api.answer.routes import get_prompt +from application.api.answer.services.stream_processor import get_prompt from application.core.mongo_db import MongoDB from application.core.settings import settings @@ -35,17 +35,22 @@ db = mongo[settings.MONGO_DB_NAME] sources_collection = db["sources"] # Constants + MIN_TOKENS = 150 MAX_TOKENS = 1250 RECURSION_DEPTH = 2 # Define a function to extract metadata from a given filename. + + def metadata_from_filename(title): return {"title": title} # Define a function to generate a random string of a given length. + + def generate_random_string(length): return "".join([string.ascii_letters[i % 52] for i in range(length)]) @@ -68,7 +73,6 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5): if current_depth > max_depth: logging.warning(f"Reached maximum recursion depth of {max_depth}") return - try: with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(extract_to) @@ -76,12 +80,13 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5): except Exception as e: logging.error(f"Error extracting zip file {zip_path}: {e}", exc_info=True) return - # Check for nested zip files and extract them + for root, dirs, files in os.walk(extract_to): for file in files: if file.endswith(".zip"): # If a nested zip file is found, extract it recursively + file_path = os.path.join(root, file) extract_zip_recursive(file_path, root, current_depth + 1, max_depth) @@ -139,7 +144,7 @@ def run_agent_logic(agent_config, input_data): user_api_key = agent_config["key"] agent_type = agent_config.get("agent_type", "classic") decoded_token = {"sub": agent_config.get("user")} - prompt = get_prompt(prompt_id) + prompt = get_prompt(prompt_id, db["prompts"]) agent = AgentCreator.create_agent( agent_type, endpoint="webhook", @@ -178,7 +183,6 @@ def run_agent_logic(agent_config, input_data): tool_calls.extend(line["tool_calls"]) elif "thought" in line: thought += line["thought"] - result = { "answer": response_full, "sources": source_log_docs, @@ -193,8 +197,18 @@ def run_agent_logic(agent_config, input_data): # Define the main function for ingesting and processing documents. + + def ingest_worker( - self, directory, formats, job_name, filename, user, dir_name=None, user_dir=None, retriever="classic" + self, + directory, + formats, + job_name, + filename, + user, + dir_name=None, + user_dir=None, + retriever="classic", ): """ Ingest and process documents. @@ -218,7 +232,7 @@ def ingest_worker( limit = None exclude = True sample = False - + storage = StorageCreator.get_storage() full_path = os.path.join(directory, user_dir, dir_name) @@ -227,29 +241,29 @@ def ingest_worker( logging.info(f"Ingest file: {full_path}", extra={"user": user, "job": job_name}) # Create temporary working directory + with tempfile.TemporaryDirectory() as temp_dir: try: os.makedirs(temp_dir, exist_ok=True) # Download file from storage to temp directory + temp_file_path = os.path.join(temp_dir, filename) file_data = storage.get_file(source_file_path) with open(temp_file_path, "wb") as f: f.write(file_data.read()) - self.update_state(state="PROGRESS", meta={"current": 1}) # Handle zip files + if filename.endswith(".zip"): logging.info(f"Extracting zip file: {filename}") extract_zip_recursive( temp_file_path, temp_dir, current_depth=0, max_depth=RECURSION_DEPTH ) - if sample: logging.info(f"Sample mode enabled. Using {limit} documents.") - reader = SimpleDirectoryReader( input_dir=temp_dir, input_files=input_files, @@ -296,11 +310,9 @@ def ingest_worker( } upload_index(vector_store_path, file_data) - except Exception as e: logging.error(f"Error in ingest_worker: {e}", exc_info=True) raise - return { "directory": directory, "formats": formats, @@ -326,7 +338,6 @@ def remote_worker( full_path = os.path.join(directory, user, name_job) if not os.path.exists(full_path): os.makedirs(full_path) - self.update_state(state="PROGRESS", meta={"current": 1}) try: logging.info("Initializing remote loader with type: %s", loader) @@ -353,7 +364,6 @@ def remote_worker( raise ValueError("doc_id must be provided for sync operation.") id = ObjectId(doc_id) embed_and_store_documents(docs, full_path, id, self) - self.update_state(state="PROGRESS", meta={"current": 100}) file_data = { @@ -367,15 +377,12 @@ def remote_worker( "sync_frequency": sync_frequency, } upload_index(full_path, file_data) - except Exception as e: logging.error("Error in remote_worker task: %s", str(e), exc_info=True) raise - finally: if os.path.exists(full_path): shutil.rmtree(full_path) - logging.info("remote_worker task completed successfully") return {"urls": source_data, "name_job": name_job, "user": user, "limited": False} @@ -428,7 +435,6 @@ def sync_worker(self, frequency): sync_counts[ "sync_success" if resp["status"] == "success" else "sync_failure" ] += 1 - return { key: sync_counts[key] for key in ["total_sync_count", "sync_success", "sync_failure"] @@ -503,7 +509,6 @@ def attachment_worker(self, file_info, user): "mime_type": mime_type, "metadata": metadata, } - except Exception as e: logging.error( f"Error processing file {filename}: {e}", @@ -539,7 +544,6 @@ def agent_webhook_worker(self, agent_id, payload): except Exception as e: logging.error(f"Error processing agent webhook: {e}", exc_info=True) return {"status": "error", "error": str(e)} - self.update_state(state="PROGRESS", meta={"current": 50}) try: result = run_agent_logic(agent_config, input_data) diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts index 71f539e5..fb6e1b59 100644 --- a/frontend/src/conversation/conversationHandlers.ts +++ b/frontend/src/conversation/conversationHandlers.ts @@ -8,7 +8,6 @@ export function handleFetchAnswer( signal: AbortSignal, token: string | null, selectedDocs: Doc | null, - history: Array = [], conversationId: string | null, promptId: string | null, chunks: string, @@ -37,16 +36,8 @@ export function handleFetchAnswer( title: any; } > { - history = history.map((item) => { - return { - prompt: item.prompt, - response: item.response, - tool_calls: item.tool_calls, - }; - }); const payload: RetrievalPayload = { question: question, - history: JSON.stringify(history), conversation_id: conversationId, prompt_id: promptId, chunks: chunks, @@ -94,7 +85,6 @@ export function handleFetchAnswerSteaming( signal: AbortSignal, token: string | null, selectedDocs: Doc | null, - history: Array = [], conversationId: string | null, promptId: string | null, chunks: string, @@ -105,17 +95,8 @@ export function handleFetchAnswerSteaming( attachments?: string[], save_conversation = true, ): Promise { - history = history.map((item) => { - return { - prompt: item.prompt, - response: item.response, - thought: item.thought, - tool_calls: item.tool_calls, - }; - }); const payload: RetrievalPayload = { question: question, - history: JSON.stringify(history), conversation_id: conversationId, prompt_id: promptId, chunks: chunks, @@ -192,20 +173,11 @@ export function handleSearch( token: string | null, selectedDocs: Doc | null, conversation_id: string | null, - history: Array = [], chunks: string, token_limit: number, ) { - history = history.map((item) => { - return { - prompt: item.prompt, - response: item.response, - tool_calls: item.tool_calls, - }; - }); const payload: RetrievalPayload = { question: question, - history: JSON.stringify(history), conversation_id: conversation_id, chunks: chunks, token_limit: token_limit, diff --git a/frontend/src/conversation/conversationModels.ts b/frontend/src/conversation/conversationModels.ts index b16dd6c1..f95c48a6 100644 --- a/frontend/src/conversation/conversationModels.ts +++ b/frontend/src/conversation/conversationModels.ts @@ -52,7 +52,6 @@ export interface RetrievalPayload { question: string; active_docs?: string; retriever?: string; - history: string; conversation_id: string | null; prompt_id?: string | null; chunks: string; diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index b3eb92f9..38aeec7a 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -57,7 +57,6 @@ export const fetchAnswer = createAsyncThunk< signal, state.preference.token, state.preference.selectedDocs!, - state.conversation.queries, currentConversationId, state.preference.prompt.id, state.preference.chunks, @@ -153,7 +152,6 @@ export const fetchAnswer = createAsyncThunk< signal, state.preference.token, state.preference.selectedDocs!, - state.conversation.queries, state.conversation.conversationId, state.preference.prompt.id, state.preference.chunks,