migrate: link source to vector collection

This commit is contained in:
ManishMadan2882
2024-08-07 03:41:31 +05:30
parent 57b9b369b7
commit f9dbaa9407
2 changed files with 134 additions and 131 deletions

View File

@@ -9,6 +9,7 @@ import traceback
from pymongo import MongoClient
from bson.objectid import ObjectId
from bson.dbref import DBRef
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
@@ -36,9 +37,7 @@ if settings.MODEL_NAME: # in case there is particular model name configured
gpt_model = settings.MODEL_NAME
# load the prompts
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
chat_combine_template = f.read()
@@ -74,13 +73,29 @@ def run_async_chain(chain, question, chat_history):
def get_data_from_api_key(api_key):
    """Load the stored configuration for an API key and normalise its source.

    Args:
        api_key: The key string supplied by the client.

    Returns:
        The API-key document. When its "source" field is a DBRef or a plain
        string, it is replaced by the dict returned from ``get_source`` (shape
        ``{"active_docs": ...}``), which is what callers index into.

    Raises:
        Exception: ``("Invalid API Key, please generate new key", 401)`` when
            no document matches the key. ``get_source`` raises its own
            two-argument exception when a referenced source no longer exists.
    """
    data = api_key_collection.find_one({"key": api_key})
    # Custom exception: index 0 is the message, index 1 the HTTP status code.
    if data is None:
        raise Exception("Invalid API Key, please generate new key", 401)

    source = data.get("source")
    if isinstance(source, DBRef):
        # The reference already carries the target _id, so no extra
        # dereference query is needed; get_source performs the lookup and
        # reports a dangling reference with a clear error instead of a
        # TypeError on None.
        # NOTE(review): assumes the DBRef targets the same collection that
        # get_source queries - confirm.
        data["source"] = get_source(source.id)
    elif isinstance(source, str):
        # Keys stored with a plain string source still need the dict shape.
        data["source"] = get_source(source)
    return data
def get_source(active_doc):
    """Translate a source identifier into the ``{"active_docs": ...}`` dict.

    Args:
        active_doc: Either a value accepted by ``ObjectId.is_valid`` (the id
            of a document in ``vectors_collection``) or any other value, which
            is passed through unchanged.

    Returns:
        ``{"active_docs": <value>}``. For an id, the value is the last two
        "/"-separated components of the stored document's "location" field;
        otherwise it is ``active_doc`` itself.

    Raises:
        Exception: ``("Source document does not exist", 404)`` when the id
            matches no document.
    """
    # Anything that is not an ObjectId is passed through as-is.
    if not ObjectId.is_valid(active_doc):
        return {"active_docs": active_doc}

    doc = vectors_collection.find_one({"_id": ObjectId(active_doc)})
    if doc is None:
        raise Exception("Source document does not exist", 404)

    # Keep only the trailing two path components of the stored location.
    location_tail = doc["location"].split("/")[-2:]
    return {"active_docs": "/".join(location_tail)}
def get_vectorstore(data):
if "active_docs" in data:
if data["active_docs"].split("/")[0] == "default":
@@ -98,11 +113,7 @@ def get_vectorstore(data):
def is_azure_configured():
    """Report whether all three Azure OpenAI settings are set.

    Returns:
        A truthy value only when ``OPENAI_API_BASE``, ``OPENAI_API_VERSION``
        and ``AZURE_DEPLOYMENT_NAME`` are all truthy. The result is that of
        the ``and`` chain (the first falsy setting, or the last one), not a
        strict bool.
    """
    return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
def save_conversation(conversation_id, question, response, source_log_docs, llm):
@@ -128,11 +139,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the system \n\nUser: "
+question
+"\n\n"
+"AI: "
+response,
"language as the system \n\nUser: " + question + "\n\n" + "AI: " + response,
},
{
"role": "user",
@@ -173,7 +180,6 @@ def get_prompt(prompt_id):
def complete_stream(question, retriever, conversation_id, user_api_key):
try:
response_full = ""
source_log_docs = []
@@ -186,126 +192,128 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
elif "source" in line:
source_log_docs.append(line["source"])
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
if(user_api_key is None):
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key)
if user_api_key is None:
conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except Exception as e:
print("\033[91merr", str(e), file=sys.stderr)
data = json.dumps({"type": "error","error":"Please try again later. We apologize for any inconvenience.",
"error_exception": str(e)})
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
"error_exception": str(e),
}
)
yield f"data: {data}\n\n"
return
return
@answer.route("/stream", methods=["POST"])
def stream():
    """Answer a question as a ``text/event-stream`` response.

    Reads the JSON request body. "question" is required; "history",
    "conversation_id", "prompt_id", "chunks", "selectedDocs", "token_limit",
    "api_key" and "active_docs" are optional. When "api_key" is present, the
    chunk count, prompt and source come from the stored key configuration.

    Returns:
        A streaming ``Response`` produced by ``complete_stream``. On failure,
        a single-event error stream: status 400 for ``ValueError``, otherwise
        the status carried as the second argument of the raised exception
        (default 400).
    """
    try:
        data = request.get_json()
        question = data["question"]

        # History may arrive JSON-encoded (a string) or already decoded.
        history = data.get("history", [])
        if history is None:
            history = []
        elif isinstance(history, str):
            history = json.loads(history)

        conversation_id = data.get("conversation_id")
        prompt_id = data.get("prompt_id", "default")

        # An explicit null "selectedDocs" means no chunks are retrieved.
        if "selectedDocs" in data and data["selectedDocs"] is None:
            chunks = 0
        elif "chunks" in data:
            chunks = int(data["chunks"])
        else:
            chunks = 2

        token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)

        # An API key overrides the per-request chunk/prompt/source settings.
        if "api_key" in data:
            data_key = get_data_from_api_key(data["api_key"])
            chunks = int(data_key["chunks"])
            prompt_id = data_key["prompt_id"]
            source = data_key["source"]
            user_api_key = data["api_key"]
        elif "active_docs" in data:
            source = get_source(data["active_docs"])
            user_api_key = None
        else:
            source = {}
            user_api_key = None

        # With no source selected there is nothing to name a retriever
        # after, so use the classic one instead of failing on a missing key.
        # NOTE(review): confirm the classic retriever accepts an empty source.
        active_docs = source.get("active_docs")
        if active_docs is None or active_docs.split("/")[0] in ("default", "local"):
            retriever_name = "classic"
        else:
            retriever_name = active_docs

        prompt = get_prompt(prompt_id)

        retriever = RetrieverCreator.create_retriever(
            retriever_name,
            question=question,
            source=source,
            chat_history=history,
            prompt=prompt,
            chunks=chunks,
            token_limit=token_limit,
            gpt_model=gpt_model,
            user_api_key=user_api_key,
        )

        return Response(
            complete_stream(
                question=question,
                retriever=retriever,
                conversation_id=conversation_id,
                user_api_key=user_api_key,
            ),
            mimetype="text/event-stream",
        )

    except ValueError as err:
        message = "Malformed request body"
        print("\033[91merr", str(err), file=sys.stderr)
        return Response(
            error_stream_generate(message),
            status=400,
            mimetype="text/event-stream",
        )
    except Exception as e:
        print("\033[91merr", str(e), file=sys.stderr)
        # Some exceptions carry no args at all; indexing them blindly would
        # raise inside this handler and turn the reply into a bare 500.
        message = str(e.args[0]) if e.args else str(e)
        status_code = 400
        # Custom exceptions with two arguments, index 1 as status code.
        # Built-in exceptions may hold something else there, so only accept
        # an integer.
        if len(e.args) >= 2 and isinstance(e.args[1], int):
            status_code = e.args[1]
        return Response(
            error_stream_generate(message),
            status=status_code,
            mimetype="text/event-stream",
        )
def error_stream_generate(err_response):
    """Yield one server-sent event that reports an error.

    Args:
        err_response: JSON-serialisable error message to send to the client.

    Yields:
        A single ``data: {...}`` frame, terminated by a blank line, whose
        JSON payload is ``{"type": "error", "error": err_response}``.
    """
    data = json.dumps({"type": "error", "error": err_response})
    yield f"data: {data}\n\n"
@answer.route("/api/answer", methods=["POST"])
def api_answer():
@@ -340,16 +348,13 @@ def api_answer():
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key["chunks"])
prompt_id = data_key["prompt_id"]
source = {"active_docs": data_key["source"]}
source = data_key["source"]
user_api_key = data["api_key"]
else:
source = data
source = get_source(data["active_docs"])
user_api_key = None
if (
source["active_docs"].split("/")[0] == "default"
or source["active_docs"].split("/")[0] == "local"
):
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
retriever_name = "classic"
else:
retriever_name = source["active_docs"]
@@ -375,13 +380,11 @@ def api_answer():
elif "answer" in line:
response_full += line["answer"]
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key)
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
result["conversation_id"] = str(
save_conversation(conversation_id, question, response_full, source_log_docs, llm)
)
return result
@@ -404,19 +407,16 @@ def api_search():
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key["chunks"])
source = {"active_docs": data_key["source"]}
user_api_key = data["api_key"]
source = data_key["source"]
user_api_key = data_key["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
source = get_source(data["active_docs"])
user_api_key = None
else:
source = {}
user_api_key = None
if (
source["active_docs"].split("/")[0] == "default"
or source["active_docs"].split("/")[0] == "local"
):
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
retriever_name = "classic"
else:
retriever_name = source["active_docs"]

View File

@@ -264,6 +264,7 @@ def combined_json():
for index in vectors_collection.find({"user": user}).sort("date", -1):
data.append(
{
"id":str(index["_id"]),
"name": index["name"],
"language": index["language"],
"version": "",
@@ -453,7 +454,7 @@ def get_api_keys():
"id": str(key["_id"]),
"name": key["name"],
"key": key["key"][:4] + "..." + key["key"][-4:],
"source": key["source"],
"source": str(key["source"]),
"prompt_id": key["prompt_id"],
"chunks": key["chunks"],
}
@@ -470,6 +471,8 @@ def create_api_key():
chunks = data["chunks"]
key = str(uuid.uuid4())
user = "local"
if(ObjectId.is_valid(data["source"])):
source = DBRef("vectors",ObjectId(data["source"]))
resp = api_key_collection.insert_one(
{
"name": name,
@@ -524,7 +527,7 @@ def share_conversation():
{
"prompt_id": prompt_id,
"chunks": chunks,
"source": source,
"source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source,
"user": user,
}
)
@@ -574,7 +577,7 @@ def share_conversation():
{
"name": name,
"key": api_uuid,
"source": source,
"source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source,
"user": user,
"prompt_id": prompt_id,
"chunks": chunks,