store only local docs as location

ManishMadan2882
2024-08-09 18:27:54 +05:30
parent f9dbaa9407
commit 3c6fd365fb
2 changed files with 72 additions and 108 deletions
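The diff below converges on one storage convention for sources: a source that exists in the vectors collection is stored as a DBRef, and only default/local documents remain plain location strings. A minimal sketch of that convention (the helper name is invented, not from the commit):

    from bson import ObjectId
    from bson.dbref import DBRef

    def to_stored_source(source: str):
        # Indexed documents are referenced by ObjectId into "vectors";
        # anything else (e.g. "default" or a "local/..." path) stays a string.
        return DBRef("vectors", ObjectId(source)) if ObjectId.is_valid(source) else source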

View File

@@ -77,23 +77,23 @@ def get_data_from_api_key(api_key):
     if data is None:
         raise Exception("Invalid API Key, please generate new key", 401)
-    if isinstance(data["source"], DBRef):
-        source_id = db.dereference(data["source"])["_id"]
-        data["source"] = get_source(source_id)
+    if "retriever" not in data:
+        data["retriever"] = "classic"
+    if "source" in data and isinstance(data["source"], DBRef):
+        source_doc = db.dereference(data["source"])
+        data["source"] = str(source_doc._id)
+        if "retriever" in source_doc:
+            data["retriever"] = source_doc["retriever"]
     return data

-def get_source(active_doc):
-    if ObjectId.is_valid(active_doc):
-        doc = vectors_collection.find_one({"_id": ObjectId(active_doc)})
-        if doc is None:
-            raise Exception("Source document does not exist", 404)
-        print("res", doc)
-        source = {"active_docs": "/".join(doc["location"].split("/")[-2:])}
-    else:
-        source = {"active_docs": active_doc}
-    return source
+def get_retriever(source_id: str):
+    doc = vectors_collection.find_one({"_id": ObjectId(source_id)})
+    if doc is None:
+        raise Exception("Source document does not exist", 404)
+    retriever_name = "classic" if "retriever" not in doc else doc["retriever"]
+    return retriever_name

 def get_vectorstore(data):
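For reference, the reworked get_data_from_api_key now guarantees a retriever field and keeps only the source document's id. A runnable sketch of that behavior with a stubbed dereference (note: PyMongo's db.dereference returns a plain dict, so the sketch uses source_doc["_id"] key access where the diff uses attribute access; all values are invented):

    from bson import ObjectId
    from bson.dbref import DBRef

    def normalize(data, dereference):
        # Mirrors the new logic: default the retriever, then resolve a
        # DBRef source to its ObjectId string and inherit its retriever.
        if "retriever" not in data:
            data["retriever"] = "classic"
        if "source" in data and isinstance(data["source"], DBRef):
            source_doc = dereference(data["source"])
            data["source"] = str(source_doc["_id"])
            if "retriever" in source_doc:
                data["retriever"] = source_doc["retriever"]
        return data

    doc = {"_id": ObjectId("66b5fca5e4b0f2a1d3c4b5a6")}  # stand-in source document
    data = {"source": DBRef("vectors", doc["_id"]), "chunks": "2"}
    print(normalize(data, lambda ref: doc))
    # {'source': '66b5fca5e4b0f2a1d3c4b5a6', 'chunks': '2', 'retriever': 'classic'}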
@@ -244,25 +244,31 @@ def stream():
     else:
         token_limit = settings.DEFAULT_MAX_HISTORY

-    # check if active_docs or api_key is set
+    ## retriever can be "brave_search, duckduck_search or classic"
+    retriever_name = data["retriever"] if "retriever" in data else "classic"
+    # check if active_docs or api_key is set
     if "api_key" in data:
         data_key = get_data_from_api_key(data["api_key"])
         chunks = int(data_key["chunks"])
         prompt_id = data_key["prompt_id"]
-        source = data_key["source"]
+        source = {"active_docs": data_key["source"]}
+        retriever_name = data_key["retriever"]
         user_api_key = data["api_key"]
     elif "active_docs" in data:
-        source = get_source(data["active_docs"])
+        source = {"active_docs" : data["active_docs"]}
+        retriever_name = get_retriever(data["active_docs"])
         user_api_key = None
     else:
         source = {}
         user_api_key = None

-    if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
+    """ if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
         retriever_name = "classic"
     else:
-        retriever_name = source["active_docs"]
+        retriever_name = source["active_docs"] """

     prompt = get_prompt(prompt_id)
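The net effect in /stream is a precedence order for the retriever name: an API key's stored retriever wins, then a retriever looked up from active_docs, then the request body's own retriever field, then "classic". A small sketch of that order (function and argument names are illustrative only, not the app's API):

    def resolve_retriever(data, data_key=None, source_retriever=None):
        name = data.get("retriever", "classic")   # request body, else default
        if data_key is not None:
            name = data_key["retriever"]          # API key path always sets one
        elif "active_docs" in data:
            name = source_retriever or "classic"  # what get_retriever returns
        return name

    assert resolve_retriever({"retriever": "brave_search"}) == "brave_search"
    assert resolve_retriever({"active_docs": "abc"}, source_retriever="duckduck_search") == "duckduck_search"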
@@ -341,6 +347,9 @@ def api_answer():
     else:
         token_limit = settings.DEFAULT_MAX_HISTORY

+    ## retriever can be brave_search, duckduck_search or classic
+    retriever_name = data["retriever"] if "retriever" in data else "classic"
+
     # use try and except to check for exception
     try:
         # check if the vectorstore is set
@@ -350,15 +359,10 @@ def api_answer():
             prompt_id = data_key["prompt_id"]
             source = data_key["source"]
             user_api_key = data["api_key"]
-        else:
-            source = get_source(data["active_docs"])
+        elif "active_docs" in data:
+            source = data["active_docs"]
             user_api_key = None

-        if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
-            retriever_name = "classic"
-        else:
-            retriever_name = source["active_docs"]

         prompt = get_prompt(prompt_id)
         retriever = RetrieverCreator.create_retriever(
@@ -410,16 +414,16 @@ def api_search():
         source = data_key["source"]
         user_api_key = data_key["api_key"]
     elif "active_docs" in data:
-        source = get_source(data["active_docs"])
+        source = data["active_docs"]
         user_api_key = None
     else:
         source = {}
         user_api_key = None

-    if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
-        retriever_name = "classic"
+    if "retriever" in data:
+        retriever_name = data["retriever"]
     else:
-        retriever_name = source["active_docs"]
+        retriever_name = "classic"

     if "token_limit" in data:
         token_limit = data["token_limit"]
     else:

View File

@@ -25,9 +25,7 @@ shared_conversations_collections = db["shared_conversations"]
 user = Blueprint("user", __name__)

-current_dir = os.path.dirname(
-    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-)
+current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


 @user.route("/api/delete_conversation", methods=["POST"])
@@ -57,9 +55,7 @@ def get_conversations():
     conversations = conversations_collection.find().sort("date", -1).limit(30)
     list_conversations = []
     for conversation in conversations:
-        list_conversations.append(
-            {"id": str(conversation["_id"]), "name": conversation["name"]}
-        )
+        list_conversations.append({"id": str(conversation["_id"]), "name": conversation["name"]})

     # list_conversations = [{"id": "default", "name": "default"}, {"id": "jeff", "name": "jeff"}]
@@ -138,9 +134,7 @@ def delete_old():
         except FileNotFoundError:
             pass
     else:
-        vetorstore = VectorCreator.create_vectorstore(
-            settings.VECTOR_STORE, path=os.path.join(current_dir, path_clean)
-        )
+        vetorstore = VectorCreator.create_vectorstore(settings.VECTOR_STORE, path=os.path.join(current_dir, path_clean))
         vetorstore.delete_index()

     return {"status": "ok"}
@@ -175,9 +169,7 @@ def upload_file():
         file.save(os.path.join(temp_dir, filename))

     # Use shutil.make_archive to zip the temp directory
-    zip_path = shutil.make_archive(
-        base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir
-    )
+    zip_path = shutil.make_archive(base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir)
     final_filename = os.path.basename(zip_path)

     # Clean up the temporary directory after zipping
@@ -219,9 +211,7 @@ def upload_remote():
     source_data = request.form["data"]

     if source_data:
-        task = ingest_remote.delay(
-            source_data=source_data, job_name=job_name, user=user, loader=source
-        )
+        task = ingest_remote.delay(source_data=source_data, job_name=job_name, user=user, loader=source)
         task_id = task.id
         return {"status": "ok", "task_id": task_id}
     else:
@@ -264,7 +254,7 @@ def combined_json():
     for index in vectors_collection.find({"user": user}).sort("date", -1):
         data.append(
             {
-                "id":str(index["_id"]),
+                "id": str(index["_id"]),
                 "name": index["name"],
                 "language": index["language"],
                 "version": "",
@@ -278,9 +268,7 @@ def combined_json():
             }
         )
     if settings.VECTOR_STORE == "faiss":
-        data_remote = requests.get(
-            "https://d3dg1063dc54p9.cloudfront.net/combined.json"
-        ).json()
+        data_remote = requests.get("https://d3dg1063dc54p9.cloudfront.net/combined.json").json()
         for index in data_remote:
             index["location"] = "remote"
             data.append(index)
@@ -383,9 +371,7 @@ def get_prompts():
     list_prompts.append({"id": "creative", "name": "creative", "type": "public"})
     list_prompts.append({"id": "strict", "name": "strict", "type": "public"})
     for prompt in prompts:
-        list_prompts.append(
-            {"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"}
-        )
+        list_prompts.append({"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"})

     return jsonify(list_prompts)
@@ -394,21 +380,15 @@ def get_prompts():
 def get_single_prompt():
     prompt_id = request.args.get("id")
     if prompt_id == "default":
-        with open(
-            os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r"
-        ) as f:
+        with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
             chat_combine_template = f.read()
         return jsonify({"content": chat_combine_template})
     elif prompt_id == "creative":
-        with open(
-            os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r"
-        ) as f:
+        with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
             chat_reduce_creative = f.read()
         return jsonify({"content": chat_reduce_creative})
     elif prompt_id == "strict":
-        with open(
-            os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r"
-        ) as f:
+        with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
             chat_reduce_strict = f.read()
         return jsonify({"content": chat_reduce_strict})
@@ -437,9 +417,7 @@ def update_prompt_name():
     # check if name is null
     if name == "":
         return {"status": "error"}
-    prompts_collection.update_one(
-        {"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}}
-    )
+    prompts_collection.update_one({"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}})
     return {"status": "ok"}
@@ -449,12 +427,15 @@ def get_api_keys():
     keys = api_key_collection.find({"user": user})
     list_keys = []
     for key in keys:
+        source_name = (
+            db.dereference(key["source"])["name"] if isinstance(key["source"], DBRef) else key["source"].split("/")[0]
+        )
         list_keys.append(
             {
                 "id": str(key["_id"]),
                 "name": key["name"],
                 "key": key["key"][:4] + "..." + key["key"][-4:],
-                "source": str(key["source"]),
+                "source": source_name,
                 "prompt_id": key["prompt_id"],
                 "chunks": key["chunks"],
             }
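get_api_keys must now display both storage shapes, so the new source_name expression dereferences a DBRef for its name and otherwise falls back to the first path segment of a legacy location string. A sketch with a stubbed dereference (sample values invented):

    from bson import ObjectId
    from bson.dbref import DBRef

    def source_name(source, dereference):
        # DBRef -> referenced document's "name"; plain string -> first path segment.
        return dereference(source)["name"] if isinstance(source, DBRef) else source.split("/")[0]

    vectors_doc = {"name": "my-docs"}
    assert source_name(DBRef("vectors", ObjectId()), lambda ref: vectors_doc) == "my-docs"
    assert source_name("default", lambda ref: vectors_doc) == "default"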
@@ -466,23 +447,22 @@ def get_api_keys():
 def create_api_key():
     data = request.get_json()
     name = data["name"]
-    source = data["source"]
     prompt_id = data["prompt_id"]
     chunks = data["chunks"]
     key = str(uuid.uuid4())
     user = "local"
-    if(ObjectId.is_valid(data["source"])):
-        source = DBRef("vectors",ObjectId(data["source"]))
-    resp = api_key_collection.insert_one(
-        {
-            "name": name,
-            "key": key,
-            "source": source,
-            "user": user,
-            "prompt_id": prompt_id,
-            "chunks": chunks,
-        }
-    )
+    new_api_key = {
+        "name": name,
+        "key": key,
+        "user": user,
+        "prompt_id": prompt_id,
+        "chunks": chunks,
+    }
+    if "source" in data and ObjectId.is_valid(data["source"]):
+        new_api_key["source"] = DBRef("vectors", ObjectId(data["source"]))
+    if "retriever" in data:
+        new_api_key["retriever"] = data["retriever"]
+    resp = api_key_collection.insert_one(new_api_key)
     new_id = str(resp.inserted_id)
     return {"id": new_id, "key": key}
@@ -509,9 +489,7 @@ def share_conversation():
     conversation_id = data["conversation_id"]
     isPromptable = request.args.get("isPromptable").lower() == "true"

-    conversation = conversations_collection.find_one(
-        {"_id": ObjectId(conversation_id)}
-    )
+    conversation = conversations_collection.find_one({"_id": ObjectId(conversation_id)})
     current_n_queries = len(conversation["queries"])

     ##generate binary representation of uuid
@@ -527,7 +505,7 @@ def share_conversation():
                 {
                     "prompt_id": prompt_id,
                     "chunks": chunks,
-                    "source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source,
+                    "source": DBRef("vectors", ObjectId(source)) if ObjectId.is_valid(source) else source,
                     "user": user,
                 }
             )
@@ -536,9 +514,7 @@ def share_conversation():
             api_uuid = pre_existing_api_document["key"]
             pre_existing = shared_conversations_collections.find_one(
                 {
-                    "conversation_id": DBRef(
-                        "conversations", ObjectId(conversation_id)
-                    ),
+                    "conversation_id": DBRef("conversations", ObjectId(conversation_id)),
                     "isPromptable": isPromptable,
                     "first_n_queries": current_n_queries,
                     "user": user,
@@ -569,15 +545,13 @@ def share_conversation():
"api_key": api_uuid,
}
)
return jsonify(
{"success": True, "identifier": str(explicit_binary.as_uuid())}
)
return jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())})
else:
api_key_collection.insert_one(
{
"name": name,
"key": api_uuid,
"source": DBRef("vectors",ObjectId(source)) if ObjectId.is_valid(source) else source,
"source": DBRef("vectors", ObjectId(source)) if ObjectId.is_valid(source) else source,
"user": user,
"prompt_id": prompt_id,
"chunks": chunks,
@@ -598,9 +572,7 @@ def share_conversation():
                 )
                 ## Identifier as route parameter in frontend
                 return (
-                    jsonify(
-                        {"success": True, "identifier": str(explicit_binary.as_uuid())}
-                    ),
+                    jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}),
                     201,
                 )
@@ -615,9 +587,7 @@ def share_conversation():
             )
             if pre_existing is not None:
                 return (
-                    jsonify(
-                        {"success": True, "identifier": str(pre_existing["uuid"].as_uuid())}
-                    ),
+                    jsonify({"success": True, "identifier": str(pre_existing["uuid"].as_uuid())}),
                     200,
                 )
             else:
@@ -635,9 +605,7 @@ def share_conversation():
                 )
                 ## Identifier as route parameter in frontend
                 return (
-                    jsonify(
-                        {"success": True, "identifier": str(explicit_binary.as_uuid())}
-                    ),
+                    jsonify({"success": True, "identifier": str(explicit_binary.as_uuid())}),
                     201,
                 )
     except Exception as err:
@@ -649,16 +617,10 @@ def share_conversation():
@user.route("/api/shared_conversation/<string:identifier>", methods=["GET"])
def get_publicly_shared_conversations(identifier: str):
try:
query_uuid = Binary.from_uuid(
uuid.UUID(identifier), UuidRepresentation.STANDARD
)
query_uuid = Binary.from_uuid(uuid.UUID(identifier), UuidRepresentation.STANDARD)
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
conversation_queries = []
if (
shared
and "conversation_id" in shared
and isinstance(shared["conversation_id"], DBRef)
):
if shared and "conversation_id" in shared and isinstance(shared["conversation_id"], DBRef):
# Resolve the DBRef
conversation_ref = shared["conversation_id"]
conversation = db.dereference(conversation_ref)
@@ -672,9 +634,7 @@ def get_publicly_shared_conversations(identifier: str):
                 ),
                 404,
             )
-            conversation_queries = conversation["queries"][
-                : (shared["first_n_queries"])
-            ]
+            conversation_queries = conversation["queries"][: (shared["first_n_queries"])]
             for query in conversation_queries:
                 query.pop("sources")  ## avoid exposing sources
         else: