feat: tools frontend and endpoints refactor

This commit is contained in:
Siddhant Rai
2024-12-18 22:48:40 +05:30
parent f87ae429f4
commit f9a7db11eb
23 changed files with 1069 additions and 168 deletions

View File

@@ -1,9 +1,10 @@
from application.llm.llm_creator import LLMCreator
from application.core.settings import settings
from application.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
import json
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.tools.tool_manager import ToolManager
tool_tg = {
"name": "telegram_send_message",
"description": "Send a notification to telegram about current chat",
@@ -12,15 +13,15 @@ tool_tg = {
"properties": {
"text": {
"type": "string",
"description": "Text to send in the notification"
"description": "Text to send in the notification",
}
},
"required": ["text"],
"additionalProperties": False
}
"additionalProperties": False,
},
}
tool_crypto = {
tool_crypto = {
"name": "cryptoprice_get",
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
"parameters": {
@@ -28,33 +29,30 @@ tool_crypto = {
"properties": {
"symbol": {
"type": "string",
"description": "The cryptocurrency symbol (e.g. BTC)"
"description": "The cryptocurrency symbol (e.g. BTC)",
},
"currency": {
"type": "string",
"description": "The currency in which you want the price (e.g. USD)"
}
"description": "The currency in which you want the price (e.g. USD)",
},
},
"required": ["symbol", "currency"],
"additionalProperties": False
}
"additionalProperties": False,
},
}
class Agent:
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
# Initialize the LLM with the provided parameters
self.llm = LLMCreator.create_llm(llm_name, api_key=api_key, user_api_key=user_api_key)
self.llm = LLMCreator.create_llm(
llm_name, api_key=api_key, user_api_key=user_api_key
)
self.gpt_model = gpt_model
# Static tool configuration (to be replaced later)
self.tools = [
{
"type": "function",
"function": tool_crypto
}
]
self.tool_config = {
}
self.tools = [{"type": "function", "function": tool_crypto}]
self.tool_config = {}
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -65,19 +63,17 @@ class Agent:
tool.pop("_id")
user_tools = {tool["name"]: tool for tool in user_tools}
return user_tools
def _simple_tool_agent(self, messages):
tools_dict = self._get_user_tools()
# combine all tool_actions into one list
self.tools.extend([
{
"type": "function",
"function": tool_action
}
for tool in tools_dict.values()
for tool_action in tool["actions"]
])
self.tools.extend(
[
{"type": "function", "function": tool_action}
for tool in tools_dict.values()
for tool_action in tool["actions"]
]
)
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
@@ -88,7 +84,7 @@ class Agent:
while resp.finish_reason == "tool_calls":
# Append the assistant's message to the conversation
messages.append(json.loads(resp.model_dump_json())['message'])
messages.append(json.loads(resp.model_dump_json())["message"])
# Handle each tool call
tool_calls = resp.message.tool_calls
for call in tool_calls:
@@ -98,25 +94,27 @@ class Agent:
call_id = call.id
# Determine the tool name and load it
tool_name = call_name.split("_")[0]
tool = tm.load_tool(tool_name, tool_config=tools_dict[tool_name]['config'])
tool = tm.load_tool(
tool_name, tool_config=tools_dict[tool_name]["config"]
)
# Execute the tool's action
resp_tool = tool.execute_action(call_name, **call_args)
# Append the tool's response to the conversation
messages.append(
{
"role": "tool",
"content": str(resp_tool),
"tool_call_id": call_id
}
{"role": "tool", "content": str(resp_tool), "tool_call_id": call_id}
)
# Generate a new response from the LLM after processing tools
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
resp = self.llm.gen(
model=self.gpt_model, messages=messages, tools=self.tools
)
# If no tool calls are needed, generate the final response
if isinstance(resp, str):
yield resp
else:
completion = self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools)
completion = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=self.tools
)
for line in completion:
yield line
@@ -127,4 +125,4 @@ class Agent:
else:
resp = self.llm.gen_stream(model=self.gpt_model, messages=messages)
for line in resp:
yield line
yield line