diff --git a/application/app.py b/application/app.py index d9e6deeb..7350c498 100644 --- a/application/app.py +++ b/application/app.py @@ -2,6 +2,7 @@ import datetime import json import os import traceback +import asyncio import dotenv import requests @@ -97,6 +98,20 @@ mongo = MongoClient(app.config['MONGO_URI']) db = mongo["docsgpt"] vectors_collection = db["vectors"] +async def async_generate(chain, question, chat_history): + result = await chain.arun({"question": question, "chat_history": chat_history}) + return result + +def run_async_chain(chain, question, chat_history): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + result = {} + try: + answer = loop.run_until_complete(async_generate(chain, question, chat_history)) + finally: + loop.close() + result["answer"] = answer + return result @celery.task(bind=True) def ingest(self, directory, formats, name_job, filename, user): @@ -197,7 +212,9 @@ def api_answer(): combine_docs_chain=doc_chain, ) chat_history = [] - result = chain({"question": question, "chat_history": chat_history}) + #result = chain({"question": question, "chat_history": chat_history}) + # generate async with async generate method + result = run_async_chain(chain, question, chat_history) else: qa_chain = load_qa_chain(llm=llm, chain_type="map_reduce", combine_prompt=c_prompt, question_prompt=q_prompt)