feat: tooling init

This commit is contained in:
Alex
2024-12-05 22:44:40 +00:00
parent 4443bc77fd
commit 1f649274d1
12 changed files with 383 additions and 32 deletions

View File

@@ -5,6 +5,7 @@ import logging
from threading import Lock
from application.core.settings import settings
from application.utils import get_hash
import sys
logger = logging.getLogger(__name__)
@@ -23,18 +24,19 @@ def get_redis_instance():
_redis_instance = None
return _redis_instance
def gen_cache_key(*messages, model="docgpt"):
def gen_cache_key(messages, model="docgpt", tools=None):
if not all(isinstance(msg, dict) for msg in messages):
raise ValueError("All messages must be dictionaries.")
messages_str = json.dumps(list(messages), sort_keys=True)
combined = f"{model}_{messages_str}"
messages_str = json.dumps(messages)
tools_str = json.dumps(tools) if tools else ""
combined = f"{model}_{messages_str}_{tools_str}"
cache_key = get_hash(combined)
return cache_key
def gen_cache(func):
def wrapper(self, model, messages, *args, **kwargs):
def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
try:
cache_key = gen_cache_key(*messages)
cache_key = gen_cache_key(messages, model, tools)
redis_client = get_redis_instance()
if redis_client:
try:
@@ -44,8 +46,8 @@ def gen_cache(func):
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
result = func(self, model, messages, *args, **kwargs)
if redis_client:
result = func(self, model, messages, stream, tools, *args, **kwargs)
if redis_client and isinstance(result, str):
try:
redis_client.set(cache_key, result, ex=1800)
except redis.ConnectionError as e:
@@ -59,7 +61,7 @@ def gen_cache(func):
def stream_cache(func):
def wrapper(self, model, messages, stream, *args, **kwargs):
cache_key = gen_cache_key(*messages)
cache_key = gen_cache_key(messages)
logger.info(f"Stream cache key: {cache_key}")
redis_client = get_redis_instance()

View File

@@ -13,12 +13,12 @@ class BaseLLM(ABC):
return method(self, *args, **kwargs)
@abstractmethod
def _raw_gen(self, model, messages, stream, *args, **kwargs):
def _raw_gen(self, model, messages, stream, tools, *args, **kwargs):
pass
def gen(self, model, messages, stream=False, *args, **kwargs):
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
decorators = [gen_token_usage, gen_cache]
return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
@abstractmethod
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
@@ -27,3 +27,9 @@ class BaseLLM(ABC):
def gen_stream(self, model, messages, stream=True, *args, **kwargs):
decorators = [stream_cache, stream_token_usage]
return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
def supports_tools(self):
return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools'))
def _supports_tools(self):
raise NotImplementedError("Subclass must implement _supports_tools method")

View File

@@ -25,14 +25,20 @@ class OpenAILLM(BaseLLM):
model,
messages,
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self,
@@ -40,6 +46,7 @@ class OpenAILLM(BaseLLM):
model,
messages,
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs
):
@@ -53,6 +60,9 @@ class OpenAILLM(BaseLLM):
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content
def _supports_tools(self):
return True
class AzureOpenAILLM(OpenAILLM):

View File

@@ -43,7 +43,7 @@ multidict==6.1.0
mypy-extensions==1.0.0
networkx==3.3
numpy==1.26.4
openai==1.46.1
openai==1.57.0
openapi-schema-validator==0.6.2
openapi-spec-validator==0.6.0
openapi3-parser==1.1.18

View File

@@ -2,6 +2,7 @@ from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.tools.agent import Agent
from application.utils import num_tokens_from_string
@@ -90,10 +91,12 @@ class ClassicRAG(BaseRetriever):
)
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
)
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
# llm = LLMCreator.create_llm(
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
# )
# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
agent = Agent(llm_name=settings.LLM_NAME,gpt_model=self.gpt_model, api_key=settings.API_KEY, user_api_key=self.user_api_key)
completion = agent.gen(messages_combine)
for line in completion:
yield {"answer": str(line)}

View File

@@ -0,0 +1,98 @@
from application.llm.llm_creator import LLMCreator
from application.core.settings import settings
from application.tools.tool_manager import ToolManager
import json
tool_tg = {
"name": "telegram_send_message",
"description": "Send a notification to telegram about current chat",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "Text to send in the notification"
}
},
"required": ["text"],
"additionalProperties": False
}
}
tool_crypto = {
"name": "cryptoprice_get",
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
"parameters": {
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The cryptocurrency symbol (e.g. BTC)"
},
"currency": {
"type": "string",
"description": "The currency in which you want the price (e.g. USD)"
}
},
"required": ["symbol", "currency"],
"additionalProperties": False
}
}
class Agent:
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
# Initialize the LLM with the provided parameters
self.llm = LLMCreator.create_llm(llm_name, api_key=api_key, user_api_key=user_api_key)
self.gpt_model = gpt_model
# Static tool configuration (to be replaced later)
self.tools = [
{
"type": "function",
"function": tool_crypto
}
]
self.tool_config = {
}
def gen(self, messages):
# Generate initial response from the LLM
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
if isinstance(resp, str):
# Yield the response if it's a string and exit
yield resp
return
while resp.finish_reason == "tool_calls":
# Append the assistant's message to the conversation
messages.append(json.loads(resp.model_dump_json())['message'])
# Handle each tool call
tool_calls = resp.message.tool_calls
for call in tool_calls:
tm = ToolManager(config={})
call_name = call.function.name
call_args = json.loads(call.function.arguments)
call_id = call.id
# Determine the tool name and load it
tool_name = call_name.split("_")[0]
tool = tm.load_tool(tool_name, tool_config=self.tool_config)
# Execute the tool's action
resp_tool = tool.execute_action(call_name, **call_args)
# Append the tool's response to the conversation
messages.append(
{
"role": "tool",
"content": str(resp_tool),
"tool_call_id": call_id
}
)
# Generate a new response from the LLM after processing tools
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
# If no tool calls are needed, generate the final response
if isinstance(resp, str):
yield resp
else:
completion = self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools)
for line in completion:
yield line

20
application/tools/base.py Normal file
View File

@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
class Tool(ABC):
@abstractmethod
def execute_action(self, action_name: str, **kwargs):
pass
@abstractmethod
def get_actions_metadata(self):
"""
Returns a list of JSON objects describing the actions supported by the tool.
"""
pass
@abstractmethod
def get_config_requirements(self):
"""
Returns a dictionary describing the configuration requirements for the tool.
"""
pass

View File

@@ -0,0 +1,73 @@
from application.tools.base import Tool
import requests
class CryptoPriceTool(Tool):
def __init__(self, config):
self.config = config
def execute_action(self, action_name, **kwargs):
actions = {
"cryptoprice_get": self.get_price
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def get_price(self, symbol, currency):
"""
Fetches the current price of a given cryptocurrency symbol in the specified currency.
Example:
symbol = "BTC"
currency = "USD"
returns price in USD.
"""
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
# data will be like {"USD": <price>} if the call is successful
if currency.upper() in data:
return {
"status_code": response.status_code,
"price": data[currency.upper()],
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully."
}
else:
return {
"status_code": response.status_code,
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}."
}
else:
return {
"status_code": response.status_code,
"message": "Failed to retrieve price."
}
def get_actions_metadata(self):
return [
{
"name": "cryptoprice_get",
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
"parameters": {
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The cryptocurrency symbol (e.g. BTC)"
},
"currency": {
"type": "string",
"description": "The currency in which you want the price (e.g. USD)"
}
},
"required": ["symbol", "currency"],
"additionalProperties": False
}
}
]
def get_config_requirements(self):
# No specific configuration needed for this tool as it just queries a public endpoint
return {}

View File

@@ -0,0 +1,79 @@
from application.tools.base import Tool
import requests
class TelegramTool(Tool):
def __init__(self, config):
self.config = config
self.chat_id = config.get("chat_id", "142189016")
self.token = config.get("token", "YOUR_TG_TOKEN")
def execute_action(self, action_name, **kwargs):
actions = {
"telegram_send_message": self.send_message,
"telegram_send_image": self.send_image
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def send_message(self, text):
print(f"Sending message: {text}")
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": self.chat_id, "text": text}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Message sent"}
def send_image(self, image_url):
print(f"Sending image: {image_url}")
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": self.chat_id, "photo": image_url}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Image sent"}
def get_actions_metadata(self):
return [
{
"name": "telegram_send_message",
"description": "Send a notification to telegram chat",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "Text to send in the notification"
}
},
"required": ["text"],
"additionalProperties": False
}
},
{
"name": "telegram_send_image",
"description": "Send an image to the Telegram chat",
"parameters": {
"type": "object",
"properties": {
"image_url": {
"type": "string",
"description": "URL of the image to send"
}
},
"required": ["image_url"],
"additionalProperties": False
}
}
]
def get_config_requirements(self):
return {
"chat_id": {
"type": "string",
"description": "Telegram chat ID to send messages to"
},
"token": {
"type": "string",
"description": "Bot token for authentication"
}
}

View File

@@ -0,0 +1,43 @@
import importlib
import inspect
import pkgutil
import os
from application.tools.base import Tool
class ToolManager:
def __init__(self, config):
self.config = config
self.tools = {}
self.load_tools()
def load_tools(self):
tools_dir = os.path.dirname(__file__)
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
if name == 'base' or name.startswith('__'):
continue
module = importlib.import_module(f'application.tools.{name}')
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
def load_tool(self, tool_name, tool_config):
self.config[tool_name] = tool_config
tools_dir = os.path.dirname(__file__)
module = importlib.import_module(f'application.tools.{tool_name}')
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
return obj(tool_config)
def execute_action(self, tool_name, action_name, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self):
metadata = []
for tool in self.tools.values():
metadata.extend(tool.get_actions_metadata())
return metadata

View File

@@ -1,7 +1,7 @@
import sys
from datetime import datetime
from application.core.mongo_db import MongoDB
from application.utils import num_tokens_from_string
from application.utils import num_tokens_from_string, num_tokens_from_object_or_list
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -21,11 +21,16 @@ def update_token_usage(user_api_key, token_usage):
def gen_token_usage(func):
def wrapper(self, model, messages, stream, **kwargs):
def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
result = func(self, model, messages, stream, **kwargs)
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
if message["content"]:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
result = func(self, model, messages, stream, tools, **kwargs)
# check if result is a string
if isinstance(result, str):
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
else:
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result)
update_token_usage(self.user_api_key, self.token_usage)
return result
@@ -33,11 +38,11 @@ def gen_token_usage(func):
def stream_token_usage(func):
def wrapper(self, model, messages, stream, **kwargs):
def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
batch = []
result = func(self, model, messages, stream, **kwargs)
result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r

View File

@@ -15,9 +15,21 @@ def get_encoding():
def num_tokens_from_string(string: str) -> int:
encoding = get_encoding()
num_tokens = len(encoding.encode(string))
return num_tokens
if isinstance(string, str):
num_tokens = len(encoding.encode(string))
return num_tokens
else:
return 0
def num_tokens_from_object_or_list(thing):
if isinstance(thing, list):
return sum([num_tokens_from_object_or_list(x) for x in thing])
elif isinstance(thing, dict):
return sum([num_tokens_from_object_or_list(x) for x in thing.values()])
elif isinstance(thing, str):
return num_tokens_from_string(thing)
else:
return 0
def count_tokens_docs(docs):
docs_content = ""