fix: linting issue

This commit is contained in:
Siddhant Rai
2024-09-11 18:01:23 +05:30
parent 72e68a163c
commit dbf2cabd38
2 changed files with 128 additions and 69 deletions

View File

@@ -38,7 +38,9 @@ if settings.MODEL_NAME: # in case there is particular model name configured
gpt_model = settings.MODEL_NAME
# load the prompts
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
chat_combine_template = f.read()
@@ -99,9 +101,12 @@ def get_retriever(source_id: str):
return retriever_name
def is_azure_configured():
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
return (
settings.OPENAI_API_BASE
and settings.OPENAI_API_VERSION
and settings.AZURE_DEPLOYMENT_NAME
)
def save_conversation(conversation_id, question, response, source_log_docs, llm):
@@ -274,7 +279,7 @@ def stream():
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs" : data["active_docs"]}
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
@@ -282,12 +287,13 @@ def stream():
source = {}
user_api_key = None
current_app.logger.info(f"/stream - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})}
current_app.logger.info(
f"/stream - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
prompt = get_prompt(prompt_id)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
@@ -381,7 +387,7 @@ def api_answer():
retriever_name = data_key["retriever"] or retriever_name
user_api_key = data["api_key"]
elif "active_docs" in data:
source = {"active_docs":data["active_docs"]}
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
else:
@@ -424,7 +430,9 @@ def api_answer():
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(conversation_id, question, response_full, source_log_docs, llm)
save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
)
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
@@ -461,10 +469,10 @@ def api_search():
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key["chunks"])
source = {"active_docs":data_key["source"]}
source = {"active_docs": data_key["source"]}
user_api_key = data_key["api_key"]
elif "active_docs" in data:
source = {"active_docs":data["active_docs"]}
source = {"active_docs": data["active_docs"]}
user_api_key = None
else:
source = {}