Merge branch 'main' into tool-use

This commit is contained in:
Alex
2024-12-20 17:29:41 +00:00
committed by GitHub
84 changed files with 4317 additions and 5522 deletions

View File

@@ -18,7 +18,7 @@ from application.error import bad_request
from application.extensions import api
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import check_required_fields
from application.utils import check_required_fields, limit_chat_history
logger = logging.getLogger(__name__)
@@ -118,8 +118,31 @@ def is_azure_configured():
)
def save_conversation(conversation_id, question, response, source_log_docs, llm):
if conversation_id is not None and conversation_id != "None":
def save_conversation(conversation_id, question, response, source_log_docs, llm, index=None):
if conversation_id is not None and index is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$set": {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.sources": source_log_docs,
}
}
)
# Remove all queries after the edited one from the array
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$push": {
"queries": {
"$each": [],
"$slice": index + 1
}
}
}
)
elif conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
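A note on the second update_one above: pushing an empty $each array together with a positive $slice is MongoDB's idiom for truncating an array in place. Nothing is appended; the queries array is simply cut down to its first index + 1 elements, discarding the stale follow-up queries after the edited one. A minimal sketch of the effect (collection name and ids are illustrative):

    from bson.objectid import ObjectId
    from pymongo import MongoClient

    conversations = MongoClient("mongodb://localhost:27017").docsgpt.conversations
    conversation_id = ObjectId()  # placeholder; use a real conversation's _id
    index = 1                     # position of the edited query

    # If the conversation held queries [q0, q1, q2, q3], this keeps [q0, q1]:
    conversations.update_one(
        {"_id": conversation_id, f"queries.{index}": {"$exists": True}},
        {"$push": {"queries": {"$each": [], "$slice": index + 1}}},
    )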
@@ -141,17 +164,17 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
"language as the system",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system",
"system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
},
]
@@ -186,7 +209,7 @@ def get_prompt(prompt_id):
def complete_stream(
question, retriever, conversation_id, user_api_key, isNoneDoc=False
question, retriever, conversation_id, user_api_key, isNoneDoc=False, index=None
):
try:
@@ -217,7 +240,7 @@ def complete_stream(
)
if user_api_key is None:
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
conversation_id, question, response_full, source_log_docs, llm, index
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
@@ -282,6 +305,9 @@ class Stream(Resource):
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index": fields.Integer(
required=False, description="The position of the query to update within the conversation"
),
},
)
@@ -290,23 +316,23 @@ class Stream(Resource):
def post(self):
data = request.get_json()
required_fields = ["question"]
if "index" in data:
required_fields = ["question", "conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = data.get("history", [])
history = json.loads(history)
history = limit_chat_history(json.loads(data.get("history", "[]")), gpt_model=gpt_model)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
index = data.get("index", None)
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
@@ -343,7 +369,7 @@ class Stream(Resource):
gpt_model=gpt_model,
user_api_key=user_api_key,
)
return Response(
complete_stream(
question=question,
@@ -351,6 +377,7 @@ class Stream(Resource):
conversation_id=conversation_id,
user_api_key=user_api_key,
isNoneDoc=data.get("isNoneDoc"),
index=index,
),
mimetype="text/event-stream",
)
@@ -428,7 +455,7 @@ class Answer(Resource):
try:
question = data["question"]
history = data.get("history", [])
history = limit_chat_history(json.loads(data.get("history", "[]")), gpt_model=gpt_model)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))

View File

@@ -181,10 +181,12 @@ class SubmitFeedback(Resource):
"FeedbackModel",
{
"question": fields.String(
required=True, description="The user question"
required=False, description="The user question"
),
"answer": fields.String(required=True, description="The AI answer"),
"answer": fields.String(required=False, description="The AI answer"),
"feedback": fields.String(required=True, description="User feedback"),
"question_index": fields.Integer(required=True, description="The index of the question within the conversation"),
"conversation_id": fields.String(required=True, description="The ID of the conversation"),
"api_key": fields.String(description="Optional API key"),
},
)
@@ -194,23 +196,21 @@ class SubmitFeedback(Resource):
)
def post(self):
data = request.get_json()
required_fields = ["question", "answer", "feedback"]
required_fields = ["feedback", "conversation_id", "question_index"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
new_doc = {
"question": data["question"],
"answer": data["answer"],
"feedback": data["feedback"],
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
if "api_key" in data:
new_doc["api_key"] = data["api_key"]
try:
feedback_collection.insert_one(new_doc)
conversations_collection.update_one(
{"_id": ObjectId(data["conversation_id"]), f"queries.{data['question_index']}": {"$exists": True}},
{
"$set": {
f"queries.{data['question_index']}.feedback": data["feedback"]
}
}
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
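Feedback is now written onto the query it refers to (queries.{question_index}.feedback) inside the conversation document rather than stored as a standalone feedback record, which is why question and answer became optional while conversation_id and question_index are required. A sketch of the new payload (the route path is an assumption):

    import requests

    requests.post(
        "http://localhost:7091/api/feedback",  # assumed route for SubmitFeedback
        json={
            "feedback": "LIKE",
            "conversation_id": "657f...",  # hypothetical conversation id
            "question_index": 0,           # first query in the conversation
        },
    )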
@@ -253,13 +253,12 @@ class DeleteOldIndexes(Resource):
jsonify({"success": False, "message": "Missing required fields"}), 400
)
try:
doc = sources_collection.find_one(
doc = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": "local"}
)
if not doc:
)
if not doc:
return make_response(jsonify({"status": "not found"}), 404)
try:
if settings.VECTOR_STORE == "faiss":
shutil.rmtree(os.path.join(current_dir, "indexes", str(doc["_id"])))
else:
@@ -268,12 +267,12 @@ class DeleteOldIndexes(Resource):
)
vectorstore.delete_index()
sources_collection.delete_one({"_id": ObjectId(source_id)})
except FileNotFoundError:
pass
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
sources_collection.delete_one({"_id": ObjectId(source_id)})
return make_response(jsonify({"success": True}), 200)
@@ -344,6 +343,9 @@ class UploadFile(Resource):
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
job_name,
final_filename,
@@ -370,6 +372,9 @@ class UploadFile(Resource):
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
job_name,
final_filename,
@@ -478,11 +483,22 @@ class PaginatedSources(Resource):
sort_order = request.args.get("order", "desc") # Default to 'desc'
page = int(request.args.get("page", 1)) # Default to 1
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
# Strip leading and trailing whitespace from the search term
search_term = request.args.get(
"search", ""
).strip()  # optional search term for filtering documents
# Prepare
# Prepare query for filtering
query = {"user": user}
if search_term:
query["name"] = {
"$regex": search_term,
"$options": "i",  # case-insensitive match
}
total_documents = sources_collection.count_documents(query)
total_pages = max(1, math.ceil(total_documents / rows_per_page))
page = min(max(1, page), total_pages)  # clamp the requested page to the valid range
sort_order = 1 if sort_order == "asc" else -1
skip = (page - 1) * rows_per_page
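The clamping above prevents out-of-range requests from returning empty pages: total_pages is computed first, then page is forced into [1, total_pages] before the skip offset is derived. Worked example: with 23 matching documents and 10 rows per page, total_pages = ceil(23 / 10) = 3, so a request for page 7 is served page 3 with skip = 20:

    import math

    total_documents, rows_per_page, page = 23, 10, 7
    total_pages = max(1, math.ceil(total_documents / rows_per_page))  # 3
    page = min(max(1, page), total_pages)                             # 7 -> 3
    skip = (page - 1) * rows_per_page                                 # 20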
@@ -2082,3 +2098,4 @@ class DeleteTool(Resource):
return {"success": False, "error": str(err)}, 400
return {"success": True}, 200

View File

@@ -2,14 +2,22 @@ from celery import Celery
from application.core.settings import settings
from celery.signals import setup_logging
def make_celery(app_name=__name__):
celery = Celery(app_name, broker=settings.CELERY_BROKER_URL, backend=settings.CELERY_RESULT_BACKEND)
celery = Celery(
app_name,
broker=settings.CELERY_BROKER_URL,
backend=settings.CELERY_RESULT_BACKEND,
)
celery.conf.update(settings)
return celery
@setup_logging.connect
def config_loggers(*args, **kwargs):
from application.core.logging_config import setup_logging
setup_logging()
celery = make_celery()

View File

@@ -16,8 +16,9 @@ class Settings(BaseSettings):
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
MODEL_TOKEN_LIMITS: dict = {"gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5}
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search

View File

@@ -9,35 +9,25 @@ class DocsGPTAPILLM(BaseLLM):
super().__init__(*args, **kwargs)
self.api_key = api_key
self.user_api_key = user_api_key
self.endpoint = "https://llm.docsgpt.co.uk"
self.endpoint = "https://llm.arc53.com"
def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
response = requests.post(
f"{self.endpoint}/answer", json={"prompt": prompt, "max_new_tokens": 30}
f"{self.endpoint}/answer", json={"messages": messages, "max_new_tokens": 30}
)
response_clean = response.json()["a"].replace("###", "")
return response_clean
def _raw_gen_stream(self, baseself, model, messages, stream=True, *args, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
# send prompt to endpoint /stream
response = requests.post(
f"{self.endpoint}/stream",
json={"prompt": prompt, "max_new_tokens": 256},
json={"messages": messages, "max_new_tokens": 256},
stream=True,
)
for line in response.iter_lines():
if line:
# data = json.loads(line)
data_str = line.decode("utf-8")
if data_str.startswith("data: "):
data = json.loads(data_str[6:])
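The streaming loop follows the server-sent-events convention: non-empty lines are decoded, and payload lines carry a "data: " prefix whose remainder is JSON. A standalone sketch of that parsing (the sample lines are illustrative):

    import json

    raw_lines = [b'data: {"answer": "Hello"}', b"", b'data: {"answer": " world"}']
    for line in raw_lines:
        if line:
            data_str = line.decode("utf-8")
            if data_str.startswith("data: "):
                data = json.loads(data_str[6:])  # strip the "data: " prefix
                print(data["answer"], end="")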

View File

@@ -13,6 +13,7 @@ from application.parser.file.rst_parser import RstParser
from application.parser.file.tabular_parser import PandasCSVParser, ExcelParser
from application.parser.file.json_parser import JSONParser
from application.parser.file.pptx_parser import PPTXParser
from application.parser.file.image_parser import ImageParser
from application.parser.schema.base import Document
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = {
@@ -27,6 +28,9 @@ DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = {
".mdx": MarkdownParser(),
".json": JSONParser(),
".pptx": PPTXParser(),
".png": ImageParser(),
".jpg": ImageParser(),
".jpeg": ImageParser(),
}

View File

@@ -7,7 +7,8 @@ from pathlib import Path
from typing import Dict
from application.parser.file.base_parser import BaseParser
from application.core.settings import settings
import requests
class PDFParser(BaseParser):
"""PDF parser."""
@@ -18,6 +19,15 @@ class PDFParser(BaseParser):
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
if settings.PARSE_PDF_AS_IMAGE:
doc2md_service = "https://llm.arc53.com/doc2md"
# Alternatively, a local vision-capable LLM can be used
with open(file, "rb") as file_loaded:
files = {'file': file_loaded}
response = requests.post(doc2md_service, files=files)
data = response.json()["markdown"]
return data
try:
import PyPDF2
except ImportError:

View File

@@ -0,0 +1,27 @@
"""Image parser.
Contains parser for .png, .jpg, .jpeg files.
"""
from pathlib import Path
import requests
from typing import Dict, Union
from application.parser.file.base_parser import BaseParser
class ImageParser(BaseParser):
"""Image parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
doc2md_service = "https://llm.arc53.com/doc2md"
# Alternatively, a local vision-capable LLM can be used
with open(file, "rb") as file_loaded:
files = {'file': file_loaded}
response = requests.post(doc2md_service, files=files)
data = response.json()["markdown"]
return data
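Usage mirrors the other file parsers: the image is posted to the doc2md service and the extracted text comes back as markdown. A hedged usage sketch (assumes the service is reachable and the file exists):

    from pathlib import Path
    from application.parser.file.image_parser import ImageParser

    parser = ImageParser()
    markdown = parser.parse_file(Path("inputs/diagram.png"))  # hypothetical file
    print(markdown)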

View File

@@ -91,6 +91,25 @@ class RstParser(BaseParser):
]
return rst_tups
def chunk_by_token_count(self, text: str, max_tokens: int = 100) -> List[str]:
"""Chunk text by an approximate token count (about 5 characters per token)."""
avg_token_length = 5
chunk_size = max_tokens * avg_token_length
chunks = []
start = 0
while start < len(text):
chunk = text[start:start + chunk_size]
if start + chunk_size < len(text):
# Break at the last space so words are not split across chunks
last_space = chunk.rfind(' ')
if last_space != -1:
chunk = chunk[:last_space]
chunks.append(chunk.strip())
# Advance by the actual chunk length so the trimmed tail is not skipped
start += max(len(chunk), 1)
return chunks
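With the default of 100 tokens and the 5-characters-per-token heuristic, each chunk targets at most 500 characters and is trimmed back to the last space so words stay intact. For instance (assuming a default-constructed RstParser):

    text = "word " * 250  # 1250 characters
    chunks = RstParser().chunk_by_token_count(text, max_tokens=100)
    assert all(len(c) <= 500 for c in chunks)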
def remove_images(self, content: str) -> str:
pattern = r"\.\. image:: (.*)"
content = re.sub(pattern, "", content)
@@ -136,7 +155,7 @@ class RstParser(BaseParser):
return {}
def parse_tups(
self, filepath: Path, errors: str = "ignore"
self, filepath: Path, errors: str = "ignore", max_tokens: Optional[int] = 1000
) -> List[Tuple[Optional[str], str]]:
"""Parse file into tuples."""
with open(filepath, "r") as f:
@@ -156,6 +175,15 @@ class RstParser(BaseParser):
rst_tups = self.remove_whitespaces_excess(rst_tups)
if self._remove_characters_excess:
rst_tups = self.remove_characters_excess(rst_tups)
# Apply chunking if max_tokens is provided
if max_tokens is not None:
chunked_tups = []
for header, text in rst_tups:
chunks = self.chunk_by_token_count(text, max_tokens)
for idx, chunk in enumerate(chunks):
chunked_tups.append((f"{header} - Chunk {idx + 1}", chunk))
return chunked_tups
return rst_tups
def parse_file(

View File

@@ -1,66 +0,0 @@
import os
import javalang
def find_files(directory):
files_list = []
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith('.java'):
files_list.append(os.path.join(root, file))
return files_list
def extract_functions(file_path):
with open(file_path, "r") as file:
java_code = file.read()
methods = {}
tree = javalang.parse.parse(java_code)
for _, node in tree.filter(javalang.tree.MethodDeclaration):
method_name = node.name
start_line = node.position.line - 1
end_line = start_line
brace_count = 0
for line in java_code.splitlines()[start_line:]:
end_line += 1
brace_count += line.count("{") - line.count("}")
if brace_count == 0:
break
method_source_code = "\n".join(java_code.splitlines()[start_line:end_line])
methods[method_name] = method_source_code
return methods
def extract_classes(file_path):
with open(file_path, 'r') as file:
source_code = file.read()
classes = {}
tree = javalang.parse.parse(source_code)
for class_decl in tree.types:
class_name = class_decl.name
declarations = []
methods = []
for field_decl in class_decl.fields:
field_name = field_decl.declarators[0].name
field_type = field_decl.type.name
declarations.append(f"{field_type} {field_name}")
for method_decl in class_decl.methods:
methods.append(method_decl.name)
class_string = "Declarations: " + ", ".join(declarations) + "\n Method name: " + ", ".join(methods)
classes[class_name] = class_string
return classes
def extract_functions_and_classes(directory):
files = find_files(directory)
functions_dict = {}
classes_dict = {}
for file in files:
functions = extract_functions(file)
if functions:
functions_dict[file] = functions
classes = extract_classes(file)
if classes:
classes_dict[file] = classes
return functions_dict, classes_dict

View File

@@ -1,70 +0,0 @@
import os
import escodegen
import esprima
def find_files(directory):
files_list = []
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith('.js'):
files_list.append(os.path.join(root, file))
return files_list
def extract_functions(file_path):
with open(file_path, 'r') as file:
source_code = file.read()
functions = {}
tree = esprima.parseScript(source_code)
for node in tree.body:
if node.type == 'FunctionDeclaration':
func_name = node.id.name if node.id else '<anonymous>'
functions[func_name] = escodegen.generate(node)
elif node.type == 'VariableDeclaration':
for declaration in node.declarations:
if declaration.init and declaration.init.type == 'FunctionExpression':
func_name = declaration.id.name if declaration.id else '<anonymous>'
functions[func_name] = escodegen.generate(declaration.init)
elif node.type == 'ClassDeclaration':
for subnode in node.body.body:
if subnode.type == 'MethodDefinition':
func_name = subnode.key.name
functions[func_name] = escodegen.generate(subnode.value)
elif subnode.type == 'VariableDeclaration':
for declaration in subnode.declarations:
if declaration.init and declaration.init.type == 'FunctionExpression':
func_name = declaration.id.name if declaration.id else '<anonymous>'
functions[func_name] = escodegen.generate(declaration.init)
return functions
def extract_classes(file_path):
with open(file_path, 'r') as file:
source_code = file.read()
classes = {}
tree = esprima.parseScript(source_code)
for node in tree.body:
if node.type == 'ClassDeclaration':
class_name = node.id.name
function_names = []
for subnode in node.body.body:
if subnode.type == 'MethodDefinition':
function_names.append(subnode.key.name)
classes[class_name] = ", ".join(function_names)
return classes
def extract_functions_and_classes(directory):
files = find_files(directory)
functions_dict = {}
classes_dict = {}
for file in files:
functions = extract_functions(file)
if functions:
functions_dict[file] = functions
classes = extract_classes(file)
if classes:
classes_dict[file] = classes
return functions_dict, classes_dict

View File

@@ -1,121 +0,0 @@
import ast
import os
from pathlib import Path
import tiktoken
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
def find_files(directory):
files_list = []
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith('.py'):
files_list.append(os.path.join(root, file))
return files_list
def extract_functions(file_path):
with open(file_path, 'r') as file:
source_code = file.read()
functions = {}
tree = ast.parse(source_code)
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
func_name = node.name
func_def = ast.get_source_segment(source_code, node)
functions[func_name] = func_def
return functions
def extract_classes(file_path):
with open(file_path, 'r') as file:
source_code = file.read()
classes = {}
tree = ast.parse(source_code)
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
class_name = node.name
function_names = []
for subnode in ast.walk(node):
if isinstance(subnode, ast.FunctionDef):
function_names.append(subnode.name)
classes[class_name] = ", ".join(function_names)
return classes
def extract_functions_and_classes(directory):
files = find_files(directory)
functions_dict = {}
classes_dict = {}
for file in files:
functions = extract_functions(file)
if functions:
functions_dict[file] = functions
classes = extract_classes(file)
if classes:
classes_dict[file] = classes
return functions_dict, classes_dict
def parse_functions(functions_dict, formats, dir):
c1 = len(functions_dict)
for i, (source, functions) in enumerate(functions_dict.items(), start=1):
print(f"Processing file {i}/{c1}")
source_w = source.replace(dir + "/", "").replace("." + formats, ".md")
subfolders = "/".join(source_w.split("/")[:-1])
Path(f"outputs/{subfolders}").mkdir(parents=True, exist_ok=True)
for j, (name, function) in enumerate(functions.items(), start=1):
print(f"Processing function {j}/{len(functions)}")
prompt = PromptTemplate(
input_variables=["code"],
template="Code: \n{code}, \nDocumentation: ",
)
llm = OpenAI(temperature=0)
response = llm(prompt.format(code=function))
mode = "a" if Path(f"outputs/{source_w}").exists() else "w"
with open(f"outputs/{source_w}", mode) as f:
f.write(
f"\n\n# Function name: {name} \n\nFunction: \n```\n{function}\n```, \nDocumentation: \n{response}")
def parse_classes(classes_dict, formats, dir):
c1 = len(classes_dict)
for i, (source, classes) in enumerate(classes_dict.items()):
print(f"Processing file {i + 1}/{c1}")
source_w = source.replace(dir + "/", "").replace("." + formats, ".md")
subfolders = "/".join(source_w.split("/")[:-1])
Path(f"outputs/{subfolders}").mkdir(parents=True, exist_ok=True)
for name, function_names in classes.items():
print(f"Processing Class {i + 1}/{c1}")
prompt = PromptTemplate(
input_variables=["class_name", "functions_names"],
template="Class name: {class_name} \nFunctions: {functions_names}, \nDocumentation: ",
)
llm = OpenAI(temperature=0)
response = llm(prompt.format(class_name=name, functions_names=function_names))
with open(f"outputs/{source_w}", "a" if Path(f"outputs/{source_w}").exists() else "w") as f:
f.write(f"\n\n# Class name: {name} \n\nFunctions: \n{function_names}, \nDocumentation: \n{response}")
def transform_to_docs(functions_dict, classes_dict, formats, dir):
docs_content = ''.join([str(key) + str(value) for key, value in functions_dict.items()])
docs_content += ''.join([str(key) + str(value) for key, value in classes_dict.items()])
num_tokens = len(tiktoken.get_encoding("cl100k_base").encode(docs_content))
total_price = ((num_tokens / 1000) * 0.02)
print(f"Number of Tokens = {num_tokens:,d}")
print(f"Approx Cost = ${total_price:,.2f}")
user_input = input("Price Okay? (Y/N)\n").lower()
if user_input == "y" or user_input == "":
if not Path("outputs").exists():
Path("outputs").mkdir()
parse_functions(functions_dict, formats, dir)
parse_classes(classes_dict, formats, dir)
print("All done!")
else:
print("The API was not called. No money was spent.")

View File

@@ -1,4 +1,4 @@
anthropic==0.34.2
anthropic==0.40.0
boto3==1.34.153
beautifulsoup4==4.12.3
celery==5.3.6
@@ -28,12 +28,12 @@ jsonschema==4.23.0
jsonschema-spec==0.2.4
jsonschema-specifications==2023.7.1
kombu==5.4.2
langchain==0.3.0
langchain-community==0.3.0
langchain-core==0.3.2
langchain==0.3.11
langchain-community==0.3.11
langchain-core==0.3.25
langchain-openai==0.2.0
langchain-text-splitters==0.3.0
langsmith==0.1.125
langsmith==0.2.3
lazy-object-proxy==1.10.0
lxml==5.3.0
markupsafe==2.1.5
@@ -73,17 +73,17 @@ referencing==0.30.2
regex==2024.9.11
requests==2.32.3
retry==0.9.2
sentence-transformers==3.0.1
sentence-transformers==3.3.1
tiktoken==0.7.0
tokenizers==0.19.1
tokenizers==0.21.0
torch==2.4.1
tqdm==4.66.5
transformers==4.44.2
transformers==4.47.0
typing-extensions==4.12.2
typing-inspect==0.9.0
tzdata==2024.2
urllib3==2.2.3
vine==5.1.0
wcwidth==0.2.13
werkzeug==3.0.4
werkzeug==3.1.3
yarl==1.11.1

View File

@@ -2,7 +2,6 @@ import json
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import num_tokens_from_string
from langchain_community.tools import BraveSearch
@@ -73,15 +72,8 @@ class BraveRetSearch(BaseRetriever):
yield {"source": doc}
if len(self.chat_history) > 1:
tokens_current_history = 0
# count tokens in history
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
messages_combine.append(
{"role": "user", "content": i["prompt"]}
)

View File

@@ -6,6 +6,7 @@ from application.utils import num_tokens_from_string
from application.vectorstore.vector_creator import VectorCreator
class ClassicRAG(BaseRetriever):
def __init__(
@@ -73,15 +74,8 @@ class ClassicRAG(BaseRetriever):
yield {"source": doc}
if len(self.chat_history) > 1:
tokens_current_history = 0
# count tokens in history
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = num_tokens_from_string(
i["prompt"]
) + num_tokens_from_string(i["response"])
if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
if "prompt" in i and "response" in i:
messages_combine.append(
{"role": "user", "content": i["prompt"]}
)
@@ -89,7 +83,6 @@ class ClassicRAG(BaseRetriever):
{"role": "system", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question})
# llm = LLMCreator.create_llm(
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
# )

View File

@@ -1,7 +1,6 @@
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import num_tokens_from_string
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
@@ -89,16 +88,9 @@ class DuckDuckSearch(BaseRetriever):
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 1:
tokens_current_history = 0
# count tokens in history
if len(self.chat_history) > 1:
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
if "prompt" in i and "response" in i:
messages_combine.append(
{"role": "user", "content": i["prompt"]}
)

View File

@@ -58,3 +58,40 @@ def check_required_fields(data, required_fields):
def get_hash(data):
return hashlib.md5(data.encode()).hexdigest()
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
"""
Limits chat history based on token count.
Returns a list of messages that fit within the token limit.
"""
from application.core.settings import settings
model_limit = settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
max_token_limit = min(max_token_limit, model_limit) if max_token_limit else model_limit
if not history:
return []
tokens_current_history = 0
trimmed_history = []
for message in reversed(history):
if "prompt" in message and "response" in message:
tokens_batch = num_tokens_from_string(message["prompt"]) + num_tokens_from_string(
message["response"]
)
if tokens_current_history + tokens_batch < max_token_limit:
tokens_current_history += tokens_batch
trimmed_history.insert(0, message)
else:
break
return trimmed_history
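limit_chat_history walks the history newest-to-oldest and keeps messages until the token budget is exhausted, so the most recent exchanges survive trimming. A small sketch of the behaviour (exact cut-offs depend on the tokenizer, so the limits are illustrative):

    # assuming limit_chat_history is imported from application.utils
    history = [
        {"prompt": "What is DocsGPT?", "response": "An open-source RAG assistant."},
        {"prompt": "Which stores work?", "response": "faiss, elasticsearch, qdrant, milvus."},
    ]
    print(limit_chat_history(history, max_token_limit=1000))  # both messages kept
    print(limit_chat_history(history, max_token_limit=20))    # likely only the newest, possibly none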