Mirror of https://github.com/arc53/DocsGPT.git, synced 2025-12-03 02:23:14 +00:00
Merge pull request #1630 from siiddhantt/feat/show-tool-execution
feat: tool calls tracking
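This change threads a tool_calls list from the tool agent through save_conversation, so every stored query now records which tools were executed alongside the prompt, response and sources. The hunks below cover, in order, the answer routes (save_conversation, complete_stream, the Stream and Answer resources), core settings, the OpenAI LLM wrapper, the classic RAG retriever, the tool Agent and its OpenAI handler, and the token-counting utils. As a rough sketch only, one entry of a conversation's queries array would now look like this (field names are taken from save_conversation and Agent._execute_tool_action below; all values are invented):

# Illustrative shape of one persisted query after this change; values are made up.
{
    "prompt": "How do I switch the vector store?",
    "response": "Set VECTOR_STORE in settings ...",
    "sources": [...],
    "tool_calls": [
        {
            "tool_name": "api_tool",               # name from the tool's config
            "call_id": "call_abc123",              # "None" when the LLM returns no id
            "action_name": "fetch_endpoint_<tool_id>",
            "arguments": {"query": "vector store"},
            "result": "...",
        }
    ],
}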
@@ -88,9 +88,6 @@ def get_data_from_api_key(api_key):
if data is None:
raise Exception("Invalid API Key, please generate new key", 401)

if "retriever" not in data:
data["retriever"] = None

if "source" in data and isinstance(data["source"], DBRef):
source_doc = db.dereference(data["source"])
data["source"] = str(source_doc["_id"])
@@ -117,7 +114,9 @@ def is_azure_configured():
)

def save_conversation(conversation_id, question, response, source_log_docs, llm,index=None):
def save_conversation(
conversation_id, question, response, source_log_docs, tool_calls, llm, index=None
):
if conversation_id is not None and index is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
@@ -126,20 +125,14 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.sources": source_log_docs,
f"queries.{index}.tool_calls": tool_calls,
}
}
},
)
##remove following queries from the array
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$push":{
"queries":{
"$each":[],
"$slice":index+1
}
}
}
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
elif conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
@@ -150,6 +143,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
}
}
},
@@ -169,11 +163,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
"system \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]

@@ -188,6 +178,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
}
],
}
@@ -208,12 +199,13 @@ def get_prompt(prompt_id):

def complete_stream(
question, retriever, conversation_id, user_api_key, isNoneDoc=False,index=None
question, retriever, conversation_id, user_api_key, isNoneDoc=False, index=None
):

try:
try:
response_full = ""
source_log_docs = []
tool_calls = []
answer = retriever.gen()
sources = retriever.search()
for source in sources:
@@ -222,6 +214,7 @@ def complete_stream(
if len(sources) > 0:
data = json.dumps({"type": "source", "source": sources})
yield f"data: {data}\n\n"

for line in answer:
if "answer" in line:
response_full += str(line["answer"])
@@ -229,6 +222,10 @@ def complete_stream(
yield f"data: {data}\n\n"
elif "source" in line:
source_log_docs.append(line["source"])
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"

if isNoneDoc:
for doc in source_log_docs:
@@ -239,7 +236,13 @@ def complete_stream(
)
if user_api_key is None:
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm,index
conversation_id,
question,
response_full,
source_log_docs,
tool_calls,
llm,
index,
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
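For the frontend, the visible difference is one extra event type on the SSE stream. Roughly, a client now receives payloads like the following (illustrative only; the source, tool_calls and id payloads are built in complete_stream above, while the answer chunks and the final "end" marker are handled by surrounding code this hunk does not show):

# Illustrative "data: ..." payloads on the event stream, in the order they typically arrive.
{"type": "source", "source": [...]}              # retrieved sources, sent once up front
{"type": "answer", "answer": "partial text"}     # streamed answer chunks
{"type": "tool_calls", "tool_calls": [...]}      # tools executed for this query (new in this commit)
{"type": "id", "id": "<conversation_id>"}        # id of the saved conversation
{"type": "end"}                                  # end-of-stream marker mentioned in the comment above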
@@ -303,7 +306,7 @@ class Stream(Resource):
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index":fields.Integer(
"index": fields.Integer(
required=False, description="The position where query is to be updated"
),
},
@@ -315,22 +318,24 @@ class Stream(Resource):
data = request.get_json()
required_fields = ["question"]
if "index" in data:
required_fields = ["question","conversation_id"]
required_fields = ["question", "conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields

try:
question = data["question"]
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
history = limit_chat_history(
json.loads(data.get("history", [])), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")

index=data.get("index",None)

index = data.get("index", None)
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")


if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
@@ -367,7 +372,7 @@ class Stream(Resource):
gpt_model=gpt_model,
user_api_key=user_api_key,
)


return Response(
complete_stream(
question=question,
@@ -395,7 +400,7 @@ class Stream(Resource):
)
status_code = 400
return Response(
error_stream_generate('Unknown error occurred'),
error_stream_generate("Unknown error occurred"),
status=status_code,
mimetype="text/event-stream",
)
@@ -442,14 +447,16 @@ class Answer(Resource):
@api.doc(description="Provide an answer based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields

try:
question = data["question"]
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
history = limit_chat_history(
json.loads(data.get("history", [])), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))
@@ -490,13 +497,16 @@ class Answer(Resource):
user_api_key=user_api_key,
)

source_log_docs = []
response_full = ""
source_log_docs = []
tool_calls = []
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]
elif "tool_calls" in line:
tool_calls.append(line["tool_calls"])

if data.get("isNoneDoc"):
for doc in source_log_docs:
@@ -509,7 +519,12 @@ class Answer(Resource):
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(
conversation_id, question, response_full, source_log_docs, llm
conversation_id,
question,
response_full,
source_log_docs,
tool_calls,
llm,
)
)
retriever_params = retriever.get_params()

@@ -1,26 +1,37 @@
import os
from pathlib import Path
from typing import Optional
import os

from pydantic_settings import BaseSettings

current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)


class Settings(BaseSettings):
LLM_NAME: str = "docsgpt"
MODEL_NAME: Optional[str] = None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
MODEL_NAME: Optional[str] = (
None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
)
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
MODEL_TOKEN_LIMITS: dict = {"gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5}
MODEL_TOKEN_LIMITS: dict = {
"gpt-4o-mini": 128000,
"gpt-3.5-turbo": 4096,
"claude-2": 1e5,
"gemini-2.0-flash-exp": 1e6,
}
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
)
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search

# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
@@ -28,12 +39,18 @@ class Settings(BaseSettings):
API_URL: str = "http://localhost:7091" # backend url for celery worker

API_KEY: Optional[str] = None # LLM api key
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
OPENAI_BASE_URL: Optional[str] = None # openai base url for open ai compatable models
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
None # azure deployment name for embeddings
)
OPENAI_BASE_URL: Optional[str] = (
None # openai base url for open ai compatable models
)

# elasticsearch
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
@@ -68,12 +85,14 @@ class Settings(BaseSettings):

# Milvus vectorstore config
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
MILVUS_URI: Optional[str] = "./milvus_local.db"  # milvus lite version as default
MILVUS_TOKEN: Optional[str] = ""

# LanceDB vectorstore config
LANCEDB_PATH: str = "/tmp/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
BRAVE_SEARCH_API_KEY: Optional[str] = None

FLASK_DEBUG_MODE: bool = False

@@ -1,3 +1,5 @@
import json

from application.core.settings import settings
from application.llm.base import BaseLLM

@@ -15,6 +17,63 @@ class OpenAILLM(BaseLLM):
self.api_key = api_key
self.user_api_key = user_api_key

def _clean_messages_openai(self, messages):
cleaned_messages = []
for message in messages:
role = message.get("role")
content = message.get("content")

if role == "model":
role = "assistant"

if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
for item in content:
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
"function": {
"name": item["function_call"]["name"],
"arguments": json.dumps(
item["function_call"]["args"]
),
},
}
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
elif "function_response" in item:
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")

return cleaned_messages

def _raw_gen(
self,
baseself,
@@ -25,9 +84,15 @@ class OpenAILLM(BaseLLM):
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs,
):
messages = self._clean_messages_openai(messages)
print(messages)
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
model=model,
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
return response.choices[0]
else:
@@ -46,6 +111,7 @@ class OpenAILLM(BaseLLM):
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs,
):
messages = self._clean_messages_openai(messages)
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)

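The new _clean_messages_openai method converts the internal history format, where assistant and tool turns carry function_call / function_response items, into plain OpenAI chat-completion messages. A rough before/after of the two branches (tool and argument names are invented; the key handling matches the code above):

# Internal assistant turn with a recorded call ...
{"role": "assistant", "content": [{"function_call": {"name": "fetch_docs", "args": {"q": "x"}, "call_id": "call_1"}}]}
# ... becomes an OpenAI assistant message carrying tool_calls:
{"role": "assistant", "content": None, "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "fetch_docs", "arguments": "{\"q\": \"x\"}"}}]}

# Internal tool turn with the recorded result ...
{"role": "tool", "content": [{"function_response": {"name": "fetch_docs", "response": {"result": "..."}, "call_id": "call_1"}}]}
# ... becomes an OpenAI tool message:
{"role": "tool", "tool_call_id": "call_1", "content": "\"...\""}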
@@ -1,3 +1,5 @@
import uuid

from application.core.settings import settings
from application.retriever.base import BaseRetriever
from application.tools.agent import Agent
@@ -35,6 +37,12 @@ class ClassicRAG(BaseRetriever):
)
)
self.user_api_key = user_api_key
self.agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
)

def _get_data(self):
if self.chunks == 0:
@@ -78,21 +86,42 @@ class ClassicRAG(BaseRetriever):
messages_combine.append(
{"role": "assistant", "content": i["response"]}
)
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id")
if call_id is None or call_id == "None":
call_id = str(uuid.uuid4())

function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}

messages_combine.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages_combine.append(
{"role": "tool", "content": [function_response_dict]}
)

messages_combine.append({"role": "user", "content": self.question})
# llm = LLMCreator.create_llm(
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
# )
# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
)
completion = agent.gen(messages_combine)
completion = self.agent.gen(messages_combine)

for line in completion:
yield {"answer": str(line)}

yield {"tool_calls": self.agent.tool_calls.copy()}

def search(self):
return self._get_data()

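ClassicRAG.gen() now finishes by yielding the agent's recorded tool calls as one extra item, after the answer chunks. A minimal consumer sketch, mirroring how the Stream and Answer resources above read the generator:

# Sketch only: consuming the retriever generator after this change.
response_full, source_log_docs, tool_calls = "", [], []
for line in retriever.gen():
    if "answer" in line:
        response_full += str(line["answer"])
    elif "source" in line:
        source_log_docs.append(line["source"])
    elif "tool_calls" in line:
        tool_calls = line["tool_calls"]  # records built by Agent._execute_tool_action below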
@@ -16,6 +16,7 @@ class Agent:
# Static tool configuration (to be replaced later)
self.tools = []
self.tool_config = {}
self.tool_calls = []

def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
@@ -123,6 +124,16 @@ class Agent:
print(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
call_id = getattr(call, "id", None)

tool_call_data = {
"tool_name": tool_data["name"],
"call_id": call_id if call_id is not None else "None",
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": result,
}
self.tool_calls.append(tool_call_data)

return result, call_id

def _simple_tool_agent(self, messages):
@@ -134,7 +145,11 @@ class Agent:
if isinstance(resp, str):
yield resp
return
if hasattr(resp, "message") and hasattr(resp.message, "content"):
if (
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
yield resp.message.content
return

@@ -142,7 +157,11 @@ class Agent:

if isinstance(resp, str):
yield resp
elif hasattr(resp, "message") and hasattr(resp.message, "content"):
elif (
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
yield resp.message.content
else:
completion = self.llm.gen_stream(
@@ -154,6 +173,7 @@ class Agent:
return

def gen(self, messages):
self.tool_calls = []
if self.llm.supports_tools():
resp = self._simple_tool_agent(messages)
for line in resp:

@@ -24,13 +24,28 @@ class OpenAILLMHandler(LLMHandler):
tool_response, call_id = agent._execute_tool_action(
tools_dict, call
)
messages.append(
{
"role": "tool",
"content": str(tool_response),
"tool_call_id": call_id,
function_call_dict = {
"function_call": {
"name": call.function.name,
"args": call.function.arguments,
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": call.function.name,
"response": {"result": tool_response},
"call_id": call_id,
}
}

messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
)

except Exception as e:
messages.append(
{

@@ -1,8 +1,9 @@
import tiktoken
import hashlib
from flask import jsonify, make_response
import re

import tiktoken
from flask import jsonify, make_response


_encoding = None

@@ -22,6 +23,7 @@ def num_tokens_from_string(string: str) -> int:
else:
return 0


def num_tokens_from_object_or_list(thing):
if isinstance(thing, list):
return sum([num_tokens_from_object_or_list(x) for x in thing])
@@ -32,6 +34,7 @@ def num_tokens_from_object_or_list(thing):
else:
return 0


def count_tokens_docs(docs):
docs_content = ""
for doc in docs:
@@ -59,6 +62,7 @@ def check_required_fields(data, required_fields):
def get_hash(data):
return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()


def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
"""
Limits chat history based on token count.
@@ -67,38 +71,41 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
from application.core.settings import settings

max_token_limit = (
max_token_limit
if max_token_limit and
max_token_limit < settings.MODEL_TOKEN_LIMITS.get(
gpt_model, settings.DEFAULT_MAX_HISTORY
)
else settings.MODEL_TOKEN_LIMITS.get(
gpt_model, settings.DEFAULT_MAX_HISTORY
)
)

max_token_limit
if max_token_limit
and max_token_limit
< settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
)

if not history:
return []

tokens_current_history = 0

trimmed_history = []

tokens_current_history = 0

for message in reversed(history):
tokens_batch = 0
if "prompt" in message and "response" in message:
tokens_batch = num_tokens_from_string(message["prompt"]) + num_tokens_from_string(
message["response"]
)
if tokens_current_history + tokens_batch < max_token_limit:
tokens_current_history += tokens_batch
trimmed_history.insert(0, message)
else:
break
tokens_batch += num_tokens_from_string(message["prompt"])
tokens_batch += num_tokens_from_string(message["response"])

if "tool_calls" in message:
for tool_call in message["tool_calls"]:
tool_call_string = f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}"
tokens_batch += num_tokens_from_string(tool_call_string)

if tokens_current_history + tokens_batch < max_token_limit:
tokens_current_history += tokens_batch
trimmed_history.insert(0, message)
else:
break

return trimmed_history


def validate_function_name(function_name):
"""Validates if a function name matches the allowed pattern."""
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
return False
return True
return True

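limit_chat_history now also charges each history message for its tool calls: every recorded call is flattened into a single string and token-counted together with the prompt and response, and the message is kept or dropped as a whole. A small sketch of the accounting for one message (values invented, string format copied from the diff):

# Illustrative token accounting for a single history message.
tokens_batch = num_tokens_from_string(message["prompt"])
tokens_batch += num_tokens_from_string(message["response"])
for tool_call in message.get("tool_calls", []):
    tool_call_string = (
        f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | "
        f"Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}"
    )
    tokens_batch += num_tokens_from_string(tool_call_string)
# the message is kept only if tokens_current_history + tokens_batch < max_token_limit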