diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 873a0ad7..e873a1cf 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -9,6 +9,7 @@ import traceback from pymongo import MongoClient from bson.objectid import ObjectId +from bson.dbref import DBRef from application.core.settings import settings from application.llm.llm_creator import LLMCreator @@ -20,7 +21,7 @@ logger = logging.getLogger(__name__) mongo = MongoClient(settings.MONGO_URI) db = mongo["docsgpt"] conversations_collection = db["conversations"] -vectors_collection = db["vectors"] +sources_collection = db["sources"] prompts_collection = db["prompts"] api_key_collection = db["api_keys"] user_logs_collection = db["user_logs"] @@ -37,9 +38,7 @@ if settings.MODEL_NAME: # in case there is particular model name configured gpt_model = settings.MODEL_NAME # load the prompts -current_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) +current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: chat_combine_template = f.read() @@ -75,35 +74,34 @@ def run_async_chain(chain, question, chat_history): def get_data_from_api_key(api_key): data = api_key_collection.find_one({"key": api_key}) - # # Raise custom exception if the API key is not found if data is None: raise Exception("Invalid API Key, please generate new key", 401) + + if "retriever" not in data: + data["retriever"] = None + + if "source" in data and isinstance(data["source"], DBRef): + source_doc = db.dereference(data["source"]) + data["source"] = str(source_doc["_id"]) + if "retriever" in source_doc: + data["retriever"] = source_doc["retriever"] + else: + data["source"] = {} return data -def get_vectorstore(data): - if "active_docs" in data: - if data["active_docs"].split("/")[0] == "default": - vectorstore = "" - elif data["active_docs"].split("/")[0] == "local": - vectorstore = "indexes/" + data["active_docs"] - else: - vectorstore = "vectors/" + data["active_docs"] - if data["active_docs"] == "default": - vectorstore = "" - else: - vectorstore = "" - vectorstore = os.path.join("application", vectorstore) - return vectorstore +def get_retriever(source_id: str): + doc = sources_collection.find_one({"_id": ObjectId(source_id)}) + if doc is None: + raise Exception("Source document does not exist", 404) + retriever_name = None if "retriever" not in doc else doc["retriever"] + return retriever_name + def is_azure_configured(): - return ( - settings.OPENAI_API_BASE - and settings.OPENAI_API_VERSION - and settings.AZURE_DEPLOYMENT_NAME - ) + return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME def save_conversation(conversation_id, question, response, source_log_docs, llm): @@ -263,33 +261,33 @@ def stream(): else: token_limit = settings.DEFAULT_MAX_HISTORY - # check if active_docs or api_key is set + ## retriever can be "brave_search, duckduck_search or classic" + retriever_name = data["retriever"] if "retriever" in data else "classic" + # check if active_docs or api_key is set if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) chunks = int(data_key["chunks"]) prompt_id = data_key["prompt_id"] source = {"active_docs": data_key["source"]} + retriever_name = data_key["retriever"] or retriever_name user_api_key = data["api_key"] + elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} + source = {"active_docs" : data["active_docs"]} + retriever_name = get_retriever(data["active_docs"]) or retriever_name user_api_key = None + else: source = {} user_api_key = None - if source["active_docs"].split("/")[0] in ["default", "local"]: - retriever_name = "classic" - else: - retriever_name = source["active_docs"] - - current_app.logger.info( - f"/stream - request_data: {data}, source: {source}", - extra={"data": json.dumps({"request_data": data, "source": source})}, + current_app.logger.info(f"/stream - request_data: {data}, source: {source}", + extra={"data": json.dumps({"request_data": data, "source": source})} ) prompt = get_prompt(prompt_id) - + retriever = RetrieverCreator.create_retriever( retriever_name, question=question, @@ -369,6 +367,10 @@ def api_answer(): else: token_limit = settings.DEFAULT_MAX_HISTORY + ## retriever can be brave_search, duckduck_search or classic + retriever_name = data["retriever"] if "retriever" in data else "classic" + + # use try and except to check for exception try: # check if the vectorstore is set if "api_key" in data: @@ -376,15 +378,15 @@ def api_answer(): chunks = int(data_key["chunks"]) prompt_id = data_key["prompt_id"] source = {"active_docs": data_key["source"]} + retriever_name = data_key["retriever"] or retriever_name user_api_key = data["api_key"] - else: - source = data + elif "active_docs" in data: + source = {"active_docs":data["active_docs"]} + retriever_name = get_retriever(data["active_docs"]) or retriever_name user_api_key = None - - if source["active_docs"].split("/")[0] in ["default", "local"]: - retriever_name = "classic" else: - retriever_name = source["active_docs"] + source = {} + user_api_key = None prompt = get_prompt(prompt_id) @@ -421,8 +423,8 @@ def api_answer(): ) result = {"answer": response_full, "sources": source_log_docs} - result["conversation_id"] = save_conversation( - conversation_id, question, response_full, source_log_docs, llm + result["conversation_id"] = str( + save_conversation(conversation_id, question, response_full, source_log_docs, llm) ) retriever_params = retriever.get_params() user_logs_collection.insert_one( @@ -459,19 +461,19 @@ def api_search(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) chunks = int(data_key["chunks"]) - source = {"active_docs": data_key["source"]} - user_api_key = data["api_key"] + source = {"active_docs":data_key["source"]} + user_api_key = data_key["api_key"] elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} + source = {"active_docs":data["active_docs"]} user_api_key = None else: source = {} user_api_key = None - if source["active_docs"].split("/")[0] in ["default", "local"]: - retriever_name = "classic" + if "retriever" in data: + retriever_name = data["retriever"] else: - retriever_name = source["active_docs"] + retriever_name = "classic" if "token_limit" in data: token_limit = data["token_limit"] else: diff --git a/application/api/internal/routes.py b/application/api/internal/routes.py index 6039ecdf..cea6c8ca 100755 --- a/application/api/internal/routes.py +++ b/application/api/internal/routes.py @@ -3,13 +3,13 @@ import datetime from flask import Blueprint, request, send_from_directory from pymongo import MongoClient from werkzeug.utils import secure_filename - +from bson.objectid import ObjectId from application.core.settings import settings mongo = MongoClient(settings.MONGO_URI) db = mongo["docsgpt"] conversations_collection = db["conversations"] -vectors_collection = db["vectors"] +sources_collection = db["sources"] current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -35,7 +35,12 @@ def upload_index_files(): return {"status": "no name"} job_name = secure_filename(request.form["name"]) tokens = secure_filename(request.form["tokens"]) - save_dir = os.path.join(current_dir, "indexes", user, job_name) + retriever = secure_filename(request.form["retriever"]) + id = secure_filename(request.form["id"]) + type = secure_filename(request.form["type"]) + remote_data = secure_filename(request.form["remote_data"]) if "remote_data" in request.form else None + + save_dir = os.path.join(current_dir, "indexes", str(id)) if settings.VECTOR_STORE == "faiss": if "file_faiss" not in request.files: print("No file part") @@ -55,17 +60,19 @@ def upload_index_files(): os.makedirs(save_dir) file_faiss.save(os.path.join(save_dir, "index.faiss")) file_pkl.save(os.path.join(save_dir, "index.pkl")) - # create entry in vectors_collection - vectors_collection.insert_one( + # create entry in sources_collection + sources_collection.insert_one( { + "_id": ObjectId(id), "user": user, "name": job_name, "language": job_name, - "location": save_dir, "date": datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S"), "model": settings.EMBEDDINGS_NAME, - "type": "local", - "tokens": tokens + "type": type, + "tokens": tokens, + "retriever": retriever, + "remote_data": remote_data } ) return {"status": "ok"} \ No newline at end of file diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 1b86135c..0f72be97 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -20,7 +20,7 @@ from application.vectorstore.vector_creator import VectorCreator mongo = MongoClient(settings.MONGO_URI) db = mongo["docsgpt"] conversations_collection = db["conversations"] -vectors_collection = db["vectors"] +sources_collection = db["sources"] prompts_collection = db["prompts"] feedback_collection = db["feedback"] api_key_collection = db["api_keys"] @@ -30,9 +30,7 @@ user_logs_collection = db["user_logs"] user = Blueprint("user", __name__) -current_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) +current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) def generate_minute_range(start_date, end_date): @@ -83,9 +81,7 @@ def get_conversations(): conversations = conversations_collection.find().sort("date", -1).limit(30) list_conversations = [] for conversation in conversations: - list_conversations.append( - {"id": str(conversation["_id"]), "name": conversation["name"]} - ) + list_conversations.append({"id": str(conversation["_id"]), "name": conversation["name"]}) # list_conversations = [{"id": "default", "name": "default"}, {"id": "jeff", "name": "jeff"}] @@ -116,15 +112,10 @@ def api_feedback(): question = data["question"] answer = data["answer"] feedback = data["feedback"] - - feedback_collection.insert_one( - { - "question": question, - "answer": answer, - "feedback": feedback, - "timestamp": datetime.datetime.now(datetime.timezone.utc), - } - ) + new_doc = {"question": question, "answer": answer, "feedback": feedback, "timestamp": datetime.datetime.now(datetime.timezone.utc)} + if "api_key" in data: + new_doc["api_key"] = data["api_key"] + feedback_collection.insert_one(new_doc) return {"status": "ok"} @@ -137,7 +128,7 @@ def delete_by_ids(): return {"status": "error"} if settings.VECTOR_STORE == "faiss": - result = vectors_collection.delete_index(ids=ids) + result = sources_collection.delete_index(ids=ids) if result: return {"status": "ok"} return {"status": "error"} @@ -147,28 +138,24 @@ def delete_by_ids(): def delete_old(): """Delete old indexes.""" import shutil - - path = request.args.get("path") - dirs = path.split("/") - dirs_clean = [] - for i in range(0, len(dirs)): - dirs_clean.append(secure_filename(dirs[i])) - # check that path strats with indexes or vectors - - if dirs_clean[0] not in ["indexes", "vectors"]: - return {"status": "error"} - path_clean = "/".join(dirs_clean) - vectors_collection.delete_one({"name": dirs_clean[-1], "user": dirs_clean[-2]}) + source_id = request.args.get("source_id") + doc = sources_collection.find_one({ + "_id": ObjectId(source_id), + "user": "local", + }) + if(doc is None): + return {"status":"not found"},404 if settings.VECTOR_STORE == "faiss": try: - shutil.rmtree(os.path.join(current_dir, path_clean)) + shutil.rmtree(os.path.join(current_dir, str(doc["_id"]))) except FileNotFoundError: pass else: - vetorstore = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, path=os.path.join(current_dir, path_clean) - ) + vetorstore = VectorCreator.create_vectorstore(settings.VECTOR_STORE, source_id=str(doc["_id"])) vetorstore.delete_index() + sources_collection.delete_one({ + "_id": ObjectId(source_id), + }) return {"status": "ok"} @@ -202,9 +189,7 @@ def upload_file(): file.save(os.path.join(temp_dir, filename)) # Use shutil.make_archive to zip the temp directory - zip_path = shutil.make_archive( - base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir - ) + zip_path = shutil.make_archive(base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir) final_filename = os.path.basename(zip_path) # Clean up the temporary directory after zipping @@ -246,9 +231,7 @@ def upload_remote(): source_data = request.form["data"] if source_data: - task = ingest_remote.delay( - source_data=source_data, job_name=job_name, user=user, loader=source - ) + task = ingest_remote.delay(source_data=source_data, job_name=job_name, user=user, loader=source) task_id = task.id return {"status": "ok", "task_id": task_id} else: @@ -275,54 +258,36 @@ def combined_json(): data = [ { "name": "default", - "language": "default", - "version": "", - "description": "default", - "fullName": "default", "date": "default", - "docLink": "default", "model": settings.EMBEDDINGS_NAME, "location": "remote", "tokens": "", + "retriever": "classic", } ] # structure: name, language, version, description, fullName, date, docLink - # append data from vectors_collection in sorted order in descending order of date - for index in vectors_collection.find({"user": user}).sort("date", -1): + # append data from sources_collection in sorted order in descending order of date + for index in sources_collection.find({"user": user}).sort("date", -1): data.append( { + "id": str(index["_id"]), "name": index["name"], - "language": index["language"], - "version": "", - "description": index["name"], - "fullName": index["name"], "date": index["date"], - "docLink": index["location"], "model": settings.EMBEDDINGS_NAME, "location": "local", "tokens": index["tokens"] if ("tokens" in index.keys()) else "", + "retriever": index["retriever"] if ("retriever" in index.keys()) else "classic", } ) - if settings.VECTOR_STORE == "faiss": - data_remote = requests.get( - "https://d3dg1063dc54p9.cloudfront.net/combined.json" - ).json() - for index in data_remote: - index["location"] = "remote" - data.append(index) if "duckduck_search" in settings.RETRIEVERS_ENABLED: data.append( { "name": "DuckDuckGo Search", - "language": "en", - "version": "", - "description": "duckduck_search", - "fullName": "DuckDuckGo Search", "date": "duckduck_search", - "docLink": "duckduck_search", "model": settings.EMBEDDINGS_NAME, "location": "custom", "tokens": "", + "retriever": "duckduck_search", } ) if "brave_search" in settings.RETRIEVERS_ENABLED: @@ -330,14 +295,11 @@ def combined_json(): { "name": "Brave Search", "language": "en", - "version": "", - "description": "brave_search", - "fullName": "Brave Search", "date": "brave_search", - "docLink": "brave_search", "model": settings.EMBEDDINGS_NAME, "location": "custom", "tokens": "", + "retriever": "brave_search", } ) @@ -346,39 +308,13 @@ def combined_json(): @user.route("/api/docs_check", methods=["POST"]) def check_docs(): - # check if docs exist in a vectorstore folder data = request.get_json() - # split docs on / and take first part - if data["docs"].split("/")[0] == "local": - return {"status": "exists"} + vectorstore = "vectors/" + secure_filename(data["docs"]) - base_path = "https://raw.githubusercontent.com/arc53/DocsHUB/main/" if os.path.exists(vectorstore) or data["docs"] == "default": return {"status": "exists"} else: - file_url = urlparse(base_path + vectorstore + "index.faiss") - - if ( - file_url.scheme in ["https"] - and file_url.netloc == "raw.githubusercontent.com" - and file_url.path.startswith("/arc53/DocsHUB/main/") - ): - r = requests.get(file_url.geturl()) - if r.status_code != 200: - return {"status": "null"} - else: - if not os.path.exists(vectorstore): - os.makedirs(vectorstore) - with open(vectorstore + "index.faiss", "wb") as f: - f.write(r.content) - - r = requests.get(base_path + vectorstore + "index.pkl") - with open(vectorstore + "index.pkl", "wb") as f: - f.write(r.content) - else: - return {"status": "null"} - - return {"status": "loaded"} + return {"status": "not found"} @user.route("/api/create_prompt", methods=["POST"]) @@ -409,9 +345,7 @@ def get_prompts(): list_prompts.append({"id": "creative", "name": "creative", "type": "public"}) list_prompts.append({"id": "strict", "name": "strict", "type": "public"}) for prompt in prompts: - list_prompts.append( - {"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"} - ) + list_prompts.append({"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"}) return jsonify(list_prompts) @@ -420,21 +354,15 @@ def get_prompts(): def get_single_prompt(): prompt_id = request.args.get("id") if prompt_id == "default": - with open( - os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r" - ) as f: + with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: chat_combine_template = f.read() return jsonify({"content": chat_combine_template}) elif prompt_id == "creative": - with open( - os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r" - ) as f: + with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f: chat_reduce_creative = f.read() return jsonify({"content": chat_reduce_creative}) elif prompt_id == "strict": - with open( - os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r" - ) as f: + with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f: chat_reduce_strict = f.read() return jsonify({"content": chat_reduce_strict}) @@ -463,9 +391,7 @@ def update_prompt_name(): # check if name is null if name == "": return {"status": "error"} - prompts_collection.update_one( - {"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}} - ) + prompts_collection.update_one({"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}}) return {"status": "ok"} @@ -475,12 +401,23 @@ def get_api_keys(): keys = api_key_collection.find({"user": user}) list_keys = [] for key in keys: + if "source" in key and isinstance(key["source"],DBRef): + source = db.dereference(key["source"]) + if source is None: + continue + else: + source_name = source["name"] + elif "retriever" in key: + source_name = key["retriever"] + else: + continue + list_keys.append( { "id": str(key["_id"]), "name": key["name"], "key": key["key"][:4] + "..." + key["key"][-4:], - "source": key["source"], + "source": source_name, "prompt_id": key["prompt_id"], "chunks": key["chunks"], } @@ -492,21 +429,22 @@ def get_api_keys(): def create_api_key(): data = request.get_json() name = data["name"] - source = data["source"] prompt_id = data["prompt_id"] chunks = data["chunks"] key = str(uuid.uuid4()) user = "local" - resp = api_key_collection.insert_one( - { - "name": name, - "key": key, - "source": source, - "user": user, - "prompt_id": prompt_id, - "chunks": chunks, - } - ) + new_api_key = { + "name": name, + "key": key, + "user": user, + "prompt_id": prompt_id, + "chunks": chunks, + } + if "source" in data and ObjectId.is_valid(data["source"]): + new_api_key["source"] = DBRef("sources", ObjectId(data["source"])) + if "retriever" in data: + new_api_key["retriever"] = data["retriever"] + resp = api_key_collection.insert_one(new_api_key) new_id = str(resp.inserted_id) return {"id": new_id, "key": key} @@ -533,36 +471,37 @@ def share_conversation(): conversation_id = data["conversation_id"] isPromptable = request.args.get("isPromptable").lower() == "true" - conversation = conversations_collection.find_one( - {"_id": ObjectId(conversation_id)} - ) + conversation = conversations_collection.find_one({"_id": ObjectId(conversation_id)}) + if(conversation is None): + raise Exception("Conversation does not exist") current_n_queries = len(conversation["queries"]) ##generate binary representation of uuid explicit_binary = Binary.from_uuid(uuid.uuid4(), UuidRepresentation.STANDARD) if isPromptable: - source = "default" if "source" not in data else data["source"] prompt_id = "default" if "prompt_id" not in data else data["prompt_id"] chunks = "2" if "chunks" not in data else data["chunks"] name = conversation["name"] + "(shared)" - pre_existing_api_document = api_key_collection.find_one( - { + new_api_key_data = { "prompt_id": prompt_id, "chunks": chunks, - "source": source, "user": user, } + if "source" in data and ObjectId.is_valid(data["source"]): + new_api_key_data["source"] = DBRef("sources",ObjectId(data["source"])) + elif "retriever" in data: + new_api_key_data["retriever"] = data["retriever"] + + pre_existing_api_document = api_key_collection.find_one( + new_api_key_data ) - api_uuid = str(uuid.uuid4()) if pre_existing_api_document: api_uuid = pre_existing_api_document["key"] pre_existing = shared_conversations_collections.find_one( { - "conversation_id": DBRef( - "conversations", ObjectId(conversation_id) - ), + "conversation_id": DBRef("conversations", ObjectId(conversation_id)), "isPromptable": isPromptable, "first_n_queries": current_n_queries, "user": user, @@ -593,21 +532,18 @@ def share_conversation(): "api_key": api_uuid, } ) - return jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ) + return jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}) else: - api_key_collection.insert_one( - { - "name": name, - "key": api_uuid, - "source": source, - "user": user, - "prompt_id": prompt_id, - "chunks": chunks, - } - ) - shared_conversations_collections.insert_one( + + api_uuid = str(uuid.uuid4()) + new_api_key_data["key"] = api_uuid + new_api_key_data["name"] = name + if "source" in data and ObjectId.is_valid(data["source"]): + new_api_key_data["source"] = DBRef("sources", ObjectId(data["source"])) + if "retriever" in data: + new_api_key_data["retriever"] = data["retriever"] + api_key_collection.insert_one(new_api_key_data) + shared_conversations_collections.insert_one( { "uuid": explicit_binary, "conversation_id": { @@ -619,12 +555,10 @@ def share_conversation(): "user": user, "api_key": api_uuid, } - ) + ) ## Identifier as route parameter in frontend return ( - jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ), + jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}), 201, ) @@ -639,9 +573,7 @@ def share_conversation(): ) if pre_existing is not None: return ( - jsonify( - {"success": True, "identifier": str(pre_existing["uuid"].as_uuid())} - ), + jsonify({"success": True, "identifier": str(pre_existing["uuid"].as_uuid())}), 200, ) else: @@ -659,9 +591,7 @@ def share_conversation(): ) ## Identifier as route parameter in frontend return ( - jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ), + jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}), 201, ) except Exception as err: @@ -673,16 +603,10 @@ def share_conversation(): @user.route("/api/shared_conversation/", methods=["GET"]) def get_publicly_shared_conversations(identifier: str): try: - query_uuid = Binary.from_uuid( - uuid.UUID(identifier), UuidRepresentation.STANDARD - ) + query_uuid = Binary.from_uuid(uuid.UUID(identifier), UuidRepresentation.STANDARD) shared = shared_conversations_collections.find_one({"uuid": query_uuid}) conversation_queries = [] - if ( - shared - and "conversation_id" in shared - and isinstance(shared["conversation_id"], DBRef) - ): + if shared and "conversation_id" in shared and isinstance(shared["conversation_id"], DBRef): # Resolve the DBRef conversation_ref = shared["conversation_id"] conversation = db.dereference(conversation_ref) @@ -696,9 +620,7 @@ def get_publicly_shared_conversations(identifier: str): ), 404, ) - conversation_queries = conversation["queries"][ - : (shared["first_n_queries"]) - ] + conversation_queries = conversation["queries"][: (shared["first_n_queries"])] for query in conversation_queries: query.pop("sources") ## avoid exposing sources else: diff --git a/application/core/settings.py b/application/core/settings.py index bbd62fe4..e6173be4 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -18,7 +18,7 @@ class Settings(BaseSettings): DEFAULT_MAX_HISTORY: int = 150 MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5} UPLOAD_FOLDER: str = "inputs" - VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" + VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search API_URL: str = "http://localhost:7091" # backend url for celery worker @@ -62,6 +62,11 @@ class Settings(BaseSettings): QDRANT_PATH: Optional[str] = None QDRANT_DISTANCE_FUNC: str = "Cosine" + # Milvus vectorstore config + MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt" + MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default + MILVUS_TOKEN: Optional[str] = "" + BRAVE_SEARCH_API_KEY: Optional[str] = None FLASK_DEBUG_MODE: bool = False diff --git a/application/parser/open_ai_func.py b/application/parser/open_ai_func.py index c58e8059..84f92db9 100755 --- a/application/parser/open_ai_func.py +++ b/application/parser/open_ai_func.py @@ -11,12 +11,14 @@ from retry import retry @retry(tries=10, delay=60) -def store_add_texts_with_retry(store, i): +def store_add_texts_with_retry(store, i, id): + # add source_id to the metadata + i.metadata["source_id"] = str(id) store.add_texts([i.page_content], metadatas=[i.metadata]) # store_pine.add_texts([i.page_content], metadatas=[i.metadata]) -def call_openai_api(docs, folder_name, task_status): +def call_openai_api(docs, folder_name, id, task_status): # Function to create a vector store from the documents and save it to disk if not os.path.exists(f"{folder_name}"): @@ -32,13 +34,13 @@ def call_openai_api(docs, folder_name, task_status): store = VectorCreator.create_vectorstore( settings.VECTOR_STORE, docs_init=docs_init, - path=f"{folder_name}", + source_id=f"{folder_name}", embeddings_key=os.getenv("EMBEDDINGS_KEY"), ) else: store = VectorCreator.create_vectorstore( settings.VECTOR_STORE, - path=f"{folder_name}", + source_id=str(id), embeddings_key=os.getenv("EMBEDDINGS_KEY"), ) # Uncomment for MPNet embeddings @@ -57,7 +59,7 @@ def call_openai_api(docs, folder_name, task_status): task_status.update_state( state="PROGRESS", meta={"current": int((c1 / s1) * 100)} ) - store_add_texts_with_retry(store, i) + store_add_texts_with_retry(store, i, id) except Exception as e: print(e) print("Error on ", i) diff --git a/application/requirements.txt b/application/requirements.txt index b793934b..0bf59d29 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -9,13 +9,15 @@ EbookLib==0.18 elasticsearch==8.14.0 escodegen==1.0.11 esprima==4.0.1 -Flask==3.0.1 -faiss-cpu==1.8.0 +Flask==3.0.3 +faiss-cpu==1.8.0.post1 gunicorn==23.0.0 html2text==2020.1.16 javalang==0.13.0 -langchain==0.1.4 -langchain-openai==0.0.5 +langchain==0.2.16 +langchain-community==0.2.16 +langchain-core==0.2.38 +langchain-openai==0.1.23 openapi3_parser==1.1.16 pandas==2.2.2 pydantic_settings==2.4.0 @@ -26,9 +28,9 @@ qdrant-client==1.11.0 redis==5.0.1 Requests==2.32.0 retry==0.9.2 -sentence-transformers -tiktoken +sentence-transformers==3.0.1 +tiktoken==0.7.0 torch -tqdm==4.66.3 -transformers==4.44.0 -Werkzeug==3.0.3 +tqdm==4.66.5 +transformers==4.44.2 +Werkzeug==3.0.4 diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 88827188..b87b5852 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,4 +1,3 @@ -import os from application.retriever.base import BaseRetriever from application.core.settings import settings from application.vectorstore.vector_creator import VectorCreator @@ -21,7 +20,7 @@ class ClassicRAG(BaseRetriever): user_api_key=None, ): self.question = question - self.vectorstore = self._get_vectorstore(source=source) + self.vectorstore = source['active_docs'] if 'active_docs' in source else None self.chat_history = chat_history self.prompt = prompt self.chunks = chunks @@ -38,21 +37,6 @@ class ClassicRAG(BaseRetriever): ) self.user_api_key = user_api_key - def _get_vectorstore(self, source): - if "active_docs" in source: - if source["active_docs"].split("/")[0] == "default": - vectorstore = "" - elif source["active_docs"].split("/")[0] == "local": - vectorstore = "indexes/" + source["active_docs"] - else: - vectorstore = "vectors/" + source["active_docs"] - if source["active_docs"] == "default": - vectorstore = "" - else: - vectorstore = "" - vectorstore = os.path.join("application", vectorstore) - return vectorstore - def _get_data(self): if self.chunks == 0: docs = [] diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py index ad071401..07be373d 100644 --- a/application/retriever/retriever_creator.py +++ b/application/retriever/retriever_creator.py @@ -5,15 +5,16 @@ from application.retriever.brave_search import BraveRetSearch class RetrieverCreator: - retievers = { + retrievers = { 'classic': ClassicRAG, 'duckduck_search': DuckDuckSearch, - 'brave_search': BraveRetSearch + 'brave_search': BraveRetSearch, + 'default': ClassicRAG } @classmethod def create_retriever(cls, type, *args, **kwargs): - retiever_class = cls.retievers.get(type.lower()) + retiever_class = cls.retrievers.get(type.lower()) if not retiever_class: raise ValueError(f"No retievers class found for type {type}") return retiever_class(*args, **kwargs) \ No newline at end of file diff --git a/application/vectorstore/base.py b/application/vectorstore/base.py index 522ef4fa..9c76b89f 100644 --- a/application/vectorstore/base.py +++ b/application/vectorstore/base.py @@ -1,13 +1,30 @@ from abc import ABC, abstractmethod import os -from langchain_community.embeddings import ( - HuggingFaceEmbeddings, - CohereEmbeddings, - HuggingFaceInstructEmbeddings, -) +from sentence_transformers import SentenceTransformer from langchain_openai import OpenAIEmbeddings from application.core.settings import settings +class EmbeddingsWrapper: + def __init__(self, model_name, *args, **kwargs): + self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs) + self.dimension = self.model.get_sentence_embedding_dimension() + + def embed_query(self, query: str): + return self.model.encode(query).tolist() + + def embed_documents(self, documents: list): + return self.model.encode(documents).tolist() + + def __call__(self, text): + if isinstance(text, str): + return self.embed_query(text) + elif isinstance(text, list): + return self.embed_documents(text) + else: + raise ValueError("Input must be a string or a list of strings") + + + class EmbeddingsSingleton: _instances = {} @@ -23,16 +40,15 @@ class EmbeddingsSingleton: def _create_instance(embeddings_name, *args, **kwargs): embeddings_factory = { "openai_text-embedding-ada-002": OpenAIEmbeddings, - "huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings, - "huggingface_sentence-transformers-all-mpnet-base-v2": HuggingFaceEmbeddings, - "huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings, - "cohere_medium": CohereEmbeddings + "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"), + "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"), + "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"), } - if embeddings_name not in embeddings_factory: - raise ValueError(f"Invalid embeddings_name: {embeddings_name}") - - return embeddings_factory[embeddings_name](*args, **kwargs) + if embeddings_name in embeddings_factory: + return embeddings_factory[embeddings_name](*args, **kwargs) + else: + return EmbeddingsWrapper(embeddings_name, *args, **kwargs) class BaseVectorStore(ABC): def __init__(self): @@ -58,22 +74,14 @@ class BaseVectorStore(ABC): embeddings_name, openai_api_key=embeddings_key ) - elif embeddings_name == "cohere_medium": - embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name, - cohere_api_key=embeddings_key - ) elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2": if os.path.exists("./model/all-mpnet-base-v2"): embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name, - model_name="./model/all-mpnet-base-v2", - model_kwargs={"device": "cpu"} + embeddings_name="./model/all-mpnet-base-v2", ) else: embedding_instance = EmbeddingsSingleton.get_instance( embeddings_name, - model_kwargs={"device": "cpu"} ) else: embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name) diff --git a/application/vectorstore/elasticsearch.py b/application/vectorstore/elasticsearch.py index bb28d5ce..e393e4a5 100644 --- a/application/vectorstore/elasticsearch.py +++ b/application/vectorstore/elasticsearch.py @@ -9,9 +9,9 @@ import elasticsearch class ElasticsearchStore(BaseVectorStore): _es_connection = None # Class attribute to hold the Elasticsearch connection - def __init__(self, path, embeddings_key, index_name=settings.ELASTIC_INDEX): + def __init__(self, source_id, embeddings_key, index_name=settings.ELASTIC_INDEX): super().__init__() - self.path = path.replace("application/indexes/", "").rstrip("/") + self.source_id = source_id.replace("application/indexes/", "").rstrip("/") self.embeddings_key = embeddings_key self.index_name = index_name @@ -81,7 +81,7 @@ class ElasticsearchStore(BaseVectorStore): embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key) vector = embeddings.embed_query(question) knn = { - "filter": [{"match": {"metadata.store.keyword": self.path}}], + "filter": [{"match": {"metadata.source_id.keyword": self.source_id}}], "field": "vector", "k": k, "num_candidates": 100, @@ -100,7 +100,7 @@ class ElasticsearchStore(BaseVectorStore): } } ], - "filter": [{"match": {"metadata.store.keyword": self.path}}], + "filter": [{"match": {"metadata.source_id.keyword": self.source_id}}], } }, "rank": {"rrf": {}}, @@ -209,5 +209,4 @@ class ElasticsearchStore(BaseVectorStore): def delete_index(self): self._es_connection.delete_by_query(index=self.index_name, query={"match": { - "metadata.store.keyword": self.path}},) - + "metadata.source_id.keyword": self.source_id}},) diff --git a/application/vectorstore/faiss.py b/application/vectorstore/faiss.py index 8e8f3b8e..a8839cd2 100644 --- a/application/vectorstore/faiss.py +++ b/application/vectorstore/faiss.py @@ -1,12 +1,22 @@ from langchain_community.vectorstores import FAISS from application.vectorstore.base import BaseVectorStore from application.core.settings import settings +import os + +def get_vectorstore(path): + if path: + vectorstore = "indexes/"+path + vectorstore = os.path.join("application", vectorstore) + else: + vectorstore = os.path.join("application") + + return vectorstore class FaissStore(BaseVectorStore): - def __init__(self, path, embeddings_key, docs_init=None): + def __init__(self, source_id, embeddings_key, docs_init=None): super().__init__() - self.path = path + self.path = get_vectorstore(source_id) embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key) if docs_init: self.docsearch = FAISS.from_documents( @@ -14,7 +24,8 @@ class FaissStore(BaseVectorStore): ) else: self.docsearch = FAISS.load_local( - self.path, embeddings + self.path, embeddings, + allow_dangerous_deserialization=True ) self.assert_embedding_dimensions(embeddings) @@ -37,10 +48,10 @@ class FaissStore(BaseVectorStore): """ if settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2": try: - word_embedding_dimension = embeddings.client[1].word_embedding_dimension + word_embedding_dimension = embeddings.dimension except AttributeError as e: - raise AttributeError("word_embedding_dimension not found in embeddings.client[1]") from e + raise AttributeError("'dimension' attribute not found in embeddings instance. Make sure the embeddings object is properly initialized.") from e docsearch_index_dimension = self.docsearch.index.d if word_embedding_dimension != docsearch_index_dimension: - raise ValueError(f"word_embedding_dimension ({word_embedding_dimension}) " + - f"!= docsearch_index_word_embedding_dimension ({docsearch_index_dimension})") \ No newline at end of file + raise ValueError(f"Embedding dimension mismatch: embeddings.dimension ({word_embedding_dimension}) " + + f"!= docsearch index dimension ({docsearch_index_dimension})") \ No newline at end of file diff --git a/application/vectorstore/milvus.py b/application/vectorstore/milvus.py new file mode 100644 index 00000000..9871991e --- /dev/null +++ b/application/vectorstore/milvus.py @@ -0,0 +1,37 @@ +from typing import List, Optional +from uuid import uuid4 + + +from application.core.settings import settings +from application.vectorstore.base import BaseVectorStore + + +class MilvusStore(BaseVectorStore): + def __init__(self, path: str = "", embeddings_key: str = "embeddings"): + super().__init__() + from langchain_milvus import Milvus + + connection_args = { + "uri": settings.MILVUS_URI, + "token": settings.MILVUS_TOKEN, + } + self._docsearch = Milvus( + embedding_function=self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key), + collection_name=settings.MILVUS_COLLECTION_NAME, + connection_args=connection_args, + ) + self._path = path + + def search(self, question, k=2, *args, **kwargs): + return self._docsearch.similarity_search(query=question, k=k, filter={"path": self._path} *args, **kwargs) + + def add_texts(self, texts: List[str], metadatas: Optional[List[dict]], *args, **kwargs): + ids = [str(uuid4()) for _ in range(len(texts))] + + return self._docsearch.add_texts(texts=texts, metadatas=metadatas, ids=ids, *args, **kwargs) + + def save_local(self, *args, **kwargs): + pass + + def delete_index(self, *args, **kwargs): + pass diff --git a/application/vectorstore/mongodb.py b/application/vectorstore/mongodb.py index 337fc41f..32bca489 100644 --- a/application/vectorstore/mongodb.py +++ b/application/vectorstore/mongodb.py @@ -5,7 +5,7 @@ from application.vectorstore.document_class import Document class MongoDBVectorStore(BaseVectorStore): def __init__( self, - path: str = "", + source_id: str = "", embeddings_key: str = "embeddings", collection: str = "documents", index_name: str = "vector_search_index", @@ -18,7 +18,7 @@ class MongoDBVectorStore(BaseVectorStore): self._embedding_key = embedding_key self._embeddings_key = embeddings_key self._mongo_uri = settings.MONGO_URI - self._path = path.replace("application/indexes/", "").rstrip("/") + self._source_id = source_id.replace("application/indexes/", "").rstrip("/") self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key) try: @@ -46,7 +46,7 @@ class MongoDBVectorStore(BaseVectorStore): "numCandidates": k * 10, "index": self._index_name, "filter": { - "store": {"$eq": self._path} + "source_id": {"$eq": self._source_id} } } } @@ -123,4 +123,4 @@ class MongoDBVectorStore(BaseVectorStore): return result_ids def delete_index(self, *args, **kwargs): - self._collection.delete_many({"store": self._path}) \ No newline at end of file + self._collection.delete_many({"source_id": self._source_id}) \ No newline at end of file diff --git a/application/vectorstore/qdrant.py b/application/vectorstore/qdrant.py index 482d06a1..3f94505f 100644 --- a/application/vectorstore/qdrant.py +++ b/application/vectorstore/qdrant.py @@ -5,12 +5,12 @@ from qdrant_client import models class QdrantStore(BaseVectorStore): - def __init__(self, path: str = "", embeddings_key: str = "embeddings"): + def __init__(self, source_id: str = "", embeddings_key: str = "embeddings"): self._filter = models.Filter( must=[ models.FieldCondition( - key="metadata.store", - match=models.MatchValue(value=path.replace("application/indexes/", "").rstrip("/")), + key="metadata.source_id", + match=models.MatchValue(value=source_id.replace("application/indexes/", "").rstrip("/")), ) ] ) diff --git a/application/vectorstore/vector_creator.py b/application/vectorstore/vector_creator.py index 27b38645..259fa31f 100644 --- a/application/vectorstore/vector_creator.py +++ b/application/vectorstore/vector_creator.py @@ -1,5 +1,6 @@ from application.vectorstore.faiss import FaissStore from application.vectorstore.elasticsearch import ElasticsearchStore +from application.vectorstore.milvus import MilvusStore from application.vectorstore.mongodb import MongoDBVectorStore from application.vectorstore.qdrant import QdrantStore @@ -10,6 +11,7 @@ class VectorCreator: "elasticsearch": ElasticsearchStore, "mongodb": MongoDBVectorStore, "qdrant": QdrantStore, + "milvus": MilvusStore, } @classmethod diff --git a/application/worker.py b/application/worker.py index c315f916..15603908 100755 --- a/application/worker.py +++ b/application/worker.py @@ -6,6 +6,7 @@ from urllib.parse import urljoin import logging import requests +from bson.objectid import ObjectId from application.core.settings import settings from application.parser.file.bulk import SimpleDirectoryReader @@ -16,10 +17,10 @@ from application.parser.token_func import group_split from application.utils import count_tokens_docs + # Define a function to extract metadata from a given filename. def metadata_from_filename(title): - store = "/".join(title.split("/")[1:3]) - return {"title": title, "store": store} + return {"title": title} # Define a function to generate a random string of a given length. @@ -27,9 +28,7 @@ def generate_random_string(length): return "".join([string.ascii_letters[i % 52] for i in range(length)]) -current_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) +current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5): @@ -60,7 +59,7 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5): # Define the main function for ingesting and processing documents. -def ingest_worker(self, directory, formats, name_job, filename, user): +def ingest_worker(self, directory, formats, name_job, filename, user, retriever="classic"): """ Ingest and process documents. @@ -71,6 +70,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user): name_job (str): Name of the job for this ingestion task. filename (str): Name of the file to be ingested. user (str): Identifier for the user initiating the ingestion. + retriever (str): Type of retriever to use for processing the documents. Returns: dict: Information about the completed ingestion task, including input parameters and a "limited" flag. @@ -106,9 +106,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user): # check if file is .zip and extract it if filename.endswith(".zip"): - extract_zip_recursive( - os.path.join(full_path, filename), full_path, 0, recursion_depth - ) + extract_zip_recursive(os.path.join(full_path, filename), full_path, 0, recursion_depth) self.update_state(state="PROGRESS", meta={"current": 1}) @@ -129,8 +127,9 @@ def ingest_worker(self, directory, formats, name_job, filename, user): ) docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs] + id = ObjectId() - call_openai_api(docs, full_path, self) + call_openai_api(docs, full_path, id, self) tokens = count_tokens_docs(docs) self.update_state(state="PROGRESS", meta={"current": 100}) @@ -140,22 +139,15 @@ def ingest_worker(self, directory, formats, name_job, filename, user): # get files from outputs/inputs/index.faiss and outputs/inputs/index.pkl # and send them to the server (provide user and name in form) - file_data = {"name": name_job, "user": user, "tokens":tokens} + file_data = {"name": name_job, "user": user, "tokens": tokens, "retriever": retriever, "id": str(id), 'type': 'local'} if settings.VECTOR_STORE == "faiss": files = { "file_faiss": open(full_path + "/index.faiss", "rb"), "file_pkl": open(full_path + "/index.pkl", "rb"), } - response = requests.post( - urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data - ) - response = requests.get( - urljoin(settings.API_URL, "/api/delete_old?path=" + full_path) - ) + response = requests.post(urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data) else: - response = requests.post( - urljoin(settings.API_URL, "/api/upload_index"), data=file_data - ) + response = requests.post(urljoin(settings.API_URL, "/api/upload_index"), data=file_data) # delete local shutil.rmtree(full_path) @@ -170,7 +162,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user): } -def remote_worker(self, source_data, name_job, user, loader, directory="temp"): +def remote_worker(self, source_data, name_job, user, loader, directory="temp", retriever="classic"): token_check = True min_tokens = 150 max_tokens = 1250 @@ -191,25 +183,24 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp"): token_check=token_check, ) # docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs] - call_openai_api(docs, full_path, self) tokens = count_tokens_docs(docs) + id = ObjectId() + call_openai_api(docs, full_path, id, self) self.update_state(state="PROGRESS", meta={"current": 100}) # Proceed with uploading and cleaning as in the original function - file_data = {"name": name_job, "user": user, "tokens":tokens} + file_data = {"name": name_job, "user": user, "tokens": tokens, "retriever": retriever, + "id": str(id), 'type': loader, 'remote_data': source_data} if settings.VECTOR_STORE == "faiss": files = { "file_faiss": open(full_path + "/index.faiss", "rb"), "file_pkl": open(full_path + "/index.pkl", "rb"), } - - requests.post( - urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data - ) - requests.get(urljoin(settings.API_URL, "/api/delete_old?path=" + full_path)) + + requests.post(urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data) else: requests.post(urljoin(settings.API_URL, "/api/upload_index"), data=file_data) shutil.rmtree(full_path) - return {"urls": source_data, "name_job": name_job, "user": user, "limited": False} \ No newline at end of file + return {"urls": source_data, "name_job": name_job, "user": user, "limited": False} diff --git a/docker-compose-azure.yaml b/docker-compose-azure.yaml index 70a16808..601831e5 100644 --- a/docker-compose-azure.yaml +++ b/docker-compose-azure.yaml @@ -1,5 +1,3 @@ -version: "3.9" - services: frontend: build: ./frontend diff --git a/docker-compose-dev.yaml b/docker-compose-dev.yaml index f68e4e07..8a3e75c4 100644 --- a/docker-compose-dev.yaml +++ b/docker-compose-dev.yaml @@ -1,5 +1,3 @@ -version: "3.9" - services: redis: diff --git a/docker-compose-local.yaml b/docker-compose-local.yaml index 3aebe8b5..74bf0101 100644 --- a/docker-compose-local.yaml +++ b/docker-compose-local.yaml @@ -1,5 +1,3 @@ -version: "3.9" - services: frontend: build: ./frontend diff --git a/docker-compose-mock.yaml b/docker-compose-mock.yaml index a5c7419b..b4a917c9 100644 --- a/docker-compose-mock.yaml +++ b/docker-compose-mock.yaml @@ -1,5 +1,3 @@ -version: "3.9" - services: frontend: build: ./frontend diff --git a/docker-compose.yaml b/docker-compose.yaml index 7008b53d..05c8c059 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,5 +1,3 @@ -version: "3.9" - services: frontend: build: ./frontend diff --git a/extensions/react-widget/package-lock.json b/extensions/react-widget/package-lock.json index 2ad80282..610909de 100644 --- a/extensions/react-widget/package-lock.json +++ b/extensions/react-widget/package-lock.json @@ -1,12 +1,12 @@ { "name": "docsgpt", - "version": "0.4.1", + "version": "0.4.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "docsgpt", - "version": "0.4.1", + "version": "0.4.2", "license": "Apache-2.0", "dependencies": { "@babel/plugin-transform-flow-strip-types": "^7.23.3", diff --git a/extensions/react-widget/package.json b/extensions/react-widget/package.json index d12b76c4..d449d0a3 100644 --- a/extensions/react-widget/package.json +++ b/extensions/react-widget/package.json @@ -1,6 +1,6 @@ { "name": "docsgpt", - "version": "0.4.1", + "version": "0.4.2", "private": false, "description": "DocsGPT 🦖 is an innovative open-source tool designed to simplify the retrieval of information from project documentation using advanced GPT models 🤖.", "source": "./src/index.html", diff --git a/extensions/react-widget/publish.sh b/extensions/react-widget/publish.sh index 0441d50c..c4545d85 100755 --- a/extensions/react-widget/publish.sh +++ b/extensions/react-widget/publish.sh @@ -2,6 +2,7 @@ ## chmod +x publish.sh - to upgrade ownership set -e cat package.json >> package_copy.json +cat package-lock.json >> package-lock_copy.json publish_package() { PACKAGE_NAME=$1 BUILD_COMMAND=$2 @@ -24,6 +25,9 @@ publish_package() { # Publish to npm npm publish + # Clean up + mv package_copy.json package.json + mv package-lock_copy.json package-lock.json echo "Published ${PACKAGE_NAME}" } @@ -33,7 +37,7 @@ publish_package "docsgpt" "build" # Publish docsgpt-react package publish_package "docsgpt-react" "build:react" -# Clean up -mv package_copy.json package.json + rm -rf package_copy.json +rm -rf package-lock_copy.json echo "---Process completed---" \ No newline at end of file diff --git a/extensions/react-widget/src/assets/dislike.svg b/extensions/react-widget/src/assets/dislike.svg new file mode 100644 index 00000000..ec1d24c2 --- /dev/null +++ b/extensions/react-widget/src/assets/dislike.svg @@ -0,0 +1,4 @@ + + + + diff --git a/extensions/react-widget/src/assets/like.svg b/extensions/react-widget/src/assets/like.svg new file mode 100644 index 00000000..c49604ed --- /dev/null +++ b/extensions/react-widget/src/assets/like.svg @@ -0,0 +1,4 @@ + + + + diff --git a/extensions/react-widget/src/components/DocsGPTWidget.tsx b/extensions/react-widget/src/components/DocsGPTWidget.tsx index bc6adb6e..83defbcf 100644 --- a/extensions/react-widget/src/components/DocsGPTWidget.tsx +++ b/extensions/react-widget/src/components/DocsGPTWidget.tsx @@ -1,11 +1,13 @@ "use client"; -import React from 'react' +import React, { useRef } from 'react' import DOMPurify from 'dompurify'; import styled, { keyframes, createGlobalStyle } from 'styled-components'; import { PaperPlaneIcon, RocketIcon, ExclamationTriangleIcon, Cross2Icon } from '@radix-ui/react-icons'; -import { MESSAGE_TYPE, Query, Status, WidgetProps } from '../types/index'; -import { fetchAnswerStreaming } from '../requests/streamingApi'; +import { FEEDBACK, MESSAGE_TYPE, Query, Status, WidgetProps } from '../types/index'; +import { fetchAnswerStreaming, sendFeedback } from '../requests/streamingApi'; import { ThemeProvider } from 'styled-components'; +import Like from "../assets/like.svg" +import Dislike from "../assets/dislike.svg" import MarkdownIt from 'markdown-it'; const themes = { dark: { @@ -63,6 +65,10 @@ const GlobalStyles = createGlobalStyle` background-color: #646464; color: #fff !important; } +.response code { + white-space: pre-wrap !important; + line-break: loose !important; +} `; const Overlay = styled.div` position: fixed; @@ -195,12 +201,24 @@ const Conversation = styled.div<{ size: string }>` width:${props => props.size === 'large' ? '90vw' : props.size === 'medium' ? '60vw' : '400px'} !important; } `; - +const Feedback = styled.div` + background-color: transparent; + font-weight: normal; + gap: 12px; + display: flex; + padding: 6px; + clear: both; +`; const MessageBubble = styled.div<{ type: MESSAGE_TYPE }>` - display: flex; + display: block; font-size: 16px; - justify-content: ${props => props.type === 'QUESTION' ? 'flex-end' : 'flex-start'}; - margin: 0.5rem; + position: relative; + width: 100%;; + float: right; + margin: 0rem; + &:hover ${Feedback} * { + visibility: visible !important; + } `; const Message = styled.div<{ type: MESSAGE_TYPE }>` background: ${props => props.type === 'QUESTION' ? @@ -208,6 +226,7 @@ const Message = styled.div<{ type: MESSAGE_TYPE }>` props.theme.secondary.bg}; color: ${props => props.type === 'ANSWER' ? props.theme.primary.text : '#fff'}; border: none; + float: ${props => props.type === 'QUESTION' ? 'right' : 'left'}; max-width: ${props => props.type === 'ANSWER' ? '100%' : '80'}; overflow: auto; margin: 4px; @@ -315,6 +334,7 @@ const HeroDescription = styled.p` font-size: 14px; line-height: 1.5; `; + const Hero = ({ title, description, theme }: { title: string, description: string, theme: string }) => { return ( <> @@ -345,7 +365,8 @@ export const DocsGPTWidget = ({ size = 'small', theme = 'dark', buttonIcon = 'https://d3dg1063dc54p9.cloudfront.net/widget/message.svg', - buttonBg = 'linear-gradient(to bottom right, #5AF0EC, #E80D9D)' + buttonBg = 'linear-gradient(to bottom right, #5AF0EC, #E80D9D)', + collectFeedback = true }: WidgetProps) => { const [prompt, setPrompt] = React.useState(''); const [status, setStatus] = React.useState('idle'); @@ -353,6 +374,7 @@ export const DocsGPTWidget = ({ const [conversationId, setConversationId] = React.useState(null) const [open, setOpen] = React.useState(false) const [eventInterrupt, setEventInterrupt] = React.useState(false); //click or scroll by user while autoScrolling + const isBubbleHovered = useRef(false) const endMessageRef = React.useRef(null); const md = new MarkdownIt(); @@ -376,6 +398,36 @@ export const DocsGPTWidget = ({ !eventInterrupt && scrollToBottom(endMessageRef.current); }, [queries.length, queries[queries.length - 1]?.response]); + async function handleFeedback(feedback: FEEDBACK, index: number) { + let query = queries[index] + if (!query.response) + return; + if (query.feedback != feedback) { + sendFeedback({ + question: query.prompt, + answer: query.response, + feedback: feedback, + apikey: apiKey + }, apiHost) + .then(res => { + if (res.status == 200) { + query.feedback = feedback; + setQueries((prev: Query[]) => { + return prev.map((q, i) => (i === index ? query : q)); + }); + } + }) + .catch(err => console.log("Connection failed",err)) + } + else { + delete query.feedback; + setQueries((prev: Query[]) => { + return prev.map((q, i) => (i === index ? query : q)); + }); + + } + } + async function stream(question: string) { setStatus('loading') try { @@ -473,7 +525,7 @@ export const DocsGPTWidget = ({ } { - query.response ? + query.response ? { isBubbleHovered.current = true }} type='ANSWER'> + + {collectFeedback && + + handleFeedback("LIKE", index)} /> + handleFeedback("DISLIKE", index)} /> + } :
{ @@ -518,7 +588,7 @@ export const DocsGPTWidget = ({ type='text' placeholder="What do you want to do?" /> + disabled={prompt.trim().length == 0 || status !== 'idle'}> diff --git a/extensions/react-widget/src/requests/streamingApi.ts b/extensions/react-widget/src/requests/streamingApi.ts index b594915f..9cb9fddc 100644 --- a/extensions/react-widget/src/requests/streamingApi.ts +++ b/extensions/react-widget/src/requests/streamingApi.ts @@ -1,3 +1,4 @@ +import { FEEDBACK } from "@/types"; interface HistoryItem { prompt: string; response?: string; @@ -11,6 +12,12 @@ interface FetchAnswerStreamingProps { apiHost?: string; onEvent?: (event: MessageEvent) => void; } +interface FeedbackPayload { + question: string; + answer: string; + apikey: string; + feedback: FEEDBACK; +} export function fetchAnswerStreaming({ question = '', apiKey = '', @@ -20,12 +27,12 @@ export function fetchAnswerStreaming({ onEvent = () => { console.log("Event triggered, but no handler provided."); } }: FetchAnswerStreamingProps): Promise { return new Promise((resolve, reject) => { - const body= { + const body = { question: question, history: JSON.stringify(history), conversation_id: conversationId, model: 'default', - api_key:apiKey + api_key: apiKey }; fetch(apiHost + '/stream', { method: 'POST', @@ -80,4 +87,20 @@ export function fetchAnswerStreaming({ reject(error); }); }); -} \ No newline at end of file +} + + +export const sendFeedback = (payload: FeedbackPayload,apiHost:string): Promise => { + return fetch(`${apiHost}/api/feedback`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + question: payload.question, + answer: payload.answer, + feedback: payload.feedback, + api_key:payload.apikey + }), + }); +}; \ No newline at end of file diff --git a/extensions/react-widget/src/types/index.ts b/extensions/react-widget/src/types/index.ts index cb46f06b..a55b6342 100644 --- a/extensions/react-widget/src/types/index.ts +++ b/extensions/react-widget/src/types/index.ts @@ -23,4 +23,5 @@ export interface WidgetProps { theme?:THEME, buttonIcon?:string; buttonBg?:string; + collectFeedback?:boolean } \ No newline at end of file diff --git a/frontend/src/Navigation.tsx b/frontend/src/Navigation.tsx index 6514ba41..dbb3d4e6 100644 --- a/frontend/src/Navigation.tsx +++ b/frontend/src/Navigation.tsx @@ -24,9 +24,9 @@ import ConversationTile from './conversation/ConversationTile'; import { useDarkTheme, useMediaQuery, useOutsideAlerter } from './hooks'; import useDefaultDocument from './hooks/useDefaultDocument'; import DeleteConvModal from './modals/DeleteConvModal'; -import { ActiveState } from './models/misc'; +import { ActiveState, Doc } from './models/misc'; import APIKeyModal from './preferences/APIKeyModal'; -import { Doc, getConversations, getDocs } from './preferences/preferenceApi'; +import { getConversations, getDocs } from './preferences/preferenceApi'; import { selectApiKeyStatus, selectConversationId, @@ -124,10 +124,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { }; const handleDeleteClick = (doc: Doc) => { - const docPath = `indexes/local/${doc.name}`; - userService - .deletePath(docPath) + .deletePath(doc.id ?? '') .then(() => { return getDocs(); }) diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index 2b67a58c..742056b5 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -10,7 +10,7 @@ const endpoints = { DELETE_PROMPT: '/api/delete_prompt', UPDATE_PROMPT: '/api/update_prompt', SINGLE_PROMPT: (id: string) => `/api/get_single_prompt?id=${id}`, - DELETE_PATH: (docPath: string) => `/api/delete_old?path=${docPath}`, + DELETE_PATH: (docPath: string) => `/api/delete_old?source_id=${docPath}`, TASK_STATUS: (task_id: string) => `/api/task_status?task_id=${task_id}`, MESSAGE_ANALYTICS: '/api/get_message_analytics', TOKEN_ANALYTICS: '/api/get_token_analytics', diff --git a/frontend/src/components/Dropdown.tsx b/frontend/src/components/Dropdown.tsx index c5961aaa..3daa3911 100644 --- a/frontend/src/components/Dropdown.tsx +++ b/frontend/src/components/Dropdown.tsx @@ -27,6 +27,7 @@ function Dropdown({ | string | { label: string; value: string } | { value: number; description: string } + | { name: string; id: string; type: string } | null; onSelect: | ((value: string) => void) diff --git a/frontend/src/components/SourceDropdown.tsx b/frontend/src/components/SourceDropdown.tsx index ce130b4d..d5146da5 100644 --- a/frontend/src/components/SourceDropdown.tsx +++ b/frontend/src/components/SourceDropdown.tsx @@ -1,7 +1,7 @@ import React from 'react'; import Trash from '../assets/trash.svg'; import Arrow2 from '../assets/dropdown-arrow.svg'; -import { Doc } from '../preferences/preferenceApi'; +import { Doc } from '../models/misc'; import { useDispatch } from 'react-redux'; import { useTranslation } from 'react-i18next'; type Props = { @@ -63,9 +63,6 @@ function SourceDropdown({

{selectedDocs?.name || 'None'}

-

- {selectedDocs?.version} -

(); - const endMessageRef = useRef(null); + const conversationRef = useRef(null); const inputRef = useRef(null); const [isDarkTheme] = useDarkTheme(); const [hasScrolledToLast, setHasScrolledToLast] = useState(true); @@ -58,26 +58,6 @@ export default function Conversation() { fetchStream.current && fetchStream.current.abort(); }, [conversationId]); - useEffect(() => { - const observerCallback: IntersectionObserverCallback = (entries) => { - entries.forEach((entry) => { - setHasScrolledToLast(entry.isIntersecting); - }); - }; - - const observer = new IntersectionObserver(observerCallback, { - root: null, - threshold: [1, 0.8], - }); - if (endMessageRef.current) { - observer.observe(endMessageRef.current); - } - - return () => { - observer.disconnect(); - }; - }, [endMessageRef.current]); - useEffect(() => { if (queries.length) { queries[queries.length - 1].error && setLastQueryReturnedErr(true); @@ -86,10 +66,16 @@ export default function Conversation() { }, [queries[queries.length - 1]]); const scrollIntoView = () => { - endMessageRef?.current?.scrollIntoView({ - behavior: 'smooth', - block: 'start', - }); + if (!conversationRef?.current || eventInterrupt) return; + + if (status === 'idle' || !queries[queries.length - 1].response) { + conversationRef.current.scrollTo({ + behavior: 'smooth', + top: conversationRef.current.scrollHeight, + }); + } else { + conversationRef.current.scrollTop = conversationRef.current.scrollHeight; + } }; const handleQuestion = ({ @@ -143,7 +129,6 @@ export default function Conversation() { if (query.response) { responseView = ( )}
) : ( - + {children} ); diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts index 4e87678b..eeefd0f8 100644 --- a/frontend/src/conversation/conversationHandlers.ts +++ b/frontend/src/conversation/conversationHandlers.ts @@ -1,32 +1,6 @@ import conversationService from '../api/services/conversationService'; -import { Doc } from '../preferences/preferenceApi'; -import { Answer, FEEDBACK } from './conversationModels'; - -function getDocPath(selectedDocs: Doc | null): string { - let docPath = 'default'; - if (selectedDocs) { - let namePath = selectedDocs.name; - if (selectedDocs.language === namePath) { - namePath = '.project'; - } - if (selectedDocs.location === 'local') { - docPath = 'local' + '/' + selectedDocs.name + '/'; - } else if (selectedDocs.location === 'remote') { - docPath = - selectedDocs.language + - '/' + - namePath + - '/' + - selectedDocs.version + - '/' + - selectedDocs.model + - '/'; - } else if (selectedDocs.location === 'custom') { - docPath = selectedDocs.docLink; - } - } - return docPath; -} +import { Doc } from '../models/misc'; +import { Answer, FEEDBACK, RetrievalPayload } from './conversationModels'; export function handleFetchAnswer( question: string, @@ -54,23 +28,22 @@ export function handleFetchAnswer( title: any; } > { - const docPath = getDocPath(selectedDocs); history = history.map((item) => { return { prompt: item.prompt, response: item.response }; }); + const payload: RetrievalPayload = { + question: question, + history: JSON.stringify(history), + conversation_id: conversationId, + prompt_id: promptId, + chunks: chunks, + token_limit: token_limit, + }; + if (selectedDocs && 'id' in selectedDocs) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs?.retriever as string; return conversationService - .answer( - { - question: question, - history: history, - active_docs: docPath, - conversation_id: conversationId, - prompt_id: promptId, - chunks: chunks, - token_limit: token_limit, - }, - signal, - ) + .answer(payload, signal) .then((response) => { if (response.ok) { return response.json(); @@ -101,16 +74,27 @@ export function handleFetchAnswerSteaming( token_limit: number, onEvent: (event: MessageEvent) => void, ): Promise { - const docPath = getDocPath(selectedDocs); history = history.map((item) => { return { prompt: item.prompt, response: item.response }; }); + const payload: RetrievalPayload = { + question: question, + history: JSON.stringify(history), + conversation_id: conversationId, + prompt_id: promptId, + chunks: chunks, + token_limit: token_limit, + }; + if (selectedDocs && 'id' in selectedDocs) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs?.retriever as string; + return new Promise((resolve, reject) => { conversationService .answerStream( { question: question, - active_docs: docPath, + active_docs: selectedDocs?.id as string, history: JSON.stringify(history), conversation_id: conversationId, prompt_id: promptId, @@ -176,11 +160,23 @@ export function handleSearch( chunks: string, token_limit: number, ) { - const docPath = getDocPath(selectedDocs); + history = history.map((item) => { + return { prompt: item.prompt, response: item.response }; + }); + const payload: RetrievalPayload = { + question: question, + history: JSON.stringify(history), + conversation_id: conversation_id, + chunks: chunks, + token_limit: token_limit, + }; + if (selectedDocs && 'id' in selectedDocs) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs?.retriever as string; return conversationService .search({ question: question, - active_docs: docPath, + active_docs: selectedDocs?.id as string, conversation_id, history, chunks: chunks, diff --git a/frontend/src/conversation/conversationModels.ts b/frontend/src/conversation/conversationModels.ts index 347a2521..bf86678b 100644 --- a/frontend/src/conversation/conversationModels.ts +++ b/frontend/src/conversation/conversationModels.ts @@ -31,3 +31,13 @@ export interface Query { conversationId?: string | null; title?: string | null; } +export interface RetrievalPayload { + question: string; + active_docs?: string; + retriever?: string; + history: string; + conversation_id: string | null; + prompt_id?: string | null; + chunks: string; + token_limit: number; +} diff --git a/frontend/src/hooks/useDefaultDocument.ts b/frontend/src/hooks/useDefaultDocument.ts index 46c10473..37374ce0 100644 --- a/frontend/src/hooks/useDefaultDocument.ts +++ b/frontend/src/hooks/useDefaultDocument.ts @@ -1,7 +1,8 @@ import React from 'react'; import { useDispatch, useSelector } from 'react-redux'; -import { Doc, getDocs } from '../preferences/preferenceApi'; +import { getDocs } from '../preferences/preferenceApi'; +import { Doc } from '../models/misc'; import { selectSelectedDocs, setSelectedDocs, diff --git a/frontend/src/modals/CreateAPIKeyModal.tsx b/frontend/src/modals/CreateAPIKeyModal.tsx index 2f67d83b..71d86330 100644 --- a/frontend/src/modals/CreateAPIKeyModal.tsx +++ b/frontend/src/modals/CreateAPIKeyModal.tsx @@ -22,8 +22,9 @@ export default function CreateAPIKeyModal({ const [APIKeyName, setAPIKeyName] = React.useState(''); const [sourcePath, setSourcePath] = React.useState<{ - label: string; - value: string; + name: string; + id: string; + type: string; } | null>(null); const [prompt, setPrompt] = React.useState<{ name: string; @@ -41,27 +42,17 @@ export default function CreateAPIKeyModal({ ? docs .filter((doc) => doc.model === embeddingsName) .map((doc: Doc) => { - let namePath = doc.name; - if (doc.language === namePath) { - namePath = '.project'; - } - let docPath = 'default'; - if (doc.location === 'local') { - docPath = 'local' + '/' + doc.name + '/'; - } else if (doc.location === 'remote') { - docPath = - doc.language + - '/' + - namePath + - '/' + - doc.version + - '/' + - doc.model + - '/'; + if ('id' in doc) { + return { + name: doc.name, + id: doc.id as string, + type: 'local', + }; } return { - label: doc.name, - value: docPath, + name: doc.name, + id: doc.id ?? 'default', + type: doc.type ?? 'default', }; }) : []; @@ -107,9 +98,14 @@ export default function CreateAPIKeyModal({ - setSourcePath(selection) - } + onSelect={(selection: { + name: string; + id: string; + type: string; + }) => { + setSourcePath(selection); + console.log(selection); + }} options={extractDocPaths()} size="w-full" rounded="xl" @@ -142,16 +138,22 @@ export default function CreateAPIKeyModal({