Add support for setting the number of chunks processed per query

This commit is contained in:
Alex
2024-03-22 14:50:56 +00:00
parent add2db5b7a
commit ed08123550
7 changed files with 76 additions and 9 deletions

View File

@@ -98,7 +98,7 @@ def is_azure_configured():
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conversation_id):
def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conversation_id, chunks=2):
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
if prompt_id == 'default':
@@ -109,8 +109,11 @@ def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conve
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
docs = docsearch.search(question, k=2)
if chunks == 0:
docs = []
else:
docs = docsearch.search(question, k=chunks)
if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]]
# join all page_content together with a newline
@@ -193,6 +196,10 @@ def stream():
prompt_id = data["prompt_id"]
else:
prompt_id = 'default'
if 'chunks' in data:
chunks = int(data["chunks"])
else:
chunks = 2
# check if active_docs is set
@@ -214,7 +221,8 @@ def stream():
complete_stream(question, docsearch,
chat_history=history, api_key=api_key,
prompt_id=prompt_id,
conversation_id=conversation_id), mimetype="text/event-stream"
conversation_id=conversation_id,
chunks=chunks), mimetype="text/event-stream"
)
@@ -240,6 +248,10 @@ def api_answer():
prompt_id = data["prompt_id"]
else:
prompt_id = 'default'
if 'chunks' in data:
chunks = int(data["chunks"])
else:
chunks = 2
if prompt_id == 'default':
prompt = chat_combine_template
@@ -263,7 +275,10 @@ def api_answer():
docs = docsearch.search(question, k=2)
if chunks == 0:
docs = []
else:
docs = docsearch.search(question, k=chunks)
# join all page_content together with a newline
docs_together = "\n".join([doc.page_content for doc in docs])
p_chat_combine = prompt.replace("{summaries}", docs_together)
@@ -362,9 +377,15 @@ def api_search():
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
else:
vectorstore = ""
if 'chunks' in data:
chunks = int(data["chunks"])
else:
chunks = 2
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)
docs = docsearch.search(question, k=2)
if chunks == 0:
docs = []
else:
docs = docsearch.search(question, k=chunks)
source_log_docs = []
for doc in docs: