diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index f076285d..15feeb3c 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) mongo = MongoClient(settings.MONGO_URI) db = mongo["docsgpt"] conversations_collection = db["conversations"] -vectors_collection = db["vectors"] +sources_collection = db["sources"] prompts_collection = db["prompts"] api_key_collection = db["api_keys"] answer = Blueprint("answer", __name__) @@ -77,40 +77,27 @@ def get_data_from_api_key(api_key): if data is None: raise Exception("Invalid API Key, please generate new key", 401) - if isinstance(data["source"], DBRef): - source_id = db.dereference(data["source"])["_id"] - data["source"] = get_source(source_id) + if "retriever" not in data: + data["retriever"] = None + if "source" in data and isinstance(data["source"], DBRef): + source_doc = db.dereference(data["source"]) + data["source"] = str(source_doc["_id"]) + if "retriever" in source_doc: + data["retriever"] = source_doc["retriever"] + else: + data["source"] = {} return data -def get_source(active_doc): - if ObjectId.is_valid(active_doc): - doc = vectors_collection.find_one({"_id": ObjectId(active_doc)}) - if doc is None: - raise Exception("Source document does not exist", 404) - print("res", doc) - source = {"active_docs": "/".join(doc["location"].split("/")[-2:])} - else: - source = {"active_docs": active_doc} - return source +def get_retriever(source_id: str): + doc = sources_collection.find_one({"_id": ObjectId(source_id)}) + if doc is None: + raise Exception("Source document does not exist", 404) + retriever_name = None if "retriever" not in doc else doc["retriever"] + return retriever_name -def get_vectorstore(data): - if "active_docs" in data: - if data["active_docs"].split("/")[0] == "default": - vectorstore = "" - elif data["active_docs"].split("/")[0] == "local": - vectorstore = "indexes/" + data["active_docs"] - else: - vectorstore = "vectors/" + data["active_docs"] - if data["active_docs"] == "default": - vectorstore = "" - else: - vectorstore = "" - vectorstore = os.path.join("application", vectorstore) - return vectorstore - def is_azure_configured(): return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME @@ -244,28 +231,34 @@ def stream(): else: token_limit = settings.DEFAULT_MAX_HISTORY - # check if active_docs or api_key is set + ## retriever can be "brave_search, duckduck_search or classic" + retriever_name = data["retriever"] if "retriever" in data else "classic" + # check if active_docs or api_key is set if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) chunks = int(data_key["chunks"]) prompt_id = data_key["prompt_id"] - source = data_key["source"] + source = {"active_docs": data_key["source"]} + retriever_name = data_key["retriever"] or retriever_name user_api_key = data["api_key"] + elif "active_docs" in data: - source = get_source(data["active_docs"]) + source = {"active_docs" : data["active_docs"]} + retriever_name = get_retriever(data["active_docs"]) or retriever_name user_api_key = None + else: source = {} user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + """ if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": retriever_name = "classic" else: - retriever_name = source["active_docs"] + retriever_name = source["active_docs"] """ prompt = get_prompt(prompt_id) - + retriever = RetrieverCreator.create_retriever( retriever_name, question=question, @@ -341,6 +334,9 @@ def api_answer(): else: token_limit = settings.DEFAULT_MAX_HISTORY + ## retriever can be brave_search, duckduck_search or classic + retriever_name = data["retriever"] if "retriever" in data else "classic" + # use try and except to check for exception try: # check if the vectorstore is set @@ -348,16 +344,16 @@ def api_answer(): data_key = get_data_from_api_key(data["api_key"]) chunks = int(data_key["chunks"]) prompt_id = data_key["prompt_id"] - source = data_key["source"] + source = {"active_docs": data_key["source"]} + retriever_name = data_key["retriever"] or retriever_name user_api_key = data["api_key"] - else: - source = get_source(data["active_docs"]) + elif "active_docs" in data: + source = {"active_docs":data["active_docs"]} + retriever_name = get_retriever(data["active_docs"]) or retriever_name user_api_key = None - - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": - retriever_name = "classic" else: - retriever_name = source["active_docs"] + source = {} + user_api_key = None prompt = get_prompt(prompt_id) @@ -407,19 +403,19 @@ def api_search(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) chunks = int(data_key["chunks"]) - source = data_key["source"] + source = {"active_docs":data_key["source"]} user_api_key = data_key["api_key"] elif "active_docs" in data: - source = get_source(data["active_docs"]) + source = {"active_docs":data["active_docs"]} user_api_key = None else: source = {} user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": - retriever_name = "classic" + if "retriever" in data: + retriever_name = data["retriever"] else: - retriever_name = source["active_docs"] + retriever_name = "classic" if "token_limit" in data: token_limit = data["token_limit"] else: diff --git a/application/api/internal/routes.py b/application/api/internal/routes.py index 6039ecdf..cea6c8ca 100755 --- a/application/api/internal/routes.py +++ b/application/api/internal/routes.py @@ -3,13 +3,13 @@ import datetime from flask import Blueprint, request, send_from_directory from pymongo import MongoClient from werkzeug.utils import secure_filename - +from bson.objectid import ObjectId from application.core.settings import settings mongo = MongoClient(settings.MONGO_URI) db = mongo["docsgpt"] conversations_collection = db["conversations"] -vectors_collection = db["vectors"] +sources_collection = db["sources"] current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -35,7 +35,12 @@ def upload_index_files(): return {"status": "no name"} job_name = secure_filename(request.form["name"]) tokens = secure_filename(request.form["tokens"]) - save_dir = os.path.join(current_dir, "indexes", user, job_name) + retriever = secure_filename(request.form["retriever"]) + id = secure_filename(request.form["id"]) + type = secure_filename(request.form["type"]) + remote_data = secure_filename(request.form["remote_data"]) if "remote_data" in request.form else None + + save_dir = os.path.join(current_dir, "indexes", str(id)) if settings.VECTOR_STORE == "faiss": if "file_faiss" not in request.files: print("No file part") @@ -55,17 +60,19 @@ def upload_index_files(): os.makedirs(save_dir) file_faiss.save(os.path.join(save_dir, "index.faiss")) file_pkl.save(os.path.join(save_dir, "index.pkl")) - # create entry in vectors_collection - vectors_collection.insert_one( + # create entry in sources_collection + sources_collection.insert_one( { + "_id": ObjectId(id), "user": user, "name": job_name, "language": job_name, - "location": save_dir, "date": datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S"), "model": settings.EMBEDDINGS_NAME, - "type": "local", - "tokens": tokens + "type": type, + "tokens": tokens, + "retriever": retriever, + "remote_data": remote_data } ) return {"status": "ok"} \ No newline at end of file diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 06bab591..73023a89 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -2,8 +2,6 @@ import os import uuid import shutil from flask import Blueprint, request, jsonify -from urllib.parse import urlparse -import requests from pymongo import MongoClient from bson.objectid import ObjectId from bson.binary import Binary, UuidRepresentation @@ -17,7 +15,7 @@ from application.vectorstore.vector_creator import VectorCreator mongo = MongoClient(settings.MONGO_URI) db = mongo["docsgpt"] conversations_collection = db["conversations"] -vectors_collection = db["vectors"] +sources_collection = db["sources"] prompts_collection = db["prompts"] feedback_collection = db["feedback"] api_key_collection = db["api_keys"] @@ -25,9 +23,7 @@ shared_conversations_collections = db["shared_conversations"] user = Blueprint("user", __name__) -current_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) +current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @user.route("/api/delete_conversation", methods=["POST"]) @@ -57,9 +53,7 @@ def get_conversations(): conversations = conversations_collection.find().sort("date", -1).limit(30) list_conversations = [] for conversation in conversations: - list_conversations.append( - {"id": str(conversation["_id"]), "name": conversation["name"]} - ) + list_conversations.append({"id": str(conversation["_id"]), "name": conversation["name"]}) # list_conversations = [{"id": "default", "name": "default"}, {"id": "jeff", "name": "jeff"}] @@ -110,7 +104,7 @@ def delete_by_ids(): return {"status": "error"} if settings.VECTOR_STORE == "faiss": - result = vectors_collection.delete_index(ids=ids) + result = sources_collection.delete_index(ids=ids) if result: return {"status": "ok"} return {"status": "error"} @@ -120,28 +114,24 @@ def delete_by_ids(): def delete_old(): """Delete old indexes.""" import shutil - - path = request.args.get("path") - dirs = path.split("/") - dirs_clean = [] - for i in range(0, len(dirs)): - dirs_clean.append(secure_filename(dirs[i])) - # check that path strats with indexes or vectors - - if dirs_clean[0] not in ["indexes", "vectors"]: - return {"status": "error"} - path_clean = "/".join(dirs_clean) - vectors_collection.delete_one({"name": dirs_clean[-1], "user": dirs_clean[-2]}) + source_id = request.args.get("source_id") + doc = sources_collection.find_one({ + "_id": ObjectId(source_id), + "user": "local", + }) + if(doc is None): + return {"status":"not found"},404 if settings.VECTOR_STORE == "faiss": try: - shutil.rmtree(os.path.join(current_dir, path_clean)) + shutil.rmtree(os.path.join(current_dir, str(doc["_id"]))) except FileNotFoundError: pass else: - vetorstore = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, path=os.path.join(current_dir, path_clean) - ) + vetorstore = VectorCreator.create_vectorstore(settings.VECTOR_STORE, source_id=str(doc["_id"])) vetorstore.delete_index() + sources_collection.delete_one({ + "_id": ObjectId(source_id), + }) return {"status": "ok"} @@ -175,9 +165,7 @@ def upload_file(): file.save(os.path.join(temp_dir, filename)) # Use shutil.make_archive to zip the temp directory - zip_path = shutil.make_archive( - base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir - ) + zip_path = shutil.make_archive(base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir) final_filename = os.path.basename(zip_path) # Clean up the temporary directory after zipping @@ -219,9 +207,7 @@ def upload_remote(): source_data = request.form["data"] if source_data: - task = ingest_remote.delay( - source_data=source_data, job_name=job_name, user=user, loader=source - ) + task = ingest_remote.delay(source_data=source_data, job_name=job_name, user=user, loader=source) task_id = task.id return {"status": "ok", "task_id": task_id} else: @@ -248,55 +234,36 @@ def combined_json(): data = [ { "name": "default", - "language": "default", - "version": "", - "description": "default", - "fullName": "default", "date": "default", - "docLink": "default", "model": settings.EMBEDDINGS_NAME, "location": "remote", "tokens": "", + "retriever": "classic", } ] # structure: name, language, version, description, fullName, date, docLink - # append data from vectors_collection in sorted order in descending order of date - for index in vectors_collection.find({"user": user}).sort("date", -1): + # append data from sources_collection in sorted order in descending order of date + for index in sources_collection.find({"user": user}).sort("date", -1): data.append( { - "id":str(index["_id"]), + "id": str(index["_id"]), "name": index["name"], - "language": index["language"], - "version": "", - "description": index["name"], - "fullName": index["name"], "date": index["date"], - "docLink": index["location"], "model": settings.EMBEDDINGS_NAME, "location": "local", "tokens": index["tokens"] if ("tokens" in index.keys()) else "", + "retriever": index["retriever"] if ("retriever" in index.keys()) else "classic", } ) - if settings.VECTOR_STORE == "faiss": - data_remote = requests.get( - "https://d3dg1063dc54p9.cloudfront.net/combined.json" - ).json() - for index in data_remote: - index["location"] = "remote" - data.append(index) if "duckduck_search" in settings.RETRIEVERS_ENABLED: data.append( { "name": "DuckDuckGo Search", - "language": "en", - "version": "", - "description": "duckduck_search", - "fullName": "DuckDuckGo Search", "date": "duckduck_search", - "docLink": "duckduck_search", "model": settings.EMBEDDINGS_NAME, "location": "custom", "tokens": "", + "retriever": "duckduck_search", } ) if "brave_search" in settings.RETRIEVERS_ENABLED: @@ -304,14 +271,11 @@ def combined_json(): { "name": "Brave Search", "language": "en", - "version": "", - "description": "brave_search", - "fullName": "Brave Search", "date": "brave_search", - "docLink": "brave_search", "model": settings.EMBEDDINGS_NAME, "location": "custom", "tokens": "", + "retriever": "brave_search", } ) @@ -320,39 +284,13 @@ def combined_json(): @user.route("/api/docs_check", methods=["POST"]) def check_docs(): - # check if docs exist in a vectorstore folder data = request.get_json() - # split docs on / and take first part - if data["docs"].split("/")[0] == "local": - return {"status": "exists"} + vectorstore = "vectors/" + secure_filename(data["docs"]) - base_path = "https://raw.githubusercontent.com/arc53/DocsHUB/main/" if os.path.exists(vectorstore) or data["docs"] == "default": return {"status": "exists"} else: - file_url = urlparse(base_path + vectorstore + "index.faiss") - - if ( - file_url.scheme in ["https"] - and file_url.netloc == "raw.githubusercontent.com" - and file_url.path.startswith("/arc53/DocsHUB/main/") - ): - r = requests.get(file_url.geturl()) - if r.status_code != 200: - return {"status": "null"} - else: - if not os.path.exists(vectorstore): - os.makedirs(vectorstore) - with open(vectorstore + "index.faiss", "wb") as f: - f.write(r.content) - - r = requests.get(base_path + vectorstore + "index.pkl") - with open(vectorstore + "index.pkl", "wb") as f: - f.write(r.content) - else: - return {"status": "null"} - - return {"status": "loaded"} + return {"status": "not found"} @user.route("/api/create_prompt", methods=["POST"]) @@ -383,9 +321,7 @@ def get_prompts(): list_prompts.append({"id": "creative", "name": "creative", "type": "public"}) list_prompts.append({"id": "strict", "name": "strict", "type": "public"}) for prompt in prompts: - list_prompts.append( - {"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"} - ) + list_prompts.append({"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"}) return jsonify(list_prompts) @@ -394,21 +330,15 @@ def get_prompts(): def get_single_prompt(): prompt_id = request.args.get("id") if prompt_id == "default": - with open( - os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r" - ) as f: + with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: chat_combine_template = f.read() return jsonify({"content": chat_combine_template}) elif prompt_id == "creative": - with open( - os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r" - ) as f: + with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f: chat_reduce_creative = f.read() return jsonify({"content": chat_reduce_creative}) elif prompt_id == "strict": - with open( - os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r" - ) as f: + with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f: chat_reduce_strict = f.read() return jsonify({"content": chat_reduce_strict}) @@ -437,9 +367,7 @@ def update_prompt_name(): # check if name is null if name == "": return {"status": "error"} - prompts_collection.update_one( - {"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}} - ) + prompts_collection.update_one({"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}}) return {"status": "ok"} @@ -449,12 +377,23 @@ def get_api_keys(): keys = api_key_collection.find({"user": user}) list_keys = [] for key in keys: + if "source" in key and isinstance(key["source"],DBRef): + source = db.dereference(key["source"]) + if source is None: + continue + else: + source_name = source["name"] + elif "retriever" in key: + source_name = key["retriever"] + else: + continue + list_keys.append( { "id": str(key["_id"]), "name": key["name"], "key": key["key"][:4] + "..." + key["key"][-4:], - "source": str(key["source"]), + "source": source_name, "prompt_id": key["prompt_id"], "chunks": key["chunks"], } @@ -466,23 +405,22 @@ def get_api_keys(): def create_api_key(): data = request.get_json() name = data["name"] - source = data["source"] prompt_id = data["prompt_id"] chunks = data["chunks"] key = str(uuid.uuid4()) user = "local" - if(ObjectId.is_valid(data["source"])): - source = DBRef("vectors",ObjectId(data["source"])) - resp = api_key_collection.insert_one( - { - "name": name, - "key": key, - "source": source, - "user": user, - "prompt_id": prompt_id, - "chunks": chunks, - } - ) + new_api_key = { + "name": name, + "key": key, + "user": user, + "prompt_id": prompt_id, + "chunks": chunks, + } + if "source" in data and ObjectId.is_valid(data["source"]): + new_api_key["source"] = DBRef("sources", ObjectId(data["source"])) + if "retriever" in data: + new_api_key["retriever"] = data["retriever"] + resp = api_key_collection.insert_one(new_api_key) new_id = str(resp.inserted_id) return {"id": new_id, "key": key} @@ -509,36 +447,37 @@ def share_conversation(): conversation_id = data["conversation_id"] isPromptable = request.args.get("isPromptable").lower() == "true" - conversation = conversations_collection.find_one( - {"_id": ObjectId(conversation_id)} - ) + conversation = conversations_collection.find_one({"_id": ObjectId(conversation_id)}) + if(conversation is None): + raise Exception("Conversation does not exist") current_n_queries = len(conversation["queries"]) ##generate binary representation of uuid explicit_binary = Binary.from_uuid(uuid.uuid4(), UuidRepresentation.STANDARD) if isPromptable: - source = "default" if "source" not in data else data["source"] prompt_id = "default" if "prompt_id" not in data else data["prompt_id"] chunks = "2" if "chunks" not in data else data["chunks"] name = conversation["name"] + "(shared)" - pre_existing_api_document = api_key_collection.find_one( - { + new_api_key_data = { "prompt_id": prompt_id, "chunks": chunks, - "source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source, "user": user, } + if "source" in data and ObjectId.is_valid(data["source"]): + new_api_key_data["source"] = DBRef("sources",ObjectId(data["source"])) + elif "retriever" in data: + new_api_key_data["retriever"] = data["retriever"] + + pre_existing_api_document = api_key_collection.find_one( + new_api_key_data ) - api_uuid = str(uuid.uuid4()) if pre_existing_api_document: api_uuid = pre_existing_api_document["key"] pre_existing = shared_conversations_collections.find_one( { - "conversation_id": DBRef( - "conversations", ObjectId(conversation_id) - ), + "conversation_id": DBRef("conversations", ObjectId(conversation_id)), "isPromptable": isPromptable, "first_n_queries": current_n_queries, "user": user, @@ -569,21 +508,18 @@ def share_conversation(): "api_key": api_uuid, } ) - return jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ) + return jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}) else: - api_key_collection.insert_one( - { - "name": name, - "key": api_uuid, - "source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source, - "user": user, - "prompt_id": prompt_id, - "chunks": chunks, - } - ) - shared_conversations_collections.insert_one( + + api_uuid = str(uuid.uuid4()) + new_api_key_data["key"] = api_uuid + new_api_key_data["name"] = name + if "source" in data and ObjectId.is_valid(data["source"]): + new_api_key_data["source"] = DBRef("sources", ObjectId(data["source"])) + if "retriever" in data: + new_api_key_data["retriever"] = data["retriever"] + api_key_collection.insert_one(new_api_key_data) + shared_conversations_collections.insert_one( { "uuid": explicit_binary, "conversation_id": { @@ -595,12 +531,10 @@ def share_conversation(): "user": user, "api_key": api_uuid, } - ) + ) ## Identifier as route parameter in frontend return ( - jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ), + jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}), 201, ) @@ -615,9 +549,7 @@ def share_conversation(): ) if pre_existing is not None: return ( - jsonify( - {"success": True, "identifier": str(pre_existing["uuid"].as_uuid())} - ), + jsonify({"success": True, "identifier": str(pre_existing["uuid"].as_uuid())}), 200, ) else: @@ -635,9 +567,7 @@ def share_conversation(): ) ## Identifier as route parameter in frontend return ( - jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ), + jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}), 201, ) except Exception as err: @@ -649,16 +579,10 @@ def share_conversation(): @user.route("/api/shared_conversation/", methods=["GET"]) def get_publicly_shared_conversations(identifier: str): try: - query_uuid = Binary.from_uuid( - uuid.UUID(identifier), UuidRepresentation.STANDARD - ) + query_uuid = Binary.from_uuid(uuid.UUID(identifier), UuidRepresentation.STANDARD) shared = shared_conversations_collections.find_one({"uuid": query_uuid}) conversation_queries = [] - if ( - shared - and "conversation_id" in shared - and isinstance(shared["conversation_id"], DBRef) - ): + if shared and "conversation_id" in shared and isinstance(shared["conversation_id"], DBRef): # Resolve the DBRef conversation_ref = shared["conversation_id"] conversation = db.dereference(conversation_ref) @@ -672,9 +596,7 @@ def get_publicly_shared_conversations(identifier: str): ), 404, ) - conversation_queries = conversation["queries"][ - : (shared["first_n_queries"]) - ] + conversation_queries = conversation["queries"][: (shared["first_n_queries"])] for query in conversation_queries: query.pop("sources") ## avoid exposing sources else: diff --git a/application/parser/open_ai_func.py b/application/parser/open_ai_func.py index c58e8059..84f92db9 100755 --- a/application/parser/open_ai_func.py +++ b/application/parser/open_ai_func.py @@ -11,12 +11,14 @@ from retry import retry @retry(tries=10, delay=60) -def store_add_texts_with_retry(store, i): +def store_add_texts_with_retry(store, i, id): + # add source_id to the metadata + i.metadata["source_id"] = str(id) store.add_texts([i.page_content], metadatas=[i.metadata]) # store_pine.add_texts([i.page_content], metadatas=[i.metadata]) -def call_openai_api(docs, folder_name, task_status): +def call_openai_api(docs, folder_name, id, task_status): # Function to create a vector store from the documents and save it to disk if not os.path.exists(f"{folder_name}"): @@ -32,13 +34,13 @@ def call_openai_api(docs, folder_name, task_status): store = VectorCreator.create_vectorstore( settings.VECTOR_STORE, docs_init=docs_init, - path=f"{folder_name}", + source_id=f"{folder_name}", embeddings_key=os.getenv("EMBEDDINGS_KEY"), ) else: store = VectorCreator.create_vectorstore( settings.VECTOR_STORE, - path=f"{folder_name}", + source_id=str(id), embeddings_key=os.getenv("EMBEDDINGS_KEY"), ) # Uncomment for MPNet embeddings @@ -57,7 +59,7 @@ def call_openai_api(docs, folder_name, task_status): task_status.update_state( state="PROGRESS", meta={"current": int((c1 / s1) * 100)} ) - store_add_texts_with_retry(store, i) + store_add_texts_with_retry(store, i, id) except Exception as e: print(e) print("Error on ", i) diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 2b77db34..499a4b7e 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,4 +1,3 @@ -import os from application.retriever.base import BaseRetriever from application.core.settings import settings from application.vectorstore.vector_creator import VectorCreator @@ -21,7 +20,7 @@ class ClassicRAG(BaseRetriever): user_api_key=None, ): self.question = question - self.vectorstore = self._get_vectorstore(source=source) + self.vectorstore = source['active_docs'] if 'active_docs' in source else None self.chat_history = chat_history self.prompt = prompt self.chunks = chunks @@ -38,21 +37,6 @@ class ClassicRAG(BaseRetriever): ) self.user_api_key = user_api_key - def _get_vectorstore(self, source): - if "active_docs" in source: - if source["active_docs"].split("/")[0] == "default": - vectorstore = "" - elif source["active_docs"].split("/")[0] == "local": - vectorstore = "indexes/" + source["active_docs"] - else: - vectorstore = "vectors/" + source["active_docs"] - if source["active_docs"] == "default": - vectorstore = "" - else: - vectorstore = "" - vectorstore = os.path.join("application", vectorstore) - return vectorstore - def _get_data(self): if self.chunks == 0: docs = [] diff --git a/application/vectorstore/elasticsearch.py b/application/vectorstore/elasticsearch.py index bb28d5ce..e393e4a5 100644 --- a/application/vectorstore/elasticsearch.py +++ b/application/vectorstore/elasticsearch.py @@ -9,9 +9,9 @@ import elasticsearch class ElasticsearchStore(BaseVectorStore): _es_connection = None # Class attribute to hold the Elasticsearch connection - def __init__(self, path, embeddings_key, index_name=settings.ELASTIC_INDEX): + def __init__(self, source_id, embeddings_key, index_name=settings.ELASTIC_INDEX): super().__init__() - self.path = path.replace("application/indexes/", "").rstrip("/") + self.source_id = source_id.replace("application/indexes/", "").rstrip("/") self.embeddings_key = embeddings_key self.index_name = index_name @@ -81,7 +81,7 @@ class ElasticsearchStore(BaseVectorStore): embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key) vector = embeddings.embed_query(question) knn = { - "filter": [{"match": {"metadata.store.keyword": self.path}}], + "filter": [{"match": {"metadata.source_id.keyword": self.source_id}}], "field": "vector", "k": k, "num_candidates": 100, @@ -100,7 +100,7 @@ class ElasticsearchStore(BaseVectorStore): } } ], - "filter": [{"match": {"metadata.store.keyword": self.path}}], + "filter": [{"match": {"metadata.source_id.keyword": self.source_id}}], } }, "rank": {"rrf": {}}, @@ -209,5 +209,4 @@ class ElasticsearchStore(BaseVectorStore): def delete_index(self): self._es_connection.delete_by_query(index=self.index_name, query={"match": { - "metadata.store.keyword": self.path}},) - + "metadata.source_id.keyword": self.source_id}},) diff --git a/application/vectorstore/faiss.py b/application/vectorstore/faiss.py index 8e8f3b8e..b504ebf8 100644 --- a/application/vectorstore/faiss.py +++ b/application/vectorstore/faiss.py @@ -1,12 +1,22 @@ from langchain_community.vectorstores import FAISS from application.vectorstore.base import BaseVectorStore from application.core.settings import settings +import os + +def get_vectorstore(path): + if path: + vectorstore = "indexes/"+path + vectorstore = os.path.join("application", vectorstore) + else: + vectorstore = os.path.join("application") + + return vectorstore class FaissStore(BaseVectorStore): - def __init__(self, path, embeddings_key, docs_init=None): + def __init__(self, source_id, embeddings_key, docs_init=None): super().__init__() - self.path = path + self.path = get_vectorstore(source_id) embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key) if docs_init: self.docsearch = FAISS.from_documents( diff --git a/application/vectorstore/mongodb.py b/application/vectorstore/mongodb.py index 337fc41f..32bca489 100644 --- a/application/vectorstore/mongodb.py +++ b/application/vectorstore/mongodb.py @@ -5,7 +5,7 @@ from application.vectorstore.document_class import Document class MongoDBVectorStore(BaseVectorStore): def __init__( self, - path: str = "", + source_id: str = "", embeddings_key: str = "embeddings", collection: str = "documents", index_name: str = "vector_search_index", @@ -18,7 +18,7 @@ class MongoDBVectorStore(BaseVectorStore): self._embedding_key = embedding_key self._embeddings_key = embeddings_key self._mongo_uri = settings.MONGO_URI - self._path = path.replace("application/indexes/", "").rstrip("/") + self._source_id = source_id.replace("application/indexes/", "").rstrip("/") self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key) try: @@ -46,7 +46,7 @@ class MongoDBVectorStore(BaseVectorStore): "numCandidates": k * 10, "index": self._index_name, "filter": { - "store": {"$eq": self._path} + "source_id": {"$eq": self._source_id} } } } @@ -123,4 +123,4 @@ class MongoDBVectorStore(BaseVectorStore): return result_ids def delete_index(self, *args, **kwargs): - self._collection.delete_many({"store": self._path}) \ No newline at end of file + self._collection.delete_many({"source_id": self._source_id}) \ No newline at end of file diff --git a/application/vectorstore/qdrant.py b/application/vectorstore/qdrant.py index 482d06a1..3f94505f 100644 --- a/application/vectorstore/qdrant.py +++ b/application/vectorstore/qdrant.py @@ -5,12 +5,12 @@ from qdrant_client import models class QdrantStore(BaseVectorStore): - def __init__(self, path: str = "", embeddings_key: str = "embeddings"): + def __init__(self, source_id: str = "", embeddings_key: str = "embeddings"): self._filter = models.Filter( must=[ models.FieldCondition( - key="metadata.store", - match=models.MatchValue(value=path.replace("application/indexes/", "").rstrip("/")), + key="metadata.source_id", + match=models.MatchValue(value=source_id.replace("application/indexes/", "").rstrip("/")), ) ] ) diff --git a/application/worker.py b/application/worker.py index bd1bc15a..40e66431 100755 --- a/application/worker.py +++ b/application/worker.py @@ -6,6 +6,7 @@ import tiktoken from urllib.parse import urljoin import requests +from bson.objectid import ObjectId from application.core.settings import settings from application.parser.file.bulk import SimpleDirectoryReader @@ -14,10 +15,10 @@ from application.parser.open_ai_func import call_openai_api from application.parser.schema.base import Document from application.parser.token_func import group_split + # Define a function to extract metadata from a given filename. def metadata_from_filename(title): - store = "/".join(title.split("/")[1:3]) - return {"title": title, "store": store} + return {"title": title} # Define a function to generate a random string of a given length. @@ -25,9 +26,7 @@ def generate_random_string(length): return "".join([string.ascii_letters[i % 52] for i in range(length)]) -current_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) +current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5): @@ -58,7 +57,7 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5): # Define the main function for ingesting and processing documents. -def ingest_worker(self, directory, formats, name_job, filename, user): +def ingest_worker(self, directory, formats, name_job, filename, user, retriever="classic"): """ Ingest and process documents. @@ -69,6 +68,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user): name_job (str): Name of the job for this ingestion task. filename (str): Name of the file to be ingested. user (str): Identifier for the user initiating the ingestion. + retriever (str): Type of retriever to use for processing the documents. Returns: dict: Information about the completed ingestion task, including input parameters and a "limited" flag. @@ -93,9 +93,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user): print(full_path, file=sys.stderr) # check if API_URL env variable is set file_data = {"name": name_job, "file": filename, "user": user} - response = requests.get( - urljoin(settings.API_URL, "/api/download"), params=file_data - ) + response = requests.get(urljoin(settings.API_URL, "/api/download"), params=file_data) # check if file is in the response print(response, file=sys.stderr) file = response.content @@ -107,9 +105,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user): # check if file is .zip and extract it if filename.endswith(".zip"): - extract_zip_recursive( - os.path.join(full_path, filename), full_path, 0, recursion_depth - ) + extract_zip_recursive(os.path.join(full_path, filename), full_path, 0, recursion_depth) self.update_state(state="PROGRESS", meta={"current": 1}) @@ -130,8 +126,9 @@ def ingest_worker(self, directory, formats, name_job, filename, user): ) docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs] + id = ObjectId() - call_openai_api(docs, full_path, self) + call_openai_api(docs, full_path, id, self) tokens = count_tokens_docs(docs) self.update_state(state="PROGRESS", meta={"current": 100}) @@ -141,22 +138,15 @@ def ingest_worker(self, directory, formats, name_job, filename, user): # get files from outputs/inputs/index.faiss and outputs/inputs/index.pkl # and send them to the server (provide user and name in form) - file_data = {"name": name_job, "user": user, "tokens":tokens} + file_data = {"name": name_job, "user": user, "tokens": tokens, "retriever": retriever, "id": str(id), 'type': 'local'} if settings.VECTOR_STORE == "faiss": files = { "file_faiss": open(full_path + "/index.faiss", "rb"), "file_pkl": open(full_path + "/index.pkl", "rb"), } - response = requests.post( - urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data - ) - response = requests.get( - urljoin(settings.API_URL, "/api/delete_old?path=" + full_path) - ) + response = requests.post(urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data) else: - response = requests.post( - urljoin(settings.API_URL, "/api/upload_index"), data=file_data - ) + response = requests.post(urljoin(settings.API_URL, "/api/upload_index"), data=file_data) # delete local shutil.rmtree(full_path) @@ -171,7 +161,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user): } -def remote_worker(self, source_data, name_job, user, loader, directory="temp"): +def remote_worker(self, source_data, name_job, user, loader, directory="temp", retriever="classic"): token_check = True min_tokens = 150 max_tokens = 1250 @@ -191,22 +181,21 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp"): token_check=token_check, ) # docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs] - call_openai_api(docs, full_path, self) tokens = count_tokens_docs(docs) + id = ObjectId() + call_openai_api(docs, full_path, id, self) self.update_state(state="PROGRESS", meta={"current": 100}) # Proceed with uploading and cleaning as in the original function - file_data = {"name": name_job, "user": user, "tokens":tokens} + file_data = {"name": name_job, "user": user, "tokens": tokens, "retriever": retriever, + "id": str(id), 'type': loader, 'remote_data': source_data} if settings.VECTOR_STORE == "faiss": files = { "file_faiss": open(full_path + "/index.faiss", "rb"), "file_pkl": open(full_path + "/index.pkl", "rb"), } - - requests.post( - urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data - ) - requests.get(urljoin(settings.API_URL, "/api/delete_old?path=" + full_path)) + + requests.post(urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data) else: requests.post(urljoin(settings.API_URL, "/api/upload_index"), data=file_data) @@ -222,9 +211,7 @@ def count_tokens_docs(docs): for doc in docs: docs_content += doc.page_content - tokens, total_price = num_tokens_from_string( - string=docs_content, encoding_name="cl100k_base" - ) + tokens, total_price = num_tokens_from_string(string=docs_content, encoding_name="cl100k_base") # Here we print the number of tokens and the approx user cost with some visually appealing formatting. return tokens @@ -234,4 +221,4 @@ def num_tokens_from_string(string: str, encoding_name: str) -> int: encoding = tiktoken.get_encoding(encoding_name) num_tokens = len(encoding.encode(string)) total_price = (num_tokens / 1000) * 0.0004 - return num_tokens, total_price \ No newline at end of file + return num_tokens, total_price diff --git a/frontend/src/Navigation.tsx b/frontend/src/Navigation.tsx index cbfe5d95..4a617970 100644 --- a/frontend/src/Navigation.tsx +++ b/frontend/src/Navigation.tsx @@ -23,9 +23,9 @@ import { import ConversationTile from './conversation/ConversationTile'; import { useDarkTheme, useMediaQuery, useOutsideAlerter } from './hooks'; import DeleteConvModal from './modals/DeleteConvModal'; -import { ActiveState } from './models/misc'; +import { ActiveState, Doc } from './models/misc'; import APIKeyModal from './preferences/APIKeyModal'; -import { Doc, getConversations, getDocs } from './preferences/preferenceApi'; +import { getConversations, getDocs } from './preferences/preferenceApi'; import { selectApiKeyStatus, selectConversationId, @@ -124,10 +124,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { }; const handleDeleteClick = (doc: Doc) => { - const docPath = `indexes/local/${doc.name}`; - userService - .deletePath(docPath) + .deletePath(doc.id ?? '') .then(() => { return getDocs(); }) diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index af2fb920..c06ac3d2 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -10,7 +10,7 @@ const endpoints = { DELETE_PROMPT: '/api/delete_prompt', UPDATE_PROMPT: '/api/update_prompt', SINGLE_PROMPT: (id: string) => `/api/get_single_prompt?id=${id}`, - DELETE_PATH: (docPath: string) => `/api/delete_old?path=${docPath}`, + DELETE_PATH: (docPath: string) => `/api/delete_old?source_id=${docPath}`, TASK_STATUS: (task_id: string) => `/api/task_status?task_id=${task_id}`, }, CONVERSATION: { diff --git a/frontend/src/components/Dropdown.tsx b/frontend/src/components/Dropdown.tsx index 17516aaa..0353a191 100644 --- a/frontend/src/components/Dropdown.tsx +++ b/frontend/src/components/Dropdown.tsx @@ -26,6 +26,7 @@ function Dropdown({ | string | { label: string; value: string } | { value: number; description: string } + | { name: string; id: string; type: string } | null; onSelect: | ((value: string) => void) @@ -96,6 +97,10 @@ function Dropdown({ ? selectedValue.value + ` (${selectedValue.description})` : selectedValue.description }` + : selectedValue && + 'name' in selectedValue && + 'id' in selectedValue + ? `${selectedValue.name}` : placeholder ? placeholder : 'From URL'} diff --git a/frontend/src/components/SourceDropdown.tsx b/frontend/src/components/SourceDropdown.tsx index ce130b4d..d5146da5 100644 --- a/frontend/src/components/SourceDropdown.tsx +++ b/frontend/src/components/SourceDropdown.tsx @@ -1,7 +1,7 @@ import React from 'react'; import Trash from '../assets/trash.svg'; import Arrow2 from '../assets/dropdown-arrow.svg'; -import { Doc } from '../preferences/preferenceApi'; +import { Doc } from '../models/misc'; import { useDispatch } from 'react-redux'; import { useTranslation } from 'react-i18next'; type Props = { @@ -63,9 +63,6 @@ function SourceDropdown({

{selectedDocs?.name || 'None'}

-

- {selectedDocs?.version} -

{ - const docPath = getDocPath(selectedDocs); history = history.map((item) => { return { prompt: item.prompt, response: item.response }; }); + const payload: RetrievalPayload = { + question: question, + history: JSON.stringify(history), + conversation_id: conversationId, + prompt_id: promptId, + chunks: chunks, + token_limit: token_limit, + }; + if (selectedDocs && 'id' in selectedDocs) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs?.retriever as string; return conversationService - .answer( - { - question: question, - history: history, - active_docs: docPath, - conversation_id: conversationId, - prompt_id: promptId, - chunks: chunks, - token_limit: token_limit, - }, - signal, - ) + .answer(payload, signal) .then((response) => { if (response.ok) { return response.json(); @@ -101,24 +74,24 @@ export function handleFetchAnswerSteaming( token_limit: number, onEvent: (event: MessageEvent) => void, ): Promise { - const docPath = getDocPath(selectedDocs); history = history.map((item) => { return { prompt: item.prompt, response: item.response }; }); + const payload: RetrievalPayload = { + question: question, + history: JSON.stringify(history), + conversation_id: conversationId, + prompt_id: promptId, + chunks: chunks, + token_limit: token_limit, + }; + if (selectedDocs && 'id' in selectedDocs) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs?.retriever as string; + return new Promise((resolve, reject) => { conversationService - .answerStream( - { - question: question, - active_docs: docPath, - history: JSON.stringify(history), - conversation_id: conversationId, - prompt_id: promptId, - chunks: chunks, - token_limit: token_limit, - }, - signal, - ) + .answerStream(payload, signal) .then((response) => { if (!response.body) throw Error('No response body'); @@ -175,16 +148,21 @@ export function handleSearch( chunks: string, token_limit: number, ) { - const docPath = getDocPath(selectedDocs); + history = history.map((item) => { + return { prompt: item.prompt, response: item.response }; + }); + const payload: RetrievalPayload = { + question: question, + history: JSON.stringify(history), + conversation_id: conversation_id, + chunks: chunks, + token_limit: token_limit, + }; + if (selectedDocs && 'id' in selectedDocs) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs?.retriever as string; return conversationService - .search({ - question: question, - active_docs: docPath, - conversation_id, - history, - chunks: chunks, - token_limit: token_limit, - }) + .search(payload) .then((response) => response.json()) .then((data) => { return data; diff --git a/frontend/src/conversation/conversationModels.ts b/frontend/src/conversation/conversationModels.ts index 347a2521..bf86678b 100644 --- a/frontend/src/conversation/conversationModels.ts +++ b/frontend/src/conversation/conversationModels.ts @@ -31,3 +31,13 @@ export interface Query { conversationId?: string | null; title?: string | null; } +export interface RetrievalPayload { + question: string; + active_docs?: string; + retriever?: string; + history: string; + conversation_id: string | null; + prompt_id?: string | null; + chunks: string; + token_limit: number; +} diff --git a/frontend/src/modals/CreateAPIKeyModal.tsx b/frontend/src/modals/CreateAPIKeyModal.tsx index 2f67d83b..71d86330 100644 --- a/frontend/src/modals/CreateAPIKeyModal.tsx +++ b/frontend/src/modals/CreateAPIKeyModal.tsx @@ -22,8 +22,9 @@ export default function CreateAPIKeyModal({ const [APIKeyName, setAPIKeyName] = React.useState(''); const [sourcePath, setSourcePath] = React.useState<{ - label: string; - value: string; + name: string; + id: string; + type: string; } | null>(null); const [prompt, setPrompt] = React.useState<{ name: string; @@ -41,27 +42,17 @@ export default function CreateAPIKeyModal({ ? docs .filter((doc) => doc.model === embeddingsName) .map((doc: Doc) => { - let namePath = doc.name; - if (doc.language === namePath) { - namePath = '.project'; - } - let docPath = 'default'; - if (doc.location === 'local') { - docPath = 'local' + '/' + doc.name + '/'; - } else if (doc.location === 'remote') { - docPath = - doc.language + - '/' + - namePath + - '/' + - doc.version + - '/' + - doc.model + - '/'; + if ('id' in doc) { + return { + name: doc.name, + id: doc.id as string, + type: 'local', + }; } return { - label: doc.name, - value: docPath, + name: doc.name, + id: doc.id ?? 'default', + type: doc.type ?? 'default', }; }) : []; @@ -107,9 +98,14 @@ export default function CreateAPIKeyModal({ - setSourcePath(selection) - } + onSelect={(selection: { + name: string; + id: string; + type: string; + }) => { + setSourcePath(selection); + console.log(selection); + }} options={extractDocPaths()} size="w-full" rounded="xl" @@ -142,16 +138,22 @@ export default function CreateAPIKeyModal({