diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index e873a1cf..da9f2775 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -38,7 +38,9 @@ if settings.MODEL_NAME: # in case there is particular model name configured gpt_model = settings.MODEL_NAME # load the prompts -current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +current_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: chat_combine_template = f.read() @@ -99,9 +101,12 @@ def get_retriever(source_id: str): return retriever_name - def is_azure_configured(): - return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME + return ( + settings.OPENAI_API_BASE + and settings.OPENAI_API_VERSION + and settings.AZURE_DEPLOYMENT_NAME + ) def save_conversation(conversation_id, question, response, source_log_docs, llm): @@ -274,7 +279,7 @@ def stream(): user_api_key = data["api_key"] elif "active_docs" in data: - source = {"active_docs" : data["active_docs"]} + source = {"active_docs": data["active_docs"]} retriever_name = get_retriever(data["active_docs"]) or retriever_name user_api_key = None @@ -282,12 +287,13 @@ def stream(): source = {} user_api_key = None - current_app.logger.info(f"/stream - request_data: {data}, source: {source}", - extra={"data": json.dumps({"request_data": data, "source": source})} + current_app.logger.info( + f"/stream - request_data: {data}, source: {source}", + extra={"data": json.dumps({"request_data": data, "source": source})}, ) prompt = get_prompt(prompt_id) - + retriever = RetrieverCreator.create_retriever( retriever_name, question=question, @@ -381,7 +387,7 @@ def api_answer(): retriever_name = data_key["retriever"] or retriever_name user_api_key = data["api_key"] elif "active_docs" in data: - source = {"active_docs":data["active_docs"]} + source = {"active_docs": data["active_docs"]} retriever_name = get_retriever(data["active_docs"]) or retriever_name user_api_key = None else: @@ -424,7 +430,9 @@ def api_answer(): result = {"answer": response_full, "sources": source_log_docs} result["conversation_id"] = str( - save_conversation(conversation_id, question, response_full, source_log_docs, llm) + save_conversation( + conversation_id, question, response_full, source_log_docs, llm + ) ) retriever_params = retriever.get_params() user_logs_collection.insert_one( @@ -461,10 +469,10 @@ def api_search(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) chunks = int(data_key["chunks"]) - source = {"active_docs":data_key["source"]} + source = {"active_docs": data_key["source"]} user_api_key = data_key["api_key"] elif "active_docs" in data: - source = {"active_docs":data["active_docs"]} + source = {"active_docs": data["active_docs"]} user_api_key = None else: source = {} diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 0f72be97..5bdc5201 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -4,7 +4,6 @@ import shutil import uuid from urllib.parse import urlparse -import requests from bson.binary import Binary, UuidRepresentation from bson.dbref import DBRef from bson.objectid import ObjectId @@ -30,7 +29,9 @@ user_logs_collection = db["user_logs"] user = Blueprint("user", __name__) -current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +current_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) def generate_minute_range(start_date, end_date): @@ -81,7 +82,9 @@ def get_conversations(): conversations = conversations_collection.find().sort("date", -1).limit(30) list_conversations = [] for conversation in conversations: - list_conversations.append({"id": str(conversation["_id"]), "name": conversation["name"]}) + list_conversations.append( + {"id": str(conversation["_id"]), "name": conversation["name"]} + ) # list_conversations = [{"id": "default", "name": "default"}, {"id": "jeff", "name": "jeff"}] @@ -112,7 +115,12 @@ def api_feedback(): question = data["question"] answer = data["answer"] feedback = data["feedback"] - new_doc = {"question": question, "answer": answer, "feedback": feedback, "timestamp": datetime.datetime.now(datetime.timezone.utc)} + new_doc = { + "question": question, + "answer": answer, + "feedback": feedback, + "timestamp": datetime.datetime.now(datetime.timezone.utc), + } if "api_key" in data: new_doc["api_key"] = data["api_key"] feedback_collection.insert_one(new_doc) @@ -138,24 +146,31 @@ def delete_by_ids(): def delete_old(): """Delete old indexes.""" import shutil + source_id = request.args.get("source_id") - doc = sources_collection.find_one({ - "_id": ObjectId(source_id), - "user": "local", - }) - if(doc is None): - return {"status":"not found"},404 + doc = sources_collection.find_one( + { + "_id": ObjectId(source_id), + "user": "local", + } + ) + if doc is None: + return {"status": "not found"}, 404 if settings.VECTOR_STORE == "faiss": try: shutil.rmtree(os.path.join(current_dir, str(doc["_id"]))) except FileNotFoundError: pass else: - vetorstore = VectorCreator.create_vectorstore(settings.VECTOR_STORE, source_id=str(doc["_id"])) + vetorstore = VectorCreator.create_vectorstore( + settings.VECTOR_STORE, source_id=str(doc["_id"]) + ) vetorstore.delete_index() - sources_collection.delete_one({ - "_id": ObjectId(source_id), - }) + sources_collection.delete_one( + { + "_id": ObjectId(source_id), + } + ) return {"status": "ok"} @@ -189,7 +204,9 @@ def upload_file(): file.save(os.path.join(temp_dir, filename)) # Use shutil.make_archive to zip the temp directory - zip_path = shutil.make_archive(base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir) + zip_path = shutil.make_archive( + base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir + ) final_filename = os.path.basename(zip_path) # Clean up the temporary directory after zipping @@ -231,7 +248,9 @@ def upload_remote(): source_data = request.form["data"] if source_data: - task = ingest_remote.delay(source_data=source_data, job_name=job_name, user=user, loader=source) + task = ingest_remote.delay( + source_data=source_data, job_name=job_name, user=user, loader=source + ) task_id = task.id return {"status": "ok", "task_id": task_id} else: @@ -276,7 +295,9 @@ def combined_json(): "model": settings.EMBEDDINGS_NAME, "location": "local", "tokens": index["tokens"] if ("tokens" in index.keys()) else "", - "retriever": index["retriever"] if ("retriever" in index.keys()) else "classic", + "retriever": ( + index["retriever"] if ("retriever" in index.keys()) else "classic" + ), } ) if "duckduck_search" in settings.RETRIEVERS_ENABLED: @@ -345,7 +366,9 @@ def get_prompts(): list_prompts.append({"id": "creative", "name": "creative", "type": "public"}) list_prompts.append({"id": "strict", "name": "strict", "type": "public"}) for prompt in prompts: - list_prompts.append({"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"}) + list_prompts.append( + {"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"} + ) return jsonify(list_prompts) @@ -354,15 +377,21 @@ def get_prompts(): def get_single_prompt(): prompt_id = request.args.get("id") if prompt_id == "default": - with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: + with open( + os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r" + ) as f: chat_combine_template = f.read() return jsonify({"content": chat_combine_template}) elif prompt_id == "creative": - with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f: + with open( + os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r" + ) as f: chat_reduce_creative = f.read() return jsonify({"content": chat_reduce_creative}) elif prompt_id == "strict": - with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f: + with open( + os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r" + ) as f: chat_reduce_strict = f.read() return jsonify({"content": chat_reduce_strict}) @@ -391,7 +420,9 @@ def update_prompt_name(): # check if name is null if name == "": return {"status": "error"} - prompts_collection.update_one({"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}}) + prompts_collection.update_one( + {"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}} + ) return {"status": "ok"} @@ -401,7 +432,7 @@ def get_api_keys(): keys = api_key_collection.find({"user": user}) list_keys = [] for key in keys: - if "source" in key and isinstance(key["source"],DBRef): + if "source" in key and isinstance(key["source"], DBRef): source = db.dereference(key["source"]) if source is None: continue @@ -411,7 +442,7 @@ def get_api_keys(): source_name = key["retriever"] else: continue - + list_keys.append( { "id": str(key["_id"]), @@ -471,8 +502,10 @@ def share_conversation(): conversation_id = data["conversation_id"] isPromptable = request.args.get("isPromptable").lower() == "true" - conversation = conversations_collection.find_one({"_id": ObjectId(conversation_id)}) - if(conversation is None): + conversation = conversations_collection.find_one( + {"_id": ObjectId(conversation_id)} + ) + if conversation is None: raise Exception("Conversation does not exist") current_n_queries = len(conversation["queries"]) @@ -484,24 +517,24 @@ def share_conversation(): chunks = "2" if "chunks" not in data else data["chunks"] name = conversation["name"] + "(shared)" - new_api_key_data = { - "prompt_id": prompt_id, - "chunks": chunks, - "user": user, - } + new_api_key_data = { + "prompt_id": prompt_id, + "chunks": chunks, + "user": user, + } if "source" in data and ObjectId.is_valid(data["source"]): - new_api_key_data["source"] = DBRef("sources",ObjectId(data["source"])) + new_api_key_data["source"] = DBRef("sources", ObjectId(data["source"])) elif "retriever" in data: new_api_key_data["retriever"] = data["retriever"] - - pre_existing_api_document = api_key_collection.find_one( - new_api_key_data - ) + + pre_existing_api_document = api_key_collection.find_one(new_api_key_data) if pre_existing_api_document: api_uuid = pre_existing_api_document["key"] pre_existing = shared_conversations_collections.find_one( { - "conversation_id": DBRef("conversations", ObjectId(conversation_id)), + "conversation_id": DBRef( + "conversations", ObjectId(conversation_id) + ), "isPromptable": isPromptable, "first_n_queries": current_n_queries, "user": user, @@ -532,33 +565,39 @@ def share_conversation(): "api_key": api_uuid, } ) - return jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}) + return jsonify( + {"success": True, "identifier": str(explicit_binary.as_uuid())} + ) else: - + api_uuid = str(uuid.uuid4()) new_api_key_data["key"] = api_uuid new_api_key_data["name"] = name if "source" in data and ObjectId.is_valid(data["source"]): - new_api_key_data["source"] = DBRef("sources", ObjectId(data["source"])) + new_api_key_data["source"] = DBRef( + "sources", ObjectId(data["source"]) + ) if "retriever" in data: new_api_key_data["retriever"] = data["retriever"] api_key_collection.insert_one(new_api_key_data) shared_conversations_collections.insert_one( - { - "uuid": explicit_binary, - "conversation_id": { - "$ref": "conversations", - "$id": ObjectId(conversation_id), - }, - "isPromptable": isPromptable, - "first_n_queries": current_n_queries, - "user": user, - "api_key": api_uuid, - } - ) + { + "uuid": explicit_binary, + "conversation_id": { + "$ref": "conversations", + "$id": ObjectId(conversation_id), + }, + "isPromptable": isPromptable, + "first_n_queries": current_n_queries, + "user": user, + "api_key": api_uuid, + } + ) ## Identifier as route parameter in frontend return ( - jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}), + jsonify( + {"success": True, "identifier": str(explicit_binary.as_uuid())} + ), 201, ) @@ -573,7 +612,9 @@ def share_conversation(): ) if pre_existing is not None: return ( - jsonify({"success": True, "identifier": str(pre_existing["uuid"].as_uuid())}), + jsonify( + {"success": True, "identifier": str(pre_existing["uuid"].as_uuid())} + ), 200, ) else: @@ -591,7 +632,9 @@ def share_conversation(): ) ## Identifier as route parameter in frontend return ( - jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}), + jsonify( + {"success": True, "identifier": str(explicit_binary.as_uuid())} + ), 201, ) except Exception as err: @@ -603,10 +646,16 @@ def share_conversation(): @user.route("/api/shared_conversation/", methods=["GET"]) def get_publicly_shared_conversations(identifier: str): try: - query_uuid = Binary.from_uuid(uuid.UUID(identifier), UuidRepresentation.STANDARD) + query_uuid = Binary.from_uuid( + uuid.UUID(identifier), UuidRepresentation.STANDARD + ) shared = shared_conversations_collections.find_one({"uuid": query_uuid}) conversation_queries = [] - if shared and "conversation_id" in shared and isinstance(shared["conversation_id"], DBRef): + if ( + shared + and "conversation_id" in shared + and isinstance(shared["conversation_id"], DBRef) + ): # Resolve the DBRef conversation_ref = shared["conversation_id"] conversation = db.dereference(conversation_ref) @@ -620,7 +669,9 @@ def get_publicly_shared_conversations(identifier: str): ), 404, ) - conversation_queries = conversation["queries"][: (shared["first_n_queries"])] + conversation_queries = conversation["queries"][ + : (shared["first_n_queries"]) + ] for query in conversation_queries: query.pop("sources") ## avoid exposing sources else: