Merge pull request #1630 from siiddhantt/feat/show-tool-execution

feat: tool calls tracking
Alex
2025-02-14 10:27:15 +00:00
committed by GitHub
18 changed files with 502 additions and 99 deletions

View File

@@ -88,9 +88,6 @@ def get_data_from_api_key(api_key):
if data is None:
raise Exception("Invalid API Key, please generate new key", 401)
if "retriever" not in data:
data["retriever"] = None
if "source" in data and isinstance(data["source"], DBRef):
source_doc = db.dereference(data["source"])
data["source"] = str(source_doc["_id"])
@@ -117,7 +114,9 @@ def is_azure_configured():
)
def save_conversation(conversation_id, question, response, source_log_docs, llm,index=None):
def save_conversation(
conversation_id, question, response, source_log_docs, tool_calls, llm, index=None
):
if conversation_id is not None and index is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
@@ -126,20 +125,14 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.sources": source_log_docs,
f"queries.{index}.tool_calls": tool_calls,
}
}
},
)
# remove the following queries from the array
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$push":{
"queries":{
"$each":[],
"$slice":index+1
}
}
}
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
elif conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
@@ -150,6 +143,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
}
}
},
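Note: the $push update above is a MongoDB truncation idiom. Pushing an empty $each adds nothing, while $slice: index + 1 trims the array, discarding every query stored after the one being regenerated. A minimal sketch of the same call (client setup assumed; the real one lives elsewhere in the application):

    from bson.objectid import ObjectId
    from pymongo import MongoClient

    conversations_collection = MongoClient("mongodb://localhost:27017")["docsgpt"]["conversations"]

    def truncate_queries(conversation_id, index):
        # $each: [] pushes nothing; $slice: index + 1 keeps queries[0..index] only.
        conversations_collection.update_one(
            {"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
            {"$push": {"queries": {"$each": [], "$slice": index + 1}}},
        )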
@@ -169,11 +163,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
"system \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
@@ -188,6 +178,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
}
],
}
@@ -208,12 +199,13 @@ def get_prompt(prompt_id):
def complete_stream(
question, retriever, conversation_id, user_api_key, isNoneDoc=False, index=None
):
try:
try:
response_full = ""
source_log_docs = []
tool_calls = []
answer = retriever.gen()
sources = retriever.search()
for source in sources:
@@ -222,6 +214,7 @@ def complete_stream(
if len(sources) > 0:
data = json.dumps({"type": "source", "source": sources})
yield f"data: {data}\n\n"
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
@@ -229,6 +222,10 @@ def complete_stream(
yield f"data: {data}\n\n"
elif "source" in line:
source_log_docs.append(line["source"])
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
@@ -239,7 +236,13 @@ def complete_stream(
)
if user_api_key is None:
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm,index
conversation_id,
question,
response_full,
source_log_docs,
tool_calls,
llm,
index,
)
# send data.type = "end" as json to indicate that the stream has ended
data = json.dumps({"type": "id", "id": str(conversation_id)})
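For context, clients of this stream now receive one extra event type. Every SSE data: line carries a JSON object whose "type" is "source", "answer", "tool_calls", or, at the end, "id" and "end". A hypothetical consumer (the /stream endpoint path is an assumption):

    import json
    import requests

    url = "http://localhost:7091/stream"  # assumed; matches the default API_URL in settings
    with requests.post(url, json={"question": "What is DocsGPT?"}, stream=True) as resp:
        for raw in resp.iter_lines():
            if not raw.startswith(b"data: "):
                continue
            event = json.loads(raw[len(b"data: "):])
            if event["type"] == "answer":
                print(event["answer"], end="")
            elif event["type"] == "tool_calls":
                print("\ntools executed:", [c.get("tool_name") for c in event["tool_calls"]])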
@@ -303,7 +306,7 @@ class Stream(Resource):
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index":fields.Integer(
"index": fields.Integer(
required=False, description="The position where query is to be updated"
),
},
@@ -315,22 +318,24 @@ class Stream(Resource):
data = request.get_json()
required_fields = ["question"]
if "index" in data:
required_fields = ["question", "conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
history = limit_chat_history(
json.loads(data.get("history", [])), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
index = data.get("index", None)
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
@@ -367,7 +372,7 @@ class Stream(Resource):
gpt_model=gpt_model,
user_api_key=user_api_key,
)
return Response(
complete_stream(
question=question,
@@ -395,7 +400,7 @@ class Stream(Resource):
)
status_code = 400
return Response(
error_stream_generate("Unknown error occurred"),
status=status_code,
mimetype="text/event-stream",
)
@@ -442,14 +447,16 @@ class Answer(Resource):
@api.doc(description="Provide an answer based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
history = limit_chat_history(
json.loads(data.get("history", [])), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))
@@ -490,13 +497,16 @@ class Answer(Resource):
user_api_key=user_api_key,
)
response_full = ""
source_log_docs = []
tool_calls = []
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]
elif "tool_calls" in line:
tool_calls.append(line["tool_calls"])
if data.get("isNoneDoc"):
for doc in source_log_docs:
@@ -509,7 +519,12 @@ class Answer(Resource):
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(
conversation_id, question, response_full, source_log_docs, llm
conversation_id,
question,
response_full,
source_log_docs,
tool_calls,
llm,
)
)
retriever_params = retriever.get_params()
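With both Stream and Answer saving tool_calls, a stored conversation document now looks roughly like this (field names taken from the Agent's tool_call_data further down; all values illustrative):

    conversation = {
        "queries": [
            {
                "prompt": "What is the weather in Berlin?",
                "response": "Cold and cloudy today.",
                "sources": [],
                "tool_calls": [
                    {
                        "tool_name": "weather",           # hypothetical tool
                        "call_id": "call_abc123",
                        "action_name": "get_weather_t1",  # f"{action_name}_{tool_id}"
                        "arguments": {"city": "Berlin"},
                        "result": "4C, cloudy",
                    }
                ],
            }
        ],
    }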

View File

@@ -1,26 +1,37 @@
import os
from pathlib import Path
from typing import Optional

from pydantic_settings import BaseSettings
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
class Settings(BaseSettings):
LLM_NAME: str = "docsgpt"
MODEL_NAME: Optional[str] = None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
MODEL_NAME: Optional[str] = (
None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
)
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
MODEL_TOKEN_LIMITS: dict = {"gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5}
MODEL_TOKEN_LIMITS: dict = {
"gpt-4o-mini": 128000,
"gpt-3.5-turbo": 4096,
"claude-2": 1e5,
"gemini-2.0-flash-exp": 1e6,
}
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
)
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
@@ -28,12 +39,18 @@ class Settings(BaseSettings):
API_URL: str = "http://localhost:7091" # backend url for celery worker
API_KEY: Optional[str] = None # LLM api key
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
OPENAI_BASE_URL: Optional[str] = None # openai base url for openai compatible models
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
None # azure deployment name for embeddings
)
OPENAI_BASE_URL: Optional[str] = (
None # openai base url for openai compatible models
)
# elasticsearch
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
@@ -68,12 +85,14 @@ class Settings(BaseSettings):
# Milvus vectorstore config
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
MILVUS_URI: Optional[str] = "./milvus_local.db"  # milvus lite version as default
MILVUS_TOKEN: Optional[str] = ""
# LanceDB vectorstore config
LANCEDB_PATH: str = "/tmp/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
BRAVE_SEARCH_API_KEY: Optional[str] = None
FLASK_DEBUG_MODE: bool = False
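Since Settings extends pydantic's BaseSettings, each field can be overridden from the environment before the module instantiates the settings object, e.g.:

    import os

    # Must happen before application.core.settings creates its Settings() instance.
    os.environ["MODEL_NAME"] = "gpt-4o-mini"

    from application.core.settings import settings

    # The token-limit lookup pattern used in utils below, with its fallback:
    limit = settings.MODEL_TOKEN_LIMITS.get(settings.MODEL_NAME, settings.DEFAULT_MAX_HISTORY)
    print(limit)  # 128000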

View File

@@ -1,3 +1,5 @@
import json
from application.core.settings import settings
from application.llm.base import BaseLLM
@@ -15,6 +17,63 @@ class OpenAILLM(BaseLLM):
self.api_key = api_key
self.user_api_key = user_api_key
def _clean_messages_openai(self, messages):
cleaned_messages = []
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
for item in content:
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
"function": {
"name": item["function_call"]["name"],
"arguments": json.dumps(
item["function_call"]["args"]
),
},
}
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
elif "function_response" in item:
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
def _raw_gen(
self,
baseself,
@@ -25,9 +84,15 @@ class OpenAILLM(BaseLLM):
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs,
):
messages = self._clean_messages_openai(messages)
print(messages)
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
model=model,
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
return response.choices[0]
else:
@@ -46,6 +111,7 @@ class OpenAILLM(BaseLLM):
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs,
):
messages = self._clean_messages_openai(messages)
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
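To make the new _clean_messages_openai concrete: it maps the provider-neutral history format (built by ClassicRAG below) onto the OpenAI chat-completions shape. Illustrative values:

    internal = [
        {"role": "user", "content": "What is the weather in Berlin?"},
        {"role": "assistant", "content": [
            {"function_call": {"name": "get_weather_t1", "args": {"city": "Berlin"}, "call_id": "abc"}}
        ]},
        {"role": "tool", "content": [
            {"function_response": {"name": "get_weather_t1", "response": {"result": "4C, cloudy"}, "call_id": "abc"}}
        ]},
    ]

    # becomes, after cleaning:
    cleaned = [
        {"role": "user", "content": "What is the weather in Berlin?"},
        {"role": "assistant", "content": None, "tool_calls": [
            {"id": "abc", "type": "function",
             "function": {"name": "get_weather_t1", "arguments": '{"city": "Berlin"}'}}
        ]},
        {"role": "tool", "tool_call_id": "abc", "content": '"4C, cloudy"'},
    ]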

View File

@@ -1,3 +1,5 @@
import uuid
from application.core.settings import settings
from application.retriever.base import BaseRetriever
from application.tools.agent import Agent
@@ -35,6 +37,12 @@ class ClassicRAG(BaseRetriever):
)
)
self.user_api_key = user_api_key
self.agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
)
def _get_data(self):
if self.chunks == 0:
@@ -78,21 +86,42 @@ class ClassicRAG(BaseRetriever):
messages_combine.append(
{"role": "assistant", "content": i["response"]}
)
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id")
if call_id is None or call_id == "None":
call_id = str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages_combine.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages_combine.append(
{"role": "tool", "content": [function_response_dict]}
)
messages_combine.append({"role": "user", "content": self.question})
# llm = LLMCreator.create_llm(
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
# )
# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
)
completion = agent.gen(messages_combine)
completion = self.agent.gen(messages_combine)
for line in completion:
yield {"answer": str(line)}
yield {"tool_calls": self.agent.tool_calls.copy()}
def search(self):
return self._get_data()
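A consumer of ClassicRAG.gen() therefore receives the streamed answer first and a single trailing tool_calls item; copying self.agent.tool_calls protects the snapshot from being cleared by the next gen() run. A sketch, where rag is an assumed, already-initialized ClassicRAG:

    answer, tool_calls = "", []
    for item in rag.gen():
        if "answer" in item:
            answer += item["answer"]
        elif "tool_calls" in item:
            tool_calls = item["tool_calls"]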

View File

@@ -16,6 +16,7 @@ class Agent:
# Static tool configuration (to be replaced later)
self.tools = []
self.tool_config = {}
self.tool_calls = []
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
@@ -123,6 +124,16 @@ class Agent:
print(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
call_id = getattr(call, "id", None)
tool_call_data = {
"tool_name": tool_data["name"],
"call_id": call_id if call_id is not None else "None",
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": result,
}
self.tool_calls.append(tool_call_data)
return result, call_id
def _simple_tool_agent(self, messages):
@@ -134,7 +145,11 @@ class Agent:
if isinstance(resp, str):
yield resp
return
if hasattr(resp, "message") and hasattr(resp.message, "content"):
if (
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
yield resp.message.content
return
@@ -142,7 +157,11 @@ class Agent:
if isinstance(resp, str):
yield resp
elif hasattr(resp, "message") and hasattr(resp.message, "content"):
elif (
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
yield resp.message.content
else:
completion = self.llm.gen_stream(
@@ -154,6 +173,7 @@ class Agent:
return
def gen(self, messages):
self.tool_calls = []
if self.llm.supports_tools():
resp = self._simple_tool_agent(messages)
for line in resp:
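Because gen() clears tool_calls at the start of every run, callers should drain the generator before reading the list, as ClassicRAG does above. A hedged usage sketch (constructor arguments as in the retriever):

    agent = Agent(
        llm_name="openai",
        gpt_model="gpt-4o-mini",
        api_key="sk-...",
        user_api_key=None,
    )
    chunks = list(agent.gen(messages))  # exhaust the stream first
    executed = agent.tool_calls.copy()  # then snapshot what was executed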

View File

@@ -24,13 +24,28 @@ class OpenAILLMHandler(LLMHandler):
tool_response, call_id = agent._execute_tool_action(
tools_dict, call
)
messages.append(
{
"role": "tool",
"content": str(tool_response),
"tool_call_id": call_id,
function_call_dict = {
"function_call": {
"name": call.function.name,
"args": call.function.arguments,
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": call.function.name,
"response": {"result": tool_response},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
)
except Exception as e:
messages.append(
{
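In effect, after each tool execution the handler now appends a provider-neutral pair like the following (values illustrative; call.function.arguments is the model's raw JSON string), which _clean_messages_openai above converts back into OpenAI tool messages on the next round trip:

    messages += [
        {"role": "assistant", "content": [{"function_call": {
            "name": "get_weather_t1", "args": '{"city": "Berlin"}', "call_id": "abc"}}]},
        {"role": "tool", "content": [{"function_response": {
            "name": "get_weather_t1", "response": {"result": "4C, cloudy"}, "call_id": "abc"}}]},
    ]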

View File

@@ -1,8 +1,9 @@
import hashlib
import re

import tiktoken
from flask import jsonify, make_response
_encoding = None
@@ -22,6 +23,7 @@ def num_tokens_from_string(string: str) -> int:
else:
return 0
def num_tokens_from_object_or_list(thing):
if isinstance(thing, list):
return sum([num_tokens_from_object_or_list(x) for x in thing])
@@ -32,6 +34,7 @@ def num_tokens_from_object_or_list(thing):
else:
return 0
def count_tokens_docs(docs):
docs_content = ""
for doc in docs:
@@ -59,6 +62,7 @@ def check_required_fields(data, required_fields):
def get_hash(data):
return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
"""
Limits chat history based on token count.
@@ -67,38 +71,41 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
from application.core.settings import settings
max_token_limit = (
max_token_limit
if max_token_limit
and max_token_limit
< settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
)
if not history:
return []
trimmed_history = []
tokens_current_history = 0
for message in reversed(history):
tokens_batch = 0
if "prompt" in message and "response" in message:
tokens_batch = num_tokens_from_string(message["prompt"]) + num_tokens_from_string(
message["response"]
)
if tokens_current_history + tokens_batch < max_token_limit:
tokens_current_history += tokens_batch
trimmed_history.insert(0, message)
else:
break
tokens_batch += num_tokens_from_string(message["prompt"])
tokens_batch += num_tokens_from_string(message["response"])
if "tool_calls" in message:
for tool_call in message["tool_calls"]:
tool_call_string = f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}"
tokens_batch += num_tokens_from_string(tool_call_string)
if tokens_current_history + tokens_batch < max_token_limit:
tokens_current_history += tokens_batch
trimmed_history.insert(0, message)
else:
break
return trimmed_history
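Tool-call metadata now counts against the history budget, so a tool-heavy exchange is trimmed earlier than a plain one. An illustrative call (exact token counts depend on the tiktoken encoding):

    history = [
        {"prompt": "Weather in Berlin?", "response": "Cold and cloudy.",
         "tool_calls": [{"tool_name": "weather", "action_name": "get_weather_t1",
                         "arguments": {"city": "Berlin"}, "result": "4C, cloudy"}]},
    ]
    trimmed = limit_chat_history(history, max_token_limit=1000, gpt_model="gpt-4o-mini")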
def validate_function_name(function_name):
"""Validates if a function name matches the allowed pattern."""
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
return False
return True
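For example, get_weather_t1 passes the check, while names containing dots, spaces, or slashes are rejected:

    assert validate_function_name("get_weather_t1")
    assert not validate_function_name("os.system")
    assert not validate_function_name("rm -rf /")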