mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: tooling init
This commit is contained in:
@@ -5,6 +5,7 @@ import logging
|
||||
from threading import Lock
|
||||
from application.core.settings import settings
|
||||
from application.utils import get_hash
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,18 +24,19 @@ def get_redis_instance():
|
||||
_redis_instance = None
|
||||
return _redis_instance
|
||||
|
||||
def gen_cache_key(messages, model="docgpt", tools=None):
    """Build the cache key for one LLM request.

    Args:
        messages: Iterable of chat message dicts.
        model: Model name; part of the key so different models never share
            an entry.
        tools: Optional tool definitions; part of the key so the same
            prompt with and without tools is cached separately.

    Returns:
        Whatever ``get_hash`` returns for the combined
        ``"{model}_{messages}_{tools}"`` string.

    Raises:
        ValueError: If any element of ``messages`` is not a dict.
    """
    messages = list(messages)
    if not all(isinstance(msg, dict) for msg in messages):
        raise ValueError("All messages must be dictionaries.")
    # sort_keys keeps the key independent of dict insertion order, so two
    # equal conversations built in a different key order hit the same entry.
    messages_str = json.dumps(messages, sort_keys=True)
    tools_str = json.dumps(tools, sort_keys=True) if tools else ""
    combined = f"{model}_{messages_str}_{tools_str}"
    return get_hash(combined)
|
||||
|
||||
def gen_cache(func):
|
||||
def wrapper(self, model, messages, *args, **kwargs):
|
||||
def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
|
||||
try:
|
||||
cache_key = gen_cache_key(*messages)
|
||||
cache_key = gen_cache_key(messages, model, tools)
|
||||
redis_client = get_redis_instance()
|
||||
if redis_client:
|
||||
try:
|
||||
@@ -44,8 +46,8 @@ def gen_cache(func):
|
||||
except redis.ConnectionError as e:
|
||||
logger.error(f"Redis connection error: {e}")
|
||||
|
||||
result = func(self, model, messages, *args, **kwargs)
|
||||
if redis_client:
|
||||
result = func(self, model, messages, stream, tools, *args, **kwargs)
|
||||
if redis_client and isinstance(result, str):
|
||||
try:
|
||||
redis_client.set(cache_key, result, ex=1800)
|
||||
except redis.ConnectionError as e:
|
||||
@@ -59,7 +61,7 @@ def gen_cache(func):
|
||||
|
||||
def stream_cache(func):
|
||||
def wrapper(self, model, messages, stream, *args, **kwargs):
|
||||
cache_key = gen_cache_key(*messages)
|
||||
cache_key = gen_cache_key(messages)
|
||||
logger.info(f"Stream cache key: {cache_key}")
|
||||
|
||||
redis_client = get_redis_instance()
|
||||
|
||||
@@ -13,12 +13,12 @@ class BaseLLM(ABC):
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _raw_gen(self, model, messages, stream, *args, **kwargs):
|
||||
def _raw_gen(self, model, messages, stream, tools, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def gen(self, model, messages, stream=False, *args, **kwargs):
|
||||
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
|
||||
decorators = [gen_token_usage, gen_cache]
|
||||
return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
|
||||
return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
|
||||
@@ -27,3 +27,9 @@ class BaseLLM(ABC):
|
||||
def gen_stream(self, model, messages, stream=True, *args, **kwargs):
|
||||
decorators = [stream_cache, stream_token_usage]
|
||||
return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
|
||||
|
||||
def supports_tools(self):
    """Return True when this backend declares support for tool calling.

    The base class defines a ``_supports_tools`` hook that raises
    ``NotImplementedError``, so merely checking that the attribute exists
    and is callable is true for every subclass. The hook is therefore
    called, and "not implemented" is reported as no support.
    """
    try:
        return bool(self._supports_tools())
    except NotImplementedError:
        return False
|
||||
|
||||
def _supports_tools(self):
    """Default hook: always raises, so a subclass must override it.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError("Subclass must implement _supports_tools method")
|
||||
@@ -25,14 +25,20 @@ class OpenAILLM(BaseLLM):
|
||||
model,
|
||||
messages,
|
||||
stream=False,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
**kwargs
|
||||
):
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, tools=tools, **kwargs
|
||||
)
|
||||
return response.choices[0]
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
def _raw_gen_stream(
|
||||
self,
|
||||
@@ -40,6 +46,7 @@ class OpenAILLM(BaseLLM):
|
||||
model,
|
||||
messages,
|
||||
stream=True,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
**kwargs
|
||||
):
|
||||
@@ -53,6 +60,9 @@ class OpenAILLM(BaseLLM):
|
||||
if line.choices[0].delta.content is not None:
|
||||
yield line.choices[0].delta.content
|
||||
|
||||
def _supports_tools(self):
    """Return True; this class's ``_raw_gen`` accepts a ``tools`` argument."""
    return True
|
||||
|
||||
|
||||
class AzureOpenAILLM(OpenAILLM):
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ multidict==6.1.0
|
||||
mypy-extensions==1.0.0
|
||||
networkx==3.3
|
||||
numpy==1.26.4
|
||||
openai==1.46.1
|
||||
openai==1.57.0
|
||||
openapi-schema-validator==0.6.2
|
||||
openapi-spec-validator==0.6.0
|
||||
openapi3-parser==1.1.18
|
||||
|
||||
@@ -2,6 +2,7 @@ from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.tools.agent import Agent
|
||||
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
@@ -90,10 +91,12 @@ class ClassicRAG(BaseRetriever):
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
|
||||
)
|
||||
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
||||
# llm = LLMCreator.create_llm(
|
||||
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
|
||||
# )
|
||||
# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
||||
agent = Agent(llm_name=settings.LLM_NAME,gpt_model=self.gpt_model, api_key=settings.API_KEY, user_api_key=self.user_api_key)
|
||||
completion = agent.gen(messages_combine)
|
||||
for line in completion:
|
||||
yield {"answer": str(line)}
|
||||
|
||||
|
||||
98
application/tools/agent.py
Normal file
98
application/tools/agent.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.core.settings import settings
|
||||
from application.tools.tool_manager import ToolManager
|
||||
import json
|
||||
|
||||
# Function schema for the Telegram "send message" action: a name, a
# description and a JSON-schema style "parameters" object.
# NOTE(review): not part of Agent.tools in the code visible here, which only
# lists tool_crypto.
tool_tg = {
    "name": "telegram_send_message",
    "description": "Send a notification to telegram about current chat",
    "parameters": {
        "type": "object",
        "properties": {
            "text": {
                "type": "string",
                "description": "Text to send in the notification"
            }
        },
        "required": ["text"],
        "additionalProperties": False
    }
}

# Function schema for the crypto price lookup action; both arguments are
# required strings and no extra properties are allowed.
tool_crypto = {
    "name": "cryptoprice_get",
    "description": "Retrieve the price of a specified cryptocurrency in a given currency",
    "parameters": {
        "type": "object",
        "properties": {
            "symbol": {
                "type": "string",
                "description": "The cryptocurrency symbol (e.g. BTC)"
            },
            "currency": {
                "type": "string",
                "description": "The currency in which you want the price (e.g. USD)"
            }
        },
        "required": ["symbol", "currency"],
        "additionalProperties": False
    }
}
|
||||
|
||||
class Agent:
    """Run an LLM conversation that may call tools before it answers.

    The agent asks the LLM for a completion with its tool definitions
    attached, executes every tool call the model requests, feeds the results
    back, and repeats until the model stops asking for tools.
    """

    def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
        """Create the LLM client and the static tool setup.

        Args:
            llm_name: Backend name passed to ``LLMCreator.create_llm``.
            gpt_model: Model identifier sent with every request.
            api_key: API key passed to ``LLMCreator.create_llm``.
            user_api_key: Optional per-user key, passed through as well.
        """
        self.llm = LLMCreator.create_llm(llm_name, api_key=api_key, user_api_key=user_api_key)
        self.gpt_model = gpt_model
        # Static tool configuration (to be replaced later)
        self.tools = [
            {
                "type": "function",
                "function": tool_crypto
            }
        ]
        self.tool_config = {}

    def gen(self, messages):
        """Yield the answer for ``messages``, running tool calls as needed.

        Args:
            messages: List of chat message dicts. It is extended in place
                with the assistant's tool-call messages and the tool results.

        Yields:
            The whole answer as one string when the LLM returns a string;
            otherwise the chunks produced by ``gen_stream``.
        """
        resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)

        # Created on first use only, so a request without tool calls never
        # constructs a ToolManager.
        tool_manager = None
        # A string response has no finish_reason. The isinstance guard covers
        # both the first response and every follow-up response in the loop.
        while not isinstance(resp, str) and resp.finish_reason == "tool_calls":
            if tool_manager is None:
                tool_manager = ToolManager(config={})
            # Append the assistant's message to the conversation
            messages.append(json.loads(resp.model_dump_json())['message'])
            for call in resp.message.tool_calls:
                messages.append(self._run_tool_call(tool_manager, call))
            # Generate a new response from the LLM after processing tools
            resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)

        if isinstance(resp, str):
            yield resp
            return

        # The non-string response is discarded and the final answer is
        # requested again through the streaming API.
        yield from self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools)

    def _run_tool_call(self, tool_manager, call):
        """Execute one tool call and return the ``role: tool`` message for it."""
        call_name = call.function.name
        call_args = json.loads(call.function.arguments)
        # Action names look like "<tool module>_<action>"; the prefix selects
        # the tool module to load.
        tool_name = call_name.split("_")[0]
        tool = tool_manager.load_tool(tool_name, tool_config=self.tool_config)
        result = tool.execute_action(call_name, **call_args)
        return {
            "role": "tool",
            "content": str(result),
            "tool_call_id": call.id
        }
|
||||
20
application/tools/base.py
Normal file
20
application/tools/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class Tool(ABC):
    """Abstract interface that every pluggable tool has to implement."""

    @abstractmethod
    def execute_action(self, action_name: str, **kwargs):
        """Run the action called ``action_name`` with ``kwargs`` as its arguments."""

    @abstractmethod
    def get_actions_metadata(self):
        """Return a list of JSON objects describing the actions supported by the tool."""

    @abstractmethod
    def get_config_requirements(self):
        """Return a dictionary describing the configuration requirements for the tool."""
|
||||
73
application/tools/cryptoprice.py
Normal file
73
application/tools/cryptoprice.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from application.tools.base import Tool
|
||||
import requests
|
||||
|
||||
class CryptoPriceTool(Tool):
    """Tool with a single action, ``cryptoprice_get``, that looks up a coin price."""

    # Upper bound, in seconds, for one price lookup. Without a timeout a
    # stalled connection would block the calling request indefinitely.
    REQUEST_TIMEOUT = 10

    def __init__(self, config):
        """Store ``config``; this tool reads nothing from it."""
        self.config = config

    def execute_action(self, action_name, **kwargs):
        """Dispatch ``action_name`` to its handler.

        Args:
            action_name: Name of the action; only ``"cryptoprice_get"`` exists.
            **kwargs: Arguments forwarded to the handler.

        Returns:
            The handler's result dict.

        Raises:
            ValueError: If ``action_name`` is not a known action.
        """
        actions = {
            "cryptoprice_get": self.get_price
        }
        if action_name not in actions:
            raise ValueError(f"Unknown action: {action_name}")
        return actions[action_name](**kwargs)

    def get_price(self, symbol, currency):
        """Fetch the current price of ``symbol`` expressed in ``currency``.

        Example: ``symbol="BTC"``, ``currency="USD"`` returns the price in USD.

        Args:
            symbol: Cryptocurrency symbol; upper-cased before use.
            currency: Quote currency; upper-cased before use.

        Returns:
            A dict with ``status_code`` and ``message``, plus ``price`` when
            the response contains the requested currency.
        """
        symbol = symbol.upper()
        currency = currency.upper()
        # Both values come from model-generated tool arguments, so they are
        # passed via ``params`` and URL-encoded rather than pasted into the URL.
        response = requests.get(
            "https://min-api.cryptocompare.com/data/price",
            params={"fsym": symbol, "tsyms": currency},
            timeout=self.REQUEST_TIMEOUT,
        )
        if response.status_code != 200:
            return {
                "status_code": response.status_code,
                "message": "Failed to retrieve price."
            }

        data = response.json()
        # data will be like {"USD": <price>} if the call is successful
        if currency not in data:
            return {
                "status_code": response.status_code,
                "message": f"Couldn't find price for {symbol} in {currency}."
            }
        return {
            "status_code": response.status_code,
            "price": data[currency],
            "message": f"Price of {symbol} in {currency} retrieved successfully."
        }

    def get_actions_metadata(self):
        """Return the function schema for the ``cryptoprice_get`` action."""
        return [
            {
                "name": "cryptoprice_get",
                "description": "Retrieve the price of a specified cryptocurrency in a given currency",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "symbol": {
                            "type": "string",
                            "description": "The cryptocurrency symbol (e.g. BTC)"
                        },
                        "currency": {
                            "type": "string",
                            "description": "The currency in which you want the price (e.g. USD)"
                        }
                    },
                    "required": ["symbol", "currency"],
                    "additionalProperties": False
                }
            }
        ]

    def get_config_requirements(self):
        """Return an empty dict: the tool only queries a public endpoint."""
        return {}
|
||||
79
application/tools/telegram.py
Normal file
79
application/tools/telegram.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from application.tools.base import Tool
|
||||
import requests
|
||||
|
||||
class TelegramTool(Tool):
    """Tool that sends text messages and images to a Telegram chat via the Bot API."""

    # Upper bound, in seconds, for one Bot API call. Without a timeout a
    # stalled connection would block the calling request indefinitely.
    REQUEST_TIMEOUT = 10

    def __init__(self, config):
        """Read ``chat_id`` and ``token`` from ``config``.

        Args:
            config: Mapping that may contain ``"chat_id"`` and ``"token"``.
        """
        self.config = config
        # NOTE(review): these fallbacks look like development placeholders.
        # The default chat id resembles a real chat and the default token can
        # never authenticate. Confirm whether missing config should raise.
        self.chat_id = config.get("chat_id", "142189016")
        self.token = config.get("token", "YOUR_TG_TOKEN")

    def execute_action(self, action_name, **kwargs):
        """Dispatch ``action_name`` to its handler.

        Args:
            action_name: ``"telegram_send_message"`` or ``"telegram_send_image"``.
            **kwargs: Arguments forwarded to the handler.

        Returns:
            The handler's result dict.

        Raises:
            ValueError: If ``action_name`` is not a known action.
        """
        actions = {
            "telegram_send_message": self.send_message,
            "telegram_send_image": self.send_image
        }
        if action_name not in actions:
            raise ValueError(f"Unknown action: {action_name}")
        return actions[action_name](**kwargs)

    def send_message(self, text):
        """Send ``text`` to the configured chat; return status code and outcome message."""
        payload = {"chat_id": self.chat_id, "text": text}
        return self._post("sendMessage", payload, "Message sent", "Failed to send message")

    def send_image(self, image_url):
        """Send the image at ``image_url`` to the configured chat; return status code and outcome message."""
        payload = {"chat_id": self.chat_id, "photo": image_url}
        return self._post("sendPhoto", payload, "Image sent", "Failed to send image")

    def _post(self, method, payload, success_message, failure_message):
        """POST ``payload`` to Bot API ``method`` and summarise the outcome."""
        import logging

        # The content is deliberately not logged: it is chat data, and the
        # URL contains the bot token.
        logging.getLogger(__name__).debug("Calling Telegram Bot API method %s", method)
        url = f"https://api.telegram.org/bot{self.token}/{method}"
        response = requests.post(url, data=payload, timeout=self.REQUEST_TIMEOUT)
        # Report failure truthfully; the result is fed back to the model,
        # which would otherwise be told the send succeeded.
        message = success_message if response.ok else failure_message
        return {"status_code": response.status_code, "message": message}

    def get_actions_metadata(self):
        """Return the function schemas for the two Telegram actions."""
        return [
            {
                "name": "telegram_send_message",
                "description": "Send a notification to telegram chat",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "text": {
                            "type": "string",
                            "description": "Text to send in the notification"
                        }
                    },
                    "required": ["text"],
                    "additionalProperties": False
                }
            },
            {
                "name": "telegram_send_image",
                "description": "Send an image to the Telegram chat",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "image_url": {
                            "type": "string",
                            "description": "URL of the image to send"
                        }
                    },
                    "required": ["image_url"],
                    "additionalProperties": False
                }
            }
        ]

    def get_config_requirements(self):
        """Describe the ``chat_id`` and ``token`` settings this tool reads."""
        return {
            "chat_id": {
                "type": "string",
                "description": "Telegram chat ID to send messages to"
            },
            "token": {
                "type": "string",
                "description": "Bot token for authentication"
            }
        }
|
||||
43
application/tools/tool_manager.py
Normal file
43
application/tools/tool_manager.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
import os
|
||||
|
||||
from application.tools.base import Tool
|
||||
|
||||
class ToolManager:
    """Discover, instantiate and invoke the tools in ``application.tools``."""

    def __init__(self, config):
        """Store ``config`` and eagerly load every tool module.

        Args:
            config: Dict mapping a tool module name to that tool's config.
        """
        self.config = config
        self.tools = {}
        self.load_tools()

    @staticmethod
    def _tool_classes(module):
        """Yield every ``Tool`` subclass (other than ``Tool`` itself) found in ``module``."""
        for _name, obj in inspect.getmembers(module, inspect.isclass):
            if issubclass(obj, Tool) and obj is not Tool:
                yield obj

    def load_tools(self):
        """Import each module next to this file and register its tools.

        ``base`` and dunder-named modules are skipped. Instances are stored
        in ``self.tools`` under the module name, so when one module exposes
        several ``Tool`` subclasses the last one found wins.
        """
        tools_dir = os.path.dirname(__file__)
        for _finder, name, _ispkg in pkgutil.iter_modules([tools_dir]):
            if name == 'base' or name.startswith('__'):
                continue
            module = importlib.import_module(f'application.tools.{name}')
            for tool_cls in self._tool_classes(module):
                tool_config = self.config.get(name, {})
                self.tools[name] = tool_cls(tool_config)

    def load_tool(self, tool_name, tool_config):
        """Instantiate the tool defined in module ``tool_name``.

        Args:
            tool_name: Module name inside ``application.tools``.
            tool_config: Config passed to the tool's constructor; also
                recorded in ``self.config`` under ``tool_name``.

        Returns:
            An instance of the first ``Tool`` subclass found in the module.

        Raises:
            ValueError: If the module contains no ``Tool`` subclass. Before,
                this case returned None and failed later at the call site.
        """
        self.config[tool_name] = tool_config
        module = importlib.import_module(f'application.tools.{tool_name}')
        tool_cls = next(self._tool_classes(module), None)
        if tool_cls is None:
            raise ValueError(f"Tool '{tool_name}' does not define a Tool subclass")
        return tool_cls(tool_config)

    def execute_action(self, tool_name, action_name, **kwargs):
        """Run ``action_name`` on an already loaded tool.

        Raises:
            ValueError: If ``tool_name`` is not in ``self.tools``.
        """
        if tool_name not in self.tools:
            raise ValueError(f"Tool '{tool_name}' not loaded")
        return self.tools[tool_name].execute_action(action_name, **kwargs)

    def get_all_actions_metadata(self):
        """Return the concatenated action metadata of every loaded tool."""
        metadata = []
        for tool in self.tools.values():
            metadata.extend(tool.get_actions_metadata())
        return metadata
|
||||
@@ -1,7 +1,7 @@
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.utils import num_tokens_from_string
|
||||
from application.utils import num_tokens_from_string, num_tokens_from_object_or_list
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
@@ -21,11 +21,16 @@ def update_token_usage(user_api_key, token_usage):
|
||||
|
||||
|
||||
def gen_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, **kwargs):
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
for message in messages:
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
|
||||
result = func(self, model, messages, stream, **kwargs)
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
|
||||
if message["content"]:
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
|
||||
result = func(self, model, messages, stream, tools, **kwargs)
|
||||
# check if result is a string
|
||||
if isinstance(result, str):
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
|
||||
else:
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result)
|
||||
update_token_usage(self.user_api_key, self.token_usage)
|
||||
return result
|
||||
|
||||
@@ -33,11 +38,11 @@ def gen_token_usage(func):
|
||||
|
||||
|
||||
def stream_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, **kwargs):
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
for message in messages:
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
|
||||
batch = []
|
||||
result = func(self, model, messages, stream, **kwargs)
|
||||
result = func(self, model, messages, stream, tools, **kwargs)
|
||||
for r in result:
|
||||
batch.append(r)
|
||||
yield r
|
||||
|
||||
@@ -15,9 +15,21 @@ def get_encoding():
|
||||
|
||||
def num_tokens_from_string(string: str) -> int:
|
||||
encoding = get_encoding()
|
||||
num_tokens = len(encoding.encode(string))
|
||||
return num_tokens
|
||||
if isinstance(string, str):
|
||||
num_tokens = len(encoding.encode(string))
|
||||
return num_tokens
|
||||
else:
|
||||
return 0
|
||||
|
||||
def num_tokens_from_object_or_list(thing):
    """Add up the token counts of every string nested inside ``thing``.

    Lists are walked element by element and dicts value by value (keys are
    not counted). Strings are measured with ``num_tokens_from_string``; any
    other kind of object contributes 0.
    """
    if isinstance(thing, str):
        return num_tokens_from_string(thing)
    if isinstance(thing, dict):
        children = thing.values()
    elif isinstance(thing, list):
        children = thing
    else:
        return 0
    total = 0
    for child in children:
        total += num_tokens_from_object_or_list(child)
    return total
|
||||
|
||||
def count_tokens_docs(docs):
|
||||
docs_content = ""
|
||||
|
||||
Reference in New Issue
Block a user