mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge pull request #915 from arc53/feat/retrievers-class
Update application files and fix LLM models, create new retriever class
This commit is contained in:
@@ -8,13 +8,12 @@ import traceback
|
|||||||
|
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from bson.objectid import ObjectId
|
from bson.objectid import ObjectId
|
||||||
from transformers import GPT2TokenizerFast
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.vectorstore.vector_creator import VectorCreator
|
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
|
from application.retriever.retriever_creator import RetrieverCreator
|
||||||
from application.error import bad_request
|
from application.error import bad_request
|
||||||
|
|
||||||
|
|
||||||
@@ -62,9 +61,6 @@ async def async_generate(chain, question, chat_history):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def count_tokens(string):
|
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
|
|
||||||
return len(tokenizer(string)['input_ids'])
|
|
||||||
|
|
||||||
|
|
||||||
def run_async_chain(chain, question, chat_history):
|
def run_async_chain(chain, question, chat_history):
|
||||||
@@ -104,61 +100,11 @@ def get_vectorstore(data):
|
|||||||
def is_azure_configured():
|
def is_azure_configured():
|
||||||
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
|
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
|
||||||
|
|
||||||
|
def save_conversation(conversation_id, question, response, source_log_docs, llm):
|
||||||
def complete_stream(question, docsearch, chat_history, prompt_id, conversation_id, chunks=2):
|
if conversation_id is not None and conversation_id != "None":
|
||||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
|
||||||
if prompt_id == 'default':
|
|
||||||
prompt = chat_combine_template
|
|
||||||
elif prompt_id == 'creative':
|
|
||||||
prompt = chat_combine_creative
|
|
||||||
elif prompt_id == 'strict':
|
|
||||||
prompt = chat_combine_strict
|
|
||||||
else:
|
|
||||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
|
||||||
|
|
||||||
if chunks == 0:
|
|
||||||
docs = []
|
|
||||||
else:
|
|
||||||
docs = docsearch.search(question, k=chunks)
|
|
||||||
if settings.LLM_NAME == "llama.cpp":
|
|
||||||
docs = [docs[0]]
|
|
||||||
# join all page_content together with a newline
|
|
||||||
docs_together = "\n".join([doc.page_content for doc in docs])
|
|
||||||
p_chat_combine = prompt.replace("{summaries}", docs_together)
|
|
||||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
|
||||||
source_log_docs = []
|
|
||||||
for doc in docs:
|
|
||||||
if doc.metadata:
|
|
||||||
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
|
|
||||||
else:
|
|
||||||
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
|
|
||||||
|
|
||||||
if len(chat_history) > 1:
|
|
||||||
tokens_current_history = 0
|
|
||||||
# count tokens in history
|
|
||||||
chat_history.reverse()
|
|
||||||
for i in chat_history:
|
|
||||||
if "prompt" in i and "response" in i:
|
|
||||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
|
|
||||||
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
|
|
||||||
tokens_current_history += tokens_batch
|
|
||||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
|
||||||
messages_combine.append({"role": "system", "content": i["response"]})
|
|
||||||
messages_combine.append({"role": "user", "content": question})
|
|
||||||
|
|
||||||
response_full = ""
|
|
||||||
completion = llm.gen_stream(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
|
|
||||||
messages=messages_combine)
|
|
||||||
for line in completion:
|
|
||||||
data = json.dumps({"answer": str(line)})
|
|
||||||
response_full += str(line)
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
|
|
||||||
# save conversation to database
|
|
||||||
if conversation_id is not None:
|
|
||||||
conversations_collection.update_one(
|
conversations_collection.update_one(
|
||||||
{"_id": ObjectId(conversation_id)},
|
{"_id": ObjectId(conversation_id)},
|
||||||
{"$push": {"queries": {"prompt": question, "response": response_full, "sources": source_log_docs}}},
|
{"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}},
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -168,19 +114,50 @@ def complete_stream(question, docsearch, chat_history, prompt_id, conversation_i
|
|||||||
"words, respond ONLY with the summary, use the same "
|
"words, respond ONLY with the summary, use the same "
|
||||||
"language as the system \n\nUser: " + question + "\n\n" +
|
"language as the system \n\nUser: " + question + "\n\n" +
|
||||||
"AI: " +
|
"AI: " +
|
||||||
response_full},
|
response},
|
||||||
{"role": "user", "content": "Summarise following conversation in no more than 3 words, "
|
{"role": "user", "content": "Summarise following conversation in no more than 3 words, "
|
||||||
"respond ONLY with the summary, use the same language as the "
|
"respond ONLY with the summary, use the same language as the "
|
||||||
"system"}]
|
"system"}]
|
||||||
|
|
||||||
completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
|
completion = llm.gen(model=gpt_model,
|
||||||
messages=messages_summary, max_tokens=30)
|
messages=messages_summary, max_tokens=30)
|
||||||
conversation_id = conversations_collection.insert_one(
|
conversation_id = conversations_collection.insert_one(
|
||||||
{"user": "local",
|
{"user": "local",
|
||||||
"date": datetime.datetime.utcnow(),
|
"date": datetime.datetime.utcnow(),
|
||||||
"name": completion,
|
"name": completion,
|
||||||
"queries": [{"prompt": question, "response": response_full, "sources": source_log_docs}]}
|
"queries": [{"prompt": question, "response": response, "sources": source_log_docs}]}
|
||||||
).inserted_id
|
).inserted_id
|
||||||
|
return conversation_id
|
||||||
|
|
||||||
|
def get_prompt(prompt_id):
|
||||||
|
if prompt_id == 'default':
|
||||||
|
prompt = chat_combine_template
|
||||||
|
elif prompt_id == 'creative':
|
||||||
|
prompt = chat_combine_creative
|
||||||
|
elif prompt_id == 'strict':
|
||||||
|
prompt = chat_combine_strict
|
||||||
|
else:
|
||||||
|
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def complete_stream(question, retriever, conversation_id):
|
||||||
|
|
||||||
|
|
||||||
|
response_full = ""
|
||||||
|
source_log_docs = []
|
||||||
|
answer = retriever.gen()
|
||||||
|
for line in answer:
|
||||||
|
if "answer" in line:
|
||||||
|
response_full += str(line["answer"])
|
||||||
|
data = json.dumps(line)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
elif "source" in line:
|
||||||
|
source_log_docs.append(line["source"])
|
||||||
|
|
||||||
|
|
||||||
|
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
||||||
|
conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
|
||||||
|
|
||||||
# send data.type = "end" to indicate that the stream has ended as json
|
# send data.type = "end" to indicate that the stream has ended as json
|
||||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||||
@@ -213,25 +190,31 @@ def stream():
|
|||||||
chunks = int(data["chunks"])
|
chunks = int(data["chunks"])
|
||||||
else:
|
else:
|
||||||
chunks = 2
|
chunks = 2
|
||||||
|
|
||||||
|
prompt = get_prompt(prompt_id)
|
||||||
|
|
||||||
# check if active_docs is set
|
# check if active_docs is set
|
||||||
|
|
||||||
if "api_key" in data:
|
if "api_key" in data:
|
||||||
data_key = get_data_from_api_key(data["api_key"])
|
data_key = get_data_from_api_key(data["api_key"])
|
||||||
vectorstore = get_vectorstore({"active_docs": data_key["source"]})
|
source = {"active_docs": data_key["source"]}
|
||||||
elif "active_docs" in data:
|
elif "active_docs" in data:
|
||||||
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
|
source = {"active_docs": data["active_docs"]}
|
||||||
else:
|
else:
|
||||||
vectorstore = ""
|
source = {}
|
||||||
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
|
|
||||||
|
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
|
||||||
|
retriever_name = "classic"
|
||||||
|
else:
|
||||||
|
retriever_name = source['active_docs']
|
||||||
|
|
||||||
|
retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
|
||||||
|
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
|
||||||
|
)
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
complete_stream(question, docsearch,
|
complete_stream(question=question, retriever=retriever,
|
||||||
chat_history=history,
|
conversation_id=conversation_id), mimetype="text/event-stream")
|
||||||
prompt_id=prompt_id,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
chunks=chunks), mimetype="text/event-stream"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@answer.route("/api/answer", methods=["POST"])
|
@answer.route("/api/answer", methods=["POST"])
|
||||||
@@ -255,110 +238,40 @@ def api_answer():
|
|||||||
chunks = int(data["chunks"])
|
chunks = int(data["chunks"])
|
||||||
else:
|
else:
|
||||||
chunks = 2
|
chunks = 2
|
||||||
|
|
||||||
if prompt_id == 'default':
|
prompt = get_prompt(prompt_id)
|
||||||
prompt = chat_combine_template
|
|
||||||
elif prompt_id == 'creative':
|
|
||||||
prompt = chat_combine_creative
|
|
||||||
elif prompt_id == 'strict':
|
|
||||||
prompt = chat_combine_strict
|
|
||||||
else:
|
|
||||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
|
||||||
|
|
||||||
# use try and except to check for exception
|
# use try and except to check for exception
|
||||||
try:
|
try:
|
||||||
# check if the vectorstore is set
|
# check if the vectorstore is set
|
||||||
if "api_key" in data:
|
if "api_key" in data:
|
||||||
data_key = get_data_from_api_key(data["api_key"])
|
data_key = get_data_from_api_key(data["api_key"])
|
||||||
vectorstore = get_vectorstore({"active_docs": data_key["source"]})
|
source = {"active_docs": data_key["source"]}
|
||||||
else:
|
else:
|
||||||
vectorstore = get_vectorstore(data)
|
source = {data}
|
||||||
# loading the index and the store and the prompt template
|
|
||||||
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
|
|
||||||
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
|
|
||||||
|
|
||||||
|
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
|
||||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
retriever_name = "classic"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if chunks == 0:
|
|
||||||
docs = []
|
|
||||||
else:
|
else:
|
||||||
docs = docsearch.search(question, k=chunks)
|
retriever_name = source['active_docs']
|
||||||
# join all page_content together with a newline
|
|
||||||
docs_together = "\n".join([doc.page_content for doc in docs])
|
retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
|
||||||
p_chat_combine = prompt.replace("{summaries}", docs_together)
|
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
|
||||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
)
|
||||||
source_log_docs = []
|
source_log_docs = []
|
||||||
for doc in docs:
|
response_full = ""
|
||||||
if doc.metadata:
|
for line in retriever.gen():
|
||||||
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
|
if "source" in line:
|
||||||
else:
|
source_log_docs.append(line["source"])
|
||||||
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
|
elif "answer" in line:
|
||||||
# join all page_content together with a newline
|
response_full += line["answer"]
|
||||||
|
|
||||||
|
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
||||||
|
|
||||||
|
|
||||||
|
result = {"answer": response_full, "sources": source_log_docs}
|
||||||
|
result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
|
||||||
|
|
||||||
if len(history) > 1:
|
|
||||||
tokens_current_history = 0
|
|
||||||
# count tokens in history
|
|
||||||
history.reverse()
|
|
||||||
for i in history:
|
|
||||||
if "prompt" in i and "response" in i:
|
|
||||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
|
|
||||||
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
|
|
||||||
tokens_current_history += tokens_batch
|
|
||||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
|
||||||
messages_combine.append({"role": "system", "content": i["response"]})
|
|
||||||
messages_combine.append({"role": "user", "content": question})
|
|
||||||
|
|
||||||
|
|
||||||
completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
|
|
||||||
messages=messages_combine)
|
|
||||||
|
|
||||||
|
|
||||||
result = {"answer": completion, "sources": source_log_docs}
|
|
||||||
logger.debug(result)
|
|
||||||
|
|
||||||
# generate conversationId
|
|
||||||
if conversation_id is not None:
|
|
||||||
conversations_collection.update_one(
|
|
||||||
{"_id": ObjectId(conversation_id)},
|
|
||||||
{"$push": {"queries": {"prompt": question,
|
|
||||||
"response": result["answer"], "sources": result['sources']}}},
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# create new conversation
|
|
||||||
# generate summary
|
|
||||||
messages_summary = [
|
|
||||||
{"role": "assistant", "content": "Summarise following conversation in no more than 3 words, "
|
|
||||||
"respond ONLY with the summary, use the same language as the system \n\n"
|
|
||||||
"User: " + question + "\n\n" + "AI: " + result["answer"]},
|
|
||||||
{"role": "user", "content": "Summarise following conversation in no more than 3 words, "
|
|
||||||
"respond ONLY with the summary, use the same language as the system"}
|
|
||||||
]
|
|
||||||
|
|
||||||
completion = llm.gen(
|
|
||||||
model=gpt_model,
|
|
||||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
|
||||||
messages=messages_summary,
|
|
||||||
max_tokens=30
|
|
||||||
)
|
|
||||||
conversation_id = conversations_collection.insert_one(
|
|
||||||
{"user": "local",
|
|
||||||
"date": datetime.datetime.utcnow(),
|
|
||||||
"name": completion,
|
|
||||||
"queries": [{"prompt": question, "response": result["answer"], "sources": source_log_docs}]}
|
|
||||||
).inserted_id
|
|
||||||
|
|
||||||
result["conversation_id"] = str(conversation_id)
|
|
||||||
|
|
||||||
# mock result
|
|
||||||
# result = {
|
|
||||||
# "answer": "The answer is 42",
|
|
||||||
# "sources": ["https://en.wikipedia.org/wiki/42_(number)", "https://en.wikipedia.org/wiki/42_(number)"]
|
|
||||||
# }
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# print whole traceback
|
# print whole traceback
|
||||||
@@ -375,27 +288,24 @@ def api_search():
|
|||||||
|
|
||||||
if "api_key" in data:
|
if "api_key" in data:
|
||||||
data_key = get_data_from_api_key(data["api_key"])
|
data_key = get_data_from_api_key(data["api_key"])
|
||||||
vectorstore = data_key["source"]
|
source = {"active_docs": data_key["source"]}
|
||||||
elif "active_docs" in data:
|
elif "active_docs" in data:
|
||||||
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
|
source = {"active_docs": data["active_docs"]}
|
||||||
else:
|
else:
|
||||||
vectorstore = ""
|
source = {}
|
||||||
if 'chunks' in data:
|
if 'chunks' in data:
|
||||||
chunks = int(data["chunks"])
|
chunks = int(data["chunks"])
|
||||||
else:
|
else:
|
||||||
chunks = 2
|
chunks = 2
|
||||||
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
|
|
||||||
if chunks == 0:
|
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
|
||||||
docs = []
|
retriever_name = "classic"
|
||||||
else:
|
else:
|
||||||
docs = docsearch.search(question, k=chunks)
|
retriever_name = source['active_docs']
|
||||||
|
|
||||||
source_log_docs = []
|
retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
|
||||||
for doc in docs:
|
source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model
|
||||||
if doc.metadata:
|
)
|
||||||
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
|
docs = retriever.search()
|
||||||
else:
|
return docs
|
||||||
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
|
|
||||||
#yield f"data:{data}\n\n"
|
|
||||||
return source_log_docs
|
|
||||||
|
|
||||||
|
|||||||
@@ -251,6 +251,34 @@ def combined_json():
|
|||||||
for index in data_remote:
|
for index in data_remote:
|
||||||
index["location"] = "remote"
|
index["location"] = "remote"
|
||||||
data.append(index)
|
data.append(index)
|
||||||
|
if 'duckduck_search' in settings.RETRIEVERS_ENABLED:
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
"name": "DuckDuckGo Search",
|
||||||
|
"language": "en",
|
||||||
|
"version": "",
|
||||||
|
"description": "duckduck_search",
|
||||||
|
"fullName": "DuckDuckGo Search",
|
||||||
|
"date": "duckduck_search",
|
||||||
|
"docLink": "duckduck_search",
|
||||||
|
"model": settings.EMBEDDINGS_NAME,
|
||||||
|
"location": "custom",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if 'brave_search' in settings.RETRIEVERS_ENABLED:
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
"name": "Brave Search",
|
||||||
|
"language": "en",
|
||||||
|
"version": "",
|
||||||
|
"description": "brave_search",
|
||||||
|
"fullName": "Brave Search",
|
||||||
|
"date": "brave_search",
|
||||||
|
"docLink": "brave_search",
|
||||||
|
"model": settings.EMBEDDINGS_NAME,
|
||||||
|
"location": "custom",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify(data)
|
return jsonify(data)
|
||||||
|
|
||||||
@@ -269,10 +297,12 @@ def check_docs():
|
|||||||
else:
|
else:
|
||||||
file_url = urlparse(base_path + vectorstore + "index.faiss")
|
file_url = urlparse(base_path + vectorstore + "index.faiss")
|
||||||
|
|
||||||
if file_url.scheme in ['https'] and file_url.netloc == 'raw.githubusercontent.com' and file_url.path.startswith('/arc53/DocsHUB/main/'):
|
if (
|
||||||
|
file_url.scheme in ['https'] and
|
||||||
|
file_url.netloc == 'raw.githubusercontent.com' and
|
||||||
|
file_url.path.startswith('/arc53/DocsHUB/main/')
|
||||||
|
):
|
||||||
r = requests.get(file_url.geturl())
|
r = requests.get(file_url.geturl())
|
||||||
|
|
||||||
if r.status_code != 200:
|
if r.status_code != 200:
|
||||||
return {"status": "null"}
|
return {"status": "null"}
|
||||||
else:
|
else:
|
||||||
@@ -281,7 +311,6 @@ def check_docs():
|
|||||||
with open(vectorstore + "index.faiss", "wb") as f:
|
with open(vectorstore + "index.faiss", "wb") as f:
|
||||||
f.write(r.content)
|
f.write(r.content)
|
||||||
|
|
||||||
# download the store
|
|
||||||
r = requests.get(base_path + vectorstore + "index.pkl")
|
r = requests.get(base_path + vectorstore + "index.pkl")
|
||||||
with open(vectorstore + "index.pkl", "wb") as f:
|
with open(vectorstore + "index.pkl", "wb") as f:
|
||||||
f.write(r.content)
|
f.write(r.content)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__
|
|||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
LLM_NAME: str = "docsgpt"
|
LLM_NAME: str = "docsgpt"
|
||||||
MODEL_NAME: Optional[str] = None # when LLM_NAME is openai, MODEL_NAME can be e.g. gpt-4-turbo-preview or gpt-3.5-turbo
|
MODEL_NAME: Optional[str] = None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
|
||||||
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
|
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||||
@@ -18,6 +18,7 @@ class Settings(BaseSettings):
|
|||||||
TOKENS_MAX_HISTORY: int = 150
|
TOKENS_MAX_HISTORY: int = 150
|
||||||
UPLOAD_FOLDER: str = "inputs"
|
UPLOAD_FOLDER: str = "inputs"
|
||||||
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant"
|
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant"
|
||||||
|
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
|
||||||
|
|
||||||
API_URL: str = "http://localhost:7091" # backend url for celery worker
|
API_URL: str = "http://localhost:7091" # backend url for celery worker
|
||||||
|
|
||||||
@@ -59,6 +60,8 @@ class Settings(BaseSettings):
|
|||||||
QDRANT_PATH: Optional[str] = None
|
QDRANT_PATH: Optional[str] = None
|
||||||
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
||||||
|
|
||||||
|
BRAVE_SEARCH_API_KEY: Optional[str] = None
|
||||||
|
|
||||||
FLASK_DEBUG_MODE: bool = False
|
FLASK_DEBUG_MODE: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ class AnthropicLLM(BaseLLM):
|
|||||||
self.HUMAN_PROMPT = HUMAN_PROMPT
|
self.HUMAN_PROMPT = HUMAN_PROMPT
|
||||||
self.AI_PROMPT = AI_PROMPT
|
self.AI_PROMPT = AI_PROMPT
|
||||||
|
|
||||||
def gen(self, model, messages, engine=None, max_tokens=300, stream=False, **kwargs):
|
def gen(self, model, messages, max_tokens=300, stream=False, **kwargs):
|
||||||
context = messages[0]['content']
|
context = messages[0]['content']
|
||||||
user_question = messages[-1]['content']
|
user_question = messages[-1]['content']
|
||||||
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
|
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
|
||||||
@@ -25,7 +25,7 @@ class AnthropicLLM(BaseLLM):
|
|||||||
)
|
)
|
||||||
return completion.completion
|
return completion.completion
|
||||||
|
|
||||||
def gen_stream(self, model, messages, engine=None, max_tokens=300, **kwargs):
|
def gen_stream(self, model, messages, max_tokens=300, **kwargs):
|
||||||
context = messages[0]['content']
|
context = messages[0]['content']
|
||||||
user_question = messages[-1]['content']
|
user_question = messages[-1]['content']
|
||||||
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
|
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ class DocsGPTAPILLM(BaseLLM):
|
|||||||
self.endpoint = "https://llm.docsgpt.co.uk"
|
self.endpoint = "https://llm.docsgpt.co.uk"
|
||||||
|
|
||||||
|
|
||||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
def gen(self, model, messages, stream=False, **kwargs):
|
||||||
context = messages[0]['content']
|
context = messages[0]['content']
|
||||||
user_question = messages[-1]['content']
|
user_question = messages[-1]['content']
|
||||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||||
@@ -24,7 +24,7 @@ class DocsGPTAPILLM(BaseLLM):
|
|||||||
|
|
||||||
return response_clean
|
return response_clean
|
||||||
|
|
||||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
def gen_stream(self, model, messages, stream=True, **kwargs):
|
||||||
context = messages[0]['content']
|
context = messages[0]['content']
|
||||||
user_question = messages[-1]['content']
|
user_question = messages[-1]['content']
|
||||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class HuggingFaceLLM(BaseLLM):
|
|||||||
)
|
)
|
||||||
hf = HuggingFacePipeline(pipeline=pipe)
|
hf = HuggingFacePipeline(pipeline=pipe)
|
||||||
|
|
||||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
def gen(self, model, messages, stream=False, **kwargs):
|
||||||
context = messages[0]['content']
|
context = messages[0]['content']
|
||||||
user_question = messages[-1]['content']
|
user_question = messages[-1]['content']
|
||||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||||
@@ -38,7 +38,7 @@ class HuggingFaceLLM(BaseLLM):
|
|||||||
|
|
||||||
return result.content
|
return result.content
|
||||||
|
|
||||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
def gen_stream(self, model, messages, stream=True, **kwargs):
|
||||||
|
|
||||||
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")
|
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ class LlamaCpp(BaseLLM):
|
|||||||
|
|
||||||
llama = Llama(model_path=llm_name, n_ctx=2048)
|
llama = Llama(model_path=llm_name, n_ctx=2048)
|
||||||
|
|
||||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
def gen(self, model, messages, stream=False, **kwargs):
|
||||||
context = messages[0]['content']
|
context = messages[0]['content']
|
||||||
user_question = messages[-1]['content']
|
user_question = messages[-1]['content']
|
||||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||||
@@ -24,7 +24,7 @@ class LlamaCpp(BaseLLM):
|
|||||||
|
|
||||||
return result['choices'][0]['text'].split('### Answer \n')[-1]
|
return result['choices'][0]['text'].split('### Answer \n')[-1]
|
||||||
|
|
||||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
def gen_stream(self, model, messages, stream=True, **kwargs):
|
||||||
context = messages[0]['content']
|
context = messages[0]['content']
|
||||||
user_question = messages[-1]['content']
|
user_question = messages[-1]['content']
|
||||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class OpenAILLM(BaseLLM):
|
|||||||
|
|
||||||
return openai
|
return openai
|
||||||
|
|
||||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
def gen(self, model, messages, stream=False, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs):
|
||||||
response = self.client.chat.completions.create(model=model,
|
response = self.client.chat.completions.create(model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
@@ -26,7 +26,7 @@ class OpenAILLM(BaseLLM):
|
|||||||
|
|
||||||
return response.choices[0].message.content
|
return response.choices[0].message.content
|
||||||
|
|
||||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
def gen_stream(self, model, messages, stream=True, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs):
|
||||||
response = self.client.chat.completions.create(model=model,
|
response = self.client.chat.completions.create(model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ class PremAILLM(BaseLLM):
|
|||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.project_id = settings.PREMAI_PROJECT_ID
|
self.project_id = settings.PREMAI_PROJECT_ID
|
||||||
|
|
||||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
def gen(self, model, messages, stream=False, **kwargs):
|
||||||
response = self.client.chat.completions.create(model=model,
|
response = self.client.chat.completions.create(model=model,
|
||||||
project_id=self.project_id,
|
project_id=self.project_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@@ -21,7 +21,7 @@ class PremAILLM(BaseLLM):
|
|||||||
|
|
||||||
return response.choices[0].message["content"]
|
return response.choices[0].message["content"]
|
||||||
|
|
||||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
def gen_stream(self, model, messages, stream=True, **kwargs):
|
||||||
response = self.client.chat.completions.create(model=model,
|
response = self.client.chat.completions.create(model=model,
|
||||||
project_id=self.project_id,
|
project_id=self.project_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class SagemakerAPILLM(BaseLLM):
|
|||||||
self.runtime = runtime
|
self.runtime = runtime
|
||||||
|
|
||||||
|
|
||||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
def gen(self, model, messages, stream=False, **kwargs):
|
||||||
context = messages[0]['content']
|
context = messages[0]['content']
|
||||||
user_question = messages[-1]['content']
|
user_question = messages[-1]['content']
|
||||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||||
@@ -103,7 +103,7 @@ class SagemakerAPILLM(BaseLLM):
|
|||||||
print(result[0]['generated_text'], file=sys.stderr)
|
print(result[0]['generated_text'], file=sys.stderr)
|
||||||
return result[0]['generated_text'][len(prompt):]
|
return result[0]['generated_text'][len(prompt):]
|
||||||
|
|
||||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
def gen_stream(self, model, messages, stream=True, **kwargs):
|
||||||
context = messages[0]['content']
|
context = messages[0]['content']
|
||||||
user_question = messages[-1]['content']
|
user_question = messages[-1]['content']
|
||||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||||
|
|||||||
@@ -22,7 +22,10 @@ def group_documents(documents: List[Document], min_tokens: int, max_tokens: int)
|
|||||||
doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
|
doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
|
||||||
|
|
||||||
# Check if current group is empty or if the document can be added based on token count and matching metadata
|
# Check if current group is empty or if the document can be added based on token count and matching metadata
|
||||||
if current_group is None or (len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and doc_len < min_tokens and current_group.extra_info == doc.extra_info):
|
if (current_group is None or
|
||||||
|
(len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and
|
||||||
|
doc_len < min_tokens and
|
||||||
|
current_group.extra_info == doc.extra_info)):
|
||||||
if current_group is None:
|
if current_group is None:
|
||||||
current_group = doc # Use the document directly to retain its metadata
|
current_group = doc # Use the document directly to retain its metadata
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ boto3==1.34.6
|
|||||||
celery==5.3.6
|
celery==5.3.6
|
||||||
dataclasses_json==0.6.3
|
dataclasses_json==0.6.3
|
||||||
docx2txt==0.8
|
docx2txt==0.8
|
||||||
|
duckduckgo-search==5.3.0
|
||||||
EbookLib==0.18
|
EbookLib==0.18
|
||||||
elasticsearch==8.12.0
|
elasticsearch==8.12.0
|
||||||
escodegen==1.0.11
|
escodegen==1.0.11
|
||||||
|
|||||||
0
application/retriever/__init__.py
Normal file
0
application/retriever/__init__.py
Normal file
14
application/retriever/base.py
Normal file
14
application/retriever/base.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRetriever(ABC):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def gen(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
75
application/retriever/brave_search.py
Normal file
75
application/retriever/brave_search.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import json
|
||||||
|
from application.retriever.base import BaseRetriever
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.llm.llm_creator import LLMCreator
|
||||||
|
from application.utils import count_tokens
|
||||||
|
from langchain_community.tools import BraveSearch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BraveRetSearch(BaseRetriever):
|
||||||
|
|
||||||
|
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
|
||||||
|
self.question = question
|
||||||
|
self.source = source
|
||||||
|
self.chat_history = chat_history
|
||||||
|
self.prompt = prompt
|
||||||
|
self.chunks = chunks
|
||||||
|
self.gpt_model = gpt_model
|
||||||
|
|
||||||
|
def _get_data(self):
|
||||||
|
if self.chunks == 0:
|
||||||
|
docs = []
|
||||||
|
else:
|
||||||
|
search = BraveSearch.from_api_key(api_key=settings.BRAVE_SEARCH_API_KEY,
|
||||||
|
search_kwargs={"count": int(self.chunks)})
|
||||||
|
results = search.run(self.question)
|
||||||
|
results = json.loads(results)
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
for i in results:
|
||||||
|
try:
|
||||||
|
title = i['title']
|
||||||
|
link = i['link']
|
||||||
|
snippet = i['snippet']
|
||||||
|
docs.append({"text": snippet, "title": title, "link": link})
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
if settings.LLM_NAME == "llama.cpp":
|
||||||
|
docs = [docs[0]]
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def gen(self):
|
||||||
|
docs = self._get_data()
|
||||||
|
|
||||||
|
# join all page_content together with a newline
|
||||||
|
docs_together = "\n".join([doc["text"] for doc in docs])
|
||||||
|
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
||||||
|
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||||
|
for doc in docs:
|
||||||
|
yield {"source": doc}
|
||||||
|
|
||||||
|
if len(self.chat_history) > 1:
|
||||||
|
tokens_current_history = 0
|
||||||
|
# count tokens in history
|
||||||
|
self.chat_history.reverse()
|
||||||
|
for i in self.chat_history:
|
||||||
|
if "prompt" in i and "response" in i:
|
||||||
|
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
|
||||||
|
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
|
||||||
|
tokens_current_history += tokens_batch
|
||||||
|
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||||
|
messages_combine.append({"role": "system", "content": i["response"]})
|
||||||
|
messages_combine.append({"role": "user", "content": self.question})
|
||||||
|
|
||||||
|
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
||||||
|
|
||||||
|
completion = llm.gen_stream(model=self.gpt_model,
|
||||||
|
messages=messages_combine)
|
||||||
|
for line in completion:
|
||||||
|
yield {"answer": str(line)}
|
||||||
|
|
||||||
|
def search(self):
|
||||||
|
return self._get_data()
|
||||||
|
|
||||||
91
application/retriever/classic_rag.py
Normal file
91
application/retriever/classic_rag.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import os
|
||||||
|
from application.retriever.base import BaseRetriever
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.vectorstore.vector_creator import VectorCreator
|
||||||
|
from application.llm.llm_creator import LLMCreator
|
||||||
|
|
||||||
|
from application.utils import count_tokens
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ClassicRAG(BaseRetriever):
|
||||||
|
|
||||||
|
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
|
||||||
|
self.question = question
|
||||||
|
self.vectorstore = self._get_vectorstore(source=source)
|
||||||
|
self.chat_history = chat_history
|
||||||
|
self.prompt = prompt
|
||||||
|
self.chunks = chunks
|
||||||
|
self.gpt_model = gpt_model
|
||||||
|
|
||||||
|
def _get_vectorstore(self, source):
|
||||||
|
if "active_docs" in source:
|
||||||
|
if source["active_docs"].split("/")[0] == "default":
|
||||||
|
vectorstore = ""
|
||||||
|
elif source["active_docs"].split("/")[0] == "local":
|
||||||
|
vectorstore = "indexes/" + source["active_docs"]
|
||||||
|
else:
|
||||||
|
vectorstore = "vectors/" + source["active_docs"]
|
||||||
|
if source["active_docs"] == "default":
|
||||||
|
vectorstore = ""
|
||||||
|
else:
|
||||||
|
vectorstore = ""
|
||||||
|
vectorstore = os.path.join("application", vectorstore)
|
||||||
|
return vectorstore
|
||||||
|
|
||||||
|
|
||||||
|
def _get_data(self):
|
||||||
|
if self.chunks == 0:
|
||||||
|
docs = []
|
||||||
|
else:
|
||||||
|
docsearch = VectorCreator.create_vectorstore(
|
||||||
|
settings.VECTOR_STORE,
|
||||||
|
self.vectorstore,
|
||||||
|
settings.EMBEDDINGS_KEY
|
||||||
|
)
|
||||||
|
docs_temp = docsearch.search(self.question, k=self.chunks)
|
||||||
|
docs = [
|
||||||
|
{
|
||||||
|
"title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content,
|
||||||
|
"text": i.page_content
|
||||||
|
}
|
||||||
|
for i in docs_temp
|
||||||
|
]
|
||||||
|
if settings.LLM_NAME == "llama.cpp":
|
||||||
|
docs = [docs[0]]
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def gen(self):
|
||||||
|
docs = self._get_data()
|
||||||
|
|
||||||
|
# join all page_content together with a newline
|
||||||
|
docs_together = "\n".join([doc["text"] for doc in docs])
|
||||||
|
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
||||||
|
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||||
|
for doc in docs:
|
||||||
|
yield {"source": doc}
|
||||||
|
|
||||||
|
if len(self.chat_history) > 1:
|
||||||
|
tokens_current_history = 0
|
||||||
|
# count tokens in history
|
||||||
|
self.chat_history.reverse()
|
||||||
|
for i in self.chat_history:
|
||||||
|
if "prompt" in i and "response" in i:
|
||||||
|
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
|
||||||
|
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
|
||||||
|
tokens_current_history += tokens_batch
|
||||||
|
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||||
|
messages_combine.append({"role": "system", "content": i["response"]})
|
||||||
|
messages_combine.append({"role": "user", "content": self.question})
|
||||||
|
|
||||||
|
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
||||||
|
|
||||||
|
completion = llm.gen_stream(model=self.gpt_model,
|
||||||
|
messages=messages_combine)
|
||||||
|
for line in completion:
|
||||||
|
yield {"answer": str(line)}
|
||||||
|
|
||||||
|
def search(self):
|
||||||
|
return self._get_data()
|
||||||
|
|
||||||
94
application/retriever/duckduck_search.py
Normal file
94
application/retriever/duckduck_search.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
from application.retriever.base import BaseRetriever
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.llm.llm_creator import LLMCreator
|
||||||
|
from application.utils import count_tokens
|
||||||
|
from langchain_community.tools import DuckDuckGoSearchResults
|
||||||
|
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DuckDuckSearch(BaseRetriever):
|
||||||
|
|
||||||
|
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
|
||||||
|
self.question = question
|
||||||
|
self.source = source
|
||||||
|
self.chat_history = chat_history
|
||||||
|
self.prompt = prompt
|
||||||
|
self.chunks = chunks
|
||||||
|
self.gpt_model = gpt_model
|
||||||
|
|
||||||
|
def _parse_lang_string(self, input_string):
|
||||||
|
result = []
|
||||||
|
current_item = ""
|
||||||
|
inside_brackets = False
|
||||||
|
for char in input_string:
|
||||||
|
if char == "[":
|
||||||
|
inside_brackets = True
|
||||||
|
elif char == "]":
|
||||||
|
inside_brackets = False
|
||||||
|
result.append(current_item)
|
||||||
|
current_item = ""
|
||||||
|
elif inside_brackets:
|
||||||
|
current_item += char
|
||||||
|
|
||||||
|
if inside_brackets:
|
||||||
|
result.append(current_item)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_data(self):
|
||||||
|
if self.chunks == 0:
|
||||||
|
docs = []
|
||||||
|
else:
|
||||||
|
wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks)
|
||||||
|
search = DuckDuckGoSearchResults(api_wrapper=wrapper)
|
||||||
|
results = search.run(self.question)
|
||||||
|
results = self._parse_lang_string(results)
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
for i in results:
|
||||||
|
try:
|
||||||
|
text = i.split("title:")[0]
|
||||||
|
title = i.split("title:")[1].split("link:")[0]
|
||||||
|
link = i.split("link:")[1]
|
||||||
|
docs.append({"text": text, "title": title, "link": link})
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
if settings.LLM_NAME == "llama.cpp":
|
||||||
|
docs = [docs[0]]
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def gen(self):
|
||||||
|
docs = self._get_data()
|
||||||
|
|
||||||
|
# join all page_content together with a newline
|
||||||
|
docs_together = "\n".join([doc["text"] for doc in docs])
|
||||||
|
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
||||||
|
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||||
|
for doc in docs:
|
||||||
|
yield {"source": doc}
|
||||||
|
|
||||||
|
if len(self.chat_history) > 1:
|
||||||
|
tokens_current_history = 0
|
||||||
|
# count tokens in history
|
||||||
|
self.chat_history.reverse()
|
||||||
|
for i in self.chat_history:
|
||||||
|
if "prompt" in i and "response" in i:
|
||||||
|
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
|
||||||
|
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
|
||||||
|
tokens_current_history += tokens_batch
|
||||||
|
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||||
|
messages_combine.append({"role": "system", "content": i["response"]})
|
||||||
|
messages_combine.append({"role": "user", "content": self.question})
|
||||||
|
|
||||||
|
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
||||||
|
|
||||||
|
completion = llm.gen_stream(model=self.gpt_model,
|
||||||
|
messages=messages_combine)
|
||||||
|
for line in completion:
|
||||||
|
yield {"answer": str(line)}
|
||||||
|
|
||||||
|
def search(self):
|
||||||
|
return self._get_data()
|
||||||
|
|
||||||
19
application/retriever/retriever_creator.py
Normal file
19
application/retriever/retriever_creator.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from application.retriever.classic_rag import ClassicRAG
|
||||||
|
from application.retriever.duckduck_search import DuckDuckSearch
|
||||||
|
from application.retriever.brave_search import BraveRetSearch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class RetrieverCreator:
|
||||||
|
retievers = {
|
||||||
|
'classic': ClassicRAG,
|
||||||
|
'duckduck_search': DuckDuckSearch,
|
||||||
|
'brave_search': BraveRetSearch
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_retriever(cls, type, *args, **kwargs):
|
||||||
|
retiever_class = cls.retievers.get(type.lower())
|
||||||
|
if not retiever_class:
|
||||||
|
raise ValueError(f"No retievers class found for type {type}")
|
||||||
|
return retiever_class(*args, **kwargs)
|
||||||
6
application/utils.py
Normal file
6
application/utils.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from transformers import GPT2TokenizerFast
|
||||||
|
|
||||||
|
|
||||||
|
def count_tokens(string):
|
||||||
|
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
|
||||||
|
return len(tokenizer(string)['input_ids'])
|
||||||
@@ -3,6 +3,33 @@ import { Doc } from '../preferences/preferenceApi';
|
|||||||
|
|
||||||
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
|
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
|
||||||
|
|
||||||
|
function getDocPath(selectedDocs: Doc | null): string {
|
||||||
|
let docPath = 'default';
|
||||||
|
|
||||||
|
if (selectedDocs) {
|
||||||
|
let namePath = selectedDocs.name;
|
||||||
|
if (selectedDocs.language === namePath) {
|
||||||
|
namePath = '.project';
|
||||||
|
}
|
||||||
|
if (selectedDocs.location === 'local') {
|
||||||
|
docPath = 'local' + '/' + selectedDocs.name + '/';
|
||||||
|
} else if (selectedDocs.location === 'remote') {
|
||||||
|
docPath =
|
||||||
|
selectedDocs.language +
|
||||||
|
'/' +
|
||||||
|
namePath +
|
||||||
|
'/' +
|
||||||
|
selectedDocs.version +
|
||||||
|
'/' +
|
||||||
|
selectedDocs.model +
|
||||||
|
'/';
|
||||||
|
} else if (selectedDocs.location === 'custom') {
|
||||||
|
docPath = selectedDocs.docLink;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return docPath;
|
||||||
|
}
|
||||||
export function fetchAnswerApi(
|
export function fetchAnswerApi(
|
||||||
question: string,
|
question: string,
|
||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
@@ -28,27 +55,7 @@ export function fetchAnswerApi(
|
|||||||
title: any;
|
title: any;
|
||||||
}
|
}
|
||||||
> {
|
> {
|
||||||
let docPath = 'default';
|
const docPath = getDocPath(selectedDocs);
|
||||||
|
|
||||||
if (selectedDocs) {
|
|
||||||
let namePath = selectedDocs.name;
|
|
||||||
if (selectedDocs.language === namePath) {
|
|
||||||
namePath = '.project';
|
|
||||||
}
|
|
||||||
if (selectedDocs.location === 'local') {
|
|
||||||
docPath = 'local' + '/' + selectedDocs.name + '/';
|
|
||||||
} else if (selectedDocs.location === 'remote') {
|
|
||||||
docPath =
|
|
||||||
selectedDocs.language +
|
|
||||||
'/' +
|
|
||||||
namePath +
|
|
||||||
'/' +
|
|
||||||
selectedDocs.version +
|
|
||||||
'/' +
|
|
||||||
selectedDocs.model +
|
|
||||||
'/';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
//in history array remove all keys except prompt and response
|
//in history array remove all keys except prompt and response
|
||||||
history = history.map((item) => {
|
history = history.map((item) => {
|
||||||
return { prompt: item.prompt, response: item.response };
|
return { prompt: item.prompt, response: item.response };
|
||||||
@@ -98,27 +105,7 @@ export function fetchAnswerSteaming(
|
|||||||
chunks: string,
|
chunks: string,
|
||||||
onEvent: (event: MessageEvent) => void,
|
onEvent: (event: MessageEvent) => void,
|
||||||
): Promise<Answer> {
|
): Promise<Answer> {
|
||||||
let docPath = 'default';
|
const docPath = getDocPath(selectedDocs);
|
||||||
|
|
||||||
if (selectedDocs) {
|
|
||||||
let namePath = selectedDocs.name;
|
|
||||||
if (selectedDocs.language === namePath) {
|
|
||||||
namePath = '.project';
|
|
||||||
}
|
|
||||||
if (selectedDocs.location === 'local') {
|
|
||||||
docPath = 'local' + '/' + selectedDocs.name + '/';
|
|
||||||
} else if (selectedDocs.location === 'remote') {
|
|
||||||
docPath =
|
|
||||||
selectedDocs.language +
|
|
||||||
'/' +
|
|
||||||
namePath +
|
|
||||||
'/' +
|
|
||||||
selectedDocs.version +
|
|
||||||
'/' +
|
|
||||||
selectedDocs.model +
|
|
||||||
'/';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
history = history.map((item) => {
|
history = history.map((item) => {
|
||||||
return { prompt: item.prompt, response: item.response };
|
return { prompt: item.prompt, response: item.response };
|
||||||
@@ -195,31 +182,7 @@ export function searchEndpoint(
|
|||||||
history: Array<any> = [],
|
history: Array<any> = [],
|
||||||
chunks: string,
|
chunks: string,
|
||||||
) {
|
) {
|
||||||
/*
|
const docPath = getDocPath(selectedDocs);
|
||||||
"active_docs": "default",
|
|
||||||
"question": "Summarise",
|
|
||||||
"conversation_id": null,
|
|
||||||
"history": "[]" */
|
|
||||||
let docPath = 'default';
|
|
||||||
if (selectedDocs) {
|
|
||||||
let namePath = selectedDocs.name;
|
|
||||||
if (selectedDocs.language === namePath) {
|
|
||||||
namePath = '.project';
|
|
||||||
}
|
|
||||||
if (selectedDocs.location === 'local') {
|
|
||||||
docPath = 'local' + '/' + selectedDocs.name + '/';
|
|
||||||
} else if (selectedDocs.location === 'remote') {
|
|
||||||
docPath =
|
|
||||||
selectedDocs.language +
|
|
||||||
'/' +
|
|
||||||
namePath +
|
|
||||||
'/' +
|
|
||||||
selectedDocs.version +
|
|
||||||
'/' +
|
|
||||||
selectedDocs.model +
|
|
||||||
'/';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const body = {
|
const body = {
|
||||||
question: question,
|
question: question,
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class TestSagemakerAPILLM(unittest.TestCase):
|
|||||||
def test_gen(self):
|
def test_gen(self):
|
||||||
with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
|
with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
|
||||||
return_value=self.response) as mock_invoke_endpoint:
|
return_value=self.response) as mock_invoke_endpoint:
|
||||||
output = self.sagemaker.gen(None, None, self.messages)
|
output = self.sagemaker.gen(None, self.messages)
|
||||||
mock_invoke_endpoint.assert_called_once_with(
|
mock_invoke_endpoint.assert_called_once_with(
|
||||||
EndpointName=self.sagemaker.endpoint,
|
EndpointName=self.sagemaker.endpoint,
|
||||||
ContentType='application/json',
|
ContentType='application/json',
|
||||||
@@ -66,7 +66,7 @@ class TestSagemakerAPILLM(unittest.TestCase):
|
|||||||
def test_gen_stream(self):
|
def test_gen_stream(self):
|
||||||
with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
|
with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
|
||||||
return_value=self.response) as mock_invoke_endpoint:
|
return_value=self.response) as mock_invoke_endpoint:
|
||||||
output = list(self.sagemaker.gen_stream(None, None, self.messages))
|
output = list(self.sagemaker.gen_stream(None, self.messages))
|
||||||
mock_invoke_endpoint.assert_called_once_with(
|
mock_invoke_endpoint.assert_called_once_with(
|
||||||
EndpointName=self.sagemaker.endpoint,
|
EndpointName=self.sagemaker.endpoint,
|
||||||
ContentType='application/json',
|
ContentType='application/json',
|
||||||
|
|||||||
Reference in New Issue
Block a user