mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
refactor: answer routes to comply with OpenAPI spec using flask-restx
This commit is contained in:
@@ -1,20 +1,24 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from flask import Blueprint, request, Response, current_app
|
||||
import json
|
||||
import datetime
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from pymongo import MongoClient
|
||||
from bson.objectid import ObjectId
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import Blueprint, current_app, make_response, request, Response
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from pymongo import MongoClient
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.error import bad_request
|
||||
from application.extensions import api
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.error import bad_request
|
||||
from application.utils import check_required_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,7 +29,10 @@ sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
api_key_collection = db["api_keys"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
|
||||
answer = Blueprint("answer", __name__)
|
||||
answer_ns = Namespace("answer", description="Answer related operations", path="/")
|
||||
api.add_namespace(answer_ns)
|
||||
|
||||
gpt_model = ""
|
||||
# to have some kind of default behaviour
|
||||
@@ -186,10 +193,10 @@ def complete_stream(
|
||||
answer = retriever.gen()
|
||||
sources = retriever.search()
|
||||
for source in sources:
|
||||
if("text" in source):
|
||||
source["text"] = source["text"][:100].strip()+"..."
|
||||
if(len(sources) > 0):
|
||||
data = json.dumps({"type":"source","source":sources})
|
||||
if "text" in source:
|
||||
source["text"] = source["text"][:100].strip() + "..."
|
||||
if len(sources) > 0:
|
||||
data = json.dumps({"type": "source", "source": sources})
|
||||
yield f"data: {data}\n\n"
|
||||
for line in answer:
|
||||
if "answer" in line:
|
||||
@@ -243,109 +250,133 @@ def complete_stream(
|
||||
return
|
||||
|
||||
|
||||
@answer.route("/stream", methods=["POST"])
|
||||
def stream():
|
||||
try:
|
||||
data = request.get_json()
|
||||
question = data["question"]
|
||||
if "history" not in data:
|
||||
history = []
|
||||
else:
|
||||
history = data["history"]
|
||||
history = json.loads(history)
|
||||
if "conversation_id" not in data:
|
||||
conversation_id = None
|
||||
else:
|
||||
conversation_id = data["conversation_id"]
|
||||
if "prompt_id" in data:
|
||||
prompt_id = data["prompt_id"]
|
||||
else:
|
||||
prompt_id = "default"
|
||||
if "selectedDocs" in data and data["selectedDocs"] is None:
|
||||
chunks = 0
|
||||
elif "chunks" in data:
|
||||
chunks = int(data["chunks"])
|
||||
else:
|
||||
chunks = 2
|
||||
if "token_limit" in data:
|
||||
token_limit = data["token_limit"]
|
||||
else:
|
||||
token_limit = settings.DEFAULT_MAX_HISTORY
|
||||
|
||||
## retriever can be "brave_search, duckduck_search or classic"
|
||||
retriever_name = data["retriever"] if "retriever" in data else "classic"
|
||||
|
||||
# check if active_docs or api_key is set
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key["chunks"])
|
||||
prompt_id = data_key["prompt_id"]
|
||||
source = {"active_docs": data_key["source"]}
|
||||
retriever_name = data_key["retriever"] or retriever_name
|
||||
user_api_key = data["api_key"]
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
current_app.logger.info(
|
||||
f"/stream - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
|
||||
return Response(
|
||||
complete_stream(
|
||||
question=question,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
@answer_ns.route("/stream")
|
||||
class Stream(Resource):
|
||||
stream_model = api.model(
|
||||
"StreamModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="Question to be asked"
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
"history": fields.List(
|
||||
fields.String, required=False, description="Chat history"
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False, description="Conversation ID"
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"selectedDocs": fields.String(
|
||||
required=False, description="Selected documents"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
message = "Malformed request body"
|
||||
print("\033[91merr", str(message), file=sys.stderr)
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
message = e.args[0]
|
||||
status_code = 400
|
||||
# # Custom exceptions with two arguments, index 1 as status code
|
||||
if len(e.args) >= 2:
|
||||
status_code = e.args[1]
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=status_code,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
@api.expect(stream_model)
|
||||
@api.doc(description="Stream a response based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = data.get("history", [])
|
||||
history = json.loads(history)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
if "selectedDocs" in data and data["selectedDocs"] is None:
|
||||
chunks = 0
|
||||
else:
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
prompt_id = data_key.get("prompt_id", "default")
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
current_app.logger.info(
|
||||
f"/stream - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
|
||||
return Response(
|
||||
complete_stream(
|
||||
question=question,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
message = "Malformed request body"
|
||||
print("\033[91merr", str(message), file=sys.stderr)
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
message = e.args[0]
|
||||
status_code = 400
|
||||
# Custom exceptions with two arguments, index 1 as status code
|
||||
if len(e.args) >= 2:
|
||||
status_code = e.args[1]
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=status_code,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
def error_stream_generate(err_response):
|
||||
@@ -353,180 +384,235 @@ def error_stream_generate(err_response):
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
|
||||
@answer.route("/api/answer", methods=["POST"])
|
||||
def api_answer():
|
||||
data = request.get_json()
|
||||
question = data["question"]
|
||||
if "history" not in data:
|
||||
history = []
|
||||
else:
|
||||
history = data["history"]
|
||||
if "conversation_id" not in data:
|
||||
conversation_id = None
|
||||
else:
|
||||
conversation_id = data["conversation_id"]
|
||||
print("-" * 5)
|
||||
if "prompt_id" in data:
|
||||
prompt_id = data["prompt_id"]
|
||||
else:
|
||||
prompt_id = "default"
|
||||
if "chunks" in data:
|
||||
chunks = int(data["chunks"])
|
||||
else:
|
||||
chunks = 2
|
||||
if "token_limit" in data:
|
||||
token_limit = data["token_limit"]
|
||||
else:
|
||||
token_limit = settings.DEFAULT_MAX_HISTORY
|
||||
|
||||
## retriever can be brave_search, duckduck_search or classic
|
||||
retriever_name = data["retriever"] if "retriever" in data else "classic"
|
||||
|
||||
# use try and except to check for exception
|
||||
try:
|
||||
# check if the vectorstore is set
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key["chunks"])
|
||||
prompt_id = data_key["prompt_id"]
|
||||
source = {"active_docs": data_key["source"]}
|
||||
retriever_name = data_key["retriever"] or retriever_name
|
||||
user_api_key = data["api_key"]
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
current_app.logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
source_log_docs = []
|
||||
response_full = ""
|
||||
for line in retriever.gen():
|
||||
if "source" in line:
|
||||
source_log_docs.append(line["source"])
|
||||
elif "answer" in line:
|
||||
response_full += line["answer"]
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
)
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
result["conversation_id"] = str(
|
||||
save_conversation(
|
||||
conversation_id, question, response_full, source_log_docs, llm
|
||||
)
|
||||
)
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "api_answer",
|
||||
"level": "info",
|
||||
"user": "local",
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
|
||||
@answer.route("/api/search", methods=["POST"])
|
||||
def api_search():
|
||||
data = request.get_json()
|
||||
question = data["question"]
|
||||
if "chunks" in data:
|
||||
chunks = int(data["chunks"])
|
||||
else:
|
||||
chunks = 2
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key["chunks"])
|
||||
source = {"active_docs":data_key["source"]}
|
||||
user_api_key = data["api_key"]
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
user_api_key = None
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
if "retriever" in data:
|
||||
retriever_name = data["retriever"]
|
||||
else:
|
||||
retriever_name = "classic"
|
||||
if "token_limit" in data:
|
||||
token_limit = data["token_limit"]
|
||||
else:
|
||||
token_limit = settings.DEFAULT_MAX_HISTORY
|
||||
|
||||
current_app.logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=[],
|
||||
prompt="default",
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
docs = retriever.search()
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
@answer_ns.route("/api/answer")
|
||||
class Answer(Resource):
|
||||
answer_model = api.model(
|
||||
"AnswerModel",
|
||||
{
|
||||
"action": "api_search",
|
||||
"level": "info",
|
||||
"user": "local",
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"sources": docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
"question": fields.String(
|
||||
required=True, description="The question to answer"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String, required=False, description="Conversation history"
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False, description="Conversation ID"
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in docs:
|
||||
doc["source"] = "None"
|
||||
@api.expect(answer_model)
|
||||
@api.doc(description="Provide an answer based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
return docs
|
||||
try:
|
||||
question = data["question"]
|
||||
history = data.get("history", [])
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
prompt_id = data_key.get("prompt_id", "default")
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
current_app.logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
|
||||
source_log_docs = []
|
||||
response_full = ""
|
||||
for line in retriever.gen():
|
||||
if "source" in line:
|
||||
source_log_docs.append(line["source"])
|
||||
elif "answer" in line:
|
||||
response_full += line["answer"]
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
)
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
result["conversation_id"] = str(
|
||||
save_conversation(
|
||||
conversation_id, question, response_full, source_log_docs, llm
|
||||
)
|
||||
)
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "api_answer",
|
||||
"level": "info",
|
||||
"user": "local",
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
return make_response(result, 200)
|
||||
|
||||
|
||||
@answer_ns.route("/api/search")
|
||||
class Search(Resource):
|
||||
search_model = api.model(
|
||||
"SearchModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="The question to search"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"api_key": fields.String(
|
||||
required=False, description="API key for authentication"
|
||||
),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents for retrieval"
|
||||
),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"token_limit": fields.Integer(
|
||||
required=False, description="Limit for tokens"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(search_model)
|
||||
@api.doc(
|
||||
description="Search for relevant documents based on the question and retriever"
|
||||
)
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
user_api_key = data["api_key"]
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
user_api_key = None
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
current_app.logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=[],
|
||||
prompt="default",
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
|
||||
docs = retriever.search()
|
||||
retriever_params = retriever.get_params()
|
||||
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "api_search",
|
||||
"level": "info",
|
||||
"user": "local",
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"sources": docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
return make_response(docs, 200)
|
||||
|
||||
@@ -7,13 +7,15 @@ from bson.binary import Binary, UuidRepresentation
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import Blueprint, jsonify, make_response, request
|
||||
from flask_restx import Api, fields, Namespace, Resource
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from pymongo import MongoClient
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.api.user.tasks import ingest, ingest_remote
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.extensions import api
|
||||
from application.utils import check_required_fields
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
mongo = MongoClient(settings.MONGO_URI)
|
||||
@@ -28,14 +30,8 @@ shared_conversations_collections = db["shared_conversations"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
|
||||
user = Blueprint("user", __name__)
|
||||
api = Api(
|
||||
user,
|
||||
version="1.0",
|
||||
title="DocsGPT API",
|
||||
description="API for DocsGPT",
|
||||
default="user",
|
||||
default_label="User operations",
|
||||
)
|
||||
user_ns = Namespace("user", description="User related operations", path="/")
|
||||
api.add_namespace(user_ns)
|
||||
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -63,22 +59,7 @@ def generate_date_range(start_date, end_date):
|
||||
}
|
||||
|
||||
|
||||
def check_required_fields(data, required_fields):
|
||||
missing_fields = [field for field in required_fields if field not in data]
|
||||
if missing_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Missing fields: {', '.join(missing_fields)}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@api.route("/api/delete_conversation")
|
||||
@user_ns.route("/api/delete_conversation")
|
||||
class DeleteConversation(Resource):
|
||||
@api.doc(
|
||||
description="Deletes a conversation by ID",
|
||||
@@ -98,7 +79,7 @@ class DeleteConversation(Resource):
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@api.route("/api/delete_all_conversations")
|
||||
@user_ns.route("/api/delete_all_conversations")
|
||||
class DeleteAllConversations(Resource):
|
||||
@api.doc(
|
||||
description="Deletes all conversations for a specific user",
|
||||
@@ -112,7 +93,7 @@ class DeleteAllConversations(Resource):
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@api.route("/api/get_conversations")
|
||||
@user_ns.route("/api/get_conversations")
|
||||
class GetConversations(Resource):
|
||||
@api.doc(
|
||||
description="Retrieve a list of the latest 30 conversations",
|
||||
@@ -129,7 +110,7 @@ class GetConversations(Resource):
|
||||
return make_response(jsonify(list_conversations), 200)
|
||||
|
||||
|
||||
@api.route("/api/get_single_conversation")
|
||||
@user_ns.route("/api/get_single_conversation")
|
||||
class GetSingleConversation(Resource):
|
||||
@api.doc(
|
||||
description="Retrieve a single conversation by ID",
|
||||
@@ -153,7 +134,7 @@ class GetSingleConversation(Resource):
|
||||
return make_response(jsonify(conversation["queries"]), 200)
|
||||
|
||||
|
||||
@api.route("/api/update_conversation_name")
|
||||
@user_ns.route("/api/update_conversation_name")
|
||||
class UpdateConversationName(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
@@ -186,7 +167,7 @@ class UpdateConversationName(Resource):
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@api.route("/api/feedback")
|
||||
@user_ns.route("/api/feedback")
|
||||
class SubmitFeedback(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
@@ -229,7 +210,7 @@ class SubmitFeedback(Resource):
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@api.route("/api/delete_by_ids")
|
||||
@user_ns.route("/api/delete_by_ids")
|
||||
class DeleteByIds(Resource):
|
||||
@api.doc(
|
||||
description="Deletes documents from the vector store by IDs",
|
||||
@@ -252,7 +233,7 @@ class DeleteByIds(Resource):
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
|
||||
@api.route("/api/delete_old")
|
||||
@user_ns.route("/api/delete_old")
|
||||
class DeleteOldIndexes(Resource):
|
||||
@api.doc(
|
||||
description="Deletes old indexes",
|
||||
@@ -289,7 +270,7 @@ class DeleteOldIndexes(Resource):
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@api.route("/api/upload")
|
||||
@user_ns.route("/api/upload")
|
||||
class UploadFile(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
@@ -370,7 +351,7 @@ class UploadFile(Resource):
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@api.route("/api/remote")
|
||||
@user_ns.route("/api/remote")
|
||||
class UploadRemote(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
@@ -408,7 +389,7 @@ class UploadRemote(Resource):
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@api.route("/api/task_status")
|
||||
@user_ns.route("/api/task_status")
|
||||
class TaskStatus(Resource):
|
||||
task_status_model = api.model(
|
||||
"TaskStatusModel",
|
||||
@@ -435,7 +416,7 @@ class TaskStatus(Resource):
|
||||
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
|
||||
|
||||
|
||||
@api.route("/api/combine")
|
||||
@user_ns.route("/api/combine")
|
||||
class CombinedJson(Resource):
|
||||
@api.doc(description="Provide JSON file with combined available indexes")
|
||||
def get(self):
|
||||
@@ -496,7 +477,7 @@ class CombinedJson(Resource):
|
||||
return make_response(jsonify(data), 200)
|
||||
|
||||
|
||||
@api.route("/api/docs_check")
|
||||
@user_ns.route("/api/docs_check")
|
||||
class CheckDocs(Resource):
|
||||
check_docs_model = api.model(
|
||||
"CheckDocsModel",
|
||||
@@ -522,7 +503,7 @@ class CheckDocs(Resource):
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
|
||||
|
||||
@api.route("/api/create_prompt")
|
||||
@user_ns.route("/api/create_prompt")
|
||||
class CreatePrompt(Resource):
|
||||
create_prompt_model = api.model(
|
||||
"CreatePromptModel",
|
||||
@@ -560,7 +541,7 @@ class CreatePrompt(Resource):
|
||||
return make_response(jsonify({"id": new_id}), 200)
|
||||
|
||||
|
||||
@api.route("/api/get_prompts")
|
||||
@user_ns.route("/api/get_prompts")
|
||||
class GetPrompts(Resource):
|
||||
@api.doc(description="Get all prompts for the user")
|
||||
def get(self):
|
||||
@@ -587,7 +568,7 @@ class GetPrompts(Resource):
|
||||
return make_response(jsonify(list_prompts), 200)
|
||||
|
||||
|
||||
@api.route("/api/get_single_prompt")
|
||||
@user_ns.route("/api/get_single_prompt")
|
||||
class GetSinglePrompt(Resource):
|
||||
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
|
||||
def get(self):
|
||||
@@ -628,7 +609,7 @@ class GetSinglePrompt(Resource):
|
||||
return make_response(jsonify({"content": prompt["content"]}), 200)
|
||||
|
||||
|
||||
@api.route("/api/delete_prompt")
|
||||
@user_ns.route("/api/delete_prompt")
|
||||
class DeletePrompt(Resource):
|
||||
delete_prompt_model = api.model(
|
||||
"DeletePromptModel",
|
||||
@@ -652,7 +633,7 @@ class DeletePrompt(Resource):
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@api.route("/api/update_prompt")
|
||||
@user_ns.route("/api/update_prompt")
|
||||
class UpdatePrompt(Resource):
|
||||
update_prompt_model = api.model(
|
||||
"UpdatePromptModel",
|
||||
@@ -685,7 +666,7 @@ class UpdatePrompt(Resource):
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@api.route("/api/get_api_keys")
|
||||
@user_ns.route("/api/get_api_keys")
|
||||
class GetApiKeys(Resource):
|
||||
@api.doc(description="Retrieve API keys for the user")
|
||||
def get(self):
|
||||
@@ -719,7 +700,7 @@ class GetApiKeys(Resource):
|
||||
return make_response(jsonify(list_keys), 200)
|
||||
|
||||
|
||||
@api.route("/api/create_api_key")
|
||||
@user_ns.route("/api/create_api_key")
|
||||
class CreateApiKey(Resource):
|
||||
create_api_key_model = api.model(
|
||||
"CreateApiKeyModel",
|
||||
@@ -764,7 +745,7 @@ class CreateApiKey(Resource):
|
||||
return make_response(jsonify({"id": new_id, "key": key}), 201)
|
||||
|
||||
|
||||
@api.route("/api/delete_api_key")
|
||||
@user_ns.route("/api/delete_api_key")
|
||||
class DeleteApiKey(Resource):
|
||||
delete_api_key_model = api.model(
|
||||
"DeleteApiKeyModel",
|
||||
@@ -790,7 +771,7 @@ class DeleteApiKey(Resource):
|
||||
return {"success": True}, 200
|
||||
|
||||
|
||||
@api.route("/api/share")
|
||||
@user_ns.route("/api/share")
|
||||
class ShareConversation(Resource):
|
||||
share_conversation_model = api.model(
|
||||
"ShareConversationModel",
|
||||
@@ -988,7 +969,7 @@ class ShareConversation(Resource):
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
|
||||
@api.route("/api/shared_conversation/<string:identifier>")
|
||||
@user_ns.route("/api/shared_conversation/<string:identifier>")
|
||||
class GetPubliclySharedConversations(Resource):
|
||||
@api.doc(description="Get publicly shared conversations by identifier")
|
||||
def get(self, identifier: str):
|
||||
@@ -1043,7 +1024,7 @@ class GetPubliclySharedConversations(Resource):
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
|
||||
@api.route("/api/get_message_analytics")
|
||||
@user_ns.route("/api/get_message_analytics")
|
||||
class GetMessageAnalytics(Resource):
|
||||
get_message_analytics_model = api.model(
|
||||
"GetMessageAnalyticsModel",
|
||||
@@ -1181,7 +1162,7 @@ class GetMessageAnalytics(Resource):
|
||||
)
|
||||
|
||||
|
||||
@api.route("/api/get_token_analytics")
|
||||
@user_ns.route("/api/get_token_analytics")
|
||||
class GetTokenAnalytics(Resource):
|
||||
get_token_analytics_model = api.model(
|
||||
"GetTokenAnalyticsModel",
|
||||
@@ -1332,7 +1313,7 @@ class GetTokenAnalytics(Resource):
|
||||
)
|
||||
|
||||
|
||||
@api.route("/api/get_feedback_analytics")
|
||||
@user_ns.route("/api/get_feedback_analytics")
|
||||
class GetFeedbackAnalytics(Resource):
|
||||
get_feedback_analytics_model = api.model(
|
||||
"GetFeedbackAnalyticsModel",
|
||||
@@ -1550,7 +1531,7 @@ class GetFeedbackAnalytics(Resource):
|
||||
)
|
||||
|
||||
|
||||
@api.route("/api/get_user_logs")
|
||||
@user_ns.route("/api/get_user_logs")
|
||||
class GetUserLogs(Resource):
|
||||
get_user_logs_model = api.model(
|
||||
"GetUserLogsModel",
|
||||
@@ -1629,7 +1610,7 @@ class GetUserLogs(Resource):
|
||||
)
|
||||
|
||||
|
||||
@api.route("/api/manage_sync")
|
||||
@user_ns.route("/api/manage_sync")
|
||||
class ManageSync(Resource):
|
||||
manage_sync_model = api.model(
|
||||
"ManageSyncModel",
|
||||
|
||||
Reference in New Issue
Block a user