import hashlib
import re

import tiktoken
from flask import jsonify, make_response


# Tokenizer is loaded lazily and cached at module level, since
# tiktoken.get_encoding() is comparatively expensive to call repeatedly.
_encoding = None


def get_encoding():
    global _encoding
    if _encoding is None:
        _encoding = tiktoken.get_encoding("cl100k_base")
    return _encoding


def num_tokens_from_string(string: str) -> int:
    """Return the number of cl100k_base tokens in ``string``; 0 for non-string input."""
    encoding = get_encoding()
    if isinstance(string, str):
        return len(encoding.encode(string))
    return 0


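# Example (illustrative; exact counts depend on the cl100k_base vocabulary):
#
#     num_tokens_from_string("hello world")  # -> 2
#     num_tokens_from_string(None)           # -> 0, non-strings count as empty

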
def num_tokens_from_object_or_list(thing):
    """Recursively count tokens in a nested structure of lists, dicts, and strings.

    Dict keys are ignored; non-string leaves contribute 0.
    """
    if isinstance(thing, list):
        return sum(num_tokens_from_object_or_list(x) for x in thing)
    elif isinstance(thing, dict):
        return sum(num_tokens_from_object_or_list(x) for x in thing.values())
    elif isinstance(thing, str):
        return num_tokens_from_string(thing)
    else:
        return 0


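# Example (illustrative): only string leaves are tokenized.
#
#     num_tokens_from_object_or_list({"prompt": "hi", "meta": [1, 2, "three"]})
#     # == num_tokens_from_string("hi") + num_tokens_from_string("three")

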
def count_tokens_docs(docs):
    """Count tokens across the concatenated page_content of the given documents."""
    docs_content = "".join(doc.page_content for doc in docs)
    return num_tokens_from_string(docs_content)


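# Sketch of a call site, assuming LangChain-style documents exposing a
# .page_content string (SimpleNamespace stands in for the real class here):
#
#     from types import SimpleNamespace
#     docs = [SimpleNamespace(page_content="foo"), SimpleNamespace(page_content="bar")]
#     count_tokens_docs(docs)  # == num_tokens_from_string("foobar")

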
def check_required_fields(data, required_fields):
    """Return a 400 JSON response naming any missing fields, or None if all are present."""
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        return make_response(
            jsonify(
                {
                    "success": False,
                    "message": f"Missing fields: {', '.join(missing_fields)}",
                }
            ),
            400,
        )
    return None


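# Typical usage inside a Flask view (a sketch; the route, payload, and field
# names here are hypothetical):
#
#     @app.route("/api/answer", methods=["POST"])
#     def answer():
#         data = request.get_json()
#         error = check_required_fields(data, ["question", "api_key"])
#         if error:
#             return error
#         ...

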
def get_hash(data):
    """MD5-hash a string for cache keys and deduplication (not for security).

    usedforsecurity=False (Python 3.9+) marks the non-cryptographic intent.
    """
    return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()


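# Example: get_hash("docsgpt") returns a stable 32-character hex digest, so
# identical inputs always map to the same cache key.

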
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
    """
    Limits chat history based on token count.

    Returns a list of messages that fit within the token limit.
    """
    from application.core.settings import settings

    # Never exceed the model's own limit; fall back to it when no explicit
    # (or no smaller) limit was requested.
    model_limit = settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
    max_token_limit = min(max_token_limit, model_limit) if max_token_limit else model_limit

    if not history:
        return []

    trimmed_history = []
    tokens_current_history = 0

    # Walk from newest to oldest, keeping whole messages until the next one
    # would push the running total over the limit.
    for message in reversed(history):
        tokens_batch = 0
        if "prompt" in message and "response" in message:
            tokens_batch += num_tokens_from_string(message["prompt"])
            tokens_batch += num_tokens_from_string(message["response"])

        if "tool_calls" in message:
            for tool_call in message["tool_calls"]:
                tool_call_string = (
                    f"Tool: {tool_call.get('tool_name')} | "
                    f"Action: {tool_call.get('action_name')} | "
                    f"Args: {tool_call.get('arguments')} | "
                    f"Response: {tool_call.get('result')}"
                )
                tokens_batch += num_tokens_from_string(tool_call_string)

        if tokens_current_history + tokens_batch < max_token_limit:
            tokens_current_history += tokens_batch
            trimmed_history.insert(0, message)
        else:
            break

    return trimmed_history


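# Example (sketch; with cl100k_base each message below is roughly 4 tokens,
# so a limit of 8 keeps only the most recent exchange):
#
#     history = [
#         {"prompt": "old question", "response": "old answer"},
#         {"prompt": "new question", "response": "new answer"},
#     ]
#     limit_chat_history(history, max_token_limit=8)  # -> [history[-1]]

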
def validate_function_name(function_name):
    """Validates if a function name matches the allowed pattern."""
    return bool(re.match(r"^[a-zA-Z0-9_-]+$", function_name))
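# Example: validate_function_name("get_weather")  -> True
#          validate_function_name("rm -rf /")     -> False (space and slash)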