mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
migrate: link source to vector collection
This commit is contained in:
@@ -9,6 +9,7 @@ import traceback
|
||||
|
||||
from pymongo import MongoClient
|
||||
from bson.objectid import ObjectId
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
@@ -36,9 +37,7 @@ if settings.MODEL_NAME: # in case there is particular model name configured
|
||||
gpt_model = settings.MODEL_NAME
|
||||
|
||||
# load the prompts
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
|
||||
chat_combine_template = f.read()
|
||||
|
||||
@@ -74,13 +73,29 @@ def run_async_chain(chain, question, chat_history):
|
||||
|
||||
def get_data_from_api_key(api_key):
|
||||
data = api_key_collection.find_one({"key": api_key})
|
||||
|
||||
# # Raise custom exception if the API key is not found
|
||||
if data is None:
|
||||
raise Exception("Invalid API Key, please generate new key", 401)
|
||||
|
||||
if isinstance(data["source"], DBRef):
|
||||
source_id = db.dereference(data["source"])["_id"]
|
||||
data["source"] = get_source(source_id)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def get_source(active_doc):
|
||||
if ObjectId.is_valid(active_doc):
|
||||
doc = vectors_collection.find_one({"_id": ObjectId(active_doc)})
|
||||
if doc is None:
|
||||
raise Exception("Source document does not exist", 404)
|
||||
print("res", doc)
|
||||
source = {"active_docs": "/".join(doc["location"].split("/")[-2:])}
|
||||
else:
|
||||
source = {"active_docs": active_doc}
|
||||
return source
|
||||
|
||||
|
||||
def get_vectorstore(data):
|
||||
if "active_docs" in data:
|
||||
if data["active_docs"].split("/")[0] == "default":
|
||||
@@ -98,11 +113,7 @@ def get_vectorstore(data):
|
||||
|
||||
|
||||
def is_azure_configured():
|
||||
return (
|
||||
settings.OPENAI_API_BASE
|
||||
and settings.OPENAI_API_VERSION
|
||||
and settings.AZURE_DEPLOYMENT_NAME
|
||||
)
|
||||
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
|
||||
|
||||
|
||||
def save_conversation(conversation_id, question, response, source_log_docs, llm):
|
||||
@@ -128,11 +139,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
|
||||
"role": "assistant",
|
||||
"content": "Summarise following conversation in no more than 3 "
|
||||
"words, respond ONLY with the summary, use the same "
|
||||
"language as the system \n\nUser: "
|
||||
+question
|
||||
+"\n\n"
|
||||
+"AI: "
|
||||
+response,
|
||||
"language as the system \n\nUser: " + question + "\n\n" + "AI: " + response,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
@@ -173,7 +180,6 @@ def get_prompt(prompt_id):
|
||||
|
||||
|
||||
def complete_stream(question, retriever, conversation_id, user_api_key):
|
||||
|
||||
try:
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
@@ -186,126 +192,128 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
|
||||
elif "source" in line:
|
||||
source_log_docs.append(line["source"])
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
)
|
||||
if(user_api_key is None):
|
||||
conversation_id = save_conversation(
|
||||
conversation_id, question, response_full, source_log_docs, llm
|
||||
)
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key)
|
||||
if user_api_key is None:
|
||||
conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
|
||||
# send data.type = "end" to indicate that the stream has ended as json
|
||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
except Exception as e:
|
||||
print("\033[91merr", str(e), file=sys.stderr)
|
||||
data = json.dumps({"type": "error","error":"Please try again later. We apologize for any inconvenience.",
|
||||
"error_exception": str(e)})
|
||||
data = json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Please try again later. We apologize for any inconvenience.",
|
||||
"error_exception": str(e),
|
||||
}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
return
|
||||
|
||||
|
||||
@answer.route("/stream", methods=["POST"])
|
||||
def stream():
|
||||
try:
|
||||
data = request.get_json()
|
||||
# get parameter from url question
|
||||
question = data["question"]
|
||||
if "history" not in data:
|
||||
history = []
|
||||
else:
|
||||
history = data["history"]
|
||||
history = json.loads(history)
|
||||
if "conversation_id" not in data:
|
||||
conversation_id = None
|
||||
else:
|
||||
conversation_id = data["conversation_id"]
|
||||
if "prompt_id" in data:
|
||||
prompt_id = data["prompt_id"]
|
||||
else:
|
||||
prompt_id = "default"
|
||||
if "selectedDocs" in data and data["selectedDocs"] is None:
|
||||
chunks = 0
|
||||
elif "chunks" in data:
|
||||
chunks = int(data["chunks"])
|
||||
else:
|
||||
chunks = 2
|
||||
if "token_limit" in data:
|
||||
token_limit = data["token_limit"]
|
||||
else:
|
||||
token_limit = settings.DEFAULT_MAX_HISTORY
|
||||
try:
|
||||
data = request.get_json()
|
||||
# get parameter from url question
|
||||
question = data["question"]
|
||||
if "history" not in data:
|
||||
history = []
|
||||
else:
|
||||
history = data["history"]
|
||||
history = json.loads(history)
|
||||
if "conversation_id" not in data:
|
||||
conversation_id = None
|
||||
else:
|
||||
conversation_id = data["conversation_id"]
|
||||
if "prompt_id" in data:
|
||||
prompt_id = data["prompt_id"]
|
||||
else:
|
||||
prompt_id = "default"
|
||||
if "selectedDocs" in data and data["selectedDocs"] is None:
|
||||
chunks = 0
|
||||
elif "chunks" in data:
|
||||
chunks = int(data["chunks"])
|
||||
else:
|
||||
chunks = 2
|
||||
if "token_limit" in data:
|
||||
token_limit = data["token_limit"]
|
||||
else:
|
||||
token_limit = settings.DEFAULT_MAX_HISTORY
|
||||
|
||||
# check if active_docs or api_key is set
|
||||
# check if active_docs or api_key is set
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key["chunks"])
|
||||
prompt_id = data_key["prompt_id"]
|
||||
source = {"active_docs": data_key["source"]}
|
||||
user_api_key = data["api_key"]
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
user_api_key = None
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key["chunks"])
|
||||
prompt_id = data_key["prompt_id"]
|
||||
source = data_key["source"]
|
||||
user_api_key = data["api_key"]
|
||||
elif "active_docs" in data:
|
||||
source = get_source(data["active_docs"])
|
||||
user_api_key = None
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
if (
|
||||
source["active_docs"].split("/")[0] == "default"
|
||||
or source["active_docs"].split("/")[0] == "local"
|
||||
):
|
||||
retriever_name = "classic"
|
||||
else:
|
||||
retriever_name = source["active_docs"]
|
||||
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
|
||||
retriever_name = "classic"
|
||||
else:
|
||||
retriever_name = source["active_docs"]
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
|
||||
return Response(
|
||||
complete_stream(
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
message = "Malformed request body"
|
||||
print("\033[91merr", str(message), file=sys.stderr)
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
)
|
||||
|
||||
return Response(
|
||||
complete_stream(
|
||||
question=question,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
except ValueError as err:
|
||||
message = "Malformed request body"
|
||||
print("\033[91merr", str(err), file=sys.stderr)
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
print("\033[91merr", str(e), file=sys.stderr)
|
||||
message = e.args[0]
|
||||
status_code = 400
|
||||
# # Custom exceptions with two arguments, index 1 as status code
|
||||
if(len(e.args) >= 2):
|
||||
if len(e.args) >= 2:
|
||||
status_code = e.args[1]
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=status_code,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
error_stream_generate(message),
|
||||
status=status_code,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
def error_stream_generate(err_response):
|
||||
data = json.dumps({"type": "error", "error":err_response})
|
||||
yield f"data: {data}\n\n"
|
||||
data = json.dumps({"type": "error", "error": err_response})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
|
||||
@answer.route("/api/answer", methods=["POST"])
|
||||
def api_answer():
|
||||
@@ -340,16 +348,13 @@ def api_answer():
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key["chunks"])
|
||||
prompt_id = data_key["prompt_id"]
|
||||
source = {"active_docs": data_key["source"]}
|
||||
source = data_key["source"]
|
||||
user_api_key = data["api_key"]
|
||||
else:
|
||||
source = data
|
||||
source = get_source(data["active_docs"])
|
||||
user_api_key = None
|
||||
|
||||
if (
|
||||
source["active_docs"].split("/")[0] == "default"
|
||||
or source["active_docs"].split("/")[0] == "local"
|
||||
):
|
||||
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
|
||||
retriever_name = "classic"
|
||||
else:
|
||||
retriever_name = source["active_docs"]
|
||||
@@ -375,13 +380,11 @@ def api_answer():
|
||||
elif "answer" in line:
|
||||
response_full += line["answer"]
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
)
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key)
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
result["conversation_id"] = save_conversation(
|
||||
conversation_id, question, response_full, source_log_docs, llm
|
||||
result["conversation_id"] = str(
|
||||
save_conversation(conversation_id, question, response_full, source_log_docs, llm)
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -404,19 +407,16 @@ def api_search():
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key["chunks"])
|
||||
source = {"active_docs": data_key["source"]}
|
||||
user_api_key = data["api_key"]
|
||||
source = data_key["source"]
|
||||
user_api_key = data_key["api_key"]
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
source = get_source(data["active_docs"])
|
||||
user_api_key = None
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
if (
|
||||
source["active_docs"].split("/")[0] == "default"
|
||||
or source["active_docs"].split("/")[0] == "local"
|
||||
):
|
||||
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
|
||||
retriever_name = "classic"
|
||||
else:
|
||||
retriever_name = source["active_docs"]
|
||||
|
||||
@@ -264,6 +264,7 @@ def combined_json():
|
||||
for index in vectors_collection.find({"user": user}).sort("date", -1):
|
||||
data.append(
|
||||
{
|
||||
"id":str(index["_id"]),
|
||||
"name": index["name"],
|
||||
"language": index["language"],
|
||||
"version": "",
|
||||
@@ -453,7 +454,7 @@ def get_api_keys():
|
||||
"id": str(key["_id"]),
|
||||
"name": key["name"],
|
||||
"key": key["key"][:4] + "..." + key["key"][-4:],
|
||||
"source": key["source"],
|
||||
"source": str(key["source"]),
|
||||
"prompt_id": key["prompt_id"],
|
||||
"chunks": key["chunks"],
|
||||
}
|
||||
@@ -470,6 +471,8 @@ def create_api_key():
|
||||
chunks = data["chunks"]
|
||||
key = str(uuid.uuid4())
|
||||
user = "local"
|
||||
if(ObjectId.is_valid(data["source"])):
|
||||
source = DBRef("vectors",ObjectId(data["source"]))
|
||||
resp = api_key_collection.insert_one(
|
||||
{
|
||||
"name": name,
|
||||
@@ -524,7 +527,7 @@ def share_conversation():
|
||||
{
|
||||
"prompt_id": prompt_id,
|
||||
"chunks": chunks,
|
||||
"source": source,
|
||||
"source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
@@ -574,7 +577,7 @@ def share_conversation():
|
||||
{
|
||||
"name": name,
|
||||
"key": api_uuid,
|
||||
"source": source,
|
||||
"source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source,
|
||||
"user": user,
|
||||
"prompt_id": prompt_id,
|
||||
"chunks": chunks,
|
||||
|
||||
Reference in New Issue
Block a user