From 3c6fd365fbc66d94f31c77db784932ff4ec7eb57 Mon Sep 17 00:00:00 2001 From: ManishMadan2882 Date: Fri, 9 Aug 2024 18:27:54 +0530 Subject: [PATCH 01/18] store only local docs as location --- application/api/answer/routes.py | 62 ++++++++-------- application/api/user/routes.py | 118 ++++++++++--------------------- 2 files changed, 72 insertions(+), 108 deletions(-) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index f076285d..85cc3afd 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -77,23 +77,23 @@ def get_data_from_api_key(api_key): if data is None: raise Exception("Invalid API Key, please generate new key", 401) - if isinstance(data["source"], DBRef): - source_id = db.dereference(data["source"])["_id"] - data["source"] = get_source(source_id) + if "retriever" not in data: + data["retriever"] = "classic" + if "source" in data and isinstance(data["source"], DBRef): + source_doc = db.dereference(data["source"]) + data["source"] = str(source_doc._id) + if "retriever" in source_doc: + data["retriever"] = source_doc["retriever"] return data -def get_source(active_doc): - if ObjectId.is_valid(active_doc): - doc = vectors_collection.find_one({"_id": ObjectId(active_doc)}) - if doc is None: - raise Exception("Source document does not exist", 404) - print("res", doc) - source = {"active_docs": "/".join(doc["location"].split("/")[-2:])} - else: - source = {"active_docs": active_doc} - return source +def get_retriever(source_id: str): + doc = vectors_collection.find_one({"_id": ObjectId(source_id)}) + if doc is None: + raise Exception("Source document does not exist", 404) + retriever_name = "classic" if "retriever" not in doc else doc["retriever"] + return retriever_name def get_vectorstore(data): @@ -244,25 +244,31 @@ def stream(): else: token_limit = settings.DEFAULT_MAX_HISTORY - # check if active_docs or api_key is set + ## retriever can be "brave_search, duckduck_search or classic" + retriever_name = data["retriever"] if "retriever" in data else "classic" + # check if active_docs or api_key is set if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) chunks = int(data_key["chunks"]) prompt_id = data_key["prompt_id"] - source = data_key["source"] + source = {"active_docs": data_key["source"]} + retriever_name = data_key["retriever"] user_api_key = data["api_key"] + elif "active_docs" in data: - source = get_source(data["active_docs"]) + source = {"active_docs" : data["active_docs"]} + retriever_name = get_retriever(data["active_docs"]) user_api_key = None + else: source = {} user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + """ if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": retriever_name = "classic" else: - retriever_name = source["active_docs"] + retriever_name = source["active_docs"] """ prompt = get_prompt(prompt_id) @@ -341,6 +347,9 @@ def api_answer(): else: token_limit = settings.DEFAULT_MAX_HISTORY + ## retriever can be brave_search, duckduck_search or classic + retriever_name = data["retriever"] if "retriever" in data else "classic" + # use try and except to check for exception try: # check if the vectorstore is set @@ -350,15 +359,10 @@ def api_answer(): prompt_id = data_key["prompt_id"] source = data_key["source"] user_api_key = data["api_key"] - else: - source = get_source(data["active_docs"]) + elif "active_docs" in data: + source = data["active_docs"] user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": - retriever_name = "classic" - else: - retriever_name = source["active_docs"] - prompt = get_prompt(prompt_id) retriever = RetrieverCreator.create_retriever( @@ -410,16 +414,16 @@ def api_search(): source = data_key["source"] user_api_key = data_key["api_key"] elif "active_docs" in data: - source = get_source(data["active_docs"]) + source = data["active_docs"] user_api_key = None else: source = {} user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": - retriever_name = "classic" + if "retriever" in data: + retriever_name = data["retriever"] else: - retriever_name = source["active_docs"] + retriever_name = "classic" if "token_limit" in data: token_limit = data["token_limit"] else: diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 06bab591..aab30469 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -25,9 +25,7 @@ shared_conversations_collections = db["shared_conversations"] user = Blueprint("user", __name__) -current_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) +current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @user.route("/api/delete_conversation", methods=["POST"]) @@ -57,9 +55,7 @@ def get_conversations(): conversations = conversations_collection.find().sort("date", -1).limit(30) list_conversations = [] for conversation in conversations: - list_conversations.append( - {"id": str(conversation["_id"]), "name": conversation["name"]} - ) + list_conversations.append({"id": str(conversation["_id"]), "name": conversation["name"]}) # list_conversations = [{"id": "default", "name": "default"}, {"id": "jeff", "name": "jeff"}] @@ -138,9 +134,7 @@ def delete_old(): except FileNotFoundError: pass else: - vetorstore = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, path=os.path.join(current_dir, path_clean) - ) + vetorstore = VectorCreator.create_vectorstore(settings.VECTOR_STORE, path=os.path.join(current_dir, path_clean)) vetorstore.delete_index() return {"status": "ok"} @@ -175,9 +169,7 @@ def upload_file(): file.save(os.path.join(temp_dir, filename)) # Use shutil.make_archive to zip the temp directory - zip_path = shutil.make_archive( - base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir - ) + zip_path = shutil.make_archive(base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir) final_filename = os.path.basename(zip_path) # Clean up the temporary directory after zipping @@ -219,9 +211,7 @@ def upload_remote(): source_data = request.form["data"] if source_data: - task = ingest_remote.delay( - source_data=source_data, job_name=job_name, user=user, loader=source - ) + task = ingest_remote.delay(source_data=source_data, job_name=job_name, user=user, loader=source) task_id = task.id return {"status": "ok", "task_id": task_id} else: @@ -264,7 +254,7 @@ def combined_json(): for index in vectors_collection.find({"user": user}).sort("date", -1): data.append( { - "id":str(index["_id"]), + "id": str(index["_id"]), "name": index["name"], "language": index["language"], "version": "", @@ -278,9 +268,7 @@ def combined_json(): } ) if settings.VECTOR_STORE == "faiss": - data_remote = requests.get( - "https://d3dg1063dc54p9.cloudfront.net/combined.json" - ).json() + data_remote = requests.get("https://d3dg1063dc54p9.cloudfront.net/combined.json").json() for index in data_remote: index["location"] = "remote" data.append(index) @@ -383,9 +371,7 @@ def get_prompts(): list_prompts.append({"id": "creative", "name": "creative", "type": "public"}) list_prompts.append({"id": "strict", "name": "strict", "type": "public"}) for prompt in prompts: - list_prompts.append( - {"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"} - ) + list_prompts.append({"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"}) return jsonify(list_prompts) @@ -394,21 +380,15 @@ def get_prompts(): def get_single_prompt(): prompt_id = request.args.get("id") if prompt_id == "default": - with open( - os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r" - ) as f: + with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: chat_combine_template = f.read() return jsonify({"content": chat_combine_template}) elif prompt_id == "creative": - with open( - os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r" - ) as f: + with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f: chat_reduce_creative = f.read() return jsonify({"content": chat_reduce_creative}) elif prompt_id == "strict": - with open( - os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r" - ) as f: + with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f: chat_reduce_strict = f.read() return jsonify({"content": chat_reduce_strict}) @@ -437,9 +417,7 @@ def update_prompt_name(): # check if name is null if name == "": return {"status": "error"} - prompts_collection.update_one( - {"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}} - ) + prompts_collection.update_one({"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}}) return {"status": "ok"} @@ -449,12 +427,15 @@ def get_api_keys(): keys = api_key_collection.find({"user": user}) list_keys = [] for key in keys: + source_name = ( + db.dereference(key["source"])["name"] if isinstance(key["source"], DBRef) else key["source"].split("/")[0] + ) list_keys.append( { "id": str(key["_id"]), "name": key["name"], "key": key["key"][:4] + "..." + key["key"][-4:], - "source": str(key["source"]), + "source": source_name, "prompt_id": key["prompt_id"], "chunks": key["chunks"], } @@ -466,23 +447,22 @@ def get_api_keys(): def create_api_key(): data = request.get_json() name = data["name"] - source = data["source"] prompt_id = data["prompt_id"] chunks = data["chunks"] key = str(uuid.uuid4()) user = "local" - if(ObjectId.is_valid(data["source"])): - source = DBRef("vectors",ObjectId(data["source"])) - resp = api_key_collection.insert_one( - { - "name": name, - "key": key, - "source": source, - "user": user, - "prompt_id": prompt_id, - "chunks": chunks, - } - ) + new_api_key = { + "name": name, + "key": key, + "user": user, + "prompt_id": prompt_id, + "chunks": chunks, + } + if "source" in data and ObjectId.is_valid(data["source"]): + new_api_key["source"] = DBRef("vectors", ObjectId(data["source"])) + if "retriever" in data: + new_api_key["retriever"] = data["retriever"] + resp = api_key_collection.insert_one(new_api_key) new_id = str(resp.inserted_id) return {"id": new_id, "key": key} @@ -509,9 +489,7 @@ def share_conversation(): conversation_id = data["conversation_id"] isPromptable = request.args.get("isPromptable").lower() == "true" - conversation = conversations_collection.find_one( - {"_id": ObjectId(conversation_id)} - ) + conversation = conversations_collection.find_one({"_id": ObjectId(conversation_id)}) current_n_queries = len(conversation["queries"]) ##generate binary representation of uuid @@ -527,7 +505,7 @@ def share_conversation(): { "prompt_id": prompt_id, "chunks": chunks, - "source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source, + "source": DBRef("vectors", ObjectId(source)) if ObjectId.is_valid(source) else source, "user": user, } ) @@ -536,9 +514,7 @@ def share_conversation(): api_uuid = pre_existing_api_document["key"] pre_existing = shared_conversations_collections.find_one( { - "conversation_id": DBRef( - "conversations", ObjectId(conversation_id) - ), + "conversation_id": DBRef("conversations", ObjectId(conversation_id)), "isPromptable": isPromptable, "first_n_queries": current_n_queries, "user": user, @@ -569,15 +545,13 @@ def share_conversation(): "api_key": api_uuid, } ) - return jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ) + return jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}) else: api_key_collection.insert_one( { "name": name, "key": api_uuid, - "source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source, + "source": DBRef("vectors", ObjectId(source)) if ObjectId.is_valid(source) else source, "user": user, "prompt_id": prompt_id, "chunks": chunks, @@ -598,9 +572,7 @@ def share_conversation(): ) ## Identifier as route parameter in frontend return ( - jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ), + jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}), 201, ) @@ -615,9 +587,7 @@ def share_conversation(): ) if pre_existing is not None: return ( - jsonify( - {"success": True, "identifier": str(pre_existing["uuid"].as_uuid())} - ), + jsonify({"success": True, "identifier": str(pre_existing["uuid"].as_uuid())}), 200, ) else: @@ -635,9 +605,7 @@ def share_conversation(): ) ## Identifier as route parameter in frontend return ( - jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ), + jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}), 201, ) except Exception as err: @@ -649,16 +617,10 @@ def share_conversation(): @user.route("/api/shared_conversation/", methods=["GET"]) def get_publicly_shared_conversations(identifier: str): try: - query_uuid = Binary.from_uuid( - uuid.UUID(identifier), UuidRepresentation.STANDARD - ) + query_uuid = Binary.from_uuid(uuid.UUID(identifier), UuidRepresentation.STANDARD) shared = shared_conversations_collections.find_one({"uuid": query_uuid}) conversation_queries = [] - if ( - shared - and "conversation_id" in shared - and isinstance(shared["conversation_id"], DBRef) - ): + if shared and "conversation_id" in shared and isinstance(shared["conversation_id"], DBRef): # Resolve the DBRef conversation_ref = shared["conversation_id"] conversation = db.dereference(conversation_ref) @@ -672,9 +634,7 @@ def get_publicly_shared_conversations(identifier: str): ), 404, ) - conversation_queries = conversation["queries"][ - : (shared["first_n_queries"]) - ] + conversation_queries = conversation["queries"][: (shared["first_n_queries"])] for query in conversation_queries: query.pop("sources") ## avoid exposing sources else: From 1eb168be55fdcd594c2548f9aa45c4617f778fd8 Mon Sep 17 00:00:00 2001 From: ManishMadan2882 Date: Sun, 11 Aug 2024 19:33:31 +0530 Subject: [PATCH 02/18] vector indexes to be named after mongo _id --- application/api/internal/routes.py | 10 +++++-- application/api/user/routes.py | 23 +++++++-------- application/retriever/classic_rag.py | 9 +----- application/worker.py | 43 +++++++++------------------- 4 files changed, 34 insertions(+), 51 deletions(-) diff --git a/application/api/internal/routes.py b/application/api/internal/routes.py index 6039ecdf..f4203822 100755 --- a/application/api/internal/routes.py +++ b/application/api/internal/routes.py @@ -3,7 +3,7 @@ import datetime from flask import Blueprint, request, send_from_directory from pymongo import MongoClient from werkzeug.utils import secure_filename - +from bson.objectid import ObjectId from application.core.settings import settings mongo = MongoClient(settings.MONGO_URI) @@ -35,7 +35,12 @@ def upload_index_files(): return {"status": "no name"} job_name = secure_filename(request.form["name"]) tokens = secure_filename(request.form["tokens"]) - save_dir = os.path.join(current_dir, "indexes", user, job_name) + """" + ObjectId serves as a dir name in application/indexes, + and for indexing the vector metadata in the collection + """ + _id = ObjectId() + save_dir = os.path.join(current_dir, "indexes", str(_id)) if settings.VECTOR_STORE == "faiss": if "file_faiss" not in request.files: print("No file part") @@ -58,6 +63,7 @@ def upload_index_files(): # create entry in vectors_collection vectors_collection.insert_one( { + "_id":_id, "user": user, "name": job_name, "language": job_name, diff --git a/application/api/user/routes.py b/application/api/user/routes.py index aab30469..7ce0b2e2 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -116,18 +116,17 @@ def delete_by_ids(): def delete_old(): """Delete old indexes.""" import shutil - - path = request.args.get("path") - dirs = path.split("/") - dirs_clean = [] - for i in range(0, len(dirs)): - dirs_clean.append(secure_filename(dirs[i])) - # check that path strats with indexes or vectors - - if dirs_clean[0] not in ["indexes", "vectors"]: - return {"status": "error"} - path_clean = "/".join(dirs_clean) - vectors_collection.delete_one({"name": dirs_clean[-1], "user": dirs_clean[-2]}) + name = request.args.get("name") + user = request.args.get("user") + doc = vectors_collection.find_one({ + "user":user, + "name":name + }) + print("user",user) + print("file",name) + if(doc is None): + return {"status":"not found"},404 + path_clean = doc["location"] if settings.VECTOR_STORE == "faiss": try: shutil.rmtree(os.path.join(current_dir, path_clean)) diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 2b77db34..4a1aa5bc 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -40,14 +40,7 @@ class ClassicRAG(BaseRetriever): def _get_vectorstore(self, source): if "active_docs" in source: - if source["active_docs"].split("/")[0] == "default": - vectorstore = "" - elif source["active_docs"].split("/")[0] == "local": - vectorstore = "indexes/" + source["active_docs"] - else: - vectorstore = "vectors/" + source["active_docs"] - if source["active_docs"] == "default": - vectorstore = "" + vectorstore = "indexes/"+source["active_docs"] else: vectorstore = "" vectorstore = os.path.join("application", vectorstore) diff --git a/application/worker.py b/application/worker.py index bd1bc15a..b3258983 100755 --- a/application/worker.py +++ b/application/worker.py @@ -14,6 +14,7 @@ from application.parser.open_ai_func import call_openai_api from application.parser.schema.base import Document from application.parser.token_func import group_split + # Define a function to extract metadata from a given filename. def metadata_from_filename(title): store = "/".join(title.split("/")[1:3]) @@ -25,9 +26,7 @@ def generate_random_string(length): return "".join([string.ascii_letters[i % 52] for i in range(length)]) -current_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) +current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5): @@ -93,9 +92,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user): print(full_path, file=sys.stderr) # check if API_URL env variable is set file_data = {"name": name_job, "file": filename, "user": user} - response = requests.get( - urljoin(settings.API_URL, "/api/download"), params=file_data - ) + response = requests.get(urljoin(settings.API_URL, "/api/download"), params=file_data) # check if file is in the response print(response, file=sys.stderr) file = response.content @@ -107,9 +104,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user): # check if file is .zip and extract it if filename.endswith(".zip"): - extract_zip_recursive( - os.path.join(full_path, filename), full_path, 0, recursion_depth - ) + extract_zip_recursive(os.path.join(full_path, filename), full_path, 0, recursion_depth) self.update_state(state="PROGRESS", meta={"current": 1}) @@ -141,22 +136,16 @@ def ingest_worker(self, directory, formats, name_job, filename, user): # get files from outputs/inputs/index.faiss and outputs/inputs/index.pkl # and send them to the server (provide user and name in form) - file_data = {"name": name_job, "user": user, "tokens":tokens} + file_data = {"name": name_job, "user": user, "tokens": tokens} if settings.VECTOR_STORE == "faiss": files = { "file_faiss": open(full_path + "/index.faiss", "rb"), "file_pkl": open(full_path + "/index.pkl", "rb"), } - response = requests.post( - urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data - ) - response = requests.get( - urljoin(settings.API_URL, "/api/delete_old?path=" + full_path) - ) + response = requests.post(urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data) + response = requests.get(urljoin(settings.API_URL, "/api/delete_old?name=" + name_job + "&?user=" + user)) else: - response = requests.post( - urljoin(settings.API_URL, "/api/upload_index"), data=file_data - ) + response = requests.post(urljoin(settings.API_URL, "/api/upload_index"), data=file_data) # delete local shutil.rmtree(full_path) @@ -196,17 +185,15 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp"): self.update_state(state="PROGRESS", meta={"current": 100}) # Proceed with uploading and cleaning as in the original function - file_data = {"name": name_job, "user": user, "tokens":tokens} + file_data = {"name": name_job, "user": user, "tokens": tokens} if settings.VECTOR_STORE == "faiss": files = { "file_faiss": open(full_path + "/index.faiss", "rb"), "file_pkl": open(full_path + "/index.pkl", "rb"), } - - requests.post( - urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data - ) - requests.get(urljoin(settings.API_URL, "/api/delete_old?path=" + full_path)) + + requests.post(urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data) + requests.get(urljoin(settings.API_URL, "/api/delete_old?name=" + name_job + "&?user=" + user)) else: requests.post(urljoin(settings.API_URL, "/api/upload_index"), data=file_data) @@ -222,9 +209,7 @@ def count_tokens_docs(docs): for doc in docs: docs_content += doc.page_content - tokens, total_price = num_tokens_from_string( - string=docs_content, encoding_name="cl100k_base" - ) + tokens, total_price = num_tokens_from_string(string=docs_content, encoding_name="cl100k_base") # Here we print the number of tokens and the approx user cost with some visually appealing formatting. return tokens @@ -234,4 +219,4 @@ def num_tokens_from_string(string: str, encoding_name: str) -> int: encoding = tiktoken.get_encoding(encoding_name) num_tokens = len(encoding.encode(string)) total_price = (num_tokens / 1000) * 0.0004 - return num_tokens, total_price \ No newline at end of file + return num_tokens, total_price From dc4078d744f7b94bac6edefd48dcbcc9098a82c9 Mon Sep 17 00:00:00 2001 From: ManishMadan2882 Date: Sun, 11 Aug 2024 21:26:30 +0530 Subject: [PATCH 03/18] migration(fixes): retriver/sharing endpoints --- application/api/answer/routes.py | 16 +++++++++----- application/api/user/routes.py | 38 ++++++++++++++++++-------------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 85cc3afd..e2d7b6e8 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -82,9 +82,11 @@ def get_data_from_api_key(api_key): if "source" in data and isinstance(data["source"], DBRef): source_doc = db.dereference(data["source"]) - data["source"] = str(source_doc._id) + data["source"] = str(source_doc["_id"]) if "retriever" in source_doc: data["retriever"] = source_doc["retriever"] + else: + data["source"] = {} return data @@ -357,10 +359,14 @@ def api_answer(): data_key = get_data_from_api_key(data["api_key"]) chunks = int(data_key["chunks"]) prompt_id = data_key["prompt_id"] - source = data_key["source"] + source = {"active_docs": data_key["source"]} + retriever_name = data_key["retriever"] user_api_key = data["api_key"] elif "active_docs" in data: - source = data["active_docs"] + source = {"active_docs":data["active_docs"]} + user_api_key = None + else: + source = {} user_api_key = None prompt = get_prompt(prompt_id) @@ -411,10 +417,10 @@ def api_search(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) chunks = int(data_key["chunks"]) - source = data_key["source"] + source = {"active_docs":data_key["source"]} user_api_key = data_key["api_key"] elif "active_docs" in data: - source = data["active_docs"] + source = {"active_docs":data["active_docs"]} user_api_key = None else: source = {} diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 7ce0b2e2..84831a65 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -489,26 +489,31 @@ def share_conversation(): isPromptable = request.args.get("isPromptable").lower() == "true" conversation = conversations_collection.find_one({"_id": ObjectId(conversation_id)}) + if(conversation is None): + raise Exception("Conversation does not exist") current_n_queries = len(conversation["queries"]) ##generate binary representation of uuid explicit_binary = Binary.from_uuid(uuid.uuid4(), UuidRepresentation.STANDARD) if isPromptable: - source = "default" if "source" not in data else data["source"] prompt_id = "default" if "prompt_id" not in data else data["prompt_id"] chunks = "2" if "chunks" not in data else data["chunks"] name = conversation["name"] + "(shared)" - pre_existing_api_document = api_key_collection.find_one( - { + new_api_key_data = { "prompt_id": prompt_id, "chunks": chunks, - "source": DBRef("vectors", ObjectId(source)) if ObjectId.is_valid(source) else source, "user": user, } + if "source" in data and ObjectId.is_valid(data["source"]): + new_api_key_data["source"] = DBRef("vectors",ObjectId(data["source"])) + elif "retriever" in data: + new_api_key_data["retriever"] = data["retriever"] + + pre_existing_api_document = api_key_collection.find_one( + new_api_key_data ) - api_uuid = str(uuid.uuid4()) if pre_existing_api_document: api_uuid = pre_existing_api_document["key"] pre_existing = shared_conversations_collections.find_one( @@ -546,17 +551,16 @@ def share_conversation(): ) return jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}) else: - api_key_collection.insert_one( - { - "name": name, - "key": api_uuid, - "source": DBRef("vectors", ObjectId(source)) if ObjectId.is_valid(source) else source, - "user": user, - "prompt_id": prompt_id, - "chunks": chunks, - } - ) - shared_conversations_collections.insert_one( + + api_uuid = str(uuid.uuid4()) + new_api_key_data["key"] = api_uuid + new_api_key_data["name"] = name + if "source" in data and ObjectId.is_valid(data["source"]): + new_api_key_data["source"] = DBRef("vectors", ObjectId(data["source"])) + if "retriever" in data: + new_api_key_data["retriever"] = data["retriever"] + api_key_collection.insert_one(new_api_key_data) + shared_conversations_collections.insert_one( { "uuid": explicit_binary, "conversation_id": { @@ -568,7 +572,7 @@ def share_conversation(): "user": user, "api_key": api_uuid, } - ) + ) ## Identifier as route parameter in frontend return ( jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}), From 7e8dd6bba8f3fb18c2cd682c6c5117e11303da09 Mon Sep 17 00:00:00 2001 From: ManishMadan2882 Date: Mon, 12 Aug 2024 01:06:21 +0530 Subject: [PATCH 04/18] fix: get api keys endpoint --- application/api/user/routes.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 84831a65..7c6e979c 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -426,9 +426,17 @@ def get_api_keys(): keys = api_key_collection.find({"user": user}) list_keys = [] for key in keys: - source_name = ( - db.dereference(key["source"])["name"] if isinstance(key["source"], DBRef) else key["source"].split("/")[0] - ) + if "source" in key and isinstance(key["source"],DBRef): + source = db.dereference(key["source"]) + if source is None: + continue + else: + source_name = source["name"] + elif "retriever" in key: + source_name = key["retriever"] + else: + continue + list_keys.append( { "id": str(key["_id"]), From deeffbf77d1754b81d35b3353f44cef5f9d2f3ee Mon Sep 17 00:00:00 2001 From: ManishMadan2882 Date: Mon, 12 Aug 2024 15:50:16 +0530 Subject: [PATCH 05/18] fix(retriever):classic should not override --- application/api/answer/routes.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index e2d7b6e8..caca7c67 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -78,7 +78,7 @@ def get_data_from_api_key(api_key): raise Exception("Invalid API Key, please generate new key", 401) if "retriever" not in data: - data["retriever"] = "classic" + data["retriever"] = None if "source" in data and isinstance(data["source"], DBRef): source_doc = db.dereference(data["source"]) @@ -94,7 +94,7 @@ def get_retriever(source_id: str): doc = vectors_collection.find_one({"_id": ObjectId(source_id)}) if doc is None: raise Exception("Source document does not exist", 404) - retriever_name = "classic" if "retriever" not in doc else doc["retriever"] + retriever_name = None if "retriever" not in doc else doc["retriever"] return retriever_name @@ -255,12 +255,12 @@ def stream(): chunks = int(data_key["chunks"]) prompt_id = data_key["prompt_id"] source = {"active_docs": data_key["source"]} - retriever_name = data_key["retriever"] + retriever_name = data_key["retriever"] or retriever_name user_api_key = data["api_key"] elif "active_docs" in data: source = {"active_docs" : data["active_docs"]} - retriever_name = get_retriever(data["active_docs"]) + retriever_name = get_retriever(data["active_docs"]) or retriever_name user_api_key = None else: @@ -273,7 +273,7 @@ def stream(): retriever_name = source["active_docs"] """ prompt = get_prompt(prompt_id) - + retriever = RetrieverCreator.create_retriever( retriever_name, question=question, @@ -360,10 +360,11 @@ def api_answer(): chunks = int(data_key["chunks"]) prompt_id = data_key["prompt_id"] source = {"active_docs": data_key["source"]} - retriever_name = data_key["retriever"] + retriever_name = data_key["retriever"] or retriever_name user_api_key = data["api_key"] elif "active_docs" in data: source = {"active_docs":data["active_docs"]} + retriever_name = get_retriever(data["active_docs"]) or retriever_name user_api_key = None else: source = {} From 0891ef6d0ad54b1f1c0aa7406ece3fd85441ba12 Mon Sep 17 00:00:00 2001 From: ManishMadan2882 Date: Wed, 14 Aug 2024 17:15:20 +0530 Subject: [PATCH 06/18] frontend: adapting to migration --- frontend/src/components/Dropdown.tsx | 5 + .../src/conversation/conversationHandlers.ts | 102 +++++++----------- .../src/conversation/conversationModels.ts | 10 ++ frontend/src/modals/CreateAPIKeyModal.tsx | 70 ++++++------ frontend/src/models/misc.ts | 1 + frontend/src/settings/APIKeys.tsx | 3 +- 6 files changed, 94 insertions(+), 97 deletions(-) diff --git a/frontend/src/components/Dropdown.tsx b/frontend/src/components/Dropdown.tsx index 17516aaa..0353a191 100644 --- a/frontend/src/components/Dropdown.tsx +++ b/frontend/src/components/Dropdown.tsx @@ -26,6 +26,7 @@ function Dropdown({ | string | { label: string; value: string } | { value: number; description: string } + | { name: string; id: string; type: string } | null; onSelect: | ((value: string) => void) @@ -96,6 +97,10 @@ function Dropdown({ ? selectedValue.value + ` (${selectedValue.description})` : selectedValue.description }` + : selectedValue && + 'name' in selectedValue && + 'id' in selectedValue + ? `${selectedValue.name}` : placeholder ? placeholder : 'From URL'} diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts index 90bbc0a9..9e3d5d2c 100644 --- a/frontend/src/conversation/conversationHandlers.ts +++ b/frontend/src/conversation/conversationHandlers.ts @@ -1,32 +1,6 @@ import conversationService from '../api/services/conversationService'; import { Doc } from '../preferences/preferenceApi'; -import { Answer, FEEDBACK } from './conversationModels'; - -function getDocPath(selectedDocs: Doc | null): string { - let docPath = 'default'; - if (selectedDocs) { - let namePath = selectedDocs.name; - if (selectedDocs.language === namePath) { - namePath = '.project'; - } - if (selectedDocs.location === 'local') { - docPath = 'local' + '/' + selectedDocs.name + '/'; - } else if (selectedDocs.location === 'remote') { - docPath = - selectedDocs.language + - '/' + - namePath + - '/' + - selectedDocs.version + - '/' + - selectedDocs.model + - '/'; - } else if (selectedDocs.location === 'custom') { - docPath = selectedDocs.docLink; - } - } - return docPath; -} +import { Answer, FEEDBACK, RetrievalPayload } from './conversationModels'; export function handleFetchAnswer( question: string, @@ -54,23 +28,22 @@ export function handleFetchAnswer( title: any; } > { - const docPath = getDocPath(selectedDocs); history = history.map((item) => { return { prompt: item.prompt, response: item.response }; }); + const payload: RetrievalPayload = { + question: question, + history: JSON.stringify(history), + conversation_id: conversationId, + prompt_id: promptId, + chunks: chunks, + token_limit: token_limit, + }; + if (selectedDocs && 'id' in selectedDocs) + payload.active_docs = selectedDocs.id as string; + else payload.retriever = selectedDocs?.docLink as string; return conversationService - .answer( - { - question: question, - history: history, - active_docs: docPath, - conversation_id: conversationId, - prompt_id: promptId, - chunks: chunks, - token_limit: token_limit, - }, - signal, - ) + .answer(payload, signal) .then((response) => { if (response.ok) { return response.json(); @@ -101,24 +74,24 @@ export function handleFetchAnswerSteaming( token_limit: number, onEvent: (event: MessageEvent) => void, ): Promise { - const docPath = getDocPath(selectedDocs); history = history.map((item) => { return { prompt: item.prompt, response: item.response }; }); + const payload: RetrievalPayload = { + question: question, + history: JSON.stringify(history), + conversation_id: conversationId, + prompt_id: promptId, + chunks: chunks, + token_limit: token_limit, + }; + if (selectedDocs && 'id' in selectedDocs) + payload.active_docs = selectedDocs.id as string; + else payload.retriever = selectedDocs?.docLink as string; + return new Promise((resolve, reject) => { conversationService - .answerStream( - { - question: question, - active_docs: docPath, - history: JSON.stringify(history), - conversation_id: conversationId, - prompt_id: promptId, - chunks: chunks, - token_limit: token_limit, - }, - signal, - ) + .answerStream(payload, signal) .then((response) => { if (!response.body) throw Error('No response body'); @@ -175,16 +148,21 @@ export function handleSearch( chunks: string, token_limit: number, ) { - const docPath = getDocPath(selectedDocs); + history = history.map((item) => { + return { prompt: item.prompt, response: item.response }; + }); + const payload: RetrievalPayload = { + question: question, + history: JSON.stringify(history), + conversation_id: conversation_id, + chunks: chunks, + token_limit: token_limit, + }; + if (selectedDocs && 'id' in selectedDocs) + payload.active_docs = selectedDocs.id as string; + else payload.retriever = selectedDocs?.docLink as string; return conversationService - .search({ - question: question, - active_docs: docPath, - conversation_id, - history, - chunks: chunks, - token_limit: token_limit, - }) + .search(payload) .then((response) => response.json()) .then((data) => { return data; diff --git a/frontend/src/conversation/conversationModels.ts b/frontend/src/conversation/conversationModels.ts index 347a2521..bf86678b 100644 --- a/frontend/src/conversation/conversationModels.ts +++ b/frontend/src/conversation/conversationModels.ts @@ -31,3 +31,13 @@ export interface Query { conversationId?: string | null; title?: string | null; } +export interface RetrievalPayload { + question: string; + active_docs?: string; + retriever?: string; + history: string; + conversation_id: string | null; + prompt_id?: string | null; + chunks: string; + token_limit: number; +} diff --git a/frontend/src/modals/CreateAPIKeyModal.tsx b/frontend/src/modals/CreateAPIKeyModal.tsx index 2f67d83b..e59fd37e 100644 --- a/frontend/src/modals/CreateAPIKeyModal.tsx +++ b/frontend/src/modals/CreateAPIKeyModal.tsx @@ -22,8 +22,9 @@ export default function CreateAPIKeyModal({ const [APIKeyName, setAPIKeyName] = React.useState(''); const [sourcePath, setSourcePath] = React.useState<{ - label: string; - value: string; + name: string; + id: string; + type: string; } | null>(null); const [prompt, setPrompt] = React.useState<{ name: string; @@ -41,27 +42,17 @@ export default function CreateAPIKeyModal({ ? docs .filter((doc) => doc.model === embeddingsName) .map((doc: Doc) => { - let namePath = doc.name; - if (doc.language === namePath) { - namePath = '.project'; - } - let docPath = 'default'; - if (doc.location === 'local') { - docPath = 'local' + '/' + doc.name + '/'; - } else if (doc.location === 'remote') { - docPath = - doc.language + - '/' + - namePath + - '/' + - doc.version + - '/' + - doc.model + - '/'; + if ('id' in doc) { + return { + name: doc.name, + id: doc.id as string, + type: 'local', + }; } return { - label: doc.name, - value: docPath, + name: doc.name as string, + id: doc.docLink as string, + type: 'default', }; }) : []; @@ -107,9 +98,14 @@ export default function CreateAPIKeyModal({ - setSourcePath(selection) - } + onSelect={(selection: { + name: string; + id: string; + type: string; + }) => { + setSourcePath(selection); + console.log(selection); + }} options={extractDocPaths()} size="w-full" rounded="xl" @@ -142,16 +138,22 @@ export default function CreateAPIKeyModal({