Merge pull request #1648 from siiddhantt/feat/agent-refactor-and-logging

feat: agent-retriever workflow + logging stack
This commit is contained in:
Alex
2025-03-06 09:32:18 -05:00
committed by GitHub
20 changed files with 752 additions and 307 deletions

View File

@@ -0,0 +1,14 @@
from application.agents.classic_agent import ClassicAgent
class AgentCreator:
    """Factory that resolves an agent type name to its implementing class."""

    # Registry of supported agent implementations, keyed by lower-case name.
    agents = {
        "classic": ClassicAgent,
    }

    @classmethod
    def create_agent(cls, type, *args, **kwargs):
        """Instantiate the agent registered under ``type``.

        The lookup is case-insensitive. All remaining positional and keyword
        arguments are forwarded unchanged to the agent class.

        Raises:
            ValueError: if no agent class is registered for ``type``.
        """
        agent_class = cls.agents.get(type.lower())
        if agent_class:
            return agent_class(*args, **kwargs)
        raise ValueError(f"No agent class found for type {type}")

View File

@@ -1,23 +1,28 @@
from typing import Dict, Generator
from application.agents.llm_handler import get_llm_handler
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
from application.llm.llm_creator import LLMCreator
from application.tools.llm_handler import get_llm_handler
from application.tools.tool_action_parser import ToolActionParser
from application.tools.tool_manager import ToolManager
class Agent:
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
# Initialize the LLM with the provided parameters
class BaseAgent:
def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None):
self.endpoint = endpoint
self.llm = LLMCreator.create_llm(
llm_name, api_key=api_key, user_api_key=user_api_key
)
self.llm_handler = get_llm_handler(llm_name)
self.gpt_model = gpt_model
# Static tool configuration (to be replaced later)
self.tools = []
self.tool_config = {}
self.tool_calls = []
def gen(self, *args, **kwargs) -> Generator[Dict, None, None]:
raise NotImplementedError('Method "gen" must be implemented in the child class')
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -135,50 +140,3 @@ class Agent:
self.tool_calls.append(tool_call_data)
return result, call_id
def _simple_tool_agent(self, messages):
    """Yield the answer for ``messages``, running tool calls when requested.

    The model is queried once with the user's tools. If that reply already
    carries text it is yielded as-is; otherwise the reply is passed to the
    LLM handler, and if the handler's result still carries no text the
    answer is streamed from a fresh ``gen_stream`` call.
    """

    def _text_of(response):
        # A reply is "final" when it is a plain string or exposes a
        # non-None ``message.content``; anything else yields None here.
        if isinstance(response, str):
            return response
        message = getattr(response, "message", None)
        return getattr(message, "content", None)

    tools_dict = self._get_user_tools()
    self._prepare_tools(tools_dict)

    first_reply = self.llm.gen(
        model=self.gpt_model, messages=messages, tools=self.tools
    )
    text = _text_of(first_reply)
    if text is not None:
        yield text
        return

    handled = self.llm_handler.handle_response(
        self, first_reply, tools_dict, messages
    )
    text = _text_of(handled)
    if text is not None:
        yield text
        return

    completion = self.llm.gen_stream(
        model=self.gpt_model, messages=messages, tools=self.tools
    )
    for line in completion:
        yield line
def gen(self, messages):
    """Stream the model's answer for ``messages``.

    Resets the recorded tool calls, then yields from the tool-aware path
    when the LLM reports tool support and from a plain ``gen_stream`` call
    otherwise.
    """
    self.tool_calls = []
    if self.llm.supports_tools():
        stream = self._simple_tool_agent(messages)
    else:
        stream = self.llm.gen_stream(model=self.gpt_model, messages=messages)
    for line in stream:
        yield line

View File

@@ -0,0 +1,135 @@
import uuid
from typing import Dict, Generator
from application.agents.base import BaseAgent
from application.logging import build_stack_data, log_activity, LogContext
from application.retriever.base import BaseRetriever
# Agent that answers a query by combining retrieved documents, stored chat
# history (including earlier tool calls) and the user's tools into one LLM
# conversation, yielding {"answer": ...} chunks followed by {"tool_calls": ...}.
class ClassicAgent(BaseAgent):
# Forwards endpoint/model/key settings to BaseAgent and keeps the prompt
# template and the chat history used when building messages.
def __init__(
self,
endpoint,
llm_name,
gpt_model,
api_key,
user_api_key=None,
prompt="",
chat_history=None,
):
super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key)
self.prompt = prompt
# Fall back to a fresh list so instances never share a default history.
self.chat_history = chat_history if chat_history is not None else []
# Public entry point: a thin generator that delegates to _gen_inner.
# It is wrapped by log_activity(), so log_context is normally supplied by
# that decorator rather than by the caller.
@log_activity()
def gen(
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
) -> Generator[Dict, None, None]:
yield from self._gen_inner(query, retriever, log_context)
# Builds the message list and drives the LLM / tool-handling sequence.
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
retrieved_data = self._retriever_search(retriever, query, log_context)
# Substitute the joined document texts for the "{summaries}" placeholder
# of the prompt; the result becomes the system message.
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
# Replay stored history as user/assistant turns plus function-call /
# function-response pairs for recorded tool calls.
# NOTE(review): indentation is lost in this view, so it cannot be told
# whether the "tool_calls" check is nested inside the prompt/response
# check or is a sibling of it — confirm against the original file.
if len(self.chat_history) > 0:
for i in self.chat_history:
if "prompt" in i and "response" in i:
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append(
{"role": "assistant", "content": i["response"]}
)
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id")
# A missing id, or the literal string "None", is replaced by a new UUID
# so the call and its response share one identifier.
if call_id is None or call_id == "None":
call_id = str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages_combine.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages_combine.append(
{"role": "tool", "content": [function_response_dict]}
)
# The current question always goes last.
messages_combine.append({"role": "user", "content": query})
tools_dict = self._get_user_tools()
self._prepare_tools(tools_dict)
resp = self._llm_gen(messages_combine, log_context)
# A plain string, or an object exposing a non-None message.content, is
# treated as the final answer and ends the generator early (no tool_calls
# item is yielded on these two paths).
if isinstance(resp, str):
yield {"answer": resp}
return
if (
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
yield {"answer": resp.message.content}
return
# Otherwise let the LLM handler process the response (tool execution).
resp = self._llm_handler(resp, tools_dict, messages_combine, log_context)
if isinstance(resp, str):
yield {"answer": resp}
elif (
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
yield {"answer": resp.message.content}
else:
# The handler produced no direct text: stream a fresh completion over the
# (possibly extended) message list and forward only string chunks.
completion = self.llm.gen_stream(
model=self.gpt_model, messages=messages_combine, tools=self.tools
)
for line in completion:
if isinstance(line, str):
yield {"answer": line}
# Emit a copy of the tool calls recorded on the agent.
yield {"tool_calls": self.tool_calls.copy()}
# Runs the retriever search and, when a log context is present, records the
# retriever's attributes (excluding "llm") as a "retriever" stack entry.
def _retriever_search(self, retriever, query, log_context):
retrieved_data = retriever.search(query)
if log_context:
data = build_stack_data(retriever, exclude_attributes=["llm"])
log_context.stacks.append({"component": "retriever", "data": data})
return retrieved_data
# Starts a streaming completion with the prepared tools and, when a log
# context is present, records the LLM's attributes as an "llm" stack entry.
def _llm_gen(self, messages_combine, log_context):
resp = self.llm.gen_stream(
model=self.gpt_model, messages=messages_combine, tools=self.tools
)
if log_context:
data = build_stack_data(self.llm)
log_context.stacks.append({"component": "llm", "data": data})
return resp
# Delegates to the provider-specific handler and, when a log context is
# present, records the handler's attributes as an "llm_handler" stack entry.
def _llm_handler(self, resp, tools_dict, messages_combine, log_context):
resp = self.llm_handler.handle_response(
self, resp, tools_dict, messages_combine
)
if log_context:
data = build_stack_data(self.llm_handler)
log_context.stacks.append({"component": "llm_handler", "data": data})
return resp

View File

@@ -0,0 +1,250 @@
import json
from abc import ABC, abstractmethod
from application.logging import build_stack_data
class LLMHandler(ABC):
    """Abstract base for provider-specific response/tool-call handlers."""

    def __init__(self):
        # Both lists start empty; concrete handlers append to them.
        self.llm_calls, self.tool_calls = [], []

    @abstractmethod
    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Process ``resp`` for ``agent``; concrete handlers must override this."""
class OpenAILLMHandler(LLMHandler):
    """Tool-calling loop for OpenAI-style responses (complete or streamed)."""

    @staticmethod
    def _tool_exchange(name, args, call_id, tool_response):
        """Build the assistant function-call message and its tool-result reply."""
        function_call_dict = {
            "function_call": {"name": name, "args": args, "call_id": call_id}
        }
        function_response_dict = {
            "function_response": {
                "name": name,
                "response": {"result": tool_response},
                "call_id": call_id,
            }
        }
        return [
            {"role": "assistant", "content": [function_call_dict]},
            {"role": "tool", "content": [function_response_dict]},
        ]

    def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True):
        """Execute the tool calls requested in ``resp`` and re-query the model.

        Args:
            agent: object providing ``_execute_tool_action``, ``llm``,
                ``gpt_model`` and ``tools``.
            resp: the model response; a single response object when
                ``stream`` is false, an iterable of chunks otherwise.
            tools_dict: tool configuration passed through to
                ``agent._execute_tool_action``.
            messages: conversation list; extended in place with the tool
                call / tool result messages.
            stream: selects the streamed (default) or complete-response path.

        Returns:
            For ``stream=False`` the first response whose ``finish_reason``
            is not ``"tool_calls"``. For ``stream=True`` always ``None``.
        """
        if not stream:
            return self._handle_complete(agent, resp, tools_dict, messages)
        return self._handle_stream(agent, resp, tools_dict, messages)

    def _handle_complete(self, agent, resp, tools_dict, messages):
        """Run tool calls for complete responses until none are requested."""
        while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
            message = json.loads(resp.model_dump_json())["message"]
            keys_to_remove = {"audio", "function_call", "refusal"}
            messages.append(
                {k: v for k, v in message.items() if k not in keys_to_remove}
            )
            for call in resp.message.tool_calls:
                # Seed the id from the call itself: the error branch below
                # needs a value even when execution fails before an id is
                # returned (previously an unbound or stale ``call_id``).
                call_id = getattr(call, "id", None)
                try:
                    self.tool_calls.append(call)
                    tool_response, call_id = agent._execute_tool_action(
                        tools_dict, call
                    )
                    messages.extend(
                        self._tool_exchange(
                            call.function.name,
                            call.function.arguments,
                            call_id,
                            tool_response,
                        )
                    )
                except Exception as e:
                    # Best effort: report the failure to the model instead of
                    # aborting the whole conversation.
                    messages.append(
                        {
                            "role": "tool",
                            "content": f"Error executing tool: {str(e)}",
                            "tool_call_id": call_id,
                        }
                    )
            resp = agent.llm.gen_stream(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )
            self.llm_calls.append(build_stack_data(agent.llm))
        return resp

    def _handle_stream(self, agent, resp, tools_dict, messages):
        """Assemble streamed tool-call fragments, run them, and re-query."""
        while True:
            tool_calls = {}
            for chunk in resp:
                # A plain string chunk means the model is producing text, so
                # there is nothing left for this handler to do.
                if isinstance(chunk, str):
                    return

                chunk_delta = chunk.delta
                if (
                    hasattr(chunk_delta, "tool_calls")
                    and chunk_delta.tool_calls is not None
                ):
                    # Tool calls arrive in pieces keyed by index: the id and
                    # name are overwritten, argument text is concatenated.
                    for tool_call in chunk_delta.tool_calls:
                        current = tool_calls.setdefault(
                            tool_call.index,
                            {"id": "", "function": {"name": "", "arguments": ""}},
                        )
                        if tool_call.id:
                            current["id"] = tool_call.id
                        if tool_call.function.name:
                            current["function"]["name"] = tool_call.function.name
                        if tool_call.function.arguments:
                            current["function"][
                                "arguments"
                            ] += tool_call.function.arguments

                finish_reason = getattr(chunk, "finish_reason", None)
                if finish_reason == "tool_calls":
                    for index in sorted(tool_calls.keys()):
                        call = tool_calls[index]
                        try:
                            self.tool_calls.append(call)
                            tool_response, _ = agent._execute_tool_action(
                                tools_dict, call
                            )
                            messages.extend(
                                self._tool_exchange(
                                    call["function"]["name"],
                                    call["function"]["arguments"],
                                    call["id"],
                                    tool_response,
                                )
                            )
                        except Exception as e:
                            messages.append(
                                {
                                    "role": "assistant",
                                    "content": f"Error executing tool: {str(e)}",
                                }
                            )
                    tool_calls = {}

                if finish_reason == "stop":
                    return

            # The stream ended without a "stop": ask the model again with the
            # messages gathered so far.
            resp = agent.llm.gen_stream(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )
            self.llm_calls.append(build_stack_data(agent.llm))
# Tool-calling loop for Google GenAI responses, in complete and streamed form.
class GoogleLLMHandler(LLMHandler):
# Repeatedly queries agent.llm, executes any function calls found in the
# reply and appends the call / result messages to `messages` in place.
# Note that the incoming `resp` argument is never read here: every loop
# iteration issues its own request.
def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True):
# Imported locally, so the google.genai package is only needed when this
# handler actually runs.
from google.genai import types
while True:
if not stream:
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
self.llm_calls.append(build_stack_data(agent.llm))
if response.candidates and response.candidates[0].content.parts:
tool_call_found = False
for part in response.candidates[0].content.parts:
if part.function_call:
tool_call_found = True
self.tool_calls.append(part.function_call)
# The returned call_id is not used in this handler.
tool_response, call_id = agent._execute_tool_action(
tools_dict, part.function_call
)
function_response_part = types.Part.from_function_response(
name=part.function_call.name,
response={"result": tool_response},
)
messages.append(
{"role": "model", "content": [part.to_json_dict()]}
)
messages.append(
{
"role": "tool",
"content": [function_response_part.to_json_dict()],
}
)
# Without a tool call the first part's text is returned when present,
# otherwise the raw parts.
if (
not tool_call_found
and response.candidates[0].content.parts
and response.candidates[0].content.parts[0].text
):
return response.candidates[0].content.parts[0].text
elif not tool_call_found:
return response.candidates[0].content.parts
# NOTE(review): indentation is lost in this view; this else presumably pairs
# with the outer "if response.candidates ..." check (so a reply with tool
# calls loops again) — confirm against the original file.
else:
return response
else:
response = agent.llm.gen_stream(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
self.llm_calls.append(build_stack_data(agent.llm))
tool_call_found = False
for result in response:
# Items exposing a function_call attribute are treated as tool requests.
if hasattr(result, "function_call"):
tool_call_found = True
self.tool_calls.append(result.function_call)
tool_response, call_id = agent._execute_tool_action(
tools_dict, result.function_call
)
function_response_part = types.Part.from_function_response(
name=result.function_call.name,
response={"result": tool_response},
)
messages.append(
{"role": "model", "content": [result.to_json_dict()]}
)
messages.append(
{
"role": "tool",
"content": [function_response_part.to_json_dict()],
}
)
# NOTE(review): `response` has already been iterated by the loop above when
# it is returned here — verify that callers do not expect unread chunks.
if not tool_call_found:
return response
def get_llm_handler(llm_type):
    """Return a new handler instance for ``llm_type``.

    Recognised keys are ``"openai"`` and ``"google"``; any other value falls
    back to the OpenAI handler. Only the selected class is instantiated, so
    no throw-away handler objects are built for the types that were not
    requested.
    """
    handler_classes = {
        "openai": OpenAILLMHandler,
        "google": GoogleLLMHandler,
    }
    return handler_classes.get(llm_type, OpenAILLMHandler)()

View File

@@ -1,7 +1,7 @@
import json
import requests
from application.tools.base import Tool
from application.agents.tools.base import Tool
class APITool(Tool):
@@ -31,10 +31,27 @@ class APITool(Tool):
print(f"Making API call: {method} {url} with body: {body}")
response = requests.request(method, url, headers=headers, data=body)
response.raise_for_status()
try:
data = response.json()
except ValueError:
content_type = response.headers.get(
"Content-Type", "application/json"
).lower()
if "application/json" in content_type:
try:
data = response.json()
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
return {
"status_code": response.status_code,
"message": f"API call returned invalid JSON. Error: {e}",
"data": response.text,
}
elif "text/" in content_type or "application/xml" in content_type:
data = response.text
elif not response.content:
data = None
else:
print(f"Unsupported content type: {content_type}")
data = response.content
return {
"status_code": response.status_code,

View File

@@ -1,5 +1,5 @@
import requests
from application.tools.base import Tool
from application.agents.tools.base import Tool
class CryptoPriceTool(Tool):
@@ -31,7 +31,6 @@ class CryptoPriceTool(Tool):
response = requests.get(url)
if response.status_code == 200:
data = response.json()
# data will be like {"USD": <price>} if the call is successful
if currency.upper() in data:
return {
"status_code": response.status_code,

View File

@@ -1,5 +1,5 @@
import requests
from application.tools.base import Tool
from application.agents.tools.base import Tool
class TelegramTool(Tool):

View File

@@ -0,0 +1,42 @@
import json
import logging
logger = logging.getLogger(__name__)


class ToolActionParser:
    """Extract ``(tool_id, action_name, call_args)`` from an LLM tool call."""

    def __init__(self, llm_type):
        """Remember the LLM type and register the available parsers."""
        self.llm_type = llm_type
        self.parsers = {
            "OpenAILLM": self._parse_openai_llm,
            "GoogleLLM": self._parse_google_llm,
        }

    def parse_args(self, call):
        """Parse ``call`` with the parser for this LLM type.

        Unknown LLM types fall back to the OpenAI parser.
        """
        parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
        return parser(call)

    def _parse_openai_llm(self, call):
        """Parse an OpenAI tool call given as a dict or as an object.

        The function name is expected to look like ``<action>_<tool id>``:
        the text after the last underscore is the tool id, the text before
        it the action name. Arguments are a JSON-encoded string.

        Returns:
            ``(tool_id, action_name, call_args)``, or ``(None, None, None)``
            when the call is malformed (missing fields, wrong types, or
            arguments that are not valid JSON); the error is logged.
        """
        try:
            if isinstance(call, dict):
                function = call["function"]
                name, raw_args = function["name"], function["arguments"]
            else:
                name, raw_args = call.function.name, call.function.arguments
            # Models occasionally emit truncated or invalid JSON; that must
            # end in the same "unparseable" result as a missing field
            # instead of escaping as an exception.
            call_args = json.loads(raw_args)
            tool_id = name.split("_")[-1]
            action_name = name.rsplit("_", 1)[0]
        except (KeyError, AttributeError, TypeError, json.JSONDecodeError) as e:
            logger.error(f"Error parsing OpenAI LLM call: {e}")
            return None, None, None
        return tool_id, action_name, call_args

    def _parse_google_llm(self, call):
        """Parse a Google function call exposing ``name`` and ``args``."""
        call_args = call.args
        tool_id = call.name.split("_")[-1]
        action_name = call.name.rsplit("_", 1)[0]
        return tool_id, action_name, call_args

View File

@@ -3,7 +3,7 @@ import inspect
import os
import pkgutil
from application.tools.base import Tool
from application.agents.tools.base import Tool
class ToolManager:
@@ -13,13 +13,11 @@ class ToolManager:
self.load_tools()
def load_tools(self):
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
tools_dir = os.path.join(os.path.dirname(__file__))
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
if name == "base" or name.startswith("__"):
continue
module = importlib.import_module(
f"application.tools.implementations.{name}"
)
module = importlib.import_module(f"application.agents.tools.{name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
tool_config = self.config.get(name, {})
@@ -27,9 +25,7 @@ class ToolManager:
def load_tool(self, tool_name, tool_config):
self.config[tool_name] = tool_config
module = importlib.import_module(
f"application.tools.implementations.{tool_name}"
)
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
return obj(tool_config)

View File

@@ -1,15 +1,16 @@
import asyncio
import datetime
import json
import logging
import os
import traceback
import logging
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, make_response, request, Response
from flask_restx import fields, Namespace, Resource
from application.agents.agent_creator import AgentCreator
from application.core.mongo_db import MongoDB
from application.core.settings import settings
@@ -206,6 +207,7 @@ def get_prompt(prompt_id):
def complete_stream(
question,
agent,
retriever,
conversation_id,
user_api_key,
@@ -217,8 +219,8 @@ def complete_stream(
response_full = ""
source_log_docs = []
tool_calls = []
answer = retriever.gen()
sources = retriever.search()
answer = agent.gen(query=question, retriever=retriever)
sources = retriever.search(question)
for source in sources:
if "text" in source:
source["text"] = source["text"][:100].strip() + "..."
@@ -384,9 +386,20 @@ class Stream(Resource):
prompt = get_prompt(prompt_id)
if "isNoneDoc" in data and data["isNoneDoc"] is True:
chunks = 0
agent = AgentCreator.create_agent(
settings.AGENT_NAME,
endpoint="stream",
llm_name=settings.LLM_NAME,
gpt_model=gpt_model,
api_key=settings.API_KEY,
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
source=source,
chat_history=history,
prompt=prompt,
@@ -399,6 +412,7 @@ class Stream(Resource):
return Response(
complete_stream(
question=question,
agent=agent,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,

View File

@@ -1,9 +1,9 @@
import datetime
import json
import math
import os
import shutil
import uuid
import json
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
@@ -12,12 +12,13 @@ from flask import Blueprint, current_app, jsonify, make_response, redirect, requ
from flask_restx import fields, inputs, Namespace, Resource
from werkzeug.utils import secure_filename
from application.agents.tools.tool_manager import ToolManager
from application.api.user.tasks import ingest, ingest_remote
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.extensions import api
from application.tools.tool_manager import ToolManager
from application.tts.google_tts import GoogleTTS
from application.utils import check_required_fields, validate_function_name
from application.vectorstore.vector_creator import VectorCreator
@@ -449,22 +450,21 @@ class UploadRemote(Resource):
return missing_fields
try:
config = json.loads(data["data"])
source_data = None
config = json.loads(data["data"])
source_data = None
if data["source"] == "github":
if data["source"] == "github":
source_data = config.get("repo_url")
elif data["source"] in ["crawler", "url"]:
elif data["source"] in ["crawler", "url"]:
source_data = config.get("url")
elif data["source"] == "reddit":
source_data = config
elif data["source"] == "reddit":
source_data = config
task = ingest_remote.delay(
task = ingest_remote.delay(
source_data=source_data,
job_name=data["name"],
user=data["user"],
loader=data["source"]
loader=data["source"],
)
except Exception as err:
current_app.logger.error(f"Error uploading remote source: {err}")
@@ -1932,11 +1932,14 @@ class UpdateTool(Resource):
for action_name in list(data["config"]["actions"].keys()):
if not validate_function_name(action_name):
return make_response(
jsonify({
"success": False,
"message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
"param": "tools[].function.name"
}), 400
jsonify(
{
"success": False,
"message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
"param": "tools[].function.name",
}
),
400,
)
update_data["config"] = data["config"]
if "status" in data:

View File

@@ -32,6 +32,7 @@ class Settings(BaseSettings):
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
)
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
AGENT_NAME: str = "classic"
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"

View File

@@ -152,7 +152,15 @@ class GoogleLLM(BaseLLM):
config=config,
)
for chunk in response:
if chunk.text is not None:
if hasattr(chunk, "candidates") and chunk.candidates:
for candidate in chunk.candidates:
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if part.function_call:
yield part
elif part.text:
yield part.text
elif hasattr(chunk, "text"):
yield chunk.text
def _supports_tools(self):

View File

@@ -111,13 +111,24 @@ class OpenAILLM(BaseLLM):
**kwargs,
):
messages = self._clean_messages_openai(messages)
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
if tools:
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content
else:
yield line.choices[0]
def _supports_tools(self):
return True

151
application/logging.py Normal file
View File

@@ -0,0 +1,151 @@
import datetime
import functools
import inspect
import logging
import uuid
from typing import Any, Callable, Dict, Generator, List
from application.core.mongo_db import MongoDB
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class LogContext:
    """Mutable record describing one logged activity.

    Holds the identifying fields passed at construction plus ``stacks``, a
    list that starts empty and to which components append their entries.
    """

    def __init__(self, endpoint, activity_id, user, api_key, query):
        self.endpoint, self.activity_id = endpoint, activity_id
        self.user, self.api_key = user, api_key
        self.query = query
        # A new list per instance, so contexts never share entries.
        self.stacks = list()
def build_stack_data(
    obj: Any,
    include_attributes: List[str] = None,
    exclude_attributes: List[str] = None,
    custom_data: Dict = None,
) -> Dict:
    """Collect a loggable snapshot of ``obj``'s attributes.

    Args:
        obj: object to inspect.
        include_attributes: attribute names to read. When ``None``, every
            member whose name does not start with ``_`` and which is neither
            a method nor a function is used.
        exclude_attributes: names to skip even if included.
        custom_data: extra entries merged last, overriding collected ones.

    Returns:
        A dict of attribute name to value. ``None`` values and missing
        attributes are omitted; ints, floats, strings and bools are kept;
        lists of dicts are kept as-is, lists of objects with ``__dict__``
        become lists of those dicts, other lists become lists of strings;
        dict values are stringified; anything else is stringified whole.
    """
    if include_attributes is None:
        include_attributes = [
            name
            for name, value in inspect.getmembers(obj)
            if not (
                name.startswith("_")
                or inspect.ismethod(value)
                or inspect.isfunction(value)
            )
        ]

    snapshot = {}
    for attr_name in include_attributes:
        if exclude_attributes and attr_name in exclude_attributes:
            continue
        try:
            attr_value = getattr(obj, attr_name)
            if attr_value is None:
                continue
            if isinstance(attr_value, (int, float, str, bool)):
                rendered = attr_value
            elif isinstance(attr_value, list):
                if all(isinstance(item, dict) for item in attr_value):
                    rendered = attr_value
                elif all(hasattr(item, "__dict__") for item in attr_value):
                    rendered = [item.__dict__ for item in attr_value]
                else:
                    rendered = [str(item) for item in attr_value]
            elif isinstance(attr_value, dict):
                rendered = {k: str(v) for k, v in attr_value.items()}
            else:
                rendered = str(attr_value)
            snapshot[attr_name] = rendered
        except AttributeError:
            # Requested attribute is not available on this object: skip it.
            pass

    if custom_data:
        snapshot.update(custom_data)
    return snapshot
def log_activity() -> Callable:
    """Decorator factory that attaches a ``LogContext`` to a generator method.

    The wrapped callable is invoked with a ``log_context`` keyword argument
    (replacing any value the caller passed) and its output is relayed through
    ``_consume_and_log``. The wrapper is itself a generator, so nothing runs
    until it is iterated.
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # The first positional argument is the object whose attributes
            # describe the activity (the bound instance for methods).
            target = args[0]
            activity_id = str(uuid.uuid4())
            attributes = build_stack_data(target)
            endpoint = attributes.get("endpoint", "")
            user = attributes.get("user", "local")
            api_key = attributes.get("user_api_key", "")
            query = kwargs.get("query", getattr(target, "query", ""))

            context = LogContext(endpoint, activity_id, user, api_key, query)
            kwargs["log_context"] = context

            logging.info(
                f"Starting activity: {endpoint} - {activity_id} - User: {user}"
            )

            yield from _consume_and_log(func(*args, **kwargs), context)

        return wrapper

    return decorator
def _consume_and_log(generator: Generator, context: "LogContext"):
    """Relay items from ``generator`` and persist exactly one activity record.

    Every item is yielded unchanged. When the generator raises, the error is
    logged, appended to ``context.stacks`` as an ``"error"`` component and
    re-raised. One record is written to MongoDB when iteration ends: with
    level ``"error"`` after a failure and ``"info"`` otherwise (including
    when the consumer closes this generator early). A failed run therefore
    no longer gets a second, misleading ``"info"`` record.
    """
    level = "info"
    try:
        for item in generator:
            yield item
    except Exception as e:
        logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")
        context.stacks.append({"component": "error", "data": {"message": str(e)}})
        level = "error"
        raise
    finally:
        # Single write point: ``finally`` runs on success, on failure and on
        # early close, so the record is stored once with the final level.
        _log_to_mongodb(
            endpoint=context.endpoint,
            activity_id=context.activity_id,
            user=context.user,
            api_key=context.api_key,
            query=context.query,
            stacks=context.stacks,
            level=level,
        )
def _log_to_mongodb(
    endpoint: str,
    activity_id: str,
    user: str,
    api_key: str,
    query: str,
    stacks: List[Dict],
    level: str,
) -> None:
    """Insert one activity record into the ``stack_logs`` collection.

    The record is stamped with the current UTC time. Any failure is logged
    and swallowed, so logging problems never reach the caller.
    """
    try:
        collection = MongoDB.get_client()["docsgpt"]["stack_logs"]
        collection.insert_one(
            {
                "endpoint": endpoint,
                "id": activity_id,
                "level": level,
                "user": user,
                "api_key": api_key,
                "query": query,
                "stacks": stacks,
                "timestamp": datetime.datetime.now(datetime.timezone.utc),
            }
        )
        logging.debug(f"Logged activity to MongoDB: {activity_id}")
    except Exception as e:
        logging.error(f"Failed to log to MongoDB: {e}")

View File

@@ -1,28 +1,25 @@
import uuid
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.retriever.base import BaseRetriever
from application.tools.agent import Agent
from application.vectorstore.vector_creator import VectorCreator
class ClassicRAG(BaseRetriever):
def __init__(
self,
question,
source,
chat_history,
prompt,
chat_history=None,
prompt="",
chunks=2,
token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
llm_name=settings.LLM_NAME,
api_key=settings.API_KEY,
):
self.question = question
self.vectorstore = source["active_docs"] if "active_docs" in source else None
self.chat_history = chat_history
self.original_question = ""
self.chat_history = chat_history if chat_history is not None else []
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
@@ -37,12 +34,41 @@ class ClassicRAG(BaseRetriever):
)
)
self.user_api_key = user_api_key
self.agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
self.llm_name = llm_name
self.api_key = api_key
self.llm = LLMCreator.create_llm(
self.llm_name, api_key=self.api_key, user_api_key=self.user_api_key
)
self.question = self._rephrase_query()
self.vectorstore = source["active_docs"] if "active_docs" in source else None
def _rephrase_query(self):
    """Return a standalone search query for the current question.

    When there is no original question or no chat history the original
    question is returned untouched. Otherwise the LLM is asked to rephrase
    the question using the chat history; an empty result or any error falls
    back to the original question.
    """
    has_context = self.original_question and self.chat_history
    if not has_context or self.chat_history == []:
        return self.original_question

    prompt = f"""Given the following conversation history:
{self.chat_history}
Rephrase the following user question to be a standalone search query
that captures all relevant context from the conversation:
"""
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": self.original_question},
    ]
    try:
        rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages)
        print(f"Rephrased query: {rephrased_query}")
        if rephrased_query:
            return rephrased_query
        return self.original_question
    except Exception as e:
        print(f"Error rephrasing query: {e}")
        return self.original_question
def _get_data(self):
if self.chunks == 0:
@@ -69,68 +95,20 @@ class ClassicRAG(BaseRetriever):
return docs
def gen(self):
docs = self._get_data()
def gen():
pass
# join all page_content together with a newline
docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 0:
for i in self.chat_history:
if "prompt" in i and "response" in i:
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append(
{"role": "assistant", "content": i["response"]}
)
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id")
if call_id is None or call_id == "None":
call_id = str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages_combine.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages_combine.append(
{"role": "tool", "content": [function_response_dict]}
)
messages_combine.append({"role": "user", "content": self.question})
completion = self.agent.gen(messages_combine)
for line in completion:
yield {"answer": str(line)}
yield {"tool_calls": self.agent.tool_calls.copy()}
def search(self):
def search(self, query: str = ""):
if query:
self.original_question = query
self.question = self._rephrase_query()
return self._get_data()
def get_params(self):
return {
"question": self.question,
"question": self.original_question,
"rephrased_question": self.question,
"source": self.vectorstore,
"chat_history": self.chat_history,
"prompt": self.prompt,
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,

View File

@@ -1,112 +0,0 @@
import json
from abc import ABC, abstractmethod
class LLMHandler(ABC):
    """Abstract base for provider-specific response/tool-call handlers."""

    @abstractmethod
    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Process ``resp`` for ``agent``; concrete handlers must override this."""
class OpenAILLMHandler(LLMHandler):
    """Tool-calling loop for complete (non-streamed) OpenAI-style responses."""

    def handle_response(self, agent, resp, tools_dict, messages):
        """Execute requested tool calls and re-query until none remain.

        Args:
            agent: object providing ``_execute_tool_action``, ``llm``,
                ``gpt_model`` and ``tools``.
            resp: model response exposing ``finish_reason``, ``message`` and
                ``model_dump_json()``.
            tools_dict: tool configuration passed through to
                ``agent._execute_tool_action``.
            messages: conversation list; extended in place with the tool
                call / tool result messages.

        Returns:
            The first response whose ``finish_reason`` is not ``"tool_calls"``.
        """
        while resp.finish_reason == "tool_calls":
            message = json.loads(resp.model_dump_json())["message"]
            keys_to_remove = {"audio", "function_call", "refusal"}
            messages.append(
                {k: v for k, v in message.items() if k not in keys_to_remove}
            )
            for call in resp.message.tool_calls:
                # Seed the id from the call itself: the error branch below
                # needs a value even when execution fails before an id is
                # returned (previously an unbound or stale ``call_id``).
                call_id = getattr(call, "id", None)
                try:
                    tool_response, call_id = agent._execute_tool_action(
                        tools_dict, call
                    )
                    function_call_dict = {
                        "function_call": {
                            "name": call.function.name,
                            "args": call.function.arguments,
                            "call_id": call_id,
                        }
                    }
                    function_response_dict = {
                        "function_response": {
                            "name": call.function.name,
                            "response": {"result": tool_response},
                            "call_id": call_id,
                        }
                    }
                    messages.append(
                        {"role": "assistant", "content": [function_call_dict]}
                    )
                    messages.append(
                        {"role": "tool", "content": [function_response_dict]}
                    )
                except Exception as e:
                    # Best effort: report the failure to the model instead of
                    # aborting the whole conversation.
                    messages.append(
                        {
                            "role": "tool",
                            "content": f"Error executing tool: {str(e)}",
                            "tool_call_id": call_id,
                        }
                    )
            resp = agent.llm.gen(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )
        return resp
# Tool-calling loop for complete (non-streamed) Google GenAI responses.
class GoogleLLMHandler(LLMHandler):
# Repeatedly queries agent.llm, executes any function calls found in the
# reply and appends the call / result messages to `messages` in place.
# The incoming `resp` argument is never read: each iteration issues its own
# request.
def handle_response(self, agent, resp, tools_dict, messages):
# Imported locally, so the google.genai package is only needed when this
# handler actually runs.
from google.genai import types
while True:
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
if response.candidates and response.candidates[0].content.parts:
tool_call_found = False
for part in response.candidates[0].content.parts:
if part.function_call:
tool_call_found = True
# The returned call_id is not used in this handler.
tool_response, call_id = agent._execute_tool_action(
tools_dict, part.function_call
)
function_response_part = types.Part.from_function_response(
name=part.function_call.name,
response={"result": tool_response},
)
messages.append(
{"role": "model", "content": [part.to_json_dict()]}
)
messages.append(
{
"role": "tool",
"content": [function_response_part.to_json_dict()],
}
)
# Without a tool call the first part's text is returned when present,
# otherwise the raw parts.
if (
not tool_call_found
and response.candidates[0].content.parts
and response.candidates[0].content.parts[0].text
):
return response.candidates[0].content.parts[0].text
elif not tool_call_found:
return response.candidates[0].content.parts
# NOTE(review): indentation is lost in this view; this else presumably pairs
# with the outer "if response.candidates ..." check (so a reply with tool
# calls loops again) — confirm against the original file.
else:
return response
def get_llm_handler(llm_type):
    """Return a new handler instance for ``llm_type``.

    Recognised keys are ``"openai"`` and ``"google"``; any other value falls
    back to the OpenAI handler. Only the selected class is instantiated, so
    no throw-away handler objects are built for the types that were not
    requested.
    """
    handler_classes = {
        "openai": OpenAILLMHandler,
        "google": GoogleLLMHandler,
    }
    return handler_classes.get(llm_type, OpenAILLMHandler)()

View File

@@ -1,26 +0,0 @@
import json
class ToolActionParser:
    """Extract ``(tool_id, action_name, call_args)`` from an LLM tool call."""

    def __init__(self, llm_type):
        """Remember the LLM type and register the available parsers."""
        self.llm_type = llm_type
        self.parsers = {
            "OpenAILLM": self._parse_openai_llm,
            "GoogleLLM": self._parse_google_llm,
        }

    def parse_args(self, call):
        """Parse ``call`` with this LLM type's parser (OpenAI by default)."""
        return self.parsers.get(self.llm_type, self._parse_openai_llm)(call)

    def _parse_openai_llm(self, call):
        """Parse an OpenAI tool call whose arguments are a JSON string."""
        function = call.function
        call_args = json.loads(function.arguments)
        # Name layout is "<action>_<tool id>": split once from the right.
        tool_id = function.name.split("_")[-1]
        action_name = function.name.rsplit("_", 1)[0]
        return tool_id, action_name, call_args

    def _parse_google_llm(self, call):
        """Parse a Google function call exposing ``name`` and ``args``."""
        name = call.name
        return name.split("_")[-1], name.rsplit("_", 1)[0], call.args

View File

@@ -1,7 +1,8 @@
import sys
from datetime import datetime
from application.core.mongo_db import MongoDB
from application.utils import num_tokens_from_string, num_tokens_from_object_or_list
from application.utils import num_tokens_from_object_or_list, num_tokens_from_string
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -24,13 +25,16 @@ def gen_token_usage(func):
def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
if message["content"]:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
self.token_usage["prompt_tokens"] += num_tokens_from_string(
message["content"]
)
result = func(self, model, messages, stream, tools, **kwargs)
# check if result is a string
if isinstance(result, str):
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
else:
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result)
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
result
)
update_token_usage(self.user_api_key, self.token_usage)
return result
@@ -40,7 +44,9 @@ def gen_token_usage(func):
def stream_token_usage(func):
def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
self.token_usage["prompt_tokens"] += num_tokens_from_string(
message["content"]
)
batch = []
result = func(self, model, messages, stream, tools, **kwargs)
for r in result: