diff --git a/application/utils.py b/application/utils.py
index 8b5ddf2c..a96e0c9a 100644
--- a/application/utils.py
+++ b/application/utils.py
@@ -46,17 +46,37 @@ def check_required_fields(data, required_fields):
 
 def get_hash(data):
     return hashlib.md5(data.encode()).hexdigest()
 
-def limit_chat_history(history,max_token_limit = 500):
-
-    cumulative_token_count = 0
+def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
+    """
+    Limits chat history based on token count.
+    Returns a list of messages that fit within the token limit.
+    """
+    from application.core.settings import settings
+
+    max_token_limit = (
+        max_token_limit
+        if max_token_limit
+        and max_token_limit < settings.MODEL_TOKEN_LIMITS.get(
+            gpt_model, settings.DEFAULT_MAX_HISTORY
+        )
+        else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
+    )
+
+    if not history:
+        return []
+
+    tokens_current_history = 0
     trimmed_history = []
-    for i in reversed(history):
-
-        if("prompt" in i and "response" in i):
-            cumulative_token_count += num_tokens_from_string(i["prompt"] + i["response"])
-            if(cumulative_token_count > max_token_limit):
-                break
-            trimmed_history.insert(0,i)
-
-    return trimmed_history
\ No newline at end of file
+    for message in reversed(history):
+        if "prompt" in message and "response" in message:
+            tokens_batch = num_tokens_from_string(message["prompt"]) + num_tokens_from_string(
+                message["response"]
+            )
+            if tokens_current_history + tokens_batch < max_token_limit:
+                tokens_current_history += tokens_batch
+                trimmed_history.insert(0, message)
+            else:
+                break
+
+    return trimmed_history