From f9dbaa9407cfd5800d0990f8543eef6adc321453 Mon Sep 17 00:00:00 2001 From: ManishMadan2882 Date: Wed, 7 Aug 2024 03:41:31 +0530 Subject: [PATCH] migrate: link source to vector collection --- application/api/answer/routes.py | 256 +++++++++++++++---------------- application/api/user/routes.py | 9 +- 2 files changed, 134 insertions(+), 131 deletions(-) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 7eed8434..f076285d 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -9,6 +9,7 @@ import traceback from pymongo import MongoClient from bson.objectid import ObjectId +from bson.dbref import DBRef from application.core.settings import settings from application.llm.llm_creator import LLMCreator @@ -36,9 +37,7 @@ if settings.MODEL_NAME: # in case there is particular model name configured gpt_model = settings.MODEL_NAME # load the prompts -current_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) +current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: chat_combine_template = f.read() @@ -74,13 +73,29 @@ def run_async_chain(chain, question, chat_history): def get_data_from_api_key(api_key): data = api_key_collection.find_one({"key": api_key}) - # # Raise custom exception if the API key is not found if data is None: raise Exception("Invalid API Key, please generate new key", 401) + + if isinstance(data["source"], DBRef): + source_id = db.dereference(data["source"])["_id"] + data["source"] = get_source(source_id) + return data +def get_source(active_doc): + if ObjectId.is_valid(active_doc): + doc = vectors_collection.find_one({"_id": ObjectId(active_doc)}) + if doc is None: + raise Exception("Source document does not exist", 404) + print("res", doc) + source = {"active_docs": "/".join(doc["location"].split("/")[-2:])} + else: + source = 
{"active_docs": active_doc} + return source + + def get_vectorstore(data): if "active_docs" in data: if data["active_docs"].split("/")[0] == "default": @@ -98,11 +113,7 @@ def get_vectorstore(data): def is_azure_configured(): - return ( - settings.OPENAI_API_BASE - and settings.OPENAI_API_VERSION - and settings.AZURE_DEPLOYMENT_NAME - ) + return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME def save_conversation(conversation_id, question, response, source_log_docs, llm): @@ -128,11 +139,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm) "role": "assistant", "content": "Summarise following conversation in no more than 3 " "words, respond ONLY with the summary, use the same " - "language as the system \n\nUser: " - +question - +"\n\n" - +"AI: " - +response, + "language as the system \n\nUser: " + question + "\n\n" + "AI: " + response, }, { "role": "user", @@ -173,7 +180,6 @@ def get_prompt(prompt_id): def complete_stream(question, retriever, conversation_id, user_api_key): - try: response_full = "" source_log_docs = [] @@ -186,126 +192,128 @@ def complete_stream(question, retriever, conversation_id, user_api_key): elif "source" in line: source_log_docs.append(line["source"]) - llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key - ) - if(user_api_key is None): - conversation_id = save_conversation( - conversation_id, question, response_full, source_log_docs, llm - ) + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key) + if user_api_key is None: + conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm) # send data.type = "end" to indicate that the stream has ended as json data = json.dumps({"type": "id", "id": str(conversation_id)}) yield f"data: {data}\n\n" - + data = json.dumps({"type": "end"}) yield f"data: {data}\n\n" except Exception as e: 
print("\033[91merr", str(e), file=sys.stderr) - data = json.dumps({"type": "error","error":"Please try again later. We apologize for any inconvenience.", - "error_exception": str(e)}) + data = json.dumps( + { + "type": "error", + "error": "Please try again later. We apologize for any inconvenience.", + "error_exception": str(e), + } + ) yield f"data: {data}\n\n" - return + return + @answer.route("/stream", methods=["POST"]) def stream(): - try: - data = request.get_json() - # get parameter from url question - question = data["question"] - if "history" not in data: - history = [] - else: - history = data["history"] - history = json.loads(history) - if "conversation_id" not in data: - conversation_id = None - else: - conversation_id = data["conversation_id"] - if "prompt_id" in data: - prompt_id = data["prompt_id"] - else: - prompt_id = "default" - if "selectedDocs" in data and data["selectedDocs"] is None: - chunks = 0 - elif "chunks" in data: - chunks = int(data["chunks"]) - else: - chunks = 2 - if "token_limit" in data: - token_limit = data["token_limit"] - else: - token_limit = settings.DEFAULT_MAX_HISTORY + try: + data = request.get_json() + # get parameter from url question + question = data["question"] + if "history" not in data: + history = [] + else: + history = data["history"] + history = json.loads(history) + if "conversation_id" not in data: + conversation_id = None + else: + conversation_id = data["conversation_id"] + if "prompt_id" in data: + prompt_id = data["prompt_id"] + else: + prompt_id = "default" + if "selectedDocs" in data and data["selectedDocs"] is None: + chunks = 0 + elif "chunks" in data: + chunks = int(data["chunks"]) + else: + chunks = 2 + if "token_limit" in data: + token_limit = data["token_limit"] + else: + token_limit = settings.DEFAULT_MAX_HISTORY - # check if active_docs or api_key is set + # check if active_docs or api_key is set - if "api_key" in data: - data_key = get_data_from_api_key(data["api_key"]) - chunks = 
int(data_key["chunks"]) - prompt_id = data_key["prompt_id"] - source = {"active_docs": data_key["source"]} - user_api_key = data["api_key"] - elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} - user_api_key = None - else: - source = {} - user_api_key = None + if "api_key" in data: + data_key = get_data_from_api_key(data["api_key"]) + chunks = int(data_key["chunks"]) + prompt_id = data_key["prompt_id"] + source = data_key["source"] + user_api_key = data["api_key"] + elif "active_docs" in data: + source = get_source(data["active_docs"]) + user_api_key = None + else: + source = {} + user_api_key = None - if ( - source["active_docs"].split("/")[0] == "default" - or source["active_docs"].split("/")[0] == "local" - ): - retriever_name = "classic" - else: - retriever_name = source["active_docs"] + if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + retriever_name = "classic" + else: + retriever_name = source["active_docs"] - prompt = get_prompt(prompt_id) + prompt = get_prompt(prompt_id) - retriever = RetrieverCreator.create_retriever( - retriever_name, - question=question, - source=source, - chat_history=history, - prompt=prompt, - chunks=chunks, - token_limit=token_limit, - gpt_model=gpt_model, - user_api_key=user_api_key, - ) - - return Response( - complete_stream( + retriever = RetrieverCreator.create_retriever( + retriever_name, question=question, - retriever=retriever, - conversation_id=conversation_id, + source=source, + chat_history=history, + prompt=prompt, + chunks=chunks, + token_limit=token_limit, + gpt_model=gpt_model, user_api_key=user_api_key, - ), - mimetype="text/event-stream", - ) - - except ValueError: - message = "Malformed request body" - print("\033[91merr", str(message), file=sys.stderr) - return Response( - error_stream_generate(message), - status=400, - mimetype="text/event-stream", - ) - except Exception as e: + ) + + return Response( + complete_stream( + 
question=question, + retriever=retriever, + conversation_id=conversation_id, + user_api_key=user_api_key, + ), + mimetype="text/event-stream", + ) + + except ValueError as err: + message = "Malformed request body" + print("\033[91merr", str(err), file=sys.stderr) + return Response( + error_stream_generate(message), + status=400, + mimetype="text/event-stream", + ) + except Exception as e: print("\033[91merr", str(e), file=sys.stderr) message = e.args[0] status_code = 400 # # Custom exceptions with two arguments, index 1 as status code - if(len(e.args) >= 2): + if len(e.args) >= 2: status_code = e.args[1] return Response( - error_stream_generate(message), - status=status_code, - mimetype="text/event-stream", - ) + error_stream_generate(message), + status=status_code, + mimetype="text/event-stream", + ) + + def error_stream_generate(err_response): - data = json.dumps({"type": "error", "error":err_response}) - yield f"data: {data}\n\n" + data = json.dumps({"type": "error", "error": err_response}) + yield f"data: {data}\n\n" + @answer.route("/api/answer", methods=["POST"]) def api_answer(): @@ -340,16 +348,13 @@ def api_answer(): data_key = get_data_from_api_key(data["api_key"]) chunks = int(data_key["chunks"]) prompt_id = data_key["prompt_id"] - source = {"active_docs": data_key["source"]} + source = data_key["source"] user_api_key = data["api_key"] else: - source = data + source = get_source(data["active_docs"]) user_api_key = None - if ( - source["active_docs"].split("/")[0] == "default" - or source["active_docs"].split("/")[0] == "local" - ): + if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": retriever_name = "classic" else: retriever_name = source["active_docs"] @@ -375,13 +380,11 @@ def api_answer(): elif "answer" in line: response_full += line["answer"] - llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key - ) + llm = LLMCreator.create_llm(settings.LLM_NAME, 
api_key=settings.API_KEY, user_api_key=user_api_key)
 
         result = {"answer": response_full, "sources": source_log_docs}
-        result["conversation_id"] = save_conversation(
-            conversation_id, question, response_full, source_log_docs, llm
+        result["conversation_id"] = str(
+            save_conversation(conversation_id, question, response_full, source_log_docs, llm)
         )
 
         return result
@@ -404,19 +407,16 @@ def api_search():
     if "api_key" in data:
         data_key = get_data_from_api_key(data["api_key"])
         chunks = int(data_key["chunks"])
-        source = {"active_docs": data_key["source"]}
-        user_api_key = data["api_key"]
+        source = data_key["source"]
+        user_api_key = data["api_key"]
     elif "active_docs" in data:
-        source = {"active_docs": data["active_docs"]}
+        source = get_source(data["active_docs"])
         user_api_key = None
     else:
         source = {}
         user_api_key = None
 
-    if (
-        source["active_docs"].split("/")[0] == "default"
-        or source["active_docs"].split("/")[0] == "local"
-    ):
+    if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
         retriever_name = "classic"
     else:
         retriever_name = source["active_docs"]
diff --git a/application/api/user/routes.py b/application/api/user/routes.py
index 91b90d6a..06bab591 100644
--- a/application/api/user/routes.py
+++ b/application/api/user/routes.py
@@ -264,6 +264,7 @@ def combined_json():
     for index in vectors_collection.find({"user": user}).sort("date", -1):
         data.append(
             {
+                "id":str(index["_id"]),
                 "name": index["name"],
                 "language": index["language"],
                 "version": "",
@@ -453,7 +454,7 @@ def get_api_keys():
                     "id": str(key["_id"]),
                     "name": key["name"],
                     "key": key["key"][:4] + "..." 
+ key["key"][-4:], - "source": key["source"], + "source": str(key["source"]), "prompt_id": key["prompt_id"], "chunks": key["chunks"], } @@ -470,6 +471,8 @@ def create_api_key(): chunks = data["chunks"] key = str(uuid.uuid4()) user = "local" + if(ObjectId.is_valid(data["source"])): + source = DBRef("vectors",ObjectId(data["source"])) resp = api_key_collection.insert_one( { "name": name, @@ -524,7 +527,7 @@ def share_conversation(): { "prompt_id": prompt_id, "chunks": chunks, - "source": source, + "source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source, "user": user, } ) @@ -574,7 +577,7 @@ def share_conversation(): { "name": name, "key": api_uuid, - "source": source, + "source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source, "user": user, "prompt_id": prompt_id, "chunks": chunks,