diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index aba2b88e..c55ffe72 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -324,5 +324,5 @@ class Stream(Resource):
         try:
             question = data["question"]
-            history = str(limit_chat_history(json.loads(data.get("history", []))))
+            history = limit_chat_history(json.loads(data.get("history", "[]")), gpt_model=gpt_model)
             conversation_id = data.get("conversation_id")
             prompt_id = data.get("prompt_id", "default")
@@ -455,6 +455,6 @@ class Answer(Resource):
         try:
             question = data["question"]
-            history = str(limit_chat_history(json.loads(data.get("history", []))))
+            history = limit_chat_history(json.loads(data.get("history", "[]")), gpt_model=gpt_model)
             conversation_id = data.get("conversation_id")
             prompt_id = data.get("prompt_id", "default")
             chunks = int(data.get("chunks", 2))
diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py
index 4601d352..3d9ae89e 100644
--- a/application/retriever/brave_search.py
+++ b/application/retriever/brave_search.py
@@ -73,8 +73,9 @@
         if len(self.chat_history) > 1:
             for i in self.chat_history:
-                messages_combine.append(
-                    {"role": "user", "content": i["prompt"]}
-                )
-                messages_combine.append(
-                    {"role": "system", "content": i["response"]}
-                )
+                if "prompt" in i and "response" in i:
+                    messages_combine.append(
+                        {"role": "user", "content": i["prompt"]}
+                    )
+                    messages_combine.append(
+                        {"role": "system", "content": i["response"]}
+                    )
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index 75b2b576..8de625dd 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -73,15 +73,16 @@
         if len(self.chat_history) > 1:
             for i in self.chat_history:
-                messages_combine.append(
-                    {"role": "user", "content": i["prompt"]}
-                )
-                messages_combine.append(
-                    {"role": "system", "content": i["response"]}
-                )
+                if "prompt" in i and "response" in i:
+                    messages_combine.append(
+                        {"role": "user", "content": i["prompt"]}
+                    )
+                    messages_combine.append(
+                        {"role": "system", "content": i["response"]}
+                    )
         messages_combine.append({"role": "user", "content": self.question})
-
+
         llm = LLMCreator.create_llm(
             settings.LLM_NAME,
             api_key=settings.API_KEY,
             user_api_key=self.user_api_key
         )
diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py
index 80717e7d..fa19ead0 100644
--- a/application/retriever/duckduck_search.py
+++ b/application/retriever/duckduck_search.py
@@ -90,8 +90,9 @@
         if len(self.chat_history) > 1:
             for i in self.chat_history:
-                messages_combine.append(
-                    {"role": "user", "content": i["prompt"]}
-                )
-                messages_combine.append(
-                    {"role": "system", "content": i["response"]}
-                )
+                if "prompt" in i and "response" in i:
+                    messages_combine.append(
+                        {"role": "user", "content": i["prompt"]}
+                    )
+                    messages_combine.append(
+                        {"role": "system", "content": i["response"]}
+                    )
diff --git a/application/utils.py b/application/utils.py
index a96e0c9a..7099a20a 100644
--- a/application/utils.py
+++ b/application/utils.py
@@ -54,12 +54,15 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
     from application.core.settings import settings
 
     max_token_limit = (
         max_token_limit
-        if max_token_limit
-        and max_token_limit < settings.MODEL_TOKEN_LIMITS.get(
+        if max_token_limit and
+        max_token_limit < settings.MODEL_TOKEN_LIMITS.get(
             gpt_model, settings.DEFAULT_MAX_HISTORY
         )
-        else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
+        else settings.MODEL_TOKEN_LIMITS.get(
+            gpt_model, settings.DEFAULT_MAX_HISTORY
+        )
     )
+
     if not history:
         return []
@@ -78,5 +81,5 @@
             trimmed_history.insert(0, message)
         else:
             break
-
+
     return trimmed_history