From 1f649274d1f1a7f7f0d94d52b76602b166cdd1bf Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 5 Dec 2024 22:44:40 +0000 Subject: [PATCH] feat: tooling init --- application/cache.py | 18 ++--- application/llm/base.py | 14 ++-- application/llm/openai.py | 22 +++++-- application/requirements.txt | 2 +- application/retriever/classic_rag.py | 11 ++-- application/tools/agent.py | 98 ++++++++++++++++++++++++++++ application/tools/base.py | 20 ++++++ application/tools/cryptoprice.py | 73 +++++++++++++++++++++ application/tools/telegram.py | 79 ++++++++++++++++++++++ application/tools/tool_manager.py | 43 ++++++++++++ application/usage.py | 19 ++++-- application/utils.py | 16 ++++- 12 files changed, 383 insertions(+), 32 deletions(-) create mode 100644 application/tools/agent.py create mode 100644 application/tools/base.py create mode 100644 application/tools/cryptoprice.py create mode 100644 application/tools/telegram.py create mode 100644 application/tools/tool_manager.py diff --git a/application/cache.py b/application/cache.py index 33022e45..7239abac 100644 --- a/application/cache.py +++ b/application/cache.py @@ -5,6 +5,7 @@ import logging from threading import Lock from application.core.settings import settings from application.utils import get_hash +import sys logger = logging.getLogger(__name__) @@ -23,18 +24,19 @@ def get_redis_instance(): _redis_instance = None return _redis_instance -def gen_cache_key(*messages, model="docgpt"): +def gen_cache_key(messages, model="docgpt", tools=None): if not all(isinstance(msg, dict) for msg in messages): raise ValueError("All messages must be dictionaries.") - messages_str = json.dumps(list(messages), sort_keys=True) - combined = f"{model}_{messages_str}" + messages_str = json.dumps(messages) + tools_str = json.dumps(tools) if tools else "" + combined = f"{model}_{messages_str}_{tools_str}" cache_key = get_hash(combined) return cache_key def gen_cache(func): - def wrapper(self, model, messages, *args, **kwargs): + def wrapper(self, model, messages, stream, tools=None, *args, **kwargs): try: - cache_key = gen_cache_key(*messages) + cache_key = gen_cache_key(messages, model, tools) redis_client = get_redis_instance() if redis_client: try: @@ -44,8 +46,8 @@ def gen_cache(func): except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") - result = func(self, model, messages, *args, **kwargs) - if redis_client: + result = func(self, model, messages, stream, tools, *args, **kwargs) + if redis_client and isinstance(result, str): try: redis_client.set(cache_key, result, ex=1800) except redis.ConnectionError as e: @@ -59,7 +61,7 @@ def gen_cache(func): def stream_cache(func): def wrapper(self, model, messages, stream, *args, **kwargs): - cache_key = gen_cache_key(*messages) + cache_key = gen_cache_key(messages) logger.info(f"Stream cache key: {cache_key}") redis_client = get_redis_instance() diff --git a/application/llm/base.py b/application/llm/base.py index 1caab5d3..b9b0e524 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -13,12 +13,12 @@ class BaseLLM(ABC): return method(self, *args, **kwargs) @abstractmethod - def _raw_gen(self, model, messages, stream, *args, **kwargs): + def _raw_gen(self, model, messages, stream, tools, *args, **kwargs): pass - def gen(self, model, messages, stream=False, *args, **kwargs): + def gen(self, model, messages, stream=False, tools=None, *args, **kwargs): decorators = [gen_token_usage, gen_cache] - return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs) + return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs) @abstractmethod def _raw_gen_stream(self, model, messages, stream, *args, **kwargs): @@ -26,4 +26,10 @@ class BaseLLM(ABC): def gen_stream(self, model, messages, stream=True, *args, **kwargs): decorators = [stream_cache, stream_token_usage] - return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs) \ No newline at end of file + return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs) + + def supports_tools(self): + return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools')) + + def _supports_tools(self): + raise NotImplementedError("Subclass must implement _supports_tools method") \ No newline at end of file diff --git a/application/llm/openai.py b/application/llm/openai.py index f85de6ea..cc2285a1 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -25,14 +25,20 @@ class OpenAILLM(BaseLLM): model, messages, stream=False, + tools=None, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs - ): - response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, **kwargs - ) - - return response.choices[0].message.content + ): + if tools: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, tools=tools, **kwargs + ) + return response.choices[0] + else: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) + return response.choices[0].message.content def _raw_gen_stream( self, @@ -40,6 +46,7 @@ class OpenAILLM(BaseLLM): model, messages, stream=True, + tools=None, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs ): @@ -52,6 +59,9 @@ class OpenAILLM(BaseLLM): # print(line.choices[0].delta.content, file=sys.stderr) if line.choices[0].delta.content is not None: yield line.choices[0].delta.content + + def _supports_tools(self): + return True class AzureOpenAILLM(OpenAILLM): diff --git a/application/requirements.txt b/application/requirements.txt index 2f28c2ea..c8f16d85 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -43,7 +43,7 @@ multidict==6.1.0 mypy-extensions==1.0.0 networkx==3.3 numpy==1.26.4 -openai==1.46.1 +openai==1.57.0 openapi-schema-validator==0.6.2 openapi-spec-validator==0.6.0 openapi3-parser==1.1.18 diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 42e318d2..4ac52bc5 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -2,6 +2,7 @@ from application.retriever.base import BaseRetriever from application.core.settings import settings from application.vectorstore.vector_creator import VectorCreator from application.llm.llm_creator import LLMCreator +from application.tools.agent import Agent from application.utils import num_tokens_from_string @@ -90,10 +91,12 @@ class ClassicRAG(BaseRetriever): ) messages_combine.append({"role": "user", "content": self.question}) - llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key - ) - completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) + # llm = LLMCreator.create_llm( + # settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key + # ) + # completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) + agent = Agent(llm_name=settings.LLM_NAME,gpt_model=self.gpt_model, api_key=settings.API_KEY, user_api_key=self.user_api_key) + completion = agent.gen(messages_combine) for line in completion: yield {"answer": str(line)} diff --git a/application/tools/agent.py b/application/tools/agent.py new file mode 100644 index 00000000..2df14442 --- /dev/null +++ b/application/tools/agent.py @@ -0,0 +1,98 @@ +from application.llm.llm_creator import LLMCreator +from application.core.settings import settings +from application.tools.tool_manager import ToolManager +import json + +tool_tg = { + "name": "telegram_send_message", + "description": "Send a notification to telegram about current chat", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Text to send in the notification" + } + }, + "required": ["text"], + "additionalProperties": False + } +} + +tool_crypto = { + "name": "cryptoprice_get", + "description": "Retrieve the price of a specified cryptocurrency in a given currency", + "parameters": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The cryptocurrency symbol (e.g. BTC)" + }, + "currency": { + "type": "string", + "description": "The currency in which you want the price (e.g. USD)" + } + }, + "required": ["symbol", "currency"], + "additionalProperties": False + } +} + +class Agent: + def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): + # Initialize the LLM with the provided parameters + self.llm = LLMCreator.create_llm(llm_name, api_key=api_key, user_api_key=user_api_key) + self.gpt_model = gpt_model + # Static tool configuration (to be replaced later) + self.tools = [ + { + "type": "function", + "function": tool_crypto + } + ] + self.tool_config = { + } + + def gen(self, messages): + # Generate initial response from the LLM + resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) + + if isinstance(resp, str): + # Yield the response if it's a string and exit + yield resp + return + + while resp.finish_reason == "tool_calls": + # Append the assistant's message to the conversation + messages.append(json.loads(resp.model_dump_json())['message']) + # Handle each tool call + tool_calls = resp.message.tool_calls + for call in tool_calls: + tm = ToolManager(config={}) + call_name = call.function.name + call_args = json.loads(call.function.arguments) + call_id = call.id + # Determine the tool name and load it + tool_name = call_name.split("_")[0] + tool = tm.load_tool(tool_name, tool_config=self.tool_config) + # Execute the tool's action + resp_tool = tool.execute_action(call_name, **call_args) + # Append the tool's response to the conversation + messages.append( + { + "role": "tool", + "content": str(resp_tool), + "tool_call_id": call_id + } + ) + # Generate a new response from the LLM after processing tools + resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) + + # If no tool calls are needed, generate the final response + if isinstance(resp, str): + yield resp + else: + completion = self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools) + for line in completion: + yield line diff --git a/application/tools/base.py b/application/tools/base.py new file mode 100644 index 00000000..00cfee3a --- /dev/null +++ b/application/tools/base.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + +class Tool(ABC): + @abstractmethod + def execute_action(self, action_name: str, **kwargs): + pass + + @abstractmethod + def get_actions_metadata(self): + """ + Returns a list of JSON objects describing the actions supported by the tool. + """ + pass + + @abstractmethod + def get_config_requirements(self): + """ + Returns a dictionary describing the configuration requirements for the tool. + """ + pass diff --git a/application/tools/cryptoprice.py b/application/tools/cryptoprice.py new file mode 100644 index 00000000..d7cf61e1 --- /dev/null +++ b/application/tools/cryptoprice.py @@ -0,0 +1,73 @@ +from application.tools.base import Tool +import requests + +class CryptoPriceTool(Tool): + def __init__(self, config): + self.config = config + + def execute_action(self, action_name, **kwargs): + actions = { + "cryptoprice_get": self.get_price + } + + if action_name in actions: + return actions[action_name](**kwargs) + else: + raise ValueError(f"Unknown action: {action_name}") + + def get_price(self, symbol, currency): + """ + Fetches the current price of a given cryptocurrency symbol in the specified currency. + Example: + symbol = "BTC" + currency = "USD" + returns price in USD. + """ + url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}" + response = requests.get(url) + if response.status_code == 200: + data = response.json() + # data will be like {"USD": } if the call is successful + if currency.upper() in data: + return { + "status_code": response.status_code, + "price": data[currency.upper()], + "message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully." + } + else: + return { + "status_code": response.status_code, + "message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}." + } + else: + return { + "status_code": response.status_code, + "message": "Failed to retrieve price." + } + + def get_actions_metadata(self): + return [ + { + "name": "cryptoprice_get", + "description": "Retrieve the price of a specified cryptocurrency in a given currency", + "parameters": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The cryptocurrency symbol (e.g. BTC)" + }, + "currency": { + "type": "string", + "description": "The currency in which you want the price (e.g. USD)" + } + }, + "required": ["symbol", "currency"], + "additionalProperties": False + } + } + ] + + def get_config_requirements(self): + # No specific configuration needed for this tool as it just queries a public endpoint + return {} diff --git a/application/tools/telegram.py b/application/tools/telegram.py new file mode 100644 index 00000000..8210d8e7 --- /dev/null +++ b/application/tools/telegram.py @@ -0,0 +1,79 @@ +from application.tools.base import Tool +import requests + +class TelegramTool(Tool): + def __init__(self, config): + self.config = config + self.chat_id = config.get("chat_id", "142189016") + self.token = config.get("token", "YOUR_TG_TOKEN") + + def execute_action(self, action_name, **kwargs): + actions = { + "telegram_send_message": self.send_message, + "telegram_send_image": self.send_image + } + + if action_name in actions: + return actions[action_name](**kwargs) + else: + raise ValueError(f"Unknown action: {action_name}") + + def send_message(self, text): + print(f"Sending message: {text}") + url = f"https://api.telegram.org/bot{self.token}/sendMessage" + payload = {"chat_id": self.chat_id, "text": text} + response = requests.post(url, data=payload) + return {"status_code": response.status_code, "message": "Message sent"} + + def send_image(self, image_url): + print(f"Sending image: {image_url}") + url = f"https://api.telegram.org/bot{self.token}/sendPhoto" + payload = {"chat_id": self.chat_id, "photo": image_url} + response = requests.post(url, data=payload) + return {"status_code": response.status_code, "message": "Image sent"} + + def get_actions_metadata(self): + return [ + { + "name": "telegram_send_message", + "description": "Send a notification to telegram chat", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Text to send in the notification" + } + }, + "required": ["text"], + "additionalProperties": False + } + }, + { + "name": "telegram_send_image", + "description": "Send an image to the Telegram chat", + "parameters": { + "type": "object", + "properties": { + "image_url": { + "type": "string", + "description": "URL of the image to send" + } + }, + "required": ["image_url"], + "additionalProperties": False + } + } + ] + + def get_config_requirements(self): + return { + "chat_id": { + "type": "string", + "description": "Telegram chat ID to send messages to" + }, + "token": { + "type": "string", + "description": "Bot token for authentication" + } + } diff --git a/application/tools/tool_manager.py b/application/tools/tool_manager.py new file mode 100644 index 00000000..10231cb2 --- /dev/null +++ b/application/tools/tool_manager.py @@ -0,0 +1,43 @@ +import importlib +import inspect +import pkgutil +import os + +from application.tools.base import Tool + +class ToolManager: + def __init__(self, config): + self.config = config + self.tools = {} + self.load_tools() + + def load_tools(self): + tools_dir = os.path.dirname(__file__) + for finder, name, ispkg in pkgutil.iter_modules([tools_dir]): + if name == 'base' or name.startswith('__'): + continue + module = importlib.import_module(f'application.tools.{name}') + for member_name, obj in inspect.getmembers(module, inspect.isclass): + if issubclass(obj, Tool) and obj is not Tool: + tool_config = self.config.get(name, {}) + self.tools[name] = obj(tool_config) + + def load_tool(self, tool_name, tool_config): + self.config[tool_name] = tool_config + tools_dir = os.path.dirname(__file__) + module = importlib.import_module(f'application.tools.{tool_name}') + for member_name, obj in inspect.getmembers(module, inspect.isclass): + if issubclass(obj, Tool) and obj is not Tool: + return obj(tool_config) + + + def execute_action(self, tool_name, action_name, **kwargs): + if tool_name not in self.tools: + raise ValueError(f"Tool '{tool_name}' not loaded") + return self.tools[tool_name].execute_action(action_name, **kwargs) + + def get_all_actions_metadata(self): + metadata = [] + for tool in self.tools.values(): + metadata.extend(tool.get_actions_metadata()) + return metadata diff --git a/application/usage.py b/application/usage.py index e87ebe38..fe4cd50e 100644 --- a/application/usage.py +++ b/application/usage.py @@ -1,7 +1,7 @@ import sys from datetime import datetime from application.core.mongo_db import MongoDB -from application.utils import num_tokens_from_string +from application.utils import num_tokens_from_string, num_tokens_from_object_or_list mongo = MongoDB.get_client() db = mongo["docsgpt"] @@ -21,11 +21,16 @@ def update_token_usage(user_api_key, token_usage): def gen_token_usage(func): - def wrapper(self, model, messages, stream, **kwargs): + def wrapper(self, model, messages, stream, tools, **kwargs): for message in messages: - self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) - result = func(self, model, messages, stream, **kwargs) - self.token_usage["generated_tokens"] += num_tokens_from_string(result) + if message["content"]: + self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) + result = func(self, model, messages, stream, tools, **kwargs) + # check if result is a string + if isinstance(result, str): + self.token_usage["generated_tokens"] += num_tokens_from_string(result) + else: + self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result) update_token_usage(self.user_api_key, self.token_usage) return result @@ -33,11 +38,11 @@ def gen_token_usage(func): def stream_token_usage(func): - def wrapper(self, model, messages, stream, **kwargs): + def wrapper(self, model, messages, stream, tools, **kwargs): for message in messages: self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) batch = [] - result = func(self, model, messages, stream, **kwargs) + result = func(self, model, messages, stream, tools, **kwargs) for r in result: batch.append(r) yield r diff --git a/application/utils.py b/application/utils.py index 1fc9e329..3b2eb9f3 100644 --- a/application/utils.py +++ b/application/utils.py @@ -15,9 +15,21 @@ def get_encoding(): def num_tokens_from_string(string: str) -> int: encoding = get_encoding() - num_tokens = len(encoding.encode(string)) - return num_tokens + if isinstance(string, str): + num_tokens = len(encoding.encode(string)) + return num_tokens + else: + return 0 +def num_tokens_from_object_or_list(thing): + if isinstance(thing, list): + return sum([num_tokens_from_object_or_list(x) for x in thing]) + elif isinstance(thing, dict): + return sum([num_tokens_from_object_or_list(x) for x in thing.values()]) + elif isinstance(thing, str): + return num_tokens_from_string(thing) + else: + return 0 def count_tokens_docs(docs): docs_content = ""