fix(retriever):classic should not override

This commit is contained in:
ManishMadan2882
2024-08-12 15:50:16 +05:30
parent 7e8dd6bba8
commit deeffbf77d

View File

@@ -78,7 +78,7 @@ def get_data_from_api_key(api_key):
raise Exception("Invalid API Key, please generate new key", 401)
if "retriever" not in data:
data["retriever"] = "classic"
data["retriever"] = None
if "source" in data and isinstance(data["source"], DBRef):
source_doc = db.dereference(data["source"])
@@ -94,7 +94,7 @@ def get_retriever(source_id: str):
doc = vectors_collection.find_one({"_id": ObjectId(source_id)})
if doc is None:
raise Exception("Source document does not exist", 404)
retriever_name = "classic" if "retriever" not in doc else doc["retriever"]
retriever_name = None if "retriever" not in doc else doc["retriever"]
return retriever_name
@@ -255,12 +255,12 @@ def stream():
chunks = int(data_key["chunks"])
prompt_id = data_key["prompt_id"]
source = {"active_docs": data_key["source"]}
retriever_name = data_key["retriever"]
retriever_name = data_key["retriever"] or retriever_name
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs" : data["active_docs"]}
retriever_name = get_retriever(data["active_docs"])
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
else:
@@ -273,7 +273,7 @@ def stream():
retriever_name = source["active_docs"] """
prompt = get_prompt(prompt_id)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
@@ -360,10 +360,11 @@ def api_answer():
chunks = int(data_key["chunks"])
prompt_id = data_key["prompt_id"]
source = {"active_docs": data_key["source"]}
retriever_name = data_key["retriever"]
retriever_name = data_key["retriever"] or retriever_name
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs":data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
else:
source = {}