fix: linting issue

This commit is contained in:
Siddhant Rai
2024-09-11 18:01:23 +05:30
parent 72e68a163c
commit dbf2cabd38
2 changed files with 128 additions and 69 deletions

View File

@@ -4,7 +4,6 @@ import shutil
import uuid
from urllib.parse import urlparse
import requests
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
@@ -30,7 +29,9 @@ user_logs_collection = db["user_logs"]
user = Blueprint("user", __name__)
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
def generate_minute_range(start_date, end_date):
@@ -81,7 +82,9 @@ def get_conversations():
conversations = conversations_collection.find().sort("date", -1).limit(30)
list_conversations = []
for conversation in conversations:
list_conversations.append({"id": str(conversation["_id"]), "name": conversation["name"]})
list_conversations.append(
{"id": str(conversation["_id"]), "name": conversation["name"]}
)
# list_conversations = [{"id": "default", "name": "default"}, {"id": "jeff", "name": "jeff"}]
@@ -112,7 +115,12 @@ def api_feedback():
question = data["question"]
answer = data["answer"]
feedback = data["feedback"]
new_doc = {"question": question, "answer": answer, "feedback": feedback, "timestamp": datetime.datetime.now(datetime.timezone.utc)}
new_doc = {
"question": question,
"answer": answer,
"feedback": feedback,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
if "api_key" in data:
new_doc["api_key"] = data["api_key"]
feedback_collection.insert_one(new_doc)
@@ -138,24 +146,31 @@ def delete_by_ids():
def delete_old():
"""Delete old indexes."""
import shutil
source_id = request.args.get("source_id")
doc = sources_collection.find_one({
"_id": ObjectId(source_id),
"user": "local",
})
if(doc is None):
return {"status":"not found"},404
doc = sources_collection.find_one(
{
"_id": ObjectId(source_id),
"user": "local",
}
)
if doc is None:
return {"status": "not found"}, 404
if settings.VECTOR_STORE == "faiss":
try:
shutil.rmtree(os.path.join(current_dir, str(doc["_id"])))
except FileNotFoundError:
pass
else:
vetorstore = VectorCreator.create_vectorstore(settings.VECTOR_STORE, source_id=str(doc["_id"]))
vetorstore = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id=str(doc["_id"])
)
vetorstore.delete_index()
sources_collection.delete_one({
"_id": ObjectId(source_id),
})
sources_collection.delete_one(
{
"_id": ObjectId(source_id),
}
)
return {"status": "ok"}
@@ -189,7 +204,9 @@ def upload_file():
file.save(os.path.join(temp_dir, filename))
# Use shutil.make_archive to zip the temp directory
zip_path = shutil.make_archive(base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir)
zip_path = shutil.make_archive(
base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir
)
final_filename = os.path.basename(zip_path)
# Clean up the temporary directory after zipping
@@ -231,7 +248,9 @@ def upload_remote():
source_data = request.form["data"]
if source_data:
task = ingest_remote.delay(source_data=source_data, job_name=job_name, user=user, loader=source)
task = ingest_remote.delay(
source_data=source_data, job_name=job_name, user=user, loader=source
)
task_id = task.id
return {"status": "ok", "task_id": task_id}
else:
@@ -276,7 +295,9 @@ def combined_json():
"model": settings.EMBEDDINGS_NAME,
"location": "local",
"tokens": index["tokens"] if ("tokens" in index.keys()) else "",
"retriever": index["retriever"] if ("retriever" in index.keys()) else "classic",
"retriever": (
index["retriever"] if ("retriever" in index.keys()) else "classic"
),
}
)
if "duckduck_search" in settings.RETRIEVERS_ENABLED:
@@ -345,7 +366,9 @@ def get_prompts():
list_prompts.append({"id": "creative", "name": "creative", "type": "public"})
list_prompts.append({"id": "strict", "name": "strict", "type": "public"})
for prompt in prompts:
list_prompts.append({"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"})
list_prompts.append(
{"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"}
)
return jsonify(list_prompts)
@@ -354,15 +377,21 @@ def get_prompts():
def get_single_prompt():
prompt_id = request.args.get("id")
if prompt_id == "default":
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
with open(
os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r"
) as f:
chat_combine_template = f.read()
return jsonify({"content": chat_combine_template})
elif prompt_id == "creative":
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
with open(
os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r"
) as f:
chat_reduce_creative = f.read()
return jsonify({"content": chat_reduce_creative})
elif prompt_id == "strict":
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
with open(
os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r"
) as f:
chat_reduce_strict = f.read()
return jsonify({"content": chat_reduce_strict})
@@ -391,7 +420,9 @@ def update_prompt_name():
# check if name is null
if name == "":
return {"status": "error"}
prompts_collection.update_one({"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}})
prompts_collection.update_one(
{"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}}
)
return {"status": "ok"}
@@ -401,7 +432,7 @@ def get_api_keys():
keys = api_key_collection.find({"user": user})
list_keys = []
for key in keys:
if "source" in key and isinstance(key["source"],DBRef):
if "source" in key and isinstance(key["source"], DBRef):
source = db.dereference(key["source"])
if source is None:
continue
@@ -411,7 +442,7 @@ def get_api_keys():
source_name = key["retriever"]
else:
continue
list_keys.append(
{
"id": str(key["_id"]),
@@ -471,8 +502,10 @@ def share_conversation():
conversation_id = data["conversation_id"]
isPromptable = request.args.get("isPromptable").lower() == "true"
conversation = conversations_collection.find_one({"_id": ObjectId(conversation_id)})
if(conversation is None):
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}
)
if conversation is None:
raise Exception("Conversation does not exist")
current_n_queries = len(conversation["queries"])
@@ -484,24 +517,24 @@ def share_conversation():
chunks = "2" if "chunks" not in data else data["chunks"]
name = conversation["name"] + "(shared)"
new_api_key_data = {
"prompt_id": prompt_id,
"chunks": chunks,
"user": user,
}
new_api_key_data = {
"prompt_id": prompt_id,
"chunks": chunks,
"user": user,
}
if "source" in data and ObjectId.is_valid(data["source"]):
new_api_key_data["source"] = DBRef("sources",ObjectId(data["source"]))
new_api_key_data["source"] = DBRef("sources", ObjectId(data["source"]))
elif "retriever" in data:
new_api_key_data["retriever"] = data["retriever"]
pre_existing_api_document = api_key_collection.find_one(
new_api_key_data
)
pre_existing_api_document = api_key_collection.find_one(new_api_key_data)
if pre_existing_api_document:
api_uuid = pre_existing_api_document["key"]
pre_existing = shared_conversations_collections.find_one(
{
"conversation_id": DBRef("conversations", ObjectId(conversation_id)),
"conversation_id": DBRef(
"conversations", ObjectId(conversation_id)
),
"isPromptable": isPromptable,
"first_n_queries": current_n_queries,
"user": user,
@@ -532,33 +565,39 @@ def share_conversation():
"api_key": api_uuid,
}
)
return jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())})
return jsonify(
{"success": True, "identifier": str(explicit_binary.as_uuid())}
)
else:
api_uuid = str(uuid.uuid4())
new_api_key_data["key"] = api_uuid
new_api_key_data["name"] = name
if "source" in data and ObjectId.is_valid(data["source"]):
new_api_key_data["source"] = DBRef("sources", ObjectId(data["source"]))
new_api_key_data["source"] = DBRef(
"sources", ObjectId(data["source"])
)
if "retriever" in data:
new_api_key_data["retriever"] = data["retriever"]
api_key_collection.insert_one(new_api_key_data)
shared_conversations_collections.insert_one(
{
"uuid": explicit_binary,
"conversation_id": {
"$ref": "conversations",
"$id": ObjectId(conversation_id),
},
"isPromptable": isPromptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
)
{
"uuid": explicit_binary,
"conversation_id": {
"$ref": "conversations",
"$id": ObjectId(conversation_id),
},
"isPromptable": isPromptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
)
## Identifier as route parameter in frontend
return (
jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}),
jsonify(
{"success": True, "identifier": str(explicit_binary.as_uuid())}
),
201,
)
@@ -573,7 +612,9 @@ def share_conversation():
)
if pre_existing is not None:
return (
jsonify({"success": True, "identifier": str(pre_existing["uuid"].as_uuid())}),
jsonify(
{"success": True, "identifier": str(pre_existing["uuid"].as_uuid())}
),
200,
)
else:
@@ -591,7 +632,9 @@ def share_conversation():
)
## Identifier as route parameter in frontend
return (
jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}),
jsonify(
{"success": True, "identifier": str(explicit_binary.as_uuid())}
),
201,
)
except Exception as err:
@@ -603,10 +646,16 @@ def share_conversation():
@user.route("/api/shared_conversation/<string:identifier>", methods=["GET"])
def get_publicly_shared_conversations(identifier: str):
try:
query_uuid = Binary.from_uuid(uuid.UUID(identifier), UuidRepresentation.STANDARD)
query_uuid = Binary.from_uuid(
uuid.UUID(identifier), UuidRepresentation.STANDARD
)
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
conversation_queries = []
if shared and "conversation_id" in shared and isinstance(shared["conversation_id"], DBRef):
if (
shared
and "conversation_id" in shared
and isinstance(shared["conversation_id"], DBRef)
):
# Resolve the DBRef
conversation_ref = shared["conversation_id"]
conversation = db.dereference(conversation_ref)
@@ -620,7 +669,9 @@ def get_publicly_shared_conversations(identifier: str):
),
404,
)
conversation_queries = conversation["queries"][: (shared["first_n_queries"])]
conversation_queries = conversation["queries"][
: (shared["first_n_queries"])
]
for query in conversation_queries:
query.pop("sources") ## avoid exposing sources
else: