diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index bccffb66..6a4a0929 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -18,7 +18,7 @@ from application.error import bad_request
 from application.extensions import api
 from application.llm.llm_creator import LLMCreator
 from application.retriever.retriever_creator import RetrieverCreator
-from application.utils import check_required_fields
+from application.utils import check_required_fields, limit_chat_history
 
 logger = logging.getLogger(__name__)
 
@@ -324,8 +324,7 @@ class Stream(Resource):
 
         try:
             question = data["question"]
-            history = str(data.get("history", []))
-            history = str(json.loads(history))
+            history = str(limit_chat_history(json.loads(data.get("history", "[]"))))
             conversation_id = data.get("conversation_id")
             prompt_id = data.get("prompt_id", "default")
 
diff --git a/application/utils.py b/application/utils.py
index 1fc9e329..8b5ddf2c 100644
--- a/application/utils.py
+++ b/application/utils.py
@@ -46,3 +46,27 @@ def check_required_fields(data, required_fields):
 
 def get_hash(data):
     return hashlib.md5(data.encode()).hexdigest()
+
+
+def limit_chat_history(history, max_token_limit=500):
+    """Return the most recent *history* entries whose combined prompt+response
+    token count fits within *max_token_limit*.
+
+    Walks the history newest-to-oldest; entries without both "prompt" and
+    "response" keys are kept without contributing to the token budget. The
+    entry that first exceeds the budget (and everything older) is dropped.
+    """
+    # NOTE(review): num_tokens_from_string is assumed to already exist in
+    # this module (the diff adds no import for it) — confirm.
+    cumulative_token_count = 0
+    trimmed_history = []
+
+    for exchange in reversed(history):
+        if "prompt" in exchange and "response" in exchange:
+            cumulative_token_count += num_tokens_from_string(
+                exchange["prompt"] + exchange["response"]
+            )
+            if cumulative_token_count > max_token_limit:
+                break
+        trimmed_history.insert(0, exchange)
+
+    return trimmed_history