diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index f076285d..85cc3afd 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -77,23 +77,23 @@ def get_data_from_api_key(api_key): if data is None: raise Exception("Invalid API Key, please generate new key", 401) - if isinstance(data["source"], DBRef): - source_id = db.dereference(data["source"])["_id"] - data["source"] = get_source(source_id) + if "retriever" not in data: + data["retriever"] = "classic" + if "source" in data and isinstance(data["source"], DBRef): + source_doc = db.dereference(data["source"]) + data["source"] = str(source_doc._id) + if "retriever" in source_doc: + data["retriever"] = source_doc["retriever"] return data -def get_source(active_doc): - if ObjectId.is_valid(active_doc): - doc = vectors_collection.find_one({"_id": ObjectId(active_doc)}) - if doc is None: - raise Exception("Source document does not exist", 404) - print("res", doc) - source = {"active_docs": "/".join(doc["location"].split("/")[-2:])} - else: - source = {"active_docs": active_doc} - return source +def get_retriever(source_id: str): + doc = vectors_collection.find_one({"_id": ObjectId(source_id)}) + if doc is None: + raise Exception("Source document does not exist", 404) + retriever_name = "classic" if "retriever" not in doc else doc["retriever"] + return retriever_name def get_vectorstore(data): @@ -244,25 +244,31 @@ def stream(): else: token_limit = settings.DEFAULT_MAX_HISTORY - # check if active_docs or api_key is set + ## retriever can be "brave_search, duckduck_search or classic" + retriever_name = data["retriever"] if "retriever" in data else "classic" + # check if active_docs or api_key is set if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) chunks = int(data_key["chunks"]) prompt_id = data_key["prompt_id"] - source = data_key["source"] + source = {"active_docs": data_key["source"]} + retriever_name = data_key["retriever"] user_api_key = data["api_key"] + elif "active_docs" in data: - source = get_source(data["active_docs"]) + source = {"active_docs" : data["active_docs"]} + retriever_name = get_retriever(data["active_docs"]) user_api_key = None + else: source = {} user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + """ if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": retriever_name = "classic" else: - retriever_name = source["active_docs"] + retriever_name = source["active_docs"] """ prompt = get_prompt(prompt_id) @@ -341,6 +347,9 @@ def api_answer(): else: token_limit = settings.DEFAULT_MAX_HISTORY + ## retriever can be brave_search, duckduck_search or classic + retriever_name = data["retriever"] if "retriever" in data else "classic" + # use try and except to check for exception try: # check if the vectorstore is set @@ -350,15 +359,10 @@ def api_answer(): prompt_id = data_key["prompt_id"] source = data_key["source"] user_api_key = data["api_key"] - else: - source = get_source(data["active_docs"]) + elif "active_docs" in data: + source = data["active_docs"] user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": - retriever_name = "classic" - else: - retriever_name = source["active_docs"] - prompt = get_prompt(prompt_id) retriever = RetrieverCreator.create_retriever( @@ -410,16 +414,16 @@ def api_search(): source = data_key["source"] user_api_key = data_key["api_key"] elif "active_docs" in data: - source = get_source(data["active_docs"]) + source = data["active_docs"] user_api_key = None else: source = {} user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": - retriever_name = "classic" + if "retriever" in data: + retriever_name = data["retriever"] else: - retriever_name = source["active_docs"] + retriever_name = "classic" if "token_limit" in data: token_limit = data["token_limit"] else: diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 06bab591..aab30469 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -25,9 +25,7 @@ shared_conversations_collections = db["shared_conversations"] user = Blueprint("user", __name__) -current_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) +current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @user.route("/api/delete_conversation", methods=["POST"]) @@ -57,9 +55,7 @@ def get_conversations(): conversations = conversations_collection.find().sort("date", -1).limit(30) list_conversations = [] for conversation in conversations: - list_conversations.append( - {"id": str(conversation["_id"]), "name": conversation["name"]} - ) + list_conversations.append({"id": str(conversation["_id"]), "name": conversation["name"]}) # list_conversations = [{"id": "default", "name": "default"}, {"id": "jeff", "name": "jeff"}] @@ -138,9 +134,7 @@ def delete_old(): except FileNotFoundError: pass else: - vetorstore = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, path=os.path.join(current_dir, path_clean) - ) + vetorstore = VectorCreator.create_vectorstore(settings.VECTOR_STORE, path=os.path.join(current_dir, path_clean)) vetorstore.delete_index() return {"status": "ok"} @@ -175,9 +169,7 @@ def upload_file(): file.save(os.path.join(temp_dir, filename)) # Use shutil.make_archive to zip the temp directory - zip_path = shutil.make_archive( - base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir - ) + zip_path = shutil.make_archive(base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir) final_filename = os.path.basename(zip_path) # Clean up the temporary directory after zipping @@ -219,9 +211,7 @@ def upload_remote(): source_data = request.form["data"] if source_data: - task = ingest_remote.delay( - source_data=source_data, job_name=job_name, user=user, loader=source - ) + task = ingest_remote.delay(source_data=source_data, job_name=job_name, user=user, loader=source) task_id = task.id return {"status": "ok", "task_id": task_id} else: @@ -264,7 +254,7 @@ def combined_json(): for index in vectors_collection.find({"user": user}).sort("date", -1): data.append( { - "id":str(index["_id"]), + "id": str(index["_id"]), "name": index["name"], "language": index["language"], "version": "", @@ -278,9 +268,7 @@ def combined_json(): } ) if settings.VECTOR_STORE == "faiss": - data_remote = requests.get( - "https://d3dg1063dc54p9.cloudfront.net/combined.json" - ).json() + data_remote = requests.get("https://d3dg1063dc54p9.cloudfront.net/combined.json").json() for index in data_remote: index["location"] = "remote" data.append(index) @@ -383,9 +371,7 @@ def get_prompts(): list_prompts.append({"id": "creative", "name": "creative", "type": "public"}) list_prompts.append({"id": "strict", "name": "strict", "type": "public"}) for prompt in prompts: - list_prompts.append( - {"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"} - ) + list_prompts.append({"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"}) return jsonify(list_prompts) @@ -394,21 +380,15 @@ def get_prompts(): def get_single_prompt(): prompt_id = request.args.get("id") if prompt_id == "default": - with open( - os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r" - ) as f: + with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: chat_combine_template = f.read() return jsonify({"content": chat_combine_template}) elif prompt_id == "creative": - with open( - os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r" - ) as f: + with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f: chat_reduce_creative = f.read() return jsonify({"content": chat_reduce_creative}) elif prompt_id == "strict": - with open( - os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r" - ) as f: + with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f: chat_reduce_strict = f.read() return jsonify({"content": chat_reduce_strict}) @@ -437,9 +417,7 @@ def update_prompt_name(): # check if name is null if name == "": return {"status": "error"} - prompts_collection.update_one( - {"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}} - ) + prompts_collection.update_one({"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}}) return {"status": "ok"} @@ -449,12 +427,15 @@ def get_api_keys(): keys = api_key_collection.find({"user": user}) list_keys = [] for key in keys: + source_name = ( + db.dereference(key["source"])["name"] if isinstance(key["source"], DBRef) else key["source"].split("/")[0] + ) list_keys.append( { "id": str(key["_id"]), "name": key["name"], "key": key["key"][:4] + "..." + key["key"][-4:], - "source": str(key["source"]), + "source": source_name, "prompt_id": key["prompt_id"], "chunks": key["chunks"], } @@ -466,23 +447,22 @@ def get_api_keys(): def create_api_key(): data = request.get_json() name = data["name"] - source = data["source"] prompt_id = data["prompt_id"] chunks = data["chunks"] key = str(uuid.uuid4()) user = "local" - if(ObjectId.is_valid(data["source"])): - source = DBRef("vectors",ObjectId(data["source"])) - resp = api_key_collection.insert_one( - { - "name": name, - "key": key, - "source": source, - "user": user, - "prompt_id": prompt_id, - "chunks": chunks, - } - ) + new_api_key = { + "name": name, + "key": key, + "user": user, + "prompt_id": prompt_id, + "chunks": chunks, + } + if "source" in data and ObjectId.is_valid(data["source"]): + new_api_key["source"] = DBRef("vectors", ObjectId(data["source"])) + if "retriever" in data: + new_api_key["retriever"] = data["retriever"] + resp = api_key_collection.insert_one(new_api_key) new_id = str(resp.inserted_id) return {"id": new_id, "key": key} @@ -509,9 +489,7 @@ def share_conversation(): conversation_id = data["conversation_id"] isPromptable = request.args.get("isPromptable").lower() == "true" - conversation = conversations_collection.find_one( - {"_id": ObjectId(conversation_id)} - ) + conversation = conversations_collection.find_one({"_id": ObjectId(conversation_id)}) current_n_queries = len(conversation["queries"]) ##generate binary representation of uuid @@ -527,7 +505,7 @@ def share_conversation(): { "prompt_id": prompt_id, "chunks": chunks, - "source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source, + "source": DBRef("vectors", ObjectId(source)) if ObjectId.is_valid(source) else source, "user": user, } ) @@ -536,9 +514,7 @@ def share_conversation(): api_uuid = pre_existing_api_document["key"] pre_existing = shared_conversations_collections.find_one( { - "conversation_id": DBRef( - "conversations", ObjectId(conversation_id) - ), + "conversation_id": DBRef("conversations", ObjectId(conversation_id)), "isPromptable": isPromptable, "first_n_queries": current_n_queries, "user": user, @@ -569,15 +545,13 @@ def share_conversation(): "api_key": api_uuid, } ) - return jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ) + return jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}) else: api_key_collection.insert_one( { "name": name, "key": api_uuid, - "source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source, + "source": DBRef("vectors", ObjectId(source)) if ObjectId.is_valid(source) else source, "user": user, "prompt_id": prompt_id, "chunks": chunks, @@ -598,9 +572,7 @@ def share_conversation(): ) ## Identifier as route parameter in frontend return ( - jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ), + jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}), 201, ) @@ -615,9 +587,7 @@ def share_conversation(): ) if pre_existing is not None: return ( - jsonify( - {"success": True, "identifier": str(pre_existing["uuid"].as_uuid())} - ), + jsonify({"success": True, "identifier": str(pre_existing["uuid"].as_uuid())}), 200, ) else: @@ -635,9 +605,7 @@ def share_conversation(): ) ## Identifier as route parameter in frontend return ( - jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ), + jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}), 201, ) except Exception as err: @@ -649,16 +617,10 @@ def share_conversation(): @user.route("/api/shared_conversation/", methods=["GET"]) def get_publicly_shared_conversations(identifier: str): try: - query_uuid = Binary.from_uuid( - uuid.UUID(identifier), UuidRepresentation.STANDARD - ) + query_uuid = Binary.from_uuid(uuid.UUID(identifier), UuidRepresentation.STANDARD) shared = shared_conversations_collections.find_one({"uuid": query_uuid}) conversation_queries = [] - if ( - shared - and "conversation_id" in shared - and isinstance(shared["conversation_id"], DBRef) - ): + if shared and "conversation_id" in shared and isinstance(shared["conversation_id"], DBRef): # Resolve the DBRef conversation_ref = shared["conversation_id"] conversation = db.dereference(conversation_ref) @@ -672,9 +634,7 @@ def get_publicly_shared_conversations(identifier: str): ), 404, ) - conversation_queries = conversation["queries"][ - : (shared["first_n_queries"]) - ] + conversation_queries = conversation["queries"][: (shared["first_n_queries"])] for query in conversation_queries: query.pop("sources") ## avoid exposing sources else: