diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index b122eac1..1b3c9b97 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -8,13 +8,14 @@ import traceback from pymongo import MongoClient from bson.objectid import ObjectId -from transformers import GPT2TokenizerFast +from application.utils import count_tokens from application.core.settings import settings from application.vectorstore.vector_creator import VectorCreator from application.llm.llm_creator import LLMCreator +from application.retriever.retriever_creator import RetrieverCreator from application.error import bad_request @@ -62,9 +63,6 @@ async def async_generate(chain, question, chat_history): return result -def count_tokens(string): - tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') - return len(tokenizer(string)['input_ids']) def run_async_chain(chain, question, chat_history): @@ -104,61 +102,11 @@ def get_vectorstore(data): def is_azure_configured(): return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME - -def complete_stream(question, docsearch, chat_history, prompt_id, conversation_id, chunks=2): - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) - if prompt_id == 'default': - prompt = chat_combine_template - elif prompt_id == 'creative': - prompt = chat_combine_creative - elif prompt_id == 'strict': - prompt = chat_combine_strict - else: - prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] - - if chunks == 0: - docs = [] - else: - docs = docsearch.search(question, k=chunks) - if settings.LLM_NAME == "llama.cpp": - docs = [docs[0]] - # join all page_content together with a newline - docs_together = "\n".join([doc.page_content for doc in docs]) - p_chat_combine = prompt.replace("{summaries}", docs_together) - messages_combine = [{"role": "system", "content": p_chat_combine}] - source_log_docs = [] - for doc in docs: - if doc.metadata: 
- source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}) - else: - source_log_docs.append({"title": doc.page_content, "text": doc.page_content}) - - if len(chat_history) > 1: - tokens_current_history = 0 - # count tokens in history - chat_history.reverse() - for i in chat_history: - if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: - tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) - messages_combine.append({"role": "user", "content": question}) - - response_full = "" - completion = llm.gen_stream(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME, - messages=messages_combine) - for line in completion: - data = json.dumps({"answer": str(line)}) - response_full += str(line) - yield f"data: {data}\n\n" - - # save conversation to database +def save_conversation(conversation_id, question, response, source_log_docs, llm): if conversation_id is not None: conversations_collection.update_one( {"_id": ObjectId(conversation_id)}, - {"$push": {"queries": {"prompt": question, "response": response_full, "sources": source_log_docs}}}, + {"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}}, ) else: @@ -168,20 +116,50 @@ def complete_stream(question, docsearch, chat_history, prompt_id, conversation_i "words, respond ONLY with the summary, use the same " "language as the system \n\nUser: " + question + "\n\n" + "AI: " + - response_full}, + response}, {"role": "user", "content": "Summarise following conversation in no more than 3 words, " "respond ONLY with the summary, use the same language as the " "system"}] - completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME, + completion = llm.gen(model=gpt_model, 
messages=messages_summary, max_tokens=30) conversation_id = conversations_collection.insert_one( {"user": "local", "date": datetime.datetime.utcnow(), "name": completion, - "queries": [{"prompt": question, "response": response_full, "sources": source_log_docs}]} + "queries": [{"prompt": question, "response": response, "sources": source_log_docs}]} ).inserted_id + return str(conversation_id) # function-level return, reached by both branches; str() keeps the id JSON-serializable +def get_prompt(prompt_id): + if prompt_id == 'default': + prompt = chat_combine_template + elif prompt_id == 'creative': + prompt = chat_combine_creative + elif prompt_id == 'strict': + prompt = chat_combine_strict + else: + prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] + return prompt + + +def complete_stream(question, retriever, conversation_id): + + response_full = "" + source_log_docs = [] + answer = retriever.gen() + for line in answer: + if "answer" in line: + response_full += str(line["answer"]) + data = json.dumps(line) + yield f"data: {data}\n\n" + elif "source" in line: + source_log_docs.append(line["source"]) + + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm) + # send data.type = "end" to indicate that the stream has ended as json data = json.dumps({"type": "id", "id": str(conversation_id)}) yield f"data: {data}\n\n" @@ -213,25 +191,26 @@ def stream(): chunks = int(data["chunks"]) else: chunks = 2 + + prompt = get_prompt(prompt_id) # check if active_docs is set if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) - vectorstore = get_vectorstore({"active_docs": data_key["source"]}) + source = {"active_docs": data_key["source"]} elif "active_docs" in data: - vectorstore = get_vectorstore({"active_docs": data["active_docs"]}) + source = {"active_docs": data["active_docs"]} else: - vectorstore = "" - docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY) + source = {} +
+ retriever = RetrieverCreator.create_retriever("classic", question=question, + source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model + ) return Response( - complete_stream(question, docsearch, - chat_history=history, - prompt_id=prompt_id, - conversation_id=conversation_id, - chunks=chunks), mimetype="text/event-stream" - ) + complete_stream(question=question, retriever=retriever, + conversation_id=conversation_id), mimetype="text/event-stream") @answer.route("/api/answer", methods=["POST"]) @@ -255,110 +234,35 @@ def api_answer(): chunks = int(data["chunks"]) else: chunks = 2 - - if prompt_id == 'default': - prompt = chat_combine_template - elif prompt_id == 'creative': - prompt = chat_combine_creative - elif prompt_id == 'strict': - prompt = chat_combine_strict - else: - prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] + + prompt = get_prompt(prompt_id) # use try and except to check for exception try: # check if the vectorstore is set if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) - vectorstore = get_vectorstore({"active_docs": data_key["source"]}) + source = {"active_docs": data_key["source"]} else: - vectorstore = get_vectorstore(data) - # loading the index and the store and the prompt template - # Note if you have used other embeddings than OpenAI, you need to change the embeddings - docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY) + source = data - - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) - - - - if chunks == 0: - docs = [] - else: - docs = docsearch.search(question, k=chunks) - # join all page_content together with a newline - docs_together = "\n".join([doc.page_content for doc in docs]) - p_chat_combine = prompt.replace("{summaries}", docs_together) - messages_combine = [{"role": "system", "content": p_chat_combine}] + retriever = RetrieverCreator.create_retriever("classic",
question=question, + source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model + ) source_log_docs = [] - for doc in docs: - if doc.metadata: - source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}) - else: - source_log_docs.append({"title": doc.page_content, "text": doc.page_content}) - # join all page_content together with a newline + response_full = "" + for line in retriever.gen(): + if "source" in line: + source_log_docs.append(line["source"]) + elif "answer" in line: + response_full += line["answer"] + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + + result = {"answer": response_full, "sources": source_log_docs} + result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm) - if len(history) > 1: - tokens_current_history = 0 - # count tokens in history - history.reverse() - for i in history: - if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: - tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) - messages_combine.append({"role": "user", "content": question}) - - - completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME, - messages=messages_combine) - - - result = {"answer": completion, "sources": source_log_docs} - logger.debug(result) - - # generate conversationId - if conversation_id is not None: - conversations_collection.update_one( - {"_id": ObjectId(conversation_id)}, - {"$push": {"queries": {"prompt": question, - "response": result["answer"], "sources": result['sources']}}}, - ) - - else: - # create new conversation - # generate summary - messages_summary = [ - {"role": "assistant", "content": "Summarise following conversation 
in no more than 3 words, " - "respond ONLY with the summary, use the same language as the system \n\n" - "User: " + question + "\n\n" + "AI: " + result["answer"]}, - {"role": "user", "content": "Summarise following conversation in no more than 3 words, " - "respond ONLY with the summary, use the same language as the system"} - ] - - completion = llm.gen( - model=gpt_model, - engine=settings.AZURE_DEPLOYMENT_NAME, - messages=messages_summary, - max_tokens=30 - ) - conversation_id = conversations_collection.insert_one( - {"user": "local", - "date": datetime.datetime.utcnow(), - "name": completion, - "queries": [{"prompt": question, "response": result["answer"], "sources": source_log_docs}]} - ).inserted_id - - result["conversation_id"] = str(conversation_id) - - # mock result - # result = { - # "answer": "The answer is 42", - # "sources": ["https://en.wikipedia.org/wiki/42_(number)", "https://en.wikipedia.org/wiki/42_(number)"] - # } return result except Exception as e: # print whole traceback @@ -375,20 +279,20 @@ def api_search(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) - vectorstore = data_key["source"] + source = {"active_docs": data_key["source"]} elif "active_docs" in data: - vectorstore = get_vectorstore({"active_docs": data["active_docs"]}) + source = {"active_docs": data["active_docs"]} else: - vectorstore = "" + source = {} if 'chunks' in data: chunks = int(data["chunks"]) else: chunks = 2 - docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY) - if chunks == 0: - docs = [] - else: - docs = docsearch.search(question, k=chunks) + + retriever = RetrieverCreator.create_retriever("classic", question=question, + source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model + ) + docs = retriever.search() source_log_docs = [] for doc in docs: @@ -396,6 +300,5 @@ def api_search(): source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": 
doc.page_content}) else: source_log_docs.append({"title": doc.page_content, "text": doc.page_content}) - #yield f"data:{data}\n\n" return source_log_docs diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index a64d71e9..6b0d6467 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -10,7 +10,7 @@ class AnthropicLLM(BaseLLM): self.HUMAN_PROMPT = HUMAN_PROMPT self.AI_PROMPT = AI_PROMPT - def gen(self, model, messages, engine=None, max_tokens=300, stream=False, **kwargs): + def gen(self, model, messages, max_tokens=300, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Context \n {context} \n ### Question \n {user_question}" @@ -25,7 +25,7 @@ class AnthropicLLM(BaseLLM): ) return completion.completion - def gen_stream(self, model, messages, engine=None, max_tokens=300, **kwargs): + def gen_stream(self, model, messages, max_tokens=300, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Context \n {context} \n ### Question \n {user_question}" diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py index e0c5dbad..d540a911 100644 --- a/application/llm/docsgpt_provider.py +++ b/application/llm/docsgpt_provider.py @@ -8,7 +8,7 @@ class DocsGPTAPILLM(BaseLLM): self.endpoint = "https://llm.docsgpt.co.uk" - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -24,7 +24,7 @@ class DocsGPTAPILLM(BaseLLM): return response_clean - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction 
\n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py index ef3b1fbc..554bee2f 100644 --- a/application/llm/huggingface.py +++ b/application/llm/huggingface.py @@ -29,7 +29,7 @@ class HuggingFaceLLM(BaseLLM): ) hf = HuggingFacePipeline(pipeline=pipe) - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -38,7 +38,7 @@ class HuggingFaceLLM(BaseLLM): return result.content - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.") diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py index f18d4379..be34d4ff 100644 --- a/application/llm/llama_cpp.py +++ b/application/llm/llama_cpp.py @@ -12,7 +12,7 @@ class LlamaCpp(BaseLLM): llama = Llama(model_path=llm_name, n_ctx=2048) - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -24,7 +24,7 @@ class LlamaCpp(BaseLLM): return result['choices'][0]['text'].split('### Answer \n')[-1] - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/llm/openai.py b/application/llm/openai.py index a132399a..4b0ed25a 100644 --- 
a/application/llm/openai.py +++ b/application/llm/openai.py @@ -18,7 +18,7 @@ class OpenAILLM(BaseLLM): return openai - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs): response = self.client.chat.completions.create(model=model, messages=messages, stream=stream, @@ -26,7 +26,7 @@ class OpenAILLM(BaseLLM): return response.choices[0].message.content - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs): response = self.client.chat.completions.create(model=model, messages=messages, stream=stream, diff --git a/application/llm/premai.py b/application/llm/premai.py index 4bc8a898..5faa5fee 100644 --- a/application/llm/premai.py +++ b/application/llm/premai.py @@ -12,7 +12,7 @@ class PremAILLM(BaseLLM): self.api_key = api_key self.project_id = settings.PREMAI_PROJECT_ID - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): response = self.client.chat.completions.create(model=model, project_id=self.project_id, messages=messages, @@ -21,7 +21,7 @@ class PremAILLM(BaseLLM): return response.choices[0].message["content"] - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): response = self.client.chat.completions.create(model=model, project_id=self.project_id, messages=messages, diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index 84ae09ad..b81f6385 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -74,7 +74,7 @@ class SagemakerAPILLM(BaseLLM): self.runtime = runtime - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = 
messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -103,7 +103,7 @@ class SagemakerAPILLM(BaseLLM): print(result[0]['generated_text'], file=sys.stderr) return result[0]['generated_text'][len(prompt):] - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/retriever/__init__.py b/application/retriever/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/application/retriever/base.py b/application/retriever/base.py new file mode 100644 index 00000000..3bfaa5e1 --- /dev/null +++ b/application/retriever/base.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + + +class BaseRetriever(ABC): + def __init__(self): + pass + + @abstractmethod + def gen(self, *args, **kwargs): + pass diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py new file mode 100644 index 00000000..dc757946 --- /dev/null +++ b/application/retriever/classic_rag.py @@ -0,0 +1,83 @@ +import os +import json +from application.retriever.base import BaseRetriever +from application.core.settings import settings +from application.vectorstore.vector_creator import VectorCreator +from application.llm.llm_creator import LLMCreator + +from application.utils import count_tokens + + + +class ClassicRAG(BaseRetriever): + + def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + self.question = question + self.vectorstore = self._get_vectorstore(source=source) + self.chat_history = chat_history + self.prompt = prompt + self.chunks = chunks + self.gpt_model = gpt_model + + def _get_vectorstore(self, source): + if "active_docs" in source: + if source["active_docs"].split("/")[0] == "default": 
+ vectorstore = "" + elif source["active_docs"].split("/")[0] == "local": + vectorstore = "indexes/" + source["active_docs"] + else: + vectorstore = "vectors/" + source["active_docs"] + if source["active_docs"] == "default": + vectorstore = "" + else: + vectorstore = "" + vectorstore = os.path.join("application", vectorstore) + return vectorstore + + + def _get_data(self): + if self.chunks == 0: + docs = [] + else: + docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY) + docs = docsearch.search(self.question, k=self.chunks) + if settings.LLM_NAME == "llama.cpp": + docs = docs[:1] + return docs + + def gen(self): + docs = self._get_data() + + # join all page_content together with a newline + docs_together = "\n".join([doc.page_content for doc in docs]) + p_chat_combine = self.prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + for doc in docs: + if doc.metadata: + yield {"source": {"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}} + else: + yield {"source": {"title": doc.page_content, "text": doc.page_content}} + + if len(self.chat_history) > 1: + tokens_current_history = 0 + # count tokens in history, newest first + # reversed() leaves the caller's chat_history list unmodified + for i in reversed(self.chat_history): + if "prompt" in i and "response" in i: + tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) + if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_current_history += tokens_batch + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append({"role": "user", "content": self.question}) + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + + completion = llm.gen_stream(model=self.gpt_model, + messages=messages_combine) + for line in completion: + yield {"answer": str(line)} + + def
search(self): + return self._get_data() + diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py new file mode 100644 index 00000000..9255f4e6 --- /dev/null +++ b/application/retriever/retriever_creator.py @@ -0,0 +1,15 @@ +from application.retriever.classic_rag import ClassicRAG + + + +class RetrieverCreator: + retievers = { + 'classic': ClassicRAG, + } + + @classmethod + def create_retriever(cls, type, *args, **kwargs): + retiever_class = cls.retievers.get(type.lower()) + if not retiever_class: + raise ValueError(f"No retievers class found for type {type}") + return retiever_class(*args, **kwargs) \ No newline at end of file diff --git a/application/utils.py b/application/utils.py new file mode 100644 index 00000000..ac98efc6 --- /dev/null +++ b/application/utils.py @@ -0,0 +1,6 @@ +from transformers import GPT2TokenizerFast + + +def count_tokens(string): + tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') + return len(tokenizer(string)['input_ids']) \ No newline at end of file